/***********************************************************************
 * Copyright (c) 2013-2025 General Atomics Integrated Intelligence, Inc.
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Apache License, Version 2.0
 * which accompanies this distribution and is available at
 * https://www.apache.org/licenses/LICENSE-2.0
 ***********************************************************************/

package org.locationtech.geomesa.convert2.transforms

import org.locationtech.geomesa.convert.EvaluationContext
import org.locationtech.geomesa.convert.EvaluationContext.{ContextDependent, FieldAccessor, NullFieldAccessor}
import org.locationtech.geomesa.convert2.Field

import scala.util.Try

/**
 * A node in a converter transform expression tree. Expressions are evaluated against the arguments
 * of an input row, and can be re-bound to an evaluation context through `withContext`.
 */
sealed trait Expression extends ContextDependent[Expression] {

  /**
   * Evaluates this expression for a single input row
   *
   * @param args arguments of the input row
   * @return the evaluated value (may be null)
   */
  def apply(args: Array[_ <: AnyRef]): AnyRef

  /**
   * Dispatches this node to the matching `visit` overload of a visitor
   *
   * @param visitor visitor used to process the expression tree
   * @tparam T visitor result type
   * @return whatever the visitor produces for this node
   */
  def accept[T](visitor: ExpressionVisitor[T]): T

  /**
   * Collects the fields this expression depends on
   *
   * @param stack fields currently being resolved - used to detect circular references
   * @param fieldMap lookup of fields by name
   * @return the set of field dependencies
   */
  def dependencies(stack: Set[Field], fieldMap: Map[String, Field]): Set[Field]

  /**
   * Direct sub-expressions of this node. Leaf expressions keep the default of no children.
   *
   * @return nested expressions
   */
  def children(): Seq[Expression] = Seq.empty[Expression]
}

/**
 * Factory methods and the concrete expression node types used by converter transforms
 */
object Expression {

  /**
   * Parses an expression from its string form
   *
   * @param e expression string
   * @return parsed expression
   */
  def apply(e: String): Expression = ExpressionParser.parse(e)

  /**
   * Returns the list of unique expressions in the input, including any descendants.
   *
   * Expressions are returned in the order they are first encountered: the inputs first, followed by
   * their descendants in breadth-first order. An insertion-ordered set is used so that the result is
   * deterministic, rather than depending on hash codes (which are partly identity-based for some
   * expressions, e.g. the argument array held by `FunctionExpression`).
   *
   * @param expressions expressions
   * @return unique expressions and all of their descendants
   */
  def flatten(expressions: Seq[Expression]): Seq[Expression] = {
    val toCheck = scala.collection.mutable.Queue(expressions: _*)
    val result = scala.collection.mutable.LinkedHashSet.empty[Expression]
    while (toCheck.nonEmpty) {
      val next = toCheck.dequeue()
      // only descend into expressions that haven't been seen yet, so shared sub-trees are processed once
      if (result.add(next)) {
        toCheck ++= next.children()
      }
    }
    result.toSeq
  }

  /**
   * A constant value, independent of the input row and of the evaluation context
   *
   * @tparam T type of the literal value
   */
  sealed trait Literal[T <: AnyRef] extends Expression {
    def value: T
    override def apply(args: Array[_ <: AnyRef]): AnyRef = value
    override def withContext(ec: EvaluationContext): Expression = this
    override def dependencies(stack: Set[Field], fieldMap: Map[String, Field]): Set[Field] = Set.empty
    override def toString: String = String.valueOf(value)
    override def accept[V](visitor: ExpressionVisitor[V]): V = visitor.visit(this)
  }

  case class LiteralString(value: String) extends Literal[String] {
    override def toString: String = s"'${String.valueOf(value)}'"
  }

  case class LiteralInt(value: Integer) extends Literal[Integer]

  case class LiteralLong(value: java.lang.Long) extends Literal[java.lang.Long]

  case class LiteralFloat(value: java.lang.Float) extends Literal[java.lang.Float]

  case class LiteralDouble(value: java.lang.Double) extends Literal[java.lang.Double]

  case class LiteralBoolean(value: java.lang.Boolean) extends Literal[java.lang.Boolean]

  case class LiteralAny(value: AnyRef) extends Literal[AnyRef]

  case object LiteralNull extends Literal[AnyRef] { override def value: AnyRef = null }

  /**
   * Base class for type casts, rendered as `expr::binding`
   *
   * @param e expression whose result is cast
   * @param binding name of the target type, used only in `toString`
   */
  abstract class CastExpression(e: Expression, binding: String) extends Expression {
    override def dependencies(stack: Set[Field], fieldMap: Map[String, Field]): Set[Field] =
      e.dependencies(stack, fieldMap)
    override def children(): Seq[Expression] = Seq(e)
    override def toString: String = s"$e::$binding"
  }

  /**
   * Casts to an Integer: numbers are narrowed, anything else is parsed from its string form.
   * Throws NullPointerException on null input.
   */
  case class CastToInt(e: Expression) extends CastExpression(e, "int") {
    override def apply(args: Array[_ <: AnyRef]): Integer = {
      e.apply(args) match {
        case n: Integer          => n
        case n: java.lang.Number => n.intValue()
        case n: String           => n.toInt
        case n: AnyRef           => n.toString.toInt
        case null                => throw new NullPointerException("Trying to cast 'null' to int")
      }
    }
    override def withContext(ec: EvaluationContext): Expression = {
      val ewc = e.withContext(ec)
      if (e.eq(ewc)) { this } else { CastToInt(ewc) }
    }
    override def accept[T](visitor: ExpressionVisitor[T]): T = visitor.visit(this)
  }

  /**
   * Casts to a Long: numbers are converted, anything else is parsed from its string form.
   * Throws NullPointerException on null input.
   */
  case class CastToLong(e: Expression) extends CastExpression(e, "long") {
    override def apply(args: Array[_ <: AnyRef]): java.lang.Long = {
      e.apply(args) match {
        case n: java.lang.Long   => n
        case n: java.lang.Number => n.longValue()
        case n: String           => n.toLong
        case n: AnyRef           => n.toString.toLong
        case null                => throw new NullPointerException("Trying to cast 'null' to long")
      }
    }
    override def withContext(ec: EvaluationContext): Expression = {
      val ewc = e.withContext(ec)
      if (e.eq(ewc)) { this } else { CastToLong(ewc) }
    }
    override def accept[T](visitor: ExpressionVisitor[T]): T = visitor.visit(this)
  }

  /**
   * Casts to a Float: numbers are converted, anything else is parsed from its string form.
   * Throws NullPointerException on null input.
   */
  case class CastToFloat(e: Expression) extends CastExpression(e, "float") {
    override def apply(args: Array[_ <: AnyRef]): java.lang.Float = {
      e.apply(args) match {
        case n: java.lang.Float  => n
        case n: java.lang.Number => n.floatValue()
        case n: String           => n.toFloat
        case n: AnyRef           => n.toString.toFloat
        case null                => throw new NullPointerException("Trying to cast 'null' to float")
      }
    }
    override def withContext(ec: EvaluationContext): Expression = {
      val ewc = e.withContext(ec)
      if (e.eq(ewc)) { this } else { CastToFloat(ewc) }
    }
    override def accept[T](visitor: ExpressionVisitor[T]): T = visitor.visit(this)
  }

  /**
   * Casts to a Double: numbers are converted, anything else is parsed from its string form.
   * Throws NullPointerException on null input.
   */
  case class CastToDouble(e: Expression) extends CastExpression(e, "double") {
    override def apply(args: Array[_ <: AnyRef]): java.lang.Double = {
      e.apply(args) match {
        case n: java.lang.Double => n
        case n: java.lang.Number => n.doubleValue()
        case n: String           => n.toDouble
        case n: AnyRef           => n.toString.toDouble
        case null                => throw new NullPointerException("Trying to cast 'null' to double")
      }
    }
    override def withContext(ec: EvaluationContext): Expression = {
      val ewc = e.withContext(ec)
      if (e.eq(ewc)) { this } else { CastToDouble(ewc) }
    }
    override def accept[T](visitor: ExpressionVisitor[T]): T = visitor.visit(this)
  }

  /**
   * Casts to a Boolean by parsing the string form of the value (numbers are not special-cased).
   * Throws NullPointerException on null input.
   */
  case class CastToBoolean(e: Expression) extends CastExpression(e, "boolean") {
    override def apply(args: Array[_ <: AnyRef]): java.lang.Boolean = {
      e.apply(args) match {
        case b: java.lang.Boolean => b
        case b: String            => b.toBoolean
        case b: AnyRef            => b.toString.toBoolean
        case null                 => throw new NullPointerException("Trying to cast 'null' to boolean")
      }
    }
    override def withContext(ec: EvaluationContext): Expression = {
      val ewc = e.withContext(ec)
      if (e.eq(ewc)) { this } else { CastToBoolean(ewc) }
    }
    override def accept[T](visitor: ExpressionVisitor[T]): T = visitor.visit(this)
  }

  /**
   * Casts to a String via `toString`. Throws NullPointerException on null input.
   */
  case class CastToString(e: Expression) extends CastExpression(e, "string") {
    override def apply(args: Array[_ <: AnyRef]): String = {
      e.apply(args) match {
        case s: String => s
        case s: AnyRef => s.toString
        case null      => throw new NullPointerException("Trying to cast 'null' to String")
      }
    }
    override def withContext(ec: EvaluationContext): Expression = {
      val ewc = e.withContext(ec)
      if (e.eq(ewc)) { this } else { CastToString(ewc) }
    }
    override def accept[T](visitor: ExpressionVisitor[T]): T = visitor.visit(this)
  }

  /**
   * Evaluates to the argument at index `i` of the input row, rendered as `$i`
   *
   * @param i argument index
   */
  case class Column(i: Int) extends Expression {
    override def apply(args: Array[_ <: AnyRef]): AnyRef = args(i)
    override def withContext(ec: EvaluationContext): Expression = this
    override def dependencies(stack: Set[Field], fieldMap: Map[String, Field]): Set[Field] = Set.empty
    override def accept[T](visitor: ExpressionVisitor[T]): T = visitor.visit(this)
    override def toString: String = s"$$$i"
  }

  /**
   * Reference to another field by name, rendered as `$name`. The value is read through an accessor,
   * which is re-bound by `withContext` using the context's accessor for the name.
   *
   * @param n field name
   * @param accessor accessor used to read the field value
   */
  case class FieldLookup(n: String, accessor: FieldAccessor = NullFieldAccessor) extends Expression {
    override def apply(args: Array[_ <: AnyRef]): AnyRef = accessor.apply()
    override def withContext(ec: EvaluationContext): Expression = FieldLookup(n, ec.accessor(n))
    // includes the referenced field plus its transitive dependencies;
    // throws IllegalArgumentException if the field is already on the resolution stack
    override def dependencies(stack: Set[Field], fieldMap: Map[String, Field]): Set[Field] = {
      fieldMap.get(n) match {
        case None => Set.empty
        case Some(field) =>
          if (stack.contains(field)) {
            throw new IllegalArgumentException(s"Cyclical dependency detected in field $field")
          } else {
            field.transforms.toSeq.flatMap(_.dependencies(stack + field, fieldMap)).toSet + field
          }
      }
    }
    override def accept[T](visitor: ExpressionVisitor[T]): T = visitor.visit(this)
    override def toString: String = s"$$$n"
  }

  /**
   * A regular expression literal, rendered as `pattern::r`; evaluates to the regex compiled once
   * at construction
   *
   * @param s regex pattern
   */
  case class RegexExpression(s: String) extends Expression {
    private val compiled = s.r
    override def apply(args: Array[_ <: AnyRef]): AnyRef = compiled
    override def withContext(ec: EvaluationContext): Expression = this
    override def dependencies(stack: Set[Field], fieldMap: Map[String, Field]): Set[Field] = Set.empty
    override def accept[T](visitor: ExpressionVisitor[T]): T = visitor.visit(this)
    override def toString: String = s"$s::r"
  }

  /**
   * Invocation of a transformer function on the results of its argument expressions
   *
   * @param f function to invoke
   * @param arguments argument expressions, evaluated in order and passed to the function
   */
  case class FunctionExpression(f: TransformerFunction, arguments: Array[Expression]) extends Expression {

    // cached result of the context check: -1 = not yet determined,
    // 0 = withContext leaves this expression unchanged, 1 = withContext produces a new expression
    @volatile private var contextDependent: Int = -1

    private def this(f: TransformerFunction, arguments: Array[Expression], contextDependent: Int) = {
      this(f, arguments)
      this.contextDependent = contextDependent
    }

    override def apply(args: Array[_ <: AnyRef]): AnyRef = f.apply(arguments.map(_.apply(args)))

    override def withContext(ec: EvaluationContext): Expression = {
      // this code is thread-safe, in that it will ensure correctness, but does not guarantee
      // that the dependency check is only performed once
      if (contextDependent == 0) { this } else {
        val fwc = f.withContext(ec)
        val awc = arguments.map(_.withContext(ec))
        if (contextDependent == 1) {
          new FunctionExpression(fwc, awc, 1)
        } else if (fwc.eq(f) && awc.indices.forall(i => awc(i).eq(arguments(i)))) {
          // nothing was re-bound, so subsequent calls can short-circuit to `this`
          contextDependent = 0
          this
        } else {
          contextDependent = 1
          new FunctionExpression(fwc, awc, 1)
        }
      }
    }

    override def dependencies(stack: Set[Field], fieldMap: Map[String, Field]): Set[Field] =
      arguments.flatMap(_.dependencies(stack, fieldMap)).toSet
    // explicit conversion, instead of relying on the implicit Array-to-Seq conversion
    override def children(): Seq[Expression] = arguments.toSeq
    override def accept[T](visitor: ExpressionVisitor[T]): T = visitor.visit(this)
    override def toString: String = s"${f.names.head}${arguments.mkString("(", ",", ")")}"
  }

  /**
   * Evaluates `toTry`, falling back to `fallback` if evaluation throws a non-fatal exception
   *
   * @param toTry primary expression
   * @param fallback expression evaluated when the primary one fails
   */
  case class TryExpression(toTry: Expression, fallback: Expression) extends Expression {

    // cached result of the context check: -1 = not yet determined,
    // 0 = withContext leaves this expression unchanged, 1 = withContext produces a new expression
    @volatile private var contextDependent: Int = -1

    private def this(toTry: Expression, fallback: Expression, contextDependent: Int) = {
      this(toTry, fallback)
      this.contextDependent = contextDependent
    }

    override def apply(args: Array[_ <: AnyRef]): AnyRef = Try(toTry.apply(args)).getOrElse(fallback.apply(args))

    override def withContext(ec: EvaluationContext): Expression = {
      // this code is thread-safe, in that it will ensure correctness, but does not guarantee
      // that the dependency check is only performed once
      if (contextDependent == 0) { this } else {
        val twc = toTry.withContext(ec)
        val fwc = fallback.withContext(ec)
        if (contextDependent == 1) {
          new TryExpression(twc, fwc, 1)
        } else if (twc.eq(toTry) && fwc.eq(fallback)) {
          // nothing was re-bound, so subsequent calls can short-circuit to `this`
          contextDependent = 0
          this
        } else {
          contextDependent = 1
          new TryExpression(twc, fwc, 1)
        }
      }
    }

    override def dependencies(stack: Set[Field], fieldMap: Map[String, Field]): Set[Field] =
      toTry.dependencies(stack, fieldMap) ++ fallback.dependencies(stack, fieldMap)
    override def children(): Seq[Expression] = Seq(toTry, fallback)
    override def accept[T](visitor: ExpressionVisitor[T]): T = visitor.visit(this)
    override def toString: String = s"try($toTry,$fallback)"
  }

  /**
   * Evaluates the expressions in order and returns the first non-null result, or null if all are null
   *
   * @param expressions candidate expressions - at least two are required
   */
  case class WithDefaultExpression(expressions: Seq[Expression]) extends Expression {

    require(expressions.lengthCompare(1) > 0, "withDefault requires at least two expressions")

    // cached result of the context check: -1 = not yet determined,
    // 0 = withContext leaves this expression unchanged, 1 = withContext produces a new expression
    @volatile private var contextDependent: Int = -1

    private def this(expressions: Seq[Expression], contextDependent: Int) = {
      this(expressions)
      this.contextDependent = contextDependent
    }

    // the iterator is lazy, so expressions after the first non-null result are never evaluated
    override def apply(args: Array[_ <: AnyRef]): AnyRef =
      expressions.iterator.map(_.apply(args)).find(_ != null).orNull

    override def withContext(ec: EvaluationContext): Expression = {
      // this code is thread-safe, in that it will ensure correctness, but does not guarantee
      // that the dependency check is only performed once
      if (contextDependent == 0) { this } else {
        val ewc = expressions.map(_.withContext(ec))
        if (contextDependent == 1) {
          new WithDefaultExpression(ewc, 1)
        } else if (ewc.corresponds(expressions)(_ eq _)) {
          // nothing was re-bound, so subsequent calls can short-circuit to `this`
          contextDependent = 0
          this
        } else {
          contextDependent = 1
          new WithDefaultExpression(ewc, 1)
        }
      }
    }

    override def dependencies(stack: Set[Field], fieldMap: Map[String, Field]): Set[Field] =
      expressions.flatMap(_.dependencies(stack, fieldMap)).toSet
    override def children(): Seq[Expression] = expressions
    override def accept[T](visitor: ExpressionVisitor[T]): T = visitor.visit(this)
    override def toString: String = s"withDefault(${expressions.mkString(",")})"
  }
}
