/***********************************************************************
 * Copyright (c) 2013-2025 General Atomics Integrated Intelligence, Inc.
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Apache License, Version 2.0
 * which accompanies this distribution and is available at
 * https://www.apache.org/licenses/LICENSE-2.0
 ***********************************************************************/

package org.locationtech.geomesa.redis.data
package index

import org.geotools.api.filter.Filter
import org.locationtech.geomesa.index.api.QueryPlan.{FeatureReducer, ResultsToFeatures}
import org.locationtech.geomesa.index.api.{BoundedByteRange, FilterStrategy, QueryPlan}
import org.locationtech.geomesa.index.utils.Explainer
import org.locationtech.geomesa.index.utils.Reprojection.QueryReferenceSystems
import org.locationtech.geomesa.redis.data.util.RedisBatchScan
import org.locationtech.geomesa.utils.collection.{CloseableIterator, SelfClosingIterator}
import org.locationtech.geomesa.utils.io.WithClose
import redis.clients.jedis.{Jedis, Response, UnifiedJedis}

import java.nio.charset.StandardCharsets

sealed trait RedisQueryPlan extends QueryPlan[RedisDataStore] {

  override type Results = Array[Byte]

  /**
    * Names of the tables read by this plan
    *
    * @return table names
    */
  def tables: Seq[String]

  /**
    * Byte ranges read from the tables of this plan
    *
    * @return ranges
    */
  def ranges: Seq[BoundedByteRange]

  /**
    * Residual filter applied to the results of this plan, if there is one
    *
    * @return final filter, or None
    */
  def ecql: Option[Filter]

  override def explain(explainer: Explainer, prefix: String = ""): Unit = {
    RedisQueryPlan.explain(this, explainer, prefix)
  }

  /**
    * Hook for plan-specific explain output. By default nothing extra is written
    *
    * @param explainer explainer to write to
    */
  protected def explain(explainer: Explainer): Unit = ()
}

object RedisQueryPlan {

  import org.locationtech.geomesa.filter.filterToString

  /**
    * Writes a human-readable description of a plan to the explainer, nested under a new level
    *
    * @param plan plan to describe
    * @param explainer explainer to write to
    * @param prefix prepended to the plan header line
    */
  def explain(plan: RedisQueryPlan, explainer: Explainer, prefix: String): Unit = {
    explainer.pushLevel(s"${prefix}Plan: ${plan.getClass.getSimpleName}")
    explainer(s"Tables: ${plan.tables.mkString(", ")}")
    explainer(s"ECQL: ${plan.ecql.map(filterToString).getOrElse("none")}")
    // only the first 5 ranges are rendered, to keep the output readable
    explainer(s"Ranges (${plan.ranges.size}): ${plan.ranges.take(5).map(rangeToString).mkString(", ")}")
    // plan-specific details, if the plan overrides the hook
    plan.explain(explainer)
    explainer(s"Reduce: ${plan.reducer.getOrElse("none")}")
    explainer.popLevel()
  }

  /**
    * Renders a byte range for display. Printable ASCII bytes (32-126) are shown as characters. All other
    * bytes are shown as a `%xx;` hex escape.
    *
    * @param range range to render
    * @return string in the form `[lower::upper]`
    */
  private [data] def rangeToString(range: BoundedByteRange): String = {
    // based on accumulo's byte representation
    def printable(b: Byte): String = {
      val c = 0xff & b
      if (c >= 32 && c <= 126) { c.toChar.toString } else { f"%%$c%02x;" }
    }
    s"[${range.lower.map(printable).mkString("")}::${range.upper.map(printable).mkString("")}]"
  }

  /**
    * Plan that will not actually scan anything - scanning always returns an empty iterator
    *
    * @param filter filter strategy for the plan
    * @param reducer optional client-side reducer
    */
  case class EmptyPlan(filter: FilterStrategy, reducer: Option[FeatureReducer] = None) extends RedisQueryPlan {
    override val tables: Seq[String] = Seq.empty
    override val ranges: Seq[BoundedByteRange] = Seq.empty
    override val ecql: Option[Filter] = None
    override val resultsToFeatures: ResultsToFeatures[Array[Byte]] = ResultsToFeatures.empty
    override val sort: Option[Seq[(String, Boolean)]] = None
    override val maxFeatures: Option[Int] = None
    override val projection: Option[QueryReferenceSystems] = None
    override def scan(ds: RedisDataStore): CloseableIterator[Array[Byte]] = CloseableIterator.empty
  }

  /**
    * Plan that reads lexicographic ranges from redis sorted sets, using zrangebylex
    *
    * @param filter filter strategy for the plan
    * @param tables tables to scan - every range is scanned in every table
    * @param ranges lexicographic ranges to scan
    * @param pipeline if true, each table's ranges are sent through a single redis pipeline; otherwise
    *                 they are executed through a `RedisBatchScan`
    * @param ecql final filter - note: will already be applied in resultsToFeatures
    * @param resultsToFeatures converts raw results into features
    * @param reducer optional client-side reducer
    * @param sort optional sort fields
    * @param maxFeatures optional feature limit
    * @param projection optional reprojection
    */
  case class ZLexPlan(
      filter: FilterStrategy,
      tables: Seq[String],
      ranges: Seq[BoundedByteRange],
      pipeline: Boolean,
      ecql: Option[Filter], // note: will already be applied in resultsToFeatures
      resultsToFeatures: ResultsToFeatures[Array[Byte]],
      reducer: Option[FeatureReducer],
      sort: Option[Seq[(String, Boolean)]],
      maxFeatures: Option[Int],
      projection: Option[QueryReferenceSystems]
    ) extends RedisQueryPlan {

    import scala.collection.JavaConverters._

    /**
      * Scans every table in this plan and concatenates the results.
      *
      * If `parallelPartitionScans` is enabled, a scan is started for every table up front. Otherwise each
      * table's scan is started only after the previous one has been exhausted.
      *
      * @param ds data store
      * @return raw result values
      */
    override def scan(ds: RedisDataStore): CloseableIterator[Array[Byte]] = {
      val scans = tables.iterator.map(t => singleTableScan(ds, t.getBytes(StandardCharsets.UTF_8)))
      if (ds.config.queries.parallelPartitionScans) {
        // kick off all the scans at once
        // if starting any scan fails, close the ones that were already started so they don't leak
        var started: CloseableIterator[Array[Byte]] = CloseableIterator.empty[Array[Byte]]
        try {
          scans.foreach(s => started = started concat s)
        } catch {
          case e: Throwable =>
            try { started.close() } catch { case suppressed: Throwable => e.addSuppressed(suppressed) }
            throw e
        }
        started
      } else {
        // kick off the scans sequentially as they finish
        SelfClosingIterator(scans).flatMap(s => s)
      }
    }

    override protected def explain(explainer: Explainer): Unit =
      explainer(s"Pipelining: ${if (pipeline) { "enabled" } else { "disabled" }}")

    /**
      * Scans all of this plan's ranges in a single table.
      *
      * When pipelining, all results are read and the connection is released before this method returns.
      *
      * @param ds data store
      * @param table table name, as UTF-8 bytes
      * @return raw result values
      */
    private def singleTableScan(ds: RedisDataStore, table: Array[Byte]): CloseableIterator[Array[Byte]] = {
      if (pipeline) {
        val responses = Seq.newBuilder[Response[java.util.List[Array[Byte]]]]
        responses.sizeHint(ranges.length)
        WithClose(ds.connection.getResource) {
          case jedis: Jedis =>
            WithClose(jedis.pipelined()) { pipe =>
              // note: use a foreach here to ensure the calls are all executing inside our close block
              ranges.foreach(range => responses += pipe.zrangeByLex(table, range.lower, range.upper))
              pipe.sync()
            }
          case jedis: UnifiedJedis =>
            WithClose(jedis.pipelined()) { pipe =>
              // note: use a foreach here to ensure the calls are all executing inside our close block
              ranges.foreach(range => responses += pipe.zrangeByLex(table, range.lower, range.upper))
              pipe.sync()
            }
        }
        // note: explicit `()` - auto-application of `result` is deprecated in scala 2.13 and removed in 3
        responses.result().iterator.flatMap(_.get.iterator().asScala)
      } else {
        RedisBatchScan(ds.connection, table, ranges, ds.config.queries.threads)
      }
    }
  }
}
