import scipy.interpolate
import scipy.optimize
import argparse
import quspin.operators # Hamiltonians and operators
import quspin.basis # Hilbert space spin basis
import quspin.tools.misc
import numpy as np # generic math functions
import matplotlib.pyplot as plt # plotting devices
import matplotlib.cm
import matplotlib.colors
import math
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import sympy # symbolic calculations
import random # random numbers
import scipy.ndimage
import scipy.special
import scipy.linalg
import scipy.sparse
import scipy.interpolate
import scipy.optimize
import finufft # version of FFT where you can specify both positions and frequencies
# documentation for finufft: https://finufft.readthedocs.io/en/latest/index.html
import dill # reading and writing files
import os
import time
import numbers
import fnmatch
import matplotlib as mpl
import matplotlib.font_manager as font_manager
import itertools
import uncertainties as unc
import uncertainties.umath
import uncertainties.unumpy

class Data_to_save:
    """Container for the spectra and metadata of one (L, alpha) data point.

    Attributes (stored as given; shapes/types depend on the producer script):
        K:            coupling data for the realizations — TODO confirm meaning against the pickling code
        tau_th:       thermalization time scale (presumably) — verify against producer
        L:            system size
        realizations: number of disorder realizations contained
        alpha:        model parameter alpha
        eigenvalues:  optional list of eigenvalue arrays, one per realization
        eigenstates:  optional eigenstate data
    """
    def __init__(self, K, tau_th, L, realizations, alpha, eigenvalues=None, eigenstates=None):
        self.K = K
        self.tau_th = tau_th
        self.L = L
        self.realizations = realizations
        self.alpha = alpha
        self.eigenvalues = eigenvalues
        self.eigenstates = eigenstates

    def __eq__(self, other):
        # Bug fix: the original condition was "(self.K==other.K) and (self.K==other.K)"
        # — the second operand was a duplicate of the first, so equality is in fact
        # decided by K alone.  NOTE(review): the duplicate may have been *meant* to
        # compare a different field (e.g. tau_th or alpha) — confirm with the author.
        if not isinstance(other, Data_to_save):
            return NotImplemented  # avoid AttributeError on foreign types
        return self.K == other.K

# Sentinel used wherever a data file is missing/unreadable: no realizations,
# empty spectra, placeholder size L=12.
data_empty = Data_to_save([],0,12,0,0,eigenvalues=[])

def calculate_r_mean(eigenvalues, return_r_list=False, targetE=0.5, adaptive_window=True, take_fraction=1):
    """Compute the mean gap ratio <r>, averaged over disorder realizations.

    Parameters:
        eigenvalues:     iterable of sorted 1D eigenvalue arrays, one per realization
        return_r_list:   if True, return ``[r_mean, r_mean_list]`` with the
                         per-realization values as well
        targetE:         energy around which r is calculated, rescaled into the
                         [0, 1] interval.  Pass a non-number (e.g. "POLFED") to
                         instead take the central ``take_fraction`` of the spectrum.
        adaptive_window: if True, shrink the 250-level half-window so it never
                         exceeds 10% of the spectrum or runs off either edge
        take_fraction:   fraction of the window/spectrum actually used

    Returns:
        float <r> (NaN if no realization yielded a finite value), or
        [<r>, list of per-realization <r>] when ``return_r_list`` is True.
    """
    r_mean_list = []
    for E_list in eigenvalues:
        # Bound up-front so the error report in the except branch cannot raise
        # NameError when the failure happens before these are assigned.
        index_middle = None
        hws = None
        try:
            if isinstance(targetE, numbers.Number):
                E_middle = E_list[0] + targetE * (E_list[-1]-E_list[0])
                index_middle = np.argmin(np.abs(E_list-E_middle))
                if adaptive_window:
                    # Half of width of spectrum taken into consideration (250 in
                    # https://arxiv.org/pdf/1907.10336.pdf, but no more than 10% of the
                    # spectrum), clamped so the window never leaves the array bounds.
                    hws = math.ceil(min(250, (len(E_list))//20, index_middle, len(E_list)-1-index_middle) * take_fraction)
                else:
                    hws = 250
                # Mean gap ratio for this realization.
                r_mean_list.append(quspin.tools.misc.mean_level_spacing(E_list[index_middle-hws:index_middle+hws]))
            else:   # this will run for example for POLFED data
                E_len = len(E_list)
                i_min = math.floor((1-take_fraction)*E_len/2)
                i_max = math.ceil((1+take_fraction)*E_len/2)
                r_mean_list.append(quspin.tools.misc.mean_level_spacing(E_list[i_min:i_max]))
        except Exception:  # narrowed from a bare except (was also swallowing KeyboardInterrupt/SystemExit)
            if isinstance(targetE, numbers.Number):
                print("problem: ",targetE,index_middle,hws)
            else:
                print("problem with POLFED data (this error should never appear)")
    if len(r_mean_list) >= 1:
        # Average only the finite per-realization values.
        no_nan = np.isfinite(r_mean_list)
        r_mean = np.mean(np.array(r_mean_list)[no_nan])
    else:
        r_mean = float("nan")

    if return_r_list:
        return [r_mean, r_mean_list]
    else:
        return r_mean

# --- Command-line interface -------------------------------------------------
parser = argparse.ArgumentParser(description='Precalculate datapoints of the rbar plots.')
# Positional: fraction of the eigenvalue window kept in the r-bar average.
parser.add_argument('fraction', type=float, help='What fraction of n_ev is kept (see https://arxiv.org/abs/2308.01073 for definition of n_ev).')
# Flag: analyze the total-spin-conserving model variant instead of the original one.
parser.add_argument('--s_tot_conservation', action='store_true')
#   code is not optimized for memory consumption, so it will require ~600 GB RAM to load all the data

args = parser.parse_args()

#######
# edit this block to produce plot
dir_str = "/nfs/clone13/home/pawlik/Quantum_Sun/data/"  # put here location of data for energies_and_entropies
N = 3                           # dot size of the quantum-sun model (presumably) — confirm against data producer
gaussian_matrix_type = "GOE"    # random-matrix ensemble label used in file-name prefixes
s_tot_conservation = args.s_tot_conservation
effective_fraction_of_window = args.fraction
# Output files for fraction != 1 get a distinguishing suffix so runs do not overwrite each other.
suffix=""
if not np.isclose(effective_fraction_of_window,1):
    suffix = "_effective_fraction_%.2f" %effective_fraction_of_window

# Build the file-name prefix; the defaults (GOE, N=3, non-conserved) produce an empty prefix.
prefix = ""
if gaussian_matrix_type!="GOE":
    prefix += gaussian_matrix_type + '_'
if N!=3:
    prefix += 'N=%i_' %(N)
if s_tot_conservation:
    prefix += 'conserved_'

print(f"Suffix: {suffix}")

# Grid of alpha values: a coarse sweep 0.6..0.9 plus extra points near the
# transition region around alpha ~ 0.75.
alphas = np.sort(np.append(np.linspace(0.6,0.9,31),np.linspace(0.735,0.765,4)))

# Rescaled target energies (in [0,1]) at which <r> is evaluated.
EnergyTargets = np.linspace(0.1,0.9,17)

force_new_files = False  # Use carefully: discards any previously cached result files
# Size range and cached-file stem differ between the two model variants.
if s_tot_conservation:
    sizes = range(7,17)
    _curve_stem = prefix + "curves_10_percent"
else:
    sizes = range(6,15)
    _curve_stem = prefix + "original_model_curves_10_percent"
try:
    if force_new_files:
        # Deliberate jump into the "start from scratch" branch below.
        raise ValueError("forced recalculation requested")
    with open(_curve_stem + suffix, 'rb') as f:
        # np.array() is a no-op on already-ndarray data; the original code only
        # converted in the conserved branch — unified here for consistency.
        curves = np.array(dill.load(f))
    with open(_curve_stem + "_error" + suffix, 'rb') as f:
        curve_errors = np.array(dill.load(f))
    with open(_curve_stem + "_params" + suffix, 'rb') as f:
        alphas_old, sizes_old = dill.load(f)
except Exception:  # narrowed from a bare except; missing/corrupt cache files land here
    print("Files not found. Starting calculation from scratch.")
    # No cache: start with all-NaN result arrays over the full parameter grid.
    alphas_old = alphas
    sizes_old = sizes
    curves = np.full((len(EnergyTargets),len(sizes),len(alphas)), float("nan"))
    curve_errors = np.full((len(EnergyTargets),len(sizes),len(alphas)), float("nan"))


# Positions at which alphas/sizes that are new relative to the cached run must
# be spliced into the cached result arrays.  The second pass subtracts the
# enumeration index because np.insert interprets a list of indices relative to
# the *original* (pre-insertion) array, not the progressively grown one.
new_alpha_insert_indices = [i for i,item in enumerate(alphas) if item not in alphas_old]
new_alpha_insert_indices = [item - i for i,item in enumerate(new_alpha_insert_indices)]
new_size_insert_indices = [i for i,item in enumerate(sizes) if item not in sizes_old]
new_size_insert_indices = [item - i for i,item in enumerate(new_size_insert_indices)]

# Grow the cached arrays with NaN rows/columns for the newly added parameters.
# Axis layout: (energy target, size, alpha).
curves = np.insert(curves,new_alpha_insert_indices,float("nan"),axis=2)
curves = np.insert(curves,new_size_insert_indices,float("nan"),axis=1)
curve_errors = np.insert(curve_errors,new_alpha_insert_indices,float("nan"),axis=2)
curve_errors = np.insert(curve_errors,new_size_insert_indices,float("nan"),axis=1)

# NOTE(review): identical re-declaration of EnergyTargets from above — redundant.
EnergyTargets = np.linspace(0.1,0.9,17)
# Plot styling tables; apparently unused in this script — presumably shared
# with a companion plotting script.
marker_list = ['o','^','s','*','P','1','+','x','d']
color_list = ['blue', 'yellow', 'green', 'purple', 'red', 'gray', 'lime', 'fuchsia', 'cyan']

# Sizes to recompute this run; range(1,25) covers every size, i.e. recalculate all.
size_to_recalculate = range(1,25)

#######

# Load the full-ED data: datas[alpha index][size index] -> Data_to_save.
# Missing or unreadable files are replaced by the data_empty sentinel.
datas = []
for ind_alpha,alpha in enumerate(alphas):
    _DummyList = []
    for ind_L, L in enumerate(sizes):
        MATCH = prefix + 'k_%i_*_alpha=%.3f' %(L,alpha)

        try:
            file_list = fnmatch.filter(os.listdir(dir_str),MATCH)
            # The '*' in MATCH stands for the realization count; strip the fixed
            # head and tail of the pattern to recover it from each file name.
            # May need modification for different models!
            realizations = [item.replace(MATCH.split("*")[0],"") for item in file_list]
            realizations = [item.replace(MATCH.split("*")[1],"") for item in realizations]
            realizations = [int(item) for item in realizations]
            # Keep only the file with the most realizations for this (L, alpha).
            max_realization_index = np.argmax(realizations)
            filename = MATCH.replace("*","%i" %realizations[max_realization_index])
            if L in size_to_recalculate:
                print(dir_str + filename)
                with open(dir_str + filename, 'rb') as f:
                    temp_data = dill.load(f)
                    print(L,temp_data.realizations)
                    _DummyList.append(temp_data)
            else:
                _DummyList.append(data_empty)
        except Exception:  # narrowed from a bare except; no matching file falls back to the sentinel
            _DummyList.append(data_empty)
    datas.append(_DummyList)

# Load the POLFED data: datas_POLFED[alpha index][size index].
# Missing files are replaced by the data_empty sentinel.
datas_POLFED = []
dir_str = "/nfs/clone13/home/pawlik/Quantum_Sun/data/"  # put here location of data for energies_and_entropies (hoisted: loop-invariant)
for ind_alpha,alpha in enumerate(alphas):
    _DummyList = []
    for ind_L, L in enumerate(sizes):
        try:
            if L in size_to_recalculate:
                _path = dir_str + "N=%i,L=%i/alpha=%.3f/" %(N,L,alpha) + prefix + "Entropies"
                with open(_path, 'rb') as f:
                    print(_path)
                    _DummyList.append(dill.load(f))
            else:
                _DummyList.append(data_empty)
        except Exception:  # narrowed from a bare except; missing POLFED file falls back to the sentinel
            _DummyList.append(data_empty)
    datas_POLFED.append(_DummyList)

# Main computation: mean gap ratio <r> for every (energy target, size, alpha),
# combining full-ED spectra with POLFED spectra where available.
# NOTE(review): `intersections` is never written or read below — apparently dead code.
intersections = np.zeros((len(EnergyTargets), len(sizes)-1))
intersections = intersections.tolist()
for epsi in range(len(EnergyTargets)):
    for size in range(len(sizes)):
        print(EnergyTargets[epsi],sizes[size])
        if sizes[size] in size_to_recalculate:
            # Per-alpha [r_mean, per-realization r list] from the full-ED data.
            # NOTE(review): the comprehension variable `alpha` here is an *index*
            # into alphas, unlike the loops below where it is the alpha value.
            y = [calculate_r_mean(datas[alpha][size].eigenvalues,targetE=EnergyTargets[epsi],return_r_list=True,take_fraction=effective_fraction_of_window) for alpha in range(len(alphas))]
            y_POLFED = []
            # Realization counts from full ED; POLFED realizations are added below.
            num_of_realizations = [datas[ind_alpha][size].realizations for ind_alpha,alpha in enumerate(alphas)]
            for ind_alpha, alpha in enumerate(alphas):
                try:
                    # POLFED eigenvalues are presumably stored as (epsilon, spectra)
                    # pairs — item[0] is the target energy. TODO confirm with producer.
                    epsilons_POLFED = np.array([item[0] for item in datas_POLFED[ind_alpha][size].eigenvalues])
                    # Use POLFED only if it is not duplicating full ED and it
                    # actually contains this target energy.
                    if not datas_POLFED[ind_alpha][size].is_full_ED and (np.any(np.isclose(EnergyTargets[epsi],epsilons_POLFED))):
                        target_index = np.argmin(np.abs(epsilons_POLFED-EnergyTargets[epsi]))
                        num_of_realizations[ind_alpha] += len(datas_POLFED[ind_alpha][size].eigenvalues[target_index][1])
                        y_POLFED.append(calculate_r_mean(datas_POLFED[ind_alpha][size].eigenvalues[target_index][1],targetE="POLFED",return_r_list=True,take_fraction=effective_fraction_of_window))
                    else:
                        y_POLFED.append([float("nan"),[]])
                except Exception as e:  # if field "is_full_ED" does not exist, then it is certain that POLFED data does not exist
                    print(e)
                    y_POLFED.append([float("nan"),[]])
            # Pool the per-realization r values of both data sources, per alpha.
            y_error = [item[1] for item in y]
            y_error_POLFED = [item[1] for item in y_POLFED]
            y_error = [np.array(item + y_error_POLFED[ind],dtype=np.float64) for ind,item in enumerate(y_error)]
            # Mean over the pooled samples (ignoring NaNs) ...
            y = [np.nanmean(item) for item in y_error]
            # ... then the sample standard deviation of the pooled values ...
            y_error = [item - y[item_index] for item_index,item in enumerate(y_error)]
            y_error = [item**2 for item in y_error]
            y_error = [np.nanmean(item) for item in y_error]
            y_error = [np.sqrt(item) for item in y_error]
            # ... divided by sqrt(N) gives the standard error of the mean.
            y_error = [item/np.sqrt(num_of_realizations[item_index]) for item_index,item in enumerate(y_error)]
            # Store one curve (value per alpha) for this (energy target, size).
            curves[epsi][size] = y
            curve_errors[epsi][size] = y_error

# Persist the computed curves, their errors, and the parameter grids.
# The file-name stem depends on which model variant was analyzed.
if s_tot_conservation:
    out_stem = prefix + "curves_10_percent"
else:
    out_stem = prefix + "original_model_curves_10_percent"

params = [alphas,sizes]
for tail, payload in ((suffix, curves),
                      ("_error" + suffix, curve_errors),
                      ("_params" + suffix, params)):
    with open(out_stem + tail, 'wb') as file:
        dill.dump(payload, file)

print("Done")