/***********************************************************************
 * Copyright (c) 2013-2025 General Atomics Integrated Intelligence, Inc.
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Apache License, Version 2.0
 * which accompanies this distribution and is available at
 * https://www.apache.org/licenses/LICENSE-2.0
 ***********************************************************************/

package org.locationtech.geomesa.arrow.filter

import com.typesafe.scalalogging.LazyLogging
import org.geotools.api.feature.simple.SimpleFeatureType
import org.geotools.api.filter._
import org.geotools.api.filter.expression.PropertyName
import org.geotools.api.filter.spatial.BBOX
import org.geotools.api.filter.temporal.During
import org.geotools.api.temporal.Period
import org.geotools.geometry.jts.ReferencedEnvelope
import org.locationtech.geomesa.arrow.features.ArrowSimpleFeature
import org.locationtech.geomesa.arrow.jts.GeometryVector
import org.locationtech.geomesa.arrow.vector.ArrowAttributeReader._
import org.locationtech.geomesa.arrow.vector.ArrowDictionary
import org.locationtech.geomesa.filter.checkOrderUnsafe
import org.locationtech.geomesa.filter.factory.FastFilterFactory
import org.locationtech.geomesa.utils.geotools.CRS_EPSG_4326
import org.locationtech.geomesa.utils.geotools.converters.FastConverter
import org.locationtech.jts.geom.{Coordinate, LineString, Point, Polygon}

import java.util.Date
import scala.util.control.NonFatal

/**
  * Optimizes filters for running against arrow files
  */
object ArrowFilterOptimizer extends LazyLogging {

  import org.locationtech.geomesa.utils.geotools.RichAttributeDescriptors.RichAttributeDescriptor

  import scala.collection.JavaConverters._

  private val ff: FilterFactory = FastFilterFactory.factory

  /**
    * Rewrite a filter so that supported predicates read directly from arrow-backed features.
    *
    * Any predicate that can't be safely rewritten is kept as-is. Rewrite errors are logged, and the
    * original (optimized) sub-filter is used instead.
    *
    * @param filter filter to rewrite
    * @param sft simple feature type of the features being filtered
    * @param dictionaries arrow dictionaries, keyed by attribute name
    * @return the rewritten filter
    */
  def rewrite(filter: Filter, sft: SimpleFeatureType, dictionaries: Map[String, ArrowDictionary]): Filter = {
    val bound = FastFilterFactory.optimize(sft, filter)
    // make the feature type available to FastFilterFactory while and/or/not filters are re-created below;
    // always cleared afterwards
    FastFilterFactory.sfts.set(sft)
    try {
      rewriteFilter(bound, sft, dictionaries)
    } finally {
      FastFilterFactory.sfts.remove()
    }
  }

  /**
    * Recursively rewrite a filter, descending into and/or/not.
    *
    * @return the rewritten filter, or the input filter if rewriting it throws a non-fatal error
    */
  private def rewriteFilter(filter: Filter, sft: SimpleFeatureType, dictionaries: Map[String, ArrowDictionary]): Filter = {
    try {
      filter match {
        case f: BBOX              => rewriteBBox(f, sft)
        case f: During            => rewriteDuring(f, sft)
        case f: PropertyIsBetween => rewriteBetween(f, sft)
        case f: PropertyIsEqualTo => rewritePropertyIsEqualTo(f, sft, dictionaries)
        case a: And               => ff.and(a.getChildren.asScala.map(rewriteFilter(_, sft, dictionaries)).asJava)
        case o: Or                => ff.or(o.getChildren.asScala.map(rewriteFilter(_, sft, dictionaries)).asJava)
        case f: Not               => ff.not(rewriteFilter(f.getFilter, sft, dictionaries))
        case _                    => filter
      }
    } catch {
      case NonFatal(e) => logger.warn(s"Error re-writing filter $filter", e); filter
    }
  }

  /**
    * Get the binding of the attribute at the given index.
    *
    * @return the binding, or None if the index is negative (i.e. the attribute was not found)
    */
  private def bindingOf(sft: SimpleFeatureType, attrIndex: Int): Option[Class[_]] =
    if (attrIndex < 0) { None } else { Option(sft.getDescriptor(attrIndex).getType.getBinding) }

  /**
    * Rewrite a bbox filter against a point or linestring attribute.
    *
    * The decision is based on the binding of the attribute that the filter actually references, not on
    * the default geometry. Otherwise a bbox on a secondary geometry would be evaluated with the wrong
    * arrow reader.
    */
  private def rewriteBBox(filter: BBOX, sft: SimpleFeatureType): Filter = {
    val props = checkOrderUnsafe(filter.getExpression1, filter.getExpression2)
    val attrIndex = sft.indexOf(props.name)
    val binding = bindingOf(sft, attrIndex).orNull
    if (binding != classOf[Point] && binding != classOf[LineString]) { filter } else {
      val bbox = FastConverter.evaluate(props.literal, classOf[Polygon]).getEnvelopeInternal
      if (binding == classOf[Point]) {
        ArrowPointBBox(attrIndex, bbox.getMinX, bbox.getMinY, bbox.getMaxX, bbox.getMaxY)
      } else {
        ArrowLineStringBBox(attrIndex, bbox.getMinX, bbox.getMinY, bbox.getMaxX, bbox.getMaxY)
      }
    }
  }

  /**
    * Rewrite a during filter against a date attribute.
    *
    * Non-date (or unknown) attributes are left as-is, since the arrow filter reads them with a date reader.
    */
  private def rewriteDuring(filter: During, sft: SimpleFeatureType): Filter = {
    val props = checkOrderUnsafe(filter.getExpression1, filter.getExpression2)
    val attrIndex = sft.indexOf(props.name)
    if (!bindingOf(sft, attrIndex).contains(classOf[Date])) { filter } else {
      val period = FastConverter.evaluate(props.literal, classOf[Period])
      val lower = period.getBeginning.getPosition.getDate.getTime
      val upper = period.getEnding.getPosition.getDate.getTime
      ArrowDuring(attrIndex, lower, upper)
    }
  }

  /**
    * Rewrite a between filter on a property bound to exactly `java.util.Date`.
    *
    * Any other expression or binding is left as-is.
    */
  private def rewriteBetween(filter: PropertyIsBetween, sft: SimpleFeatureType): Filter = {
    filter.getExpression match {
      case p: PropertyName =>
        val attrIndex = sft.indexOf(p.getPropertyName)
        if (!bindingOf(sft, attrIndex).contains(classOf[Date])) { filter } else {
          val lower = FastConverter.evaluate(filter.getLowerBoundary, classOf[Date]).getTime
          val upper = FastConverter.evaluate(filter.getUpperBoundary, classOf[Date]).getTime
          ArrowBetweenDate(attrIndex, lower, upper)
        }

      case _ => filter
    }
  }

  /**
    * Rewrite an equals filter on a dictionary-encoded attribute into a comparison of encoded values.
    *
    * Encoded comparisons are exact. Case-insensitive comparisons are therefore left as-is. For list
    * attributes, only `MatchAction.ANY` is rewritten, since it is the only action matching 'contains'
    * semantics.
    */
  private def rewritePropertyIsEqualTo(filter: PropertyIsEqualTo,
                                       sft: SimpleFeatureType,
                                       dictionaries: Map[String, ArrowDictionary]): Filter = {
    if (!filter.isMatchingCase) { filter } else {
      val props = checkOrderUnsafe(filter.getExpression1, filter.getExpression2)
      dictionaries.get(props.name) match {
        case None => filter
        case Some(dictionary) =>
          val attrIndex = sft.indexOf(props.name)
          val numericValue = dictionary.index(props.literal.evaluate(null))
          if (!sft.getDescriptor(attrIndex).isList) {
            ArrowDictionaryEquals(attrIndex, numericValue)
          } else if (filter.getMatchAction == MultiValuedFilter.MatchAction.ANY) {
            ArrowListDictionaryEquals(attrIndex, numericValue)
          } else {
            filter
          }
      }
    }
  }

  /**
    * Bbox filter against a point attribute.
    *
    * Bounds are inclusive.
    *
    * @param i index of the point attribute
    */
  case class ArrowPointBBox(i: Int, xmin: Double, ymin: Double, xmax: Double, ymax: Double) extends Filter {
    override def accept(visitor: FilterVisitor, extraData: AnyRef): AnyRef = extraData
    override def evaluate(o: AnyRef): Boolean = {
      val arrow = o.asInstanceOf[ArrowSimpleFeature]
      val reader = arrow.getReader(i).asInstanceOf[ArrowPointReader]
      val index = arrow.getIndex
      val y = reader.readPointY(index)
      // only read x if y is in range
      if (y < ymin || y > ymax) { false } else {
        val x = reader.readPointX(index)
        x >= xmin && x <= xmax
      }
    }
  }

  /**
    * Bbox filter against a linestring attribute.
    *
    * The filter matches if any vertex is inside the bbox (bounds inclusive), or if any segment between
    * consecutive vertices intersects the bbox.
    *
    * @param i index of the linestring attribute
    */
  case class ArrowLineStringBBox(i: Int, xmin: Double, ymin: Double, xmax: Double, ymax: Double) extends Filter {

    private val bbox =
      GeometryVector.factory.toGeometry(new ReferencedEnvelope(xmin, xmax, ymin, ymax, CRS_EPSG_4326))

    override def accept(visitor: FilterVisitor, extraData: AnyRef): AnyRef = extraData

    override def evaluate(o: AnyRef): Boolean = {
      val arrow = o.asInstanceOf[ArrowSimpleFeature]
      val reader = arrow.getReader(i).asInstanceOf[ArrowLineStringReader]
      val (start, end) = reader.readOffsets(arrow.getIndex)
      var found = false
      var offset = start
      var prevX = Double.NaN
      var prevY = Double.NaN
      while (!found && offset < end) {
        val x = reader.readPointX(offset)
        val y = reader.readPointY(offset)
        // A contained vertex is the cheap case. Otherwise, the segment from the previous vertex may still
        // cross the bbox even when neither of its end points is contained in it, so every consecutive
        // pair of vertices must be checked.
        found = containsPoint(x, y) || (offset > start && intersectsSegment(prevX, prevY, x, y))
        prevX = x
        prevY = y
        offset += 1
      }
      found
    }

    private def containsPoint(x: Double, y: Double): Boolean =
      x >= xmin && x <= xmax && y >= ymin && y <= ymax

    private def intersectsSegment(x0: Double, y0: Double, x1: Double, y1: Double): Boolean = {
      // cheap envelope check before allocating any geometries (NaN coordinates fail this check)
      val envelopesIntersect =
        math.min(x0, x1) <= xmax && math.max(x0, x1) >= xmin &&
          math.min(y0, y1) <= ymax && math.max(y0, y1) >= ymin
      envelopesIntersect &&
        GeometryVector.factory.createLineString(Array(new Coordinate(x0, y0), new Coordinate(x1, y1))).intersects(bbox)
    }
  }

  /**
    * During filter against a date attribute.
    *
    * Bounds are exclusive.
    *
    * @param lower lower bound, in millis
    * @param upper upper bound, in millis
    */
  case class ArrowDuring(i: Int, lower: Long, upper: Long) extends Filter {
    override def accept(visitor: FilterVisitor, extraData: AnyRef): AnyRef = extraData
    override def evaluate(o: AnyRef): Boolean = {
      val arrow = o.asInstanceOf[ArrowSimpleFeature]
      val time = arrow.getReader(i).asInstanceOf[ArrowDateReader].getTime(arrow.getIndex)
      // note that during is exclusive
      time > lower && time < upper
    }
  }

  /**
    * Between filter against a date attribute.
    *
    * Bounds are inclusive.
    *
    * @param lower lower bound, in millis
    * @param upper upper bound, in millis
    */
  case class ArrowBetweenDate(i: Int, lower: Long, upper: Long) extends Filter {
    override def accept(visitor: FilterVisitor, extraData: AnyRef): AnyRef = extraData
    override def evaluate(o: AnyRef): Boolean = {
      val arrow = o.asInstanceOf[ArrowSimpleFeature]
      val time = arrow.getReader(i).asInstanceOf[ArrowDateReader].getTime(arrow.getIndex)
      // note that between is inclusive
      time >= lower && time <= upper
    }
  }

  /**
    * Equality filter comparing the dictionary-encoded value of a single-valued attribute.
    *
    * @param value encoded value to match
    */
  case class ArrowDictionaryEquals(i: Int, value: Int) extends Filter {
    override def accept(visitor: FilterVisitor, extraData: AnyRef): AnyRef = extraData
    override def evaluate(o: AnyRef): Boolean = {
      val arrow = o.asInstanceOf[ArrowSimpleFeature]
      arrow.getReader(i).asInstanceOf[ArrowDictionaryReader].getEncoded(arrow.getIndex) == value
    }
  }

  /**
    * Equality filter for a dictionary-encoded list attribute.
    *
    * Matches if any encoded element equals the value.
    *
    * @param value encoded value to match
    */
  case class ArrowListDictionaryEquals(i: Int, value: Int) extends Filter {
    override def accept(visitor: FilterVisitor, extraData: AnyRef): AnyRef = extraData
    override def evaluate(o: AnyRef): Boolean = {
      val arrow = o.asInstanceOf[ArrowSimpleFeature]
      arrow.getReader(i).asInstanceOf[ArrowListDictionaryReader].getEncoded(arrow.getIndex).contains(value)
    }
  }
}
