//
//  TextureWriteSpeed.swift
//  ChipMemorySpeed
//
//  Created by Myles C. Maxfield on 9/14/19.
//  Copyright © 2019 Myles C. Maxfield. All rights reserved.
//

import Foundation
import Metal

/// Measures the GPU time of a two-pass render:
/// 1. an attachment-less pass whose fragment shader has `intermediateTexture`
///    (`.rgba32Float`, usage `[.shaderRead, .shaderWrite]`) bound at index 0, and
/// 2. a pass that renders into `destinationTexture` (`.r16Unorm`) with the same
///    intermediate texture bound.
///
/// NOTE(review): that pass 1 writes the intermediate texture and pass 2 reads it
/// is inferred from the usages/bindings here; confirm against
/// `textureWriteIntermediateFragmentFunction` / `textureWriteDestinationFragmentFunction`.
class TextureWriteSpeed: SpeedTest {
    var width: Int
    var height: Int
    var intermediateTexture: MTLTexture
    var destinationTexture: MTLTexture
    var vertexBuffer: MTLBuffer
    var intermediateRenderPipelineState: MTLRenderPipelineState
    var destinationRenderPipelineState: MTLRenderPipelineState
    var intermediateRenderPassDescriptor: MTLRenderPassDescriptor
    var destinationRenderPassDescriptor: MTLRenderPassDescriptor
    var commandQueue: MTLCommandQueue

    /// Creates every Metal object `runTest(_:)` needs.
    ///
    /// - Parameters:
    ///   - device: Device on which all textures, buffers, pipelines and the queue are created.
    ///   - width: Width in pixels of both textures and of the attachment-less pass; must be > 0.
    ///   - height: Height in pixels of both textures and of the attachment-less pass; must be > 0.
    ///
    /// Traps (force-unwrap / `try!`) if the default library, a shader function,
    /// or any Metal object cannot be created.
    init(device: MTLDevice, width: Int, height: Int) {
        precondition(width > 0 && height > 0, "TextureWriteSpeed requires a positive width and height")
        self.width = width
        self.height = height

        let library = device.makeDefaultLibrary()!
        let vertexFunction = library.makeFunction(name: "vertexFunction")!
        let intermediateFragmentFunction = library.makeFunction(name: "textureWriteIntermediateFragmentFunction")!
        let destinationFragmentFunction = library.makeFunction(name: "textureWriteDestinationFragmentFunction")!

        let intermediateTextureDescriptor = MTLTextureDescriptor()
        intermediateTextureDescriptor.pixelFormat = .rgba32Float
        intermediateTextureDescriptor.width = width
        intermediateTextureDescriptor.height = height
        intermediateTextureDescriptor.usage = [.shaderRead, .shaderWrite]
        intermediateTexture = device.makeTexture(descriptor: intermediateTextureDescriptor)!

        let destinationTextureDescriptor = MTLTextureDescriptor()
        destinationTextureDescriptor.pixelFormat = .r16Unorm
        destinationTextureDescriptor.width = width
        destinationTextureDescriptor.height = height
        destinationTextureDescriptor.usage = .renderTarget
        destinationTexture = device.makeTexture(descriptor: destinationTextureDescriptor)!

        // Corners (-1,-1), (-1,1), (1,-1), (1,1), drawn below as a 4-vertex triangle strip.
        // Presumably `vertexFunction` passes these through as clip-space positions,
        // giving a full-viewport quad — TODO confirm against the shader.
        let vertexData = [Float(-1), Float(-1), Float(-1), Float(1), Float(1), Float(-1), Float(1), Float(1)]
        vertexBuffer = device.makeBuffer(bytes: vertexData, length: MemoryLayout<Float>.size * vertexData.count, options: .storageModeShared)!

        // One float2 attribute per vertex, tightly packed in buffer 0.
        let vertexDescriptor = MTLVertexDescriptor()
        vertexDescriptor.attributes[0].format = .float2
        vertexDescriptor.attributes[0].offset = 0
        vertexDescriptor.attributes[0].bufferIndex = 0
        vertexDescriptor.layouts[0].stride = MemoryLayout<Float>.size * 2

        // Pass-1 pipeline: no color output (`.invalid`); the pass has no attachments.
        let intermediateRenderPipelineDescriptor = MTLRenderPipelineDescriptor()
        intermediateRenderPipelineDescriptor.vertexFunction = vertexFunction
        intermediateRenderPipelineDescriptor.fragmentFunction = intermediateFragmentFunction
        intermediateRenderPipelineDescriptor.vertexDescriptor = vertexDescriptor
        intermediateRenderPipelineDescriptor.colorAttachments[0].pixelFormat = .invalid
        intermediateRenderPipelineDescriptor.inputPrimitiveTopology = .triangle
        intermediateRenderPipelineState = try! device.makeRenderPipelineState(descriptor: intermediateRenderPipelineDescriptor)

        // Pass-2 pipeline: renders into the `.r16Unorm` destination texture.
        let destinationRenderPipelineDescriptor = MTLRenderPipelineDescriptor()
        destinationRenderPipelineDescriptor.vertexFunction = vertexFunction
        destinationRenderPipelineDescriptor.fragmentFunction = destinationFragmentFunction
        destinationRenderPipelineDescriptor.vertexDescriptor = vertexDescriptor
        destinationRenderPipelineDescriptor.colorAttachments[0].pixelFormat = .r16Unorm
        // Previously this line mistakenly re-set the intermediate descriptor's topology.
        destinationRenderPipelineDescriptor.inputPrimitiveTopology = .triangle
        destinationRenderPipelineState = try! device.makeRenderPipelineState(descriptor: destinationRenderPipelineDescriptor)

        // Attachment-less pass: its size must be given explicitly.
        intermediateRenderPassDescriptor = MTLRenderPassDescriptor()
        intermediateRenderPassDescriptor.renderTargetWidth = width
        intermediateRenderPassDescriptor.renderTargetHeight = height
        intermediateRenderPassDescriptor.defaultRasterSampleCount = 1

        // Previous contents are discarded (`.dontCare`); results are kept for readback (`.store`).
        destinationRenderPassDescriptor = MTLRenderPassDescriptor()
        destinationRenderPassDescriptor.colorAttachments[0].texture = destinationTexture
        destinationRenderPassDescriptor.colorAttachments[0].loadAction = .dontCare
        destinationRenderPassDescriptor.colorAttachments[0].storeAction = .store

        commandQueue = device.makeCommandQueue()!
    }

    /// Encodes both passes into a single command buffer, commits it, and reports
    /// `gpuEndTime - gpuStartTime` to `callback` from the command buffer's
    /// completion handler.
    ///
    /// When assertions are enabled, the destination texture is read back and every
    /// pixel is checked before `callback` runs; in optimized builds that readback
    /// is skipped entirely.
    func runTest(_ callback: @escaping (CFTimeInterval) -> ()) {
        // Passed to both fragment shaders at buffer index 0.
        let uniforms = [UInt32(width), UInt32(height)]

        let commandBuffer = commandQueue.makeCommandBuffer()!

        encodeFullScreenPass(into: commandBuffer,
                             descriptor: intermediateRenderPassDescriptor,
                             pipelineState: intermediateRenderPipelineState,
                             uniforms: uniforms)
        encodeFullScreenPass(into: commandBuffer,
                             descriptor: destinationRenderPassDescriptor,
                             pipelineState: destinationRenderPipelineState,
                             uniforms: uniforms)

        // The handler captures `self` strongly, so the textures stay alive until the GPU finishes.
        commandBuffer.addCompletedHandler { completedBuffer in
            // `assert` evaluates its condition only when assertions are enabled, so the
            // CPU readback costs nothing in optimized builds. 57343 is the value this
            // test expects in every destination pixel (presumably what the destination
            // fragment shader outputs — confirm against the shader).
            assert(self.destinationTextureIsFilled(with: 57343))
            callback(completedBuffer.gpuEndTime - completedBuffer.gpuStartTime)
        }
        commandBuffer.commit()
    }

    /// Encodes one render pass that draws the 4-vertex triangle strip from
    /// `vertexBuffer`, with `intermediateTexture` and `uniforms` bound to the
    /// fragment stage at index 0.
    private func encodeFullScreenPass(into commandBuffer: MTLCommandBuffer,
                                      descriptor: MTLRenderPassDescriptor,
                                      pipelineState: MTLRenderPipelineState,
                                      uniforms: [UInt32]) {
        let encoder = commandBuffer.makeRenderCommandEncoder(descriptor: descriptor)!
        encoder.setRenderPipelineState(pipelineState)
        encoder.setVertexBuffer(vertexBuffer, offset: 0, index: 0)
        encoder.setFragmentTexture(intermediateTexture, index: 0)
        encoder.setFragmentBytes(uniforms, length: MemoryLayout<UInt32>.size * uniforms.count, index: 0)
        encoder.drawPrimitives(type: .triangleStrip, vertexStart: 0, vertexCount: 4)
        encoder.endEncoding()
    }

    /// Copies `destinationTexture` (mip 0, full extent) to the CPU and returns
    /// whether every 16-bit pixel equals `expected`.
    ///
    /// NOTE(review): on macOS a texture descriptor defaults to `.managed` storage;
    /// on GPUs without unified memory, GPU writes may need a blit
    /// `synchronize(resource:)` before `getBytes` sees them. That is not done here
    /// because it would add to the measured GPU time — confirm on discrete-GPU Macs.
    private func destinationTextureIsFilled(with expected: UInt16) -> Bool {
        var pixels = [UInt16](repeating: 0, count: width * height)
        destinationTexture.getBytes(&pixels,
                                    bytesPerRow: width * MemoryLayout<UInt16>.size,
                                    from: MTLRegion(origin: MTLOrigin(x: 0, y: 0, z: 0),
                                                    size: MTLSize(width: width, height: height, depth: 1)),
                                    mipmapLevel: 0)
        return pixels.allSatisfy { $0 == expected }
    }
}
