import os, sys
import matplotlib.pyplot as plt
import ase
import ase.io
from ase.data import covalent_radii as radii
from ase.data.colors import jmol_colors
from matplotlib.patches import Circle
import numpy as np
sys.path.append(os.path.dirname(__file__))
import geoms_from_paper


def load_data(path, filename):
    """Load a space-delimited numeric table from ``path/filename``.

    Parameters
    ----------
    path : str
        Directory that contains the file.
    filename : str
        Name of the text file inside ``path``.

    Returns
    -------
    numpy.ndarray
        The file contents as parsed by ``numpy.loadtxt`` with a single
        space as the column delimiter.
    """
    full_path = os.path.join(path, filename)
    return np.loadtxt(full_path, delimiter=' ')

def plot_atoms(name):
    """Overlay the atoms stored in ``<name>.traj`` on the current axes.

    The structure is read with ``ase.io.read`` and every atom is drawn as a
    circle on a twin y-axis of the current matplotlib axes.  The circle is
    centred at ``(atom.z, atom.x)``, sized by the atom's covalent radius and
    filled with its jmol colour.

    Parameters
    ----------
    name : str
        Base name of the trajectory file; ``".traj"`` is appended.
    """
    atoms = ase.io.read(name+".traj")

    # The atoms get their own y-axis so that an equal aspect ratio can be
    # applied to them without touching the scale of the underlying curves.
    host_ax = plt.gca()
    overlay_ax = host_ax.twinx()
    overlay_ax.set_aspect('equal')

    for atom in atoms:
        element = atom.number
        disc = Circle(
            (atom.z, atom.x),
            radii[element],
            facecolor=jmol_colors[element],
            edgecolor='k',
            linewidth=1,
        )
        overlay_ax.add_patch(disc)

    # The secondary axis only carries the atom sketch, so hide its scale.
    overlay_ax.get_yaxis().set_visible(False)

def plot_solvation_jellum(name):
    """Plot the solvent cavity and jellium background charge profiles.

    A new figure is opened, the two-column files ``cavity.txt`` and
    ``background_charge.txt`` are read from the directory ``sjm_<name>``
    and column 1 is plotted against column 0 for each.  The atoms of
    ``<name>.traj`` are then overlaid via ``plot_atoms``.

    Parameters
    ----------
    name : str
        Calculation label used to build both the data directory name and
        the trajectory file name.
    """
    # NOTE(review): this directory is relative to the working directory,
    # whereas plot_change_in_potential looks under "results/" -- confirm
    # that the difference is intentional.
    data_dir = "sjm_"+ name
    plt.figure()

    # Read every profile up front, then draw them in the same order.
    sources = (('solvent', 'cavity.txt'),
               ('jellium', 'background_charge.txt'))
    profiles = [(label, load_data(data_dir, fname))
                for label, fname in sources]
    for label, table in profiles:
        plt.plot(table[:, 0], table[:, 1], label=label)

    plt.legend(loc='lower left')
    plt.xlabel('$z$ [A]')
    plt.ylabel('$xy$-averaged value')

    # Sketch the atomic positions on top of the profiles.
    plot_atoms(name)

def plot_change_in_potential(name, name2):
    """Plot two potential profiles and their difference in a new figure.

    ``potential.txt`` is read from ``results/sjm_<name>`` and
    ``results/sjm_<name2>``.  Both profiles (column 1 against column 0) are
    drawn on the left axis; the element-wise difference ``name2 - name`` is
    drawn in black on a twin right axis, plotted against the first file's
    column 0.  A single legend covers all three curves.

    Parameters
    ----------
    name : str
        Label of the first calculation (subtracted in the difference).
    name2 : str
        Label of the second calculation.
    """
    results_root = "results"
    first_dir = os.path.join(results_root, "sjm_"+ name)
    second_dir = os.path.join(results_root, "sjm_"+ name2)

    first = load_data(first_dir, 'potential.txt')
    second = load_data(second_dir, 'potential.txt')

    # Left axis: the two raw profiles.
    plt.figure()
    left_ax = plt.gca()
    (first_line,) = left_ax.plot(first[:, 0], first[:, 1], label=name)
    (second_line,) = left_ax.plot(second[:, 0], second[:, 1], label=name2)

    # Right axis: the difference, which lives on a different scale.
    right_ax = left_ax.twinx()
    delta = second[:, 1] - first[:, 1]
    (delta_line,) = right_ax.plot(first[:, 0], delta, 'k', label='difference')

    left_ax.set_ylabel('Hartree potential')
    right_ax.set_ylabel(f"Change in Hartree potential: \n {name2}-{name}")
    left_ax.set_xlabel("$z$ [A]")

    # Gather the handles from both axes so one legend lists every curve.
    handles = [first_line, second_line, delta_line]
    left_ax.legend(handles, [handle.get_label() for handle in handles])
    plt.tight_layout()
    

if __name__ == "__main__":
    # Labels of the two calculations to visualise.
    first_run = "IS_0VSHE"
    second_run = "IS_0VSHE_at-0.6VSHE"

    # Figure 1: solvent cavity and jellium background with the atoms overlaid.
    plot_solvation_jellum(first_run)

    # Figure 2: change in potential upon charging (second run minus first).
    plot_change_in_potential(first_run, second_run)

    # Display every figure created above.
    plt.show()

