import matplotlib as mpl;mpl.use('Agg');import sys;from multiprocessing import Pool;import h5py as h5
import numpy as np;import pandas as pd;import matplotlib.pyplot as plt;import os
import time;from subprocess import call;import mplhep as hep;hep.styles.use('ATLAS')

# Colours shared by the cumulant plots: red for v2{2}, blue for v2{4}.
red = "#c45955"
blue = "#28618f"
# Ten-step dark-blue -> dark-red palette; five entries are picked as the
# per-model curve colours of the event-plane plot.
colors = np.take(
    np.array([
        '#023047ff', '#126782ff', '#219ebcff', '#43b4baff', '#fda10dff',
        '#fb8500ff', '#db6202ff', '#bb3e03ff', '#ae2012ff', '#941b10ff',
    ]),
    [1, 2, 6, 7, 8],
)
# Custom dash pattern (offset, on/off sequence) used where a path's
# linestyle is given as the string 'longdash'.
longdash = (0, (1, 2, 1, 2, 4, 2))
def plot_vn_eta_alice(paths=None, exp_path_alice='./', lines=None, labels=None):
    """Plot model v2(eta) against ALICE cumulant data, one panel per centrality.

    Model curves
        For each centrality class, every directory in *paths* must contain
        ``<cent>/v2_eta_charged_cumulant_etaref_-0.8_0.8.dat``. Column 0 is the
        x axis. Columns 1 +/- 2 are drawn in red alongside the v2{2} data.
        Columns 3 +/- 4, multiplied by ``scale``, are drawn in blue alongside
        the v2{4} data. The v2{4} curves and points are skipped in the first
        (0-5%) panel.

    Experimental points
        The points are read from ``<exp_path_alice>/<cent>.csv``. Rows 0-16
        are plotted as v2{2` and rows 17-33 (times ``scale``) as v2{4}.
        Columns are used as x, y, stat+, stat-, sys+, sys-, and the two
        "minus" columns are treated as negative offsets.
        NOTE(review): this looks like a HEPData-style export -- confirm.

    Output
        The figure is saved to ./v2_vs_eta_Alice.png and ./v2_vs_eta_Alice.pdf.

    Args:
        paths: model output directories, one curve set per entry
            (default: no model curves).
        exp_path_alice: directory holding the per-centrality ALICE CSV files.
        lines: one matplotlib linestyle per path. The value 'longdash' selects
            the module-level ``longdash`` dash pattern and moves that curve's
            legend entry from panel 0 to panel 3 (default: 'solid' for all).
        labels: one legend label per path (default: no labels).
    """
    paths = [] if paths is None else paths
    lines = ['solid'] * len(paths) if lines is None else lines
    labels = [''] * len(paths) if labels is None else labels

    cents = ['0_5', '5_10', '10_20', '20_30', '30_40', '40_50']
    scale = 0.8  # v2{4} (model and data) is multiplied by this factor
    rows = 2
    # Ceiling division: round(n / 2) uses banker's rounding (round(2.5) == 2),
    # which would leave too few subplot slots for an odd number of panels.
    cols = -(-len(cents) // rows)
    eta_ticks = [-4, -2, 0, 2, 4]
    v2_ticks = [0, 0.05, 0.1]

    plt.figure(figsize=(14, 8))
    plt.subplots_adjust(wspace=0, hspace=0)

    for i, cent in enumerate(cents):
        plt.subplot(rows, cols, i + 1)

        for j, path in enumerate(paths):
            model = np.loadtxt(os.path.join(path, cent, 'v2_eta_charged_cumulant_etaref_-0.8_0.8.dat'))
            is_longdash = lines[j] == 'longdash'
            ls = longdash if is_longdash else lines[j]
            # 'longdash' curves are labelled in panel 3 instead of panel 0,
            # presumably to keep the first legend from getting too crowded.
            legend_panel = 3 if is_longdash else 0
            eta = model[:, 0]
            plt.plot(eta, model[:, 1], color=red, linestyle=ls,
                     label=labels[j] if i == legend_panel else '')
            plt.fill_between(eta, model[:, 1] - model[:, 2], model[:, 1] + model[:, 2],
                             color=red, alpha=0.2, linestyle=ls)
            if i != 0:
                plt.plot(eta, scale * model[:, 3], color=blue, linestyle=ls)
                plt.fill_between(eta, scale * (model[:, 3] - model[:, 4]),
                                 scale * (model[:, 3] + model[:, 4]),
                                 color=blue, alpha=0.2, linestyle=ls)

        exp = np.loadtxt(os.path.join(exp_path_alice, '%s.csv' % cent), delimiter=',')
        cent_label = cent.replace('_', '-') + ' % '
        v22, v24 = exp[:17], exp[17:34]
        plt.errorbar(v22[:, 0], v22[:, 1], yerr=[-v22[:, 3], v22[:, 2]], fmt='*',
                     label=cent_label + r' $v_2\{2\}$', color=red)
        # Systematic uncertainties drawn as hollow boxes spanning [y + sys-, y + sys+].
        plt.bar(v22[:, 0], height=-v22[:, 5] + v22[:, 4], bottom=v22[:, 1] + v22[:, 5],
                width=0.05, align='center', facecolor='none', edgecolor=red)
        if i != 0:
            plt.errorbar(v24[:, 0], scale * v24[:, 1],
                         yerr=[-scale * v24[:, 3], scale * v24[:, 2]], fmt='s',
                         label=cent_label + r' $v_2\{4\}$*(%s)' % scale, color=blue)
            plt.bar(v24[:, 0], height=scale * (-v24[:, 5] + v24[:, 4]),
                    bottom=scale * (v24[:, 1] + v24[:, 5]), width=0.1, align='center',
                    facecolor='none', edgecolor=blue)

        plt.xlim(-5, 5)
        plt.ylim(0, 0.13)
        if i != 3:
            plt.legend(ncol=1, fontsize=15)
        else:
            plt.legend(ncol=1, fontsize=15, loc='upper center', borderaxespad=-0.3)
        if i == 1:
            plt.text(-4.5, 0.112, r'$p_T>0 \mathrm{GeV}$', fontsize=15)
            plt.text(-4.5, 0.100, r'$|\eta^{\rm ref}|<0.8$', fontsize=15)
        # eta tick labels and the x label only on panels with no panel below them.
        if i + cols >= len(cents):
            plt.xlabel(r'$\eta$', loc='center')
            plt.xticks(eta_ticks)
        else:
            plt.xticks(eta_ticks, [''] * len(eta_ticks))
        # v2 tick labels and the y label only on the left column.
        if i % cols == 0:
            plt.ylabel(r'$v_2$', loc='center')
            plt.yticks(v2_ticks)
        else:
            plt.yticks(v2_ticks, [''] * len(v2_ticks))

    plt.savefig('./v2_vs_eta_Alice.png', dpi=400, bbox_inches='tight')
    plt.savefig('./v2_vs_eta_Alice.pdf', dpi=400, bbox_inches='tight')
    plt.close()

def plot_vn_eta_cms_cum(paths=None, exp_path_cms='./', lines=None, labels=None):
    """Plot model v2(eta) against CMS cumulant data, one panel per centrality.

    Model curves
        For each centrality class, every directory in *paths* must contain
        ``<cent>/v2_eta_charged_cumulant_etaref_-2.4_2.4.dat``. The file is
        parsed as complex and only the real part is used. Column 0 is the
        x axis. Columns 1 +/- 2 are drawn in red alongside the v2{2} data.
        Columns 3 +/- 4, multiplied by ``scale``, are drawn in blue alongside
        the v2{4} data; the model v2{4} is skipped in the first (0-5%) panel.

    Experimental points
        The points come from ``<exp_path_cms>/<cent>.csv``. Rows 12-23 are
        plotted as v2{2} and rows 24-35 (times ``scale``) as v2{4}. Columns
        are used as x, y, stat+, stat-, sys+, sys-, with the "minus" columns
        treated as negative offsets. A CSV that cannot be read or plotted is
        skipped with a warning, so the panel simply has no data points.

    Output
        The figure is saved to ./v2_vs_eta_cms_cum.png and
        ./v2_vs_eta_cms_cum.pdf.

    Args:
        paths: model output directories, one curve set per entry
            (default: no model curves).
        exp_path_cms: directory holding the per-centrality CMS CSV files.
        lines: one matplotlib linestyle per path. The value 'longdash' selects
            the module-level ``longdash`` dash pattern and moves that curve's
            legend entry from panel 0 to panel 3 (default: 'solid' for all).
        labels: one legend label per path (default: no labels).
    """
    import warnings

    paths = [] if paths is None else paths
    lines = ['solid'] * len(paths) if lines is None else lines
    labels = [''] * len(paths) if labels is None else labels

    # The two most peripheral classes (50-60%, 60-70%) are currently excluded.
    cents = ['0_5', '5_10', '10_15', '20_25', '30_35', '40_50', '50_60', '60_70'][:-2]
    scale = 0.8  # v2{4} (model and data) is multiplied by this factor
    rows = 2
    # Ceiling division: round(n / 2) uses banker's rounding (round(2.5) == 2),
    # which would leave too few subplot slots for an odd number of panels.
    cols = -(-len(cents) // rows)
    eta_ticks = [-4, -2, 0, 2, 4]
    v2_ticks = [0, 0.05, 0.1]

    plt.figure(figsize=(14, 8))
    plt.subplots_adjust(wspace=0, hspace=0)

    for i, cent in enumerate(cents):
        plt.subplot(rows, cols, i + 1)

        for j, path in enumerate(paths):
            model = np.loadtxt(os.path.join(path, cent, 'v2_eta_charged_cumulant_etaref_-2.4_2.4.dat'),
                               dtype=complex).real
            is_longdash = lines[j] == 'longdash'
            ls = longdash if is_longdash else lines[j]
            # 'longdash' curves are labelled in panel 3 instead of panel 0,
            # presumably to keep the first legend from getting too crowded.
            legend_panel = 3 if is_longdash else 0
            eta = model[:, 0]
            plt.plot(eta, model[:, 1], color=red, linestyle=ls,
                     label=labels[j] if i == legend_panel else '')
            plt.fill_between(eta, model[:, 1] - model[:, 2], model[:, 1] + model[:, 2],
                             color=red, alpha=0.2, linestyle=ls)
            if i != 0:
                plt.plot(eta, scale * model[:, 3], color=blue, linestyle=ls)
                plt.fill_between(eta, scale * (model[:, 3] - model[:, 4]),
                                 scale * (model[:, 3] + model[:, 4]),
                                 color=blue, alpha=0.2, linestyle=ls)

        exp_file = os.path.join(exp_path_cms, '%s.csv' % cent)
        cent_label = cent.replace('_', '-') + '% '
        # Best effort: a missing/malformed CSV (or data matplotlib rejects) must
        # not abort the whole figure -- but only those failures are tolerated.
        try:
            exp = np.loadtxt(exp_file, delimiter=',')
            v22, v24 = exp[12:24], exp[24:36]
            plt.errorbar(v22[:, 0], v22[:, 1], yerr=[-v22[:, 3], v22[:, 2]], fmt='*', color=red,
                         label=cent_label + r'$v_2\{2\}$')
            # Systematic uncertainties drawn as hollow boxes spanning [y + sys-, y + sys+].
            plt.bar(v22[:, 0], height=-v22[:, 5] + v22[:, 4], bottom=v22[:, 1] + v22[:, 5],
                    width=0.05, facecolor='none', edgecolor=red)
            plt.errorbar(v24[:, 0], scale * v24[:, 1],
                         yerr=[-scale * v24[:, 3], scale * v24[:, 2]], fmt='s', color=blue,
                         label=cent_label + r'$v_2\{4\}$(*%s)' % scale)
            plt.bar(v24[:, 0], height=scale * (-v24[:, 5] + v24[:, 4]),
                    bottom=scale * (v24[:, 1] + v24[:, 5]), width=0.05,
                    facecolor='none', edgecolor=blue)
        except (OSError, ValueError, IndexError) as err:
            warnings.warn('CMS cumulant data for centrality %s not plotted (%s): %s'
                          % (cent, exp_file, err))

        plt.xlim(-5, 5)
        plt.ylim(0, 0.13)
        plt.legend(fontsize=15)
        # eta tick labels and the x label only on panels with no panel below them.
        if i + cols >= len(cents):
            plt.xlabel(r'$\eta$', loc='center')
            plt.xticks(eta_ticks)
        else:
            plt.xticks(eta_ticks, [''] * len(eta_ticks))
        # v2 tick labels and the y label only on the left column.
        if i % cols == 0:
            plt.ylabel(r'$v_2$', loc='center')
            plt.yticks(v2_ticks)
        else:
            plt.yticks(v2_ticks, [''] * len(v2_ticks))
        if i == 3:
            plt.text(-4.5, 0.112, r'$0.3<p_T<3 \mathrm{GeV}$', fontsize=15)
            plt.text(-4.5, 0.100, r'$|\eta^{\rm ref}|<2.4$', fontsize=15)

    plt.savefig('./v2_vs_eta_cms_cum.png', dpi=400, bbox_inches='tight')
    plt.savefig('./v2_vs_eta_cms_cum.pdf', dpi=400, bbox_inches='tight')
    plt.close()

def plot_vn_eta_cms_ep(paths=None, exp_path_cms='.', lines=None, labels=None):
    """Plot model event-plane v2(eta) against CMS data, one panel per centrality.

    Model curves
        For each centrality class, every directory in *paths* must contain
        ``<cent>/v2_eta_charged_ep.dat``. The file is parsed as complex and
        only the real part is used. Column 0 is the x axis, and columns
        1 +/- 2 are drawn with the j-th colour of the module-level ``colors``
        palette. The palette cycles when there are more paths than colours.

    Experimental points
        The points come from ``<exp_path_cms>/<cent>.csv``; rows 0-11 are
        plotted as v2{EP}. Columns are used as x, y, stat+, stat-, sys+, sys-,
        with the "minus" columns treated as negative offsets. If the CSV cannot
        be read or plotted, a warning is issued and an empty errorbar keeps the
        centrality entry in the legend.

    Output
        The figure is saved to ./v2_vs_eta_cms_ep.png and
        ./v2_vs_eta_cms_ep.pdf.

    Args:
        paths: model output directories, one curve per entry
            (default: no model curves).
        exp_path_cms: directory holding the per-centrality CMS CSV files.
        lines: one matplotlib linestyle per path. The value 'longdash' selects
            the module-level ``longdash`` dash pattern (default: 'solid').
        labels: one legend label per path, shown in panel 0 (default: none).
    """
    import warnings

    paths = [] if paths is None else paths
    lines = ['solid'] * len(paths) if lines is None else lines
    labels = [''] * len(paths) if labels is None else labels

    # The two most peripheral classes (50-60%, 60-70%) are currently excluded.
    cents = ['0_5', '5_10', '10_15', '20_25', '30_35', '40_50', '50_60', '60_70'][:-2]
    rows = 2
    # Ceiling division: round(n / 2) uses banker's rounding (round(2.5) == 2),
    # which would leave too few subplot slots for an odd number of panels.
    cols = -(-len(cents) // rows)
    eta_ticks = [-4, -2, 0, 2, 4]
    v2_ticks = [0, 0.05, 0.1]

    plt.figure(figsize=(14, 8))
    plt.subplots_adjust(wspace=0, hspace=0)

    for i, cent in enumerate(cents):
        plt.subplot(rows, cols, i + 1)

        for j, path in enumerate(paths):
            model = np.loadtxt(os.path.join(path, cent, 'v2_eta_charged_ep.dat'), dtype=complex).real
            ls = longdash if lines[j] == 'longdash' else lines[j]
            color = colors[j % len(colors)]
            eta = model[:, 0]
            plt.plot(eta, model[:, 1], color=color, linestyle=ls,
                     label=labels[j] if i == 0 else '')
            plt.fill_between(eta, model[:, 1] - model[:, 2], model[:, 1] + model[:, 2],
                             color=color, alpha=0.2, linestyle=ls)

        # One label for both the data and the fallback entry; \mathrm is used
        # because \text is not understood by older matplotlib mathtext versions.
        data_label = cent.replace('_', '-') + '% ' + r'$v_2\{\mathrm{EP}\}$'
        exp_file = os.path.join(exp_path_cms, '%s.csv' % cent)
        try:
            ep = np.loadtxt(exp_file, delimiter=',')[:12]
            plt.errorbar(ep[:, 0], ep[:, 1], yerr=[-ep[:, 3], ep[:, 2]], fmt='*',
                         color='black', label=data_label)
            # Systematic uncertainties drawn as hollow boxes spanning [y + sys-, y + sys+].
            plt.bar(ep[:, 0], height=-ep[:, 5] + ep[:, 4], bottom=ep[:, 1] + ep[:, 5],
                    width=0.05, align='center', facecolor='none', edgecolor='black')
        except (OSError, ValueError, IndexError) as err:
            warnings.warn('CMS EP data for centrality %s not plotted (%s): %s'
                          % (cent, exp_file, err))
            # Empty errorbar keeps the centrality entry in the legend.
            plt.errorbar([], [], yerr=[[], []], fmt='*', color='black', label=data_label)

        plt.xlim(-5, 5)
        plt.ylim(0, 0.13)
        # eta tick labels and the x label only on panels with no panel below them.
        if i + cols >= len(cents):
            plt.xlabel(r'$\eta$', loc='center')
            plt.xticks(eta_ticks)
        else:
            plt.xticks(eta_ticks, [''] * len(eta_ticks))
        # v2 tick labels and the y label only on the left column.
        if i % cols == 0:
            plt.ylabel(r'$v_2$', loc='center')
            plt.yticks(v2_ticks)
        else:
            plt.yticks(v2_ticks, [''] * len(v2_ticks))
        plt.legend(ncol=1, loc='best', fontsize=15)
        if i == 1:
            plt.text(-3.6, 0.112, r'$0.3<p_T^{ref}<3GeV$', fontsize=15)
            plt.text(-3.6, 0.100, r'$3<|\eta^{ref}|<5$', fontsize=15)

    plt.savefig('./v2_vs_eta_cms_ep.png', dpi=400, bbox_inches='tight')
    plt.savefig('./v2_vs_eta_cms_ep.pdf', dpi=400, bbox_inches='tight')
    plt.close()

# Directories with the experimental reference data (one CSV per centrality class).
exp_path_alice = '../exp_dat/zj/pbpb2760/vn_vs_eta_Alice'
exp_path_cms = '../exp_dat/zj/pbpb2760/CMS_v2_eta_pbpb2760'

# Model output for Pb-Pb at 2.76 TeV, one directory per model variant.
p1, p2, p3, p4, p5 = [
    './dat_w_err/%s/PbPb2760/' % variant
    for variant in ('nucleon', 'nucleon_fluct', 'hotspots', 'hotspots_fluct', 'hotspots_fluctmore')
]
# Legend label and linestyle for each of p1..p5, in the same order.
labels = [
    'nucleon',
    r'nucleon fluct($\sigma$=0.637)',
    'hotspots',
    r'hotspots fluct($\sigma$=0.637)',
    r'hotspots fluct($\sigma$=1.2)',
]
lines = ['solid', 'dashed', 'dotted', 'dashdot', 'longdash']

plot_vn_eta_cms_cum(paths=[p1, p2, p3, p4, p5], exp_path_cms=exp_path_cms, labels=labels, lines=lines)
plot_vn_eta_alice(paths=[p1, p2, p3, p4, p5], exp_path_alice=exp_path_alice, labels=labels, lines=lines)
# Event-plane comparison is disabled; uncomment to also produce v2_vs_eta_cms_ep.*
#plot_vn_eta_cms_ep(paths=[p1,p2,p3,p4,p5],exp_path_cms=exp_path_cms,labels=labels,lines=lines)
