//
//  TextureWriteSpeed.swift
//  ChipMemorySpeed
//
//  Created by Myles C. Maxfield on 9/14/19.
//  Copyright © 2019 Myles C. Maxfield. All rights reserved.
//

import Foundation
import Metal

/// Measures the GPU time of one render pass that issues two full-screen draws:
/// first with the "intermediate" imageblock pipeline, then with the "destination"
/// imageblock pipeline, which renders into `destinationTexture`.
///
/// Both draws are encoded in the same render command encoder, and the pass reserves
/// imageblock memory sized for the larger of the two pipelines.
/// NOTE(review): presumably the intermediate fragment function stores data in the
/// imageblock that the destination fragment function reads — confirm against the shaders.
class ImageBlockSpeed: SpeedTest {
    /// Value every destination pixel is expected to hold after a run (0xDFFF).
    /// NOTE(review): presumably what `imageBlockDestinationFragmentFunction` produces — confirm against the shader.
    private static let expectedPixelValue: UInt16 = 57343

    var width: Int
    var height: Int
    var destinationTexture: MTLTexture
    var vertexBuffer: MTLBuffer
    var intermediateRenderPipelineState: MTLRenderPipelineState
    var destinationRenderPipelineState: MTLRenderPipelineState
    var renderPassDescriptor: MTLRenderPassDescriptor
    var commandQueue: MTLCommandQueue

    /// Creates every GPU object the test needs.
    ///
    /// - Parameters:
    ///   - device: Device used to create the library, texture, vertex buffer, pipelines and queue.
    ///   - width: Width in pixels of the destination render target.
    ///   - height: Height in pixels of the destination render target.
    ///
    /// Traps (via `!` / `try!`) if the default library, any named shader function,
    /// or any Metal object cannot be created.
    init(device: MTLDevice, width: Int, height: Int) {
        self.width = width
        self.height = height

        // Single source of truth for the render-target format, shared by the texture
        // and both pipelines so they can never disagree.
        let pixelFormat: MTLPixelFormat = .r16Unorm

        let library = device.makeDefaultLibrary()!
        let vertexFunction = library.makeFunction(name: "vertexFunction")!
        let intermediateFragmentFunction = library.makeFunction(name: "imageBlockIntermediateFragmentFunction")!
        let destinationFragmentFunction = library.makeFunction(name: "imageBlockDestinationFragmentFunction")!

        let destinationTextureDescriptor = MTLTextureDescriptor()
        destinationTextureDescriptor.pixelFormat = pixelFormat
        destinationTextureDescriptor.width = width
        destinationTextureDescriptor.height = height
        destinationTextureDescriptor.usage = .renderTarget
        destinationTexture = device.makeTexture(descriptor: destinationTextureDescriptor)!

        // Four clip-space corners; drawn as a 4-vertex triangle strip they cover the whole target.
        let vertexData: [Float] = [-1, -1, -1, 1, 1, -1, 1, 1]
        vertexBuffer = device.makeBuffer(bytes: vertexData, length: MemoryLayout<Float>.size * vertexData.count, options: .storageModeShared)!

        // One float2 position per vertex, read from buffer index 0.
        let vertexDescriptor = MTLVertexDescriptor()
        vertexDescriptor.attributes[0].format = .float2
        vertexDescriptor.attributes[0].offset = 0
        vertexDescriptor.attributes[0].bufferIndex = 0
        vertexDescriptor.layouts[0].stride = MemoryLayout<Float>.size * 2

        // Both pipelines are built by the same helper so their configuration
        // (including the primitive topology) stays identical apart from the fragment function.
        intermediateRenderPipelineState = ImageBlockSpeed.makeRenderPipelineState(device: device,
                                                                                  vertexFunction: vertexFunction,
                                                                                  fragmentFunction: intermediateFragmentFunction,
                                                                                  vertexDescriptor: vertexDescriptor,
                                                                                  pixelFormat: pixelFormat)
        destinationRenderPipelineState = ImageBlockSpeed.makeRenderPipelineState(device: device,
                                                                                 vertexFunction: vertexFunction,
                                                                                 fragmentFunction: destinationFragmentFunction,
                                                                                 vertexDescriptor: vertexDescriptor,
                                                                                 pixelFormat: pixelFormat)

        // The pass must reserve enough imageblock memory per sample for whichever pipeline needs more.
        let imageblockSampleLength = max(intermediateRenderPipelineState.imageblockSampleLength, destinationRenderPipelineState.imageblockSampleLength)

        renderPassDescriptor = MTLRenderPassDescriptor()
        renderPassDescriptor.colorAttachments[0].texture = destinationTexture
        renderPassDescriptor.colorAttachments[0].loadAction = .dontCare
        renderPassDescriptor.colorAttachments[0].storeAction = .store
        renderPassDescriptor.imageblockSampleLength = imageblockSampleLength

        commandQueue = device.makeCommandQueue()!
    }

    /// Builds a render pipeline that runs `vertexFunction` + `fragmentFunction`
    /// into a single color attachment of `pixelFormat`, with triangle input topology.
    ///
    /// Traps (via `try!`) if pipeline creation fails.
    private static func makeRenderPipelineState(device: MTLDevice,
                                                vertexFunction: MTLFunction,
                                                fragmentFunction: MTLFunction,
                                                vertexDescriptor: MTLVertexDescriptor,
                                                pixelFormat: MTLPixelFormat) -> MTLRenderPipelineState {
        let descriptor = MTLRenderPipelineDescriptor()
        descriptor.vertexFunction = vertexFunction
        descriptor.fragmentFunction = fragmentFunction
        descriptor.vertexDescriptor = vertexDescriptor
        descriptor.colorAttachments[0].pixelFormat = pixelFormat
        descriptor.inputPrimitiveTopology = .triangle
        return try! device.makeRenderPipelineState(descriptor: descriptor)
    }

    /// Encodes the intermediate draw followed by the destination draw in one render pass,
    /// commits it, and reports the GPU execution time when the command buffer completes.
    ///
    /// - Parameter callback: Called from the command buffer's completion handler with
    ///   `gpuEndTime - gpuStartTime`.
    ///
    /// In debug builds the completion handler also asserts that the command buffer
    /// completed successfully and that every destination pixel has the expected value.
    func runTest(_ callback: @escaping (CFTimeInterval) -> ()) {
        let commandBuffer = commandQueue.makeCommandBuffer()!
        let commandEncoder = commandBuffer.makeRenderCommandEncoder(descriptor: renderPassDescriptor)!

        // Order matters: the intermediate pass is drawn first, then the destination pass.
        for pipelineState in [intermediateRenderPipelineState, destinationRenderPipelineState] {
            commandEncoder.setRenderPipelineState(pipelineState)
            commandEncoder.setVertexBuffer(vertexBuffer, offset: 0, index: 0)
            commandEncoder.drawPrimitives(type: .triangleStrip, vertexStart: 0, vertexCount: 4)
        }

        commandEncoder.endEncoding()
        commandBuffer.addCompletedHandler { completedBuffer in
            // Timing from a failed command buffer would be meaningless.
            assert(completedBuffer.status == .completed, "Command buffer did not complete: \(String(describing: completedBuffer.error))")
            // Evaluated only in debug builds (assert's condition is an autoclosure), so
            // release builds skip the full-texture readback before reporting timing.
            assert(self.destinationPixelsMatchExpectedValue())
            callback(completedBuffer.gpuEndTime - completedBuffer.gpuStartTime)
        }
        commandBuffer.commit()
    }

    /// Reads back all of `destinationTexture` (mip level 0) and reports whether every
    /// pixel equals `expectedPixelValue`.
    ///
    /// NOTE(review): `getBytes` requires CPU-accessible storage; the texture uses the
    /// descriptor's default storage mode — confirm it is readable (and synchronized, if
    /// `.managed`) on the target hardware.
    private func destinationPixelsMatchExpectedValue() -> Bool {
        var pixelData = [UInt16](repeating: 0, count: width * height)
        destinationTexture.getBytes(&pixelData,
                                    bytesPerRow: width * MemoryLayout<UInt16>.size,
                                    from: MTLRegionMake2D(0, 0, width, height),
                                    mipmapLevel: 0)
        return pixelData.allSatisfy { $0 == ImageBlockSpeed.expectedPixelValue }
    }
}
