/***********************************************************************
 * Copyright (c) 2013-2025 Commonwealth Computer Research, Inc.
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Apache License, Version 2.0
 * which accompanies this distribution and is available at
 * http://www.opensource.org/licenses/apache2.0.php.
 ***********************************************************************/

package org.locationtech.geomesa.fs.storage.parquet.io

import org.apache.hadoop.conf.Configuration
import org.apache.parquet.hadoop.api.WriteSupport
import org.apache.parquet.hadoop.api.WriteSupport.{FinalizedWriteContext, WriteContext}
import org.apache.parquet.io.api.{Binary, RecordConsumer}
import org.geotools.api.feature.`type`.AttributeDescriptor
import org.geotools.api.feature.simple.{SimpleFeature, SimpleFeatureType}
import org.locationtech.geomesa.fs.storage.parquet.io.SimpleFeatureParquetSchema.GeoParquetMetadata.GeoParquetObserver
import org.locationtech.geomesa.utils.geotools.ObjectType
import org.locationtech.geomesa.utils.geotools.ObjectType.ObjectType
import org.locationtech.geomesa.utils.io.CloseWithLogging
import org.locationtech.geomesa.utils.text.WKBUtils
import org.locationtech.jts.geom._

import java.nio.ByteBuffer
import java.util.{Date, UUID}

/**
  * Parquet write support for simple features. Each feature is written through a
  * `SimpleFeatureWriter` and is also passed to a `GeoParquetObserver`, whose metadata is
  * merged into the file metadata when the write is finalized.
  */
class SimpleFeatureWriteSupport extends WriteSupport[SimpleFeature] {

  // assigned in `init`
  private var writer: SimpleFeatureWriteSupport.SimpleFeatureWriter = _
  // assigned in `prepareForWrite`
  private var consumer: RecordConsumer = _
  // assigned in `init`, closed in `finalizeWrite`
  private var geoParquetObserver: GeoParquetObserver = _
  // assigned in `init`, copied into the finalized metadata
  private var baseMetadata: java.util.Map[String, String] = _

  override val getName: String = "SimpleFeatureWriteSupport"

  /**
    * Reads the write schema from the configuration and sets up the writer state. Called once.
    *
    * @param conf hadoop configuration holding the write schema
    * @return write context with the Parquet schema and the schema metadata
    * @throws IllegalArgumentException if no schema can be extracted from the configuration
    */
  override def init(conf: Configuration): WriteContext = {
    val schema =
      SimpleFeatureParquetSchema.write(conf).getOrElse(
        throw new IllegalArgumentException("Could not extract SimpleFeatureType from write context"))

    writer = SimpleFeatureWriteSupport.SimpleFeatureWriter(schema.sft)
    geoParquetObserver = new GeoParquetObserver(schema.sft)
    baseMetadata = schema.metadata

    new WriteContext(schema.schema, schema.metadata)
  }

  /**
    * Stores the consumer that subsequent records are written to. Called per block.
    *
    * @param recordConsumer the Parquet record consumer
    */
  override def prepareForWrite(recordConsumer: RecordConsumer): Unit = {
    this.consumer = recordConsumer
  }

  /**
    * Writes a feature to the current consumer, then passes it to the observer. Called per row.
    *
    * @param record feature to write
    */
  override def write(record: SimpleFeature): Unit = {
    writer.write(consumer, record)
    geoParquetObserver.write(record)
  }

  /**
    * Builds the final file metadata. Called once at the end.
    *
    * @return context holding the base metadata overlaid with the observer's metadata
    */
  override def finalizeWrite(): FinalizedWriteContext = {
    try {
      // start from a copy of the base metadata; observer entries win on key collisions
      val merged = new java.util.HashMap[String, String](baseMetadata)
      merged.putAll(geoParquetObserver.metadata())
      new FinalizedWriteContext(merged)
    } finally {
      // the observer is closed even if assembling the metadata fails
      CloseWithLogging(geoParquetObserver)
    }
  }
}

object SimpleFeatureWriteSupport {

  /**
    * Writes a simple feature as one Parquet message: each attribute in order, followed by
    * the feature ID.
    *
    * @param attributes writers for each attribute, indexed by attribute position
    */
  class SimpleFeatureWriter(attributes: Array[AttributeWriter[AnyRef]]) {

    // the ID goes at the end of the record, in the field index right after the last attribute
    private val fids = new FidWriter(attributes.length)

    /**
      * Writes a single feature as a complete message
      *
      * @param consumer the Parquet record consumer
      * @param value feature to write
      */
    def write(consumer: RecordConsumer, value: SimpleFeature): Unit = {
      consumer.startMessage()
      val count = attributes.length
      var pos = 0
      while (pos < count) {
        val attributeWriter = attributes(pos)
        attributeWriter(consumer, value.getAttribute(pos))
        pos += 1
      }
      fids(consumer, value.getID)
      consumer.endMessage()
    }
  }

  object SimpleFeatureWriter {

    /**
      * Creates a feature writer with one attribute writer per descriptor in the feature type
      *
      * @param sft simple feature type to write
      * @return writer for features of the given type
      */
    def apply(sft: SimpleFeatureType): SimpleFeatureWriter = {
      // cast each writer as it is created, so the array is built with its final element type
      val writers = Array.tabulate[AttributeWriter[AnyRef]](sft.getAttributeCount) { i =>
        attribute(sft.getDescriptor(i), i).asInstanceOf[AttributeWriter[AnyRef]]
      }
      new SimpleFeatureWriter(writers)
    }
  }

  /**
    * Creates a writer for an attribute descriptor, using the object types selected from the
    * descriptor's binding and user data
    *
    * @param descriptor attribute descriptor
    * @param index field index of the attribute in the enclosing record
    * @return attribute writer
    */
  def attribute(descriptor: AttributeDescriptor, index: Int): AttributeWriter[_] = {
    val attributeType = descriptor.getType
    val types = ObjectType.selectType(attributeType.getBinding, descriptor.getUserData)
    attribute(descriptor.getLocalName, index, types)
  }

  /**
    * Creates a writer for a named field, based on its object types
    *
    * @param name field name
    * @param index field index in the enclosing group
    * @param bindings object types for the field - the first entry selects the writer, later
    *                 entries supply the geometry sub-type, list element type or map key/value types
    * @return attribute writer
    * @throws IllegalArgumentException if the first binding has no corresponding writer
    */
  def attribute(name: String, index: Int, bindings: Seq[ObjectType]): AttributeWriter[_] = {
    val primary = bindings.head
    primary match {
      // simple values
      case ObjectType.STRING   => new StringWriter(name, index)
      case ObjectType.INT      => new IntegerWriter(name, index)
      case ObjectType.LONG     => new LongWriter(name, index)
      case ObjectType.FLOAT    => new FloatWriter(name, index)
      case ObjectType.DOUBLE   => new DoubleWriter(name, index)
      case ObjectType.BOOLEAN  => new BooleanWriter(name, index)
      case ObjectType.DATE     => new DateWriter(name, index)
      case ObjectType.UUID     => new UuidWriter(name, index)
      case ObjectType.BYTES    => new BytesWriter(name, index)
      // types that are refined by additional bindings
      case ObjectType.GEOMETRY => geometry(name, index, bindings.last)
      case ObjectType.LIST     => new ListWriter(name, index, bindings(1))
      case ObjectType.MAP      => new MapWriter(name, index, bindings(1), bindings(2))
      case unsupported =>
        throw new IllegalArgumentException(s"Can't serialize field '$name' of type $unsupported")
    }
  }

  // TODO support z/m
  /**
    * Creates a writer for a geometry field, based on the geometry sub-type
    *
    * @param name field name
    * @param index field index in the enclosing group
    * @param binding geometry sub-type
    * @return geometry attribute writer
    * @throws IllegalArgumentException if the binding is not a supported geometry type
    */
  private def geometry(name: String, index: Int, binding: ObjectType): AttributeWriter[_] = {
    binding match {
      // single-part geometries
      case ObjectType.POINT           => new PointAttributeWriter(name, index)
      case ObjectType.LINESTRING      => new LineStringAttributeWriter(name, index)
      case ObjectType.POLYGON         => new PolygonAttributeWriter(name, index)
      // multi-part geometries
      case ObjectType.MULTIPOINT      => new MultiPointAttributeWriter(name, index)
      case ObjectType.MULTILINESTRING => new MultiLineStringAttributeWriter(name, index)
      case ObjectType.MULTIPOLYGON    => new MultiPolygonAttributeWriter(name, index)
      // generic geometries are written as WKB
      case ObjectType.GEOMETRY        => new GeometryWkbAttributeWriter(name, index)
      case _ =>
        throw new IllegalArgumentException(s"Can't serialize field '$name' of type $binding")
    }
  }

  /**
    * Base class for writing a single field of a Parquet record. Subclasses supply the
    * type-specific encoding through `write`; this class wraps it in the field start/end
    * calls and skips null values.
    *
    * @param name field name
    * @param index field index in the enclosing group
    */
  abstract class AttributeWriter[T <: AnyRef](name: String, index: Int) {

    /**
      * Writes a value to the current record. Null values are not written - the field is
      * omitted entirely.
      *
      * @param consumer the Parquet record consumer
      * @param value value to write, may be null
      */
    def apply(consumer: RecordConsumer, value: T): Unit = {
      if (value ne null) {
        consumer.startField(name, index)
        write(consumer, value)
        consumer.endField(name, index)
      }
    }

    /**
      * Writes the value itself. Only invoked with non-null values, between the field
      * start and end calls.
      *
      * @param consumer the Parquet record consumer
      * @param value non-null value to write
      */
    protected def write(consumer: RecordConsumer, value: T): Unit
  }

  /**
    * Writes the feature ID to the `SimpleFeatureParquetSchema.FeatureIdField` field
    *
    * @param index field index of the feature ID in the record
    */
  class FidWriter(index: Int) extends AttributeWriter[String](SimpleFeatureParquetSchema.FeatureIdField, index) {
    override protected def write(consumer: RecordConsumer, value: String): Unit = {
      val fid = Binary.fromString(value)
      consumer.addBinary(fid)
    }
  }

  /**
    * Writes a date as a long, using its milliseconds-since-epoch value
    */
  class DateWriter(name: String, index: Int) extends AttributeWriter[Date](name, index) {
    override protected def write(consumer: RecordConsumer, value: Date): Unit = {
      val millis: Long = value.getTime
      consumer.addLong(millis)
    }
  }

  /**
    * Writes a boxed double as a primitive double
    */
  class DoubleWriter(name: String, index: Int) extends AttributeWriter[java.lang.Double](name, index) {
    override protected def write(consumer: RecordConsumer, value: java.lang.Double): Unit = {
      consumer.addDouble(value.doubleValue)
    }
  }

  /**
    * Writes a boxed float as a primitive float
    */
  class FloatWriter(name: String, index: Int) extends AttributeWriter[java.lang.Float](name, index) {
    override protected def write(consumer: RecordConsumer, value: java.lang.Float): Unit = {
      consumer.addFloat(value.floatValue)
    }
  }

  /**
    * Writes a boxed integer as a primitive int
    */
  class IntegerWriter(name: String, index: Int) extends AttributeWriter[java.lang.Integer](name, index) {
    override protected def write(consumer: RecordConsumer, value: java.lang.Integer): Unit = {
      consumer.addInteger(value.intValue)
    }
  }

  /**
    * Writes a boxed long as a primitive long
    */
  class LongWriter(name: String, index: Int) extends AttributeWriter[java.lang.Long](name, index) {
    override protected def write(consumer: RecordConsumer, value: java.lang.Long): Unit = {
      consumer.addLong(value.longValue)
    }
  }

  /**
    * Writes a string as a binary value
    */
  class StringWriter(name: String, index: Int) extends AttributeWriter[String](name, index) {
    override protected def write(consumer: RecordConsumer, value: String): Unit = {
      val encoded = Binary.fromString(value)
      consumer.addBinary(encoded)
    }
  }

  /**
    * Writes a byte array as a binary value. The array is wrapped rather than copied here.
    */
  class BytesWriter(name: String, index: Int) extends AttributeWriter[Array[Byte]](name, index) {
    override protected def write(consumer: RecordConsumer, value: Array[Byte]): Unit = {
      val wrapped = Binary.fromConstantByteArray(value)
      consumer.addBinary(wrapped)
    }
  }

  /**
    * Writes a boxed boolean as a primitive boolean
    */
  class BooleanWriter(name: String, index: Int) extends AttributeWriter[java.lang.Boolean](name, index) {
    override protected def write(consumer: RecordConsumer, value: java.lang.Boolean): Unit = {
      consumer.addBoolean(value.booleanValue)
    }
  }

  /**
    * Writes a list as a group holding a repeated `list` field, where each entry is a group
    * with a single `element` field. Null entries are written as empty groups, and an empty
    * list is written as an empty outer group.
    *
    * @param name field name
    * @param index field index in the enclosing group
    * @param valueType type of the list elements
    */
  class ListWriter(name: String, index: Int, valueType: ObjectType)
      extends AttributeWriter[java.util.List[AnyRef]](name, index) {

    private val elementWriter = attribute("element", 0, Seq(valueType)).asInstanceOf[AttributeWriter[AnyRef]]

    override protected def write(consumer: RecordConsumer, value: java.util.List[AnyRef]): Unit = {
      consumer.startGroup()
      // the `list` field is only started when there is at least one entry to put in it
      if (!value.isEmpty) {
        consumer.startField("list", 0)
        val entries = value.iterator
        while (entries.hasNext) {
          consumer.startGroup()
          // AttributeWriter.apply skips nulls, so a null entry leaves the group empty
          elementWriter(consumer, entries.next)
          consumer.endGroup()
        }
        consumer.endField("list", 0)
      }
      consumer.endGroup()
    }
  }

  /**
    * Writes a map as a group holding a repeated `map` field, where each entry is a group
    * with a `key` field and a `value` field. Null values are omitted from their entry, and
    * an empty map is written as an empty outer group.
    *
    * @param name field name
    * @param index field index in the enclosing group
    * @param keyType type of the map keys
    * @param valueType type of the map values
    */
  class MapWriter(name: String, index: Int, keyType: ObjectType, valueType: ObjectType)
      extends AttributeWriter[java.util.Map[AnyRef, AnyRef]](name, index) {

    private val keyWriter = attribute("key", 0, Seq(keyType)).asInstanceOf[AttributeWriter[AnyRef]]
    private val valueWriter = attribute("value", 1, Seq(valueType)).asInstanceOf[AttributeWriter[AnyRef]]

    override protected def write(consumer: RecordConsumer, value: java.util.Map[AnyRef, AnyRef]): Unit = {
      consumer.startGroup()
      // the `map` field is only started when there is at least one entry to put in it
      if (!value.isEmpty) {
        consumer.startField("map", 0)
        val entries = value.entrySet().iterator
        while (entries.hasNext) {
          val kv = entries.next()
          consumer.startGroup()
          keyWriter(consumer, kv.getKey)
          // AttributeWriter.apply skips nulls, so a null value simply omits the `value` field
          valueWriter(consumer, kv.getValue)
          consumer.endGroup()
        }
        consumer.endField("map", 0)
      }
      consumer.endGroup()
    }
  }

  /**
    * Writes a UUID as a 16-byte binary value: the most significant 8 bytes followed by the
    * least significant 8 bytes, each in the buffer's default (big-endian) byte order
    */
  class UuidWriter(name: String, index: Int) extends AttributeWriter[UUID](name, index) {
    override protected def write(consumer: RecordConsumer, value: UUID): Unit = {
      val buffer = ByteBuffer.allocate(16)
      buffer.putLong(value.getMostSignificantBits).putLong(value.getLeastSignificantBits)
      consumer.addBinary(Binary.fromConstantByteArray(buffer.array()))
    }
  }

  /**
    * Writes a point as a group with one double in the x field (index 0) and one double in
    * the y field (index 1)
    */
  class PointAttributeWriter(name: String, index: Int) extends AttributeWriter[Point](name, index) {
    override def write(consumer: RecordConsumer, value: Point): Unit = {
      val xField = SimpleFeatureParquetSchema.GeometryColumnX
      val yField = SimpleFeatureParquetSchema.GeometryColumnY

      consumer.startGroup()

      consumer.startField(xField, 0)
      consumer.addDouble(value.getX)
      consumer.endField(xField, 0)

      consumer.startField(yField, 1)
      consumer.addDouble(value.getY)
      consumer.endField(yField, 1)

      consumer.endGroup()
    }
  }

  /**
    * Writes a line string as a group with all x ordinates in the x field (index 0),
    * followed by all y ordinates in the y field (index 1), both in vertex order
    */
  class LineStringAttributeWriter(name: String, index: Int) extends AttributeWriter[LineString](name, index) {
    override def write(consumer: RecordConsumer, value: LineString): Unit = {
      val count = value.getNumPoints
      consumer.startGroup()

      // first pass over the vertices: x ordinates
      consumer.startField(SimpleFeatureParquetSchema.GeometryColumnX, 0)
      var xi = 0
      while (xi < count) {
        consumer.addDouble(value.getCoordinateN(xi).x)
        xi += 1
      }
      consumer.endField(SimpleFeatureParquetSchema.GeometryColumnX, 0)

      // second pass over the vertices: y ordinates
      consumer.startField(SimpleFeatureParquetSchema.GeometryColumnY, 1)
      var yi = 0
      while (yi < count) {
        consumer.addDouble(value.getCoordinateN(yi).y)
        yi += 1
      }
      consumer.endField(SimpleFeatureParquetSchema.GeometryColumnY, 1)

      consumer.endGroup()
    }
  }

  /**
    * Writes a multi-point as a group with all x ordinates in the x field (index 0),
    * followed by all y ordinates in the y field (index 1), both in member order.
    *
    * Empty member points contribute no ordinates.
    */
  class MultiPointAttributeWriter(name: String, index: Int) extends AttributeWriter[MultiPoint](name, index) {
    override def write(consumer: RecordConsumer, value: MultiPoint): Unit = {
      // Iterate the coordinates rather than indexing member geometries up to getNumPoints:
      // getNumPoints counts coordinates, not members, so the two disagree when the collection
      // contains an empty point - which previously either threw from getX/getY on the empty
      // member or dropped a trailing point, depending on where the empty member was located.
      // getCoordinates holds exactly the coordinates of the non-empty members, in order.
      val coords = value.getCoordinates
      consumer.startGroup()

      consumer.startField(SimpleFeatureParquetSchema.GeometryColumnX, 0)
      var i = 0
      while (i < coords.length) {
        consumer.addDouble(coords(i).x)
        i += 1
      }
      consumer.endField(SimpleFeatureParquetSchema.GeometryColumnX, 0)

      consumer.startField(SimpleFeatureParquetSchema.GeometryColumnY, 1)
      i = 0
      while (i < coords.length) {
        consumer.addDouble(coords(i).y)
        i += 1
      }
      consumer.endField(SimpleFeatureParquetSchema.GeometryColumnY, 1)

      consumer.endGroup()
    }
  }

  /**
    * Base writer for geometries that are made up of a sequence of lines. The geometry is
    * written as a group with an x field (index 0) and a y field (index 1); each field holds
    * a group with a repeated `list` field containing one group per line, and each of those
    * groups holds the line's ordinates in an `element` field.
    *
    * @param name field name
    * @param index field index in the enclosing group
    */
  abstract class AbstractLinesWriter[T <: Geometry](name: String, index: Int)
      extends AttributeWriter[T](name, index) {

    /**
      * Extracts the lines that make up the geometry
      *
      * @param value geometry
      * @return lines, in the order they should be written
      */
    protected def lines(value: T): Seq[LineString]

    override def write(consumer: RecordConsumer, value: T): Unit = {
      val parts = lines(value)
      consumer.startGroup()
      writeOrdinates(consumer, SimpleFeatureParquetSchema.GeometryColumnX, 0, parts)(writeLineStringX)
      writeOrdinates(consumer, SimpleFeatureParquetSchema.GeometryColumnY, 1, parts)(writeLineStringY)
      consumer.endGroup()
    }

    // writes one ordinate field (x or y) as a list of lines, delegating each line to `writeLine`
    private def writeOrdinates(
        consumer: RecordConsumer,
        field: String,
        fieldIndex: Int,
        parts: Seq[LineString])(
        writeLine: (RecordConsumer, LineString) => Unit): Unit = {
      consumer.startField(field, fieldIndex)
      consumer.startGroup()
      consumer.startField("list", 0)
      parts.foreach { part =>
        consumer.startGroup()
        writeLine(consumer, part)
        consumer.endGroup()
      }
      consumer.endField("list", 0)
      consumer.endGroup()
      consumer.endField(field, fieldIndex)
    }
  }

  /**
    * Writes a polygon as a sequence of rings: the exterior ring first, followed by the
    * interior rings in order
    */
  class PolygonAttributeWriter(name: String, index: Int) extends AbstractLinesWriter[Polygon](name, index) {
    override protected def lines(value: Polygon): Seq[LineString] = {
      val holes = Seq.tabulate(value.getNumInteriorRing)(i => value.getInteriorRingN(i))
      value.getExteriorRing +: holes
    }
  }

  /**
    * Writes a multi-line string as a sequence of its member line strings, in order
    */
  class MultiLineStringAttributeWriter(name: String, index: Int)
      extends AbstractLinesWriter[MultiLineString](name, index) {
    override protected def lines(value: MultiLineString): Seq[LineString] = {
      val count = value.getNumGeometries
      Seq.tabulate(count) { i =>
        value.getGeometryN(i).asInstanceOf[LineString]
      }
    }
  }

  /**
    * Writes a multi-polygon as a group with an x field (index 0) and a y field (index 1).
    * Each field holds a doubly-nested list: the outer `list` has one entry per polygon, whose
    * `element` holds an inner `list` with one entry per ring (exterior first, then interiors),
    * whose `element` holds the ring's ordinates.
    */
  class MultiPolygonAttributeWriter(name: String, index: Int) extends AttributeWriter[MultiPolygon](name, index) {
    override def write(consumer: RecordConsumer, value: MultiPolygon): Unit = {
      // rings per polygon: exterior ring at position 0, interior rings after it
      val polygons = Seq.tabulate(value.getNumGeometries) { p =>
        val polygon = value.getGeometryN(p).asInstanceOf[Polygon]
        Seq.tabulate(polygon.getNumInteriorRing + 1) { r =>
          if (r == 0) { polygon.getExteriorRing } else { polygon.getInteriorRingN(r - 1) }
        }
      }

      consumer.startGroup()
      writeOrdinates(consumer, SimpleFeatureParquetSchema.GeometryColumnX, 0, polygons)(writeLineStringX)
      writeOrdinates(consumer, SimpleFeatureParquetSchema.GeometryColumnY, 1, polygons)(writeLineStringY)
      consumer.endGroup()
    }

    // writes one ordinate field (x or y) as a list of polygons, each of which is a list of rings
    private def writeOrdinates(
        consumer: RecordConsumer,
        field: String,
        fieldIndex: Int,
        polygons: Seq[Seq[LineString]])(
        writeRing: (RecordConsumer, LineString) => Unit): Unit = {
      consumer.startField(field, fieldIndex)
      consumer.startGroup()
      consumer.startField("list", 0)
      polygons.foreach { rings =>
        consumer.startGroup()
        consumer.startField("element", 0)
        consumer.startGroup()
        consumer.startField("list", 0)
        rings.foreach { ring =>
          consumer.startGroup()
          writeRing(consumer, ring)
          consumer.endGroup()
        }
        consumer.endField("list", 0)
        consumer.endGroup()
        consumer.endField("element", 0)
        consumer.endGroup()
      }
      consumer.endField("list", 0)
      consumer.endGroup()
      consumer.endField(field, fieldIndex)
    }
  }

  /**
    * Writes the x ordinates of a line, in vertex order, as the values of an `element` field
    *
    * @param consumer the Parquet record consumer
    * @param ring line to write
    */
  private def writeLineStringX(consumer: RecordConsumer, ring: LineString): Unit = {
    val count = ring.getNumPoints
    consumer.startField("element", 0)
    var n = 0
    while (n < count) {
      val vertex = ring.getCoordinateN(n)
      consumer.addDouble(vertex.x)
      n += 1
    }
    consumer.endField("element", 0)
  }

  /**
    * Writes the y ordinates of a line, in vertex order, as the values of an `element` field
    *
    * @param consumer the Parquet record consumer
    * @param ring line to write
    */
  private def writeLineStringY(consumer: RecordConsumer, ring: LineString): Unit = {
    val count = ring.getNumPoints
    consumer.startField("element", 0)
    var n = 0
    while (n < count) {
      val vertex = ring.getCoordinateN(n)
      consumer.addDouble(vertex.y)
      n += 1
    }
    consumer.endField("element", 0)
  }

  /**
    * Writes a geometry of any type as a binary value, using the bytes produced by `WKBUtils.write`
    */
  class GeometryWkbAttributeWriter(name: String, index: Int) extends AttributeWriter[Geometry](name, index) {
    override protected def write(consumer: RecordConsumer, value: Geometry): Unit = {
      val wkb = WKBUtils.write(value)
      consumer.addBinary(Binary.fromConstantByteArray(wkb))
    }
  }
}
