#!/usr/bin/env python3
import os
import time
from tqdm import tqdm
import pandas as pd
import numpy as np
from biom import load_table, Table
from biom.util import biom_open
from skbio.stats.composition import clr, centralize, closure
from skbio.stats.composition import clr_inv as softmax
import matplotlib.pyplot as plt
from scipy.stats import entropy, spearmanr
import click
from scipy.sparse import coo_matrix
import tensorflow as tf
from tensorflow.contrib.distributions import Multinomial, Normal
import datetime
from rhapsody.multimodal import MMvec, cross_validation
from rhapsody.util import onehot, rank_hits, random_multimodal, split_tables


# Top-level click command group.  Subcommands attach themselves to it via
# ``@rhapsody.command()``.  The group itself does no work, hence the bare
# ``pass``.  A ``#`` comment is used here instead of a docstring because click
# would display a docstring as the group's ``--help`` text.
@click.group()
def rhapsody():
    pass


@rhapsody.command()
@click.option('--microbe-file',
              help='Input microbial abundances')
@click.option('--metabolite-file',
              help='Input metabolite abundances')
@click.option('--metadata-file', default=None,
              help='Input sample metadata file')
@click.option('--training-column',
              help=('Column in the sample metadata specifying which '
                    'samples are for training and testing.'),
              default=None)
@click.option('--num-testing-examples',
              help=('Number of samples to randomly select for testing'),
              default=10)
@click.option('--min-feature-count',
              help=('Minimum number of samples a microbe needs to be observed '
                    'in order to not filter out'),
              default=10)
@click.option('--epochs',
              help='Number of epochs to train', default=10)
@click.option('--batch_size',
              help='Size of mini-batch', default=32)
@click.option('--latent_dim',
              help=('Dimensionality of shared latent space. '
                    'This is analogous to the number of PC axes.'),
              default=3)
@click.option('--input_prior',
              help=('Width of normal prior for input embedding.  '
                    'Smaller values will regularize parameters towards zero. '
                    'Values must be greater than 0.'),
              default=1.)
@click.option('--output_prior',
              help=('Width of normal prior for output embedding.  '
                    'Smaller values will regularize parameters towards zero. '
                    'Values must be greater than 0.'),
              default=1.)
@click.option('--arm-the-gpu', is_flag=True,
              help=('Enables GPU support'),
              default=False)
@click.option('--top-k',
              help=('Number of top hits to compare for cross-validation.'),
              default=50)
@click.option('--learning-rate',
              help=('Gradient descent learning rate.'),
              default=1e-1)
@click.option('--beta1',
              help=('Gradient decay rate for first Adam momentum estimates'),
              default=0.9)
@click.option('--beta2',
              help=('Gradient decay rate for second Adam momentum estimates'),
              default=0.95)
@click.option('--clipnorm',
              help=('Gradient clipping size.'),
              default=10.)
@click.option('--checkpoint-interval',
              help=('Number of seconds before storing a checkpoint.'),
              default=1000)
@click.option('--summary-interval',
              help=('Number of seconds before storing a summary.'),
              default=1000)
@click.option('--summary-dir', default='summarydir',
              help='Summary directory to save cross validation results.')
@click.option('--ranks-file', default=None,
              help='Ranks file containing microbe-metabolite rankings.')
def mmvec(microbe_file, metabolite_file,
          metadata_file, training_column,
          num_testing_examples, min_feature_count,
          epochs, batch_size, latent_dim,
          input_prior, output_prior, arm_the_gpu, top_k,
          learning_rate, beta1, beta2, clipnorm,
          checkpoint_interval, summary_interval,
          summary_dir, ranks_file):
    """Fit an MMvec model on paired microbe and metabolite tables.

    The fitted U, V, Ubias and Vbias arrays are written as text files into
    the summary directory.  When a ranks file is requested, the
    microbe-metabolite ranks table and the cross-validation results are
    written as well.
    """
    # NOTE(review): ``batch_size`` is accepted on the command line but is not
    # forwarded to ``MMvec`` below -- confirm whether the model should
    # receive it.
    microbes = load_table(microbe_file)
    metabolites = load_table(metabolite_file)

    metadata = None
    if metadata_file is not None:
        metadata = pd.read_table(metadata_file, index_col=0)

    res = split_tables(
        microbes, metabolites,
        metadata=metadata, training_column=training_column,
        num_test=num_testing_examples,
        min_samples=min_feature_count)
    (train_microbes_df, test_microbes_df,
     train_metabolites_df, test_metabolites_df) = res

    # Every output below is written inside ``summary_dir``; create it up
    # front so np.savetxt / to_csv do not fail on a missing directory.
    # (An empty string means "current directory" and needs no creation.)
    if summary_dir:
        os.makedirs(summary_dir, exist_ok=True)

    # Per-run sub-path encoding the hyper-parameters; handed to the model
    # as its ``save_path``.
    sname = os.path.join(
        summary_dir,
        'latent_dim_%s_input_prior_%.2f_output_prior_%.2f'
        '_beta1_%.2f_beta2_%.2f' % (
            latent_dim, input_prior, output_prior, beta1, beta2))

    train_microbes_coo = coo_matrix(train_microbes_df.values)
    test_microbes_coo = coo_matrix(test_microbes_df.values)

    # Pick out the first GPU when requested, otherwise stay on the CPU.
    device_name = '/device:GPU:0' if arm_the_gpu else '/cpu:0'

    config = tf.ConfigProto()
    with tf.Graph().as_default(), tf.Session(config=config) as session:
        model = MMvec(
            latent_dim=latent_dim,
            u_scale=input_prior, v_scale=output_prior,
            learning_rate=learning_rate,
            beta_1=beta1, beta_2=beta2,
            device_name=device_name,
            clipnorm=clipnorm, save_path=sname)

        model(session,
              train_microbes_coo, train_metabolites_df.values,
              test_microbes_coo, test_metabolites_df.values)

        # The values returned by fit() are not used by this command.
        model.fit(epoch=epochs, summary_interval=summary_interval,
                  checkpoint_interval=checkpoint_interval)

        np.savetxt(os.path.join(summary_dir, 'U.txt'), model.U)
        np.savetxt(os.path.join(summary_dir, 'V.txt'), model.V)
        np.savetxt(os.path.join(summary_dir, 'Ubias.txt'), model.Ubias)
        np.savetxt(os.path.join(summary_dir, 'Vbias.txt'), model.Vbias)

        if ranks_file is not None:
            U, V = model.U, model.V

            # Augment the factors with the bias terms so that a single
            # matrix product yields  Vbias + Ubias + U @ V.
            U_ = np.hstack(
                (np.ones((U.shape[0], 1)),
                 model.Ubias.reshape(-1, 1), U)
            )
            V_ = np.vstack(
                (model.Vbias.reshape(1, -1),
                 np.ones((1, V.shape[1])), V)
            )

            # Prepend a zero column for the first (reference) column, so the
            # result has one column per entry of train_metabolites_df.columns.
            logits = np.hstack((np.zeros((U.shape[0], 1)), U_ @ V_))
            ranks = pd.DataFrame(
                clr(softmax(logits)),
                index=train_microbes_df.columns,
                columns=train_metabolites_df.columns)

            # Shift the reference from the first column to the row average.
            # ``axis=0`` aligns the per-row means with the row index; a plain
            # ``ranks - ranks.mean(axis=1)`` would align them with the column
            # labels instead and produce NaN wherever the labels differ.
            ranks = ranks.sub(ranks.mean(axis=1), axis=0)

            params, rank_stats = cross_validation(
                model, test_microbes_df, test_metabolites_df, top_N=top_k)

            params.to_csv(os.path.join(summary_dir, 'model_results.csv'))
            rank_stats.to_csv(os.path.join(summary_dir, 'otu_cv_results.csv'))
            ranks.to_csv(ranks_file)


# Invoke the click command group when this file is executed as a script
# (not when it is imported as a module).
if __name__ == '__main__':
    rhapsody()
