use std::{collections::HashMap, convert::TryInto};
use wonnx::utils::{attribute, graph, model_with_opset, node, tensor};
mod common;

/// Runs a single Softmax node and checks its output.
///
/// Builds the model `X -> Softmax(axis) -> Y` at `opset_version`, feeds it `x` as input `X`
/// with shape `x_dims`, and asserts that output `Y` matches `expected_y`.
///
/// `Y` is declared with the same dimensions as `X`, because Softmax does not change the
/// tensor shape.
///
/// # Panics
///
/// Panics if the session cannot be created, if inference fails, or if the output differs from
/// `expected_y`.
fn softmax_with_axis(x: &[f32], x_dims: &[i64], axis: i64, expected_y: &[f32], opset_version: i64) {
    // One Softmax node (named "a") that reads "X" and writes "Y".
    let softmax_node = node(
        vec!["X"],
        vec!["Y"],
        "a",
        "Softmax",
        vec![attribute("axis", axis)],
    );
    let softmax_graph = graph(
        vec![tensor("X", x_dims)],
        vec![tensor("Y", x_dims)],
        vec![],
        vec![],
        vec![softmax_node],
    );
    let model = model_with_opset(softmax_graph, opset_version);

    let session =
        pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create");

    let mut input_data = HashMap::new();
    input_data.insert("X".to_string(), x.into());

    let outputs = pollster::block_on(session.run(&input_data)).unwrap();
    common::assert_eq_vector((&outputs["Y"]).try_into().unwrap(), expected_y);
}

/// Softmax over the last axis of a 1x3 tensor.
///
/// Expected values are the `softmax` example from the ONNX operator documentation:
/// https://github.com/onnx/onnx/blob/61c36aa4250ca607c194fc55070e95db2b95761b/docs/Operators.md#examples-126
#[test]
fn test_softmax_simple() {
    // `try_init` returns an error if a logger is already installed. That error is deliberately
    // ignored.
    let _ = env_logger::builder().is_test(true).try_init();

    let input: [f32; 3] = [-1.0, 0.0, 1.0];
    let expected: [f32; 3] = [0.09003058, 0.24472848, 0.66524094];
    softmax_with_axis(&input, &[1, 3], 1, &expected, 13);
}

/// Softmax cases, most of them evaluated with opset 7 semantics.
///
/// The first case is the 2x4 example from the ONNX operator documentation
/// (https://github.com/onnx/onnx/blob/61c36aa4250ca607c194fc55070e95db2b95761b/docs/Operators.md#examples-126),
/// run at opset 13. The remaining cases use a 3x4x5 input. Their expectations come from
/// onnxruntime's `softmax_test.cc`.
///
/// Before opset 13, Softmax flattens the input to 2-D at `axis` and normalizes over every
/// dimension from `axis` onward. As a result, the axis-0 and axis-1 expectations below are
/// normalized over 60 and 20 elements respectively. Along the innermost axis both definitions
/// agree. That is why the axis-2 expectation is also checked at opset 13, including through the
/// negative axis `-1`.
#[test]
fn test_softmax_with_axis_opset7() {
    let _ = env_logger::builder().is_test(true).try_init();

    // The second row is the first row shifted by 10000. `exp` of these values would overflow
    // f32, yet the expected output is the same distribution for both rows.
    let shifted_rows: [f32; 8] = [0., 1., 2., 3., 10000., 10001., 10002., 10003.];
    #[rustfmt::skip]
    let shifted_rows_expected: [f32; 8] = [
        0.032058604, 0.08714432, 0.23688284, 0.6439143,
        0.032058604, 0.08714432, 0.23688284, 0.6439143,
    ];
    softmax_with_axis(&shifted_rows, &[2, 4], -1, &shifted_rows_expected, 13);

    let dims_3d: [i64; 3] = [3, 4, 5];

    #[rustfmt::skip]
    let x_vals_3dims: [f32; 60] = [
        1.0856307, 0.99734545, 0.2829785, 1.5062947, 0.5786002,
        1.6514366, 2.4266791, 0.42891264, 1.2659363, 0.8667404,
        0.6788862, 0.09470897, 1.4913896, 0.638902, 0.44398195,
        0.43435127, 2.20593, 2.1867862, 1.004054, 0.3861864,

        0.7373686, 1.4907321, 0.9358339, 1.175829, 1.2538806,
        0.6377515, 0.9071052, 1.4286807, 0.14006872, 0.8617549,
        0.25561938, 2.798589, 1.7715331, 0.69987726, 0.92746246,
        0.17363568, 0.002845916, 0.6882227, 0.87953633, 0.28362733,

        0.8053665, 1.7276695, 0.3908998, 0.57380587, 0.33858904,
        0.011830495, 2.3923652, 0.41291216, 0.978736, 2.2381434,
        1.2940853, 1.0387882, 1.7437122, 0.79806274, 0.02968323,
        1.0693159, 0.8907064, 1.7548862, 1.4956441, 1.0693927
    ];

    // onnxruntime softmax_test.cc, axis 0 (opset 7):
    // https://github.com/microsoft/onnxruntime/blob/9c6cc018a9a71f2d3b36647b83ef60659ebb2a4c/onnxruntime/test/providers/cpu/math/softmax_test.cc#L142
    #[rustfmt::skip]
    let y_expected_axis0: [f32; 60] = [
        0.01424391, 0.013040296, 0.0063832495, 0.021693084, 0.0085788425,
        0.02508162, 0.054455176, 0.007386185, 0.017058268, 0.011443698,
        0.009483798, 0.0052878284, 0.021372143, 0.009112078, 0.0074983323,
        0.0074264654, 0.04366858, 0.04284054, 0.01312807, 0.007077248,

        0.010054973, 0.021358095, 0.01226234, 0.015588412, 0.016853856,
        0.009101601, 0.01191507, 0.020073075, 0.0055332067, 0.0113867875,
        0.0062109763, 0.07898735, 0.028282179, 0.009684978, 0.012160114,
        0.005722092, 0.004823717, 0.009572758, 0.011591072, 0.006387392,

        0.010762473, 0.027068434, 0.007110686, 0.00853781, 0.0067482805,
        0.004867251, 0.0526183, 0.007268943, 0.0127998665, 0.045098193,
        0.017545262, 0.0135920765, 0.027506188, 0.010684152, 0.0049549243,
        0.01401341, 0.011721271, 0.027815264, 0.021463264, 0.014014485
    ];

    // onnxruntime softmax_test.cc, axis 1 (opset 7):
    // https://github.com/microsoft/onnxruntime/blob/9c6cc018a9a71f2d3b36647b83ef60659ebb2a4c/onnxruntime/test/providers/cpu/math/softmax_test.cc#L161
    #[rustfmt::skip]
    let y_expected_axis1: [f32; 60] = [
        0.04113652, 0.037660476, 0.018434875, 0.0626498, 0.024775764,
        0.072435915, 0.15726697, 0.021331362, 0.049264412, 0.03304949,
        0.027389284, 0.015271291, 0.061722923, 0.026315752, 0.021655245,
        0.021447688, 0.1261152, 0.12372383, 0.03791397, 0.020439148,

        0.032693777, 0.069445916, 0.039871037, 0.05068577, 0.054800365,
        0.029593885, 0.03874189, 0.06526767, 0.01799124, 0.037024178,
        0.02019501, 0.25682762, 0.091959596, 0.031490736, 0.03953865,
        0.018605402, 0.01568433, 0.031125855, 0.037688408, 0.020768626,

        0.031088287, 0.0781894, 0.020539802, 0.024662167, 0.019492965,
        0.014059456, 0.15199229, 0.020996941, 0.036973465, 0.13026986,
        0.050680935, 0.03926183, 0.079453886, 0.030862054, 0.014312706,
        0.040478885, 0.033857856, 0.080346674, 0.06199841, 0.040481992
    ];

    // onnxruntime softmax_test.cc, axis 2 (innermost; identical for opset 7 and 13):
    // https://github.com/microsoft/onnxruntime/blob/9c6cc018a9a71f2d3b36647b83ef60659ebb2a4c/onnxruntime/test/providers/cpu/math/softmax_test.cc#L218
    #[rustfmt::skip]
    let y_expected_axis2: [f32; 60] = [
        0.22277209, 0.20394778, 0.09983283, 0.33927578, 0.13417149,
        0.21729809, 0.47177994, 0.06399124, 0.14778666, 0.099144064,
        0.1797734, 0.10023525, 0.40512702, 0.17272712, 0.14213723,
        0.06506401, 0.3825848, 0.37533033, 0.11501635, 0.062004484,

        0.13209775, 0.28059313, 0.16109712, 0.2047936, 0.22141843,
        0.1568978, 0.20539774, 0.3460294, 0.0953841, 0.19629094,
        0.045896534, 0.5836837, 0.20899355, 0.07156797, 0.08985819,
        0.15019783, 0.1266166, 0.2512731, 0.30425128, 0.16766116,

        0.17869644, 0.44943509, 0.11806339, 0.1417589, 0.112046175,
        0.03968324, 0.42900288, 0.059264507, 0.10435873, 0.36769062,
        0.23619612, 0.1829779, 0.37029108, 0.14383113, 0.0667037,
        0.15740506, 0.13165872, 0.31243387, 0.24108529, 0.15741715
    ];

    // (axis, expected output, opset version), run in this order.
    let cases: [(i64, &[f32], i64); 5] = [
        (0, &y_expected_axis0[..], 7),
        (1, &y_expected_axis1[..], 7),
        (2, &y_expected_axis2[..], 7),
        (2, &y_expected_axis2[..], 13),
        (-1, &y_expected_axis2[..], 13),
    ];
    for &(axis, expected_y, opset_version) in cases.iter() {
        softmax_with_axis(&x_vals_3dims, &dims_3d, axis, expected_y, opset_version);
    }
}

/// Softmax along axis 1 of a 3x4x5 tensor with opset 13 semantics.
///
/// From opset 13 onward, Softmax normalizes along the single `axis` dimension only. For this
/// input, each 4-element column is therefore normalized on its own, and the result differs from
/// the older-opset expectation for the same input and axis.
#[test]
fn test_softmax_with_axis_opset13() {
    let _ = env_logger::builder().is_test(true).try_init();

    let dims_3d: [i64; 3] = [3, 4, 5];

    #[rustfmt::skip]
    let x_vals_3dims: [f32; 60] = [
        1.0856307, 0.99734545, 0.2829785, 1.5062947, 0.5786002,
        1.6514366, 2.4266791, 0.42891264, 1.2659363, 0.8667404,
        0.6788862, 0.09470897, 1.4913896, 0.638902, 0.44398195,
        0.43435127, 2.20593, 2.1867862, 1.004054, 0.3861864,

        0.7373686, 1.4907321, 0.9358339, 1.175829, 1.2538806,
        0.6377515, 0.9071052, 1.4286807, 0.14006872, 0.8617549,
        0.25561938, 2.798589, 1.7715331, 0.69987726, 0.92746246,
        0.17363568, 0.002845916, 0.6882227, 0.87953633, 0.28362733,

        0.8053665, 1.7276695, 0.3908998, 0.57380587, 0.33858904,
        0.011830495, 2.3923652, 0.41291216, 0.978736, 2.2381434,
        1.2940853, 1.0387882, 1.7437122, 0.79806274, 0.02968323,
        1.0693159, 0.8907064, 1.7548862, 1.4956441, 1.0693927
    ];

    // onnxruntime softmax_test.cc, axis 1 with opset 13:
    // https://github.com/microsoft/onnxruntime/blob/9c6cc018a9a71f2d3b36647b83ef60659ebb2a4c/onnxruntime/test/providers/cpu/math/softmax_test.cc#L191
    #[rustfmt::skip]
    let y_expected_axis1_opset13: [f32; 60] = [
        0.253289, 0.11198013, 0.08185529, 0.35567388, 0.24795689,
        0.44600812, 0.46761957, 0.09471639, 0.2796827, 0.3307607,
        0.16864346, 0.04540785, 0.27406466, 0.14939913, 0.2167266,
        0.1320594, 0.3749925, 0.5493636, 0.21524426, 0.20455585,

        0.32341874, 0.18241648, 0.1747012, 0.36767146, 0.36021632,
        0.29275346, 0.10176494, 0.28598055, 0.13050734, 0.24336906,
        0.19977638, 0.67461985, 0.40293545, 0.22843185, 0.25989732,
        0.18405138, 0.04119869, 0.13638285, 0.27338937, 0.13651732,

        0.22807457, 0.2577944, 0.10201685, 0.15962972, 0.09529332,
        0.10314508, 0.5011263, 0.10428739, 0.23931651, 0.63683724,
        0.37181312, 0.12944824, 0.3946307, 0.19975942, 0.0699691,
        0.29696727, 0.11163106, 0.39906505, 0.4012943, 0.1979003
    ];

    // This input and axis give a different expected output under the older opset semantics.
    softmax_with_axis(&x_vals_3dims, &dims_3d, 1, &y_expected_axis1_opset13, 13);
}
