import os

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
from scipy.constants import Boltzmann, eV
from scipy.special import binom, gammaln

def config_entropy(n_sites, n, temp=300):
    """Return the configurational free-energy term -T*S_config in [eV]

    S_config = k_B * ln(C(n_sites, n)) is the configurational entropy of
    distributing n particles on n_sites gridpoints. The value returned is the
    free-energy contribution -T*S_config, so it is <= 0 for 0 <= n <= n_sites.

    ln C(n_sites, n) is evaluated with log-gamma functions, so large site
    counts stay finite. By contrast, binom(n_sites, n) itself exceeds the
    float64 range near half filling once n_sites is roughly 1000 or more.

    Parameters
    ---------
    n_sites: [int] number of sites
    n: [int or np.array of int] number of occupied sites, 0 <= n <= n_sites
    temp: temperature [K]"""
    log_multiplicity = (gammaln(n_sites + 1)
                        - gammaln(n + 1)
                        - gammaln(n_sites - n + 1))
    return -temp * Boltzmann * log_multiplicity / eV


def deltaG(U, energies, h2_energy, n_ini, n_fin, entropy=False, temp=300):
    """Reaction Gibbs free energy

    Returns the Gibbs free energy [eV] at potential U (vs. RHE) for the reaction
    n_ini * H_ads + (n_fin - n_ini) * (H+ + e-) --> n_fin * H_ads
    based on the computational hydrogen electrode method, in which the
    chemical potential of (H+ + e-) is taken as 1/2 * E(H2) - e*U.

    Parameters
    ---------
    U: [float or np.array] potential vs. RHE [V]; arrays are evaluated elementwise
    energies: numpy array containing total energies of [0H_ads, 1H_ads, ...]
    h2_energy: [float] total energy of H2
    n_ini: number[int] of initial H adsorbates
    n_fin: number[int] of final H adsorbates
    entropy: whether to add the configurational entropy term [bool]
    temp: temperature [K] used for the configurational entropy term"""
    n_sites = len(energies) - 1
    delta_n = n_fin - n_ini
    e_ini = energies[n_ini]
    e_fin = energies[n_fin]
    if entropy:
        # Rebind rather than use '+=': if energies[n] were an array view
        # (e.g. 2-D input), '+=' would modify the caller's data in place.
        e_ini = e_ini + config_entropy(n_sites, n_ini, temp)
        e_fin = e_fin + config_entropy(n_sites, n_fin, temp)
    # Each transferred (H+ + e-) pair contributes 1/2 E(H2) - e*U.
    return e_fin - e_ini - delta_n * (h2_energy / 2 - U)

def u_eq(energies, h2_energy, n_ini, n_fin, entropy=False, temp=300):
    """Find the equilibrium potential for transition from n_ini * H_ads to n_fin * H_ads

    Returns the equilibrium potential vs. RHE [V] for the reaction
    n_ini * H_ads + (n_fin - n_ini) * (H+ + e-) --> n_fin * H_ads
    based on the computational hydrogen electrode method, i.e. the potential U
    at which e_fin - e_ini - (n_fin - n_ini) * (h2_energy/2 - U) == 0.

    Parameters
    ---------
    energies: numpy array containing total energies of [0H_ads, 1H_ads, ...]
    h2_energy: [float] total energy of H2
    n_ini: number[int] of initial H adsorbates
    n_fin: number[int] of final H adsorbates
    entropy: whether to add the configurational entropy term [bool]
    temp: temperature [K] used for the configurational entropy term

    Raises
    ------
    ValueError: if n_ini == n_fin (no (H+ + e-) transferred, so no potential
    dependence and the equilibrium potential is undefined)."""
    delta_n = n_fin - n_ini
    if delta_n == 0:
        raise ValueError(
            f"n_ini and n_fin are both {n_ini}; equilibrium potential is undefined")
    n_sites = len(energies) - 1
    e_ini = energies[n_ini]
    e_fin = energies[n_fin]
    if entropy:
        # Rebind rather than use '+=' so the caller's array is never modified.
        e_ini = e_ini + config_entropy(n_sites, n_ini, temp)
        e_fin = e_fin + config_entropy(n_sites, n_fin, temp)
    # Solve the CHE reaction free energy for the potential at which it vanishes.
    return -(e_fin - e_ini - delta_n * h2_energy / 2) / delta_n

def get_theta_of_U(energies, h2_energy, entropy=False):
    """Find the equilibrium coverage theta as function of U

    Starting from the empty surface (n = 0), repeatedly select the next
    adsorption step n -> m (m > n) with the highest equilibrium potential
    u_eq, i.e. the step that becomes favourable first as U is lowered. The
    search stops once the surface is full. Because
    u_eq(n, m) = h2_energy/2 - (E_m - E_n)/(m - n), this is the step with the
    smallest energy change per added H (entropy-corrected if requested).

    Returns an array of shape (k, 2) with one row [u_eq, 0.5 * (n + m) / n_sites]
    per step: the step potential vs. RHE [V] and the midpoint coverage of the
    step n -> m. When energies has a single entry (n_sites == 0), the result is
    an empty (0, 2) array.

    Parameters
    ---------
    energies: numpy array containing total energies of [0H_ads, 1H_ads, ...]
    h2_energy: [float] total energy of H2
    entropy: whether to use configurational entropy [bool]

    Raises
    ------
    ValueError: if no step starting from some n has an equilibrium potential
    that compares greater than -inf (e.g. all NaN)."""
    n_sites = len(energies) - 1
    steps = []
    n = 0
    while n < n_sites:
        best_u = -np.inf
        best_m = None
        for m in range(n + 1, n_sites + 1):
            u = u_eq(energies, h2_energy, n, m, entropy)
            # Strict '>' keeps the smallest m when several steps tie.
            if u > best_u:
                best_u, best_m = u, m
        if best_m is None:
            raise ValueError(
                f"No valid equilibrium potential for steps starting at n={n}; "
                "check energies for NaN/inf values")
        steps.append([best_u, 0.5 * (n + best_m) / n_sites])
        n = best_m
    # reshape keeps a (0, 2) shape when there are no steps, so callers can
    # always index columns with [:, 0] and [:, 1].
    return np.array(steps, dtype=float).reshape(-1, 2)

def plot_all(urange, energies, h2_energy, entropy):
    """Plot Delta G for reactions: * + n * (H+ + e-) => nH*

    Draws one curve per n = 0..n_sites in a new figure. Each curve is
    deltaG(urange, ...) for the step from the empty surface (n_ini = 0) to n.

    Parameters
    ----------
    urange: [np.array] defining the U values (vs. RHE [V]) to plot at
    energies: [np.array] numpy array containing total energies of [0H_ads, 1H_ads, ...]
    h2_energy: [float] total energy of H2
    entropy:[bool] whether to include configurational entropy"""
    plt.figure()
    n_sites = len(energies) - 1
    # One shade of blue per n (n_sites + 1 in total). Sampling starts at
    # 1/(n_sites + 2) to avoid the near-white low end of the colormap.
    colors = [cm.Blues(i / (n_sites + 2)) for i in range(1, n_sites + 2)]
    for n, color in enumerate(colors):
        plt.plot(urange, deltaG(urange, energies, h2_energy, 0, n, entropy),
                 color=color, label="n={}".format(n))
    if entropy:
        plt.ylabel(r'$\Delta G$ [eV]')
    else:
        plt.ylabel(r'$\Delta E$ [eV]')
    # CHE potentials are on the RHE scale; U itself is in volts.
    plt.xlabel(r'$U$ vs. RHE [V]')
    plt.legend(loc='upper left', bbox_to_anchor=(1, 1))
    plt.tight_layout()

def plot_theta_of_u(theta_of_u):
    """Plot the equilibrium coverage as a function of U

    Parameters
    ----------
    theta_of_u: [np.array] rows of [u_eq, theta] as produced by get_theta_of_U;
                column 0 is the potential vs. RHE [V], column 1 the coverage.
                An empty array yields an empty plot."""
    # Normalise to (k, 2) so an empty (0,) array still supports column slicing.
    theta_of_u = np.asarray(theta_of_u, dtype=float).reshape(-1, 2)
    plt.figure()
    plt.plot(theta_of_u[:, 0], theta_of_u[:, 1], '-x')
    plt.xlabel(r'$U$ vs. RHE [V]')
    plt.ylabel(r'$\theta$')

def get_data(base_dir="."):
    """Interactively select a data directory and load its energies

    Lists the sub-directories of base_dir that contain both
    'slab_energies.npy' and 'h2_energy.npy', sorted by name. It then asks on
    stdin for the number of the one to load, re-prompting until the input is
    a valid number.

    Parameters
    ----------
    base_dir: [str] directory to search for data directories (default: cwd)

    Returns
    -------
    (slab_energies, h2_energy): arrays loaded with np.load from the chosen directory

    Raises
    ------
    FileNotFoundError: if no sub-directory of base_dir contains both files."""
    required = ('slab_energies.npy', 'h2_energy.npy')
    # Sorted so the numbering is reproducible between runs (os.listdir order
    # is arbitrary). Only directories holding the required files are offered.
    candidates = sorted(
        d for d in os.listdir(base_dir)
        if os.path.isdir(os.path.join(base_dir, d))
        and all(os.path.isfile(os.path.join(base_dir, d, f)) for f in required)
    )
    if not candidates:
        raise FileNotFoundError(
            f"No directory in {base_dir!r} contains {', '.join(required)}")
    data_dirs = dict(enumerate(candidates))
    while True:
        folder_number = input(f"Select the data to analyze: {data_dirs}: ")
        try:
            path = os.path.join(base_dir, data_dirs[int(folder_number)])
        except (ValueError, KeyError):
            print(f"Invalid selection {folder_number!r}; choose one of {list(data_dirs)}.")
        else:
            break
    # read in energies
    slab_energies = np.load(os.path.join(path, 'slab_energies.npy'))
    h2_energy = np.load(os.path.join(path, 'h2_energy.npy'))
    return slab_energies, h2_energy


if __name__ == "__main__":
    # Toggle the configurational-entropy correction for every analysis below.
    use_entropy = True

    # Interactively choose a data directory and load its energies.
    energies, e_h2 = get_data()

    # Delta G for * + n * (H+ + e-) => nH*, sampled on a fixed potential window.
    potentials = np.arange(0.2, 0.8, 0.01)
    plot_all(potentials, energies, e_h2, use_entropy)

    # Equilibrium coverage as a function of the potential, and its plot.
    coverage_curve = get_theta_of_U(energies, e_h2, use_entropy)
    plot_theta_of_u(coverage_curve)

    # Display all figures created above.
    plt.show()
