import numpy as np
import matplotlib.pyplot as plt
from diffusion import *


def get_pos(ix, dx):
    """Map a grid index to its spatial coordinate on a grid starting at x = 0.

    Parameters:
    ----------
    ix: [int or numpy array of int] grid index (or indices)
    dx: [float] grid spacing

    Output:
    ------
    x: [float or numpy array] position(s) ix*dx
    """
    position = dx * ix
    return position

def analyze_limiting_current(ax, conc, dx, time, D=1.0):
    """Plot a concentration profile and the limiting current it implies.

    The current is the diffusive flux into the electrode at x = 0,
    j = D * dc/dx|_{x=0}, estimated with a first-order forward difference
    between the first two grid points. The third panel plots j against
    1/sqrt(t); a straight line there indicates j ~ t^(-1/2) decay.

    Parameters:
    ----------
    ax: [sequence of 3 plt.Axes] axes for the profile, j(t) and j(1/sqrt(t))
    conc: [numpy array] concentration profile, conc[0] at the electrode (x=0)
    dx: [float] grid spacing
    time: [float] time of the profile; must be > 0 for the 1/sqrt(t) panel
    D: [float, optional] diffusion coefficient that scales the gradient into
       a current. Defaults to 1.0, i.e. j is the bare concentration gradient.

    Output:
    ------
    j: [float] limiting current at the electrode

    Raises:
    ------
    ValueError: if conc has fewer than two grid points
    """
    conc = np.asarray(conc, dtype=float)
    if conc.size < 2:
        raise ValueError("conc needs at least two grid points to estimate the gradient")

    # Compute current: flux into the electrode, D * (c(dx) - c(0)) / dx.
    # With a depleted electrode (conc[0] < conc[1]) this is positive.
    j = float(D * (conc[1] - conc[0]) / dx)

    # Create x range for plotting of concentration profile
    xs = get_pos(np.arange(conc.size), dx)
    # Plot concentration profile
    ax[0].plot(xs, conc, label=f"$t$={time}")
    ax[0].set_xlabel('$x$')
    ax[0].set_ylabel('concentration')
    ax[0].legend()

    # Plot limiting current over t
    ax[1].scatter(time, j)
    ax[1].set_xlabel('time $t$')
    ax[1].set_ylabel('current')

    # Plot limiting current over 1/sqrt(t)
    ax[2].scatter(1/np.sqrt(time), j)
    ax[2].set_xlabel(r'$\frac{1}{\sqrt{t}}$')
    ax[2].set_ylabel('current')

    return j


if __name__ == "__main__":
    # Set parameters (dimensionless units)
    D = 0.1  # diffusion coefficient
    L = 2  # length of the domain
    dx = 0.01  # width of a spatial bin
    # L/dx intervals need L/dx + 1 grid points so that both x = 0 (electrode)
    # and x = L lie on the grid. round() keeps a quotient that lands just
    # below an integer in floating point from being truncated by int().
    nx = int(round(L / dx)) + 1
    dt = 1E-4  # time step

    # Times at which to output results (t = 0 is the initial condition)
    t_outs = [0, 0.04, 0.16, 0.36]
    # Convert output times to step counts. round() rather than plain int():
    # e.g. 0.36 / 1e-4 evaluates to 3599.9999999999995, which int() truncates
    # to 3599, so the last profile would be taken one step early.
    nt_outs = [int(round(t / dt)) for t in t_outs]

    # Initial condition: concentration 1 throughout the cell ...
    conc = np.ones(nx)
    # ... and 0 at the electrode (x = 0). NOTE(review): presumably `propagate`
    # keeps this boundary value fixed — confirm in diffusion.py.
    conc[0] = 0.0

    # Run and analyze
    fig, ax = plt.subplots(1, 3, figsize=(9, 3))
    for nt_prev, nt_next, t_out in zip(nt_outs[:-1], nt_outs[1:], t_outs[1:]):
        # Number of steps needed to advance from the previous output time
        nt = nt_next - nt_prev
        # Propagate
        conc = propagate(conc, D, dt, nt, dx)
        # Analyze
        analyze_limiting_current(ax, conc, dx, t_out)
    plt.tight_layout()
    plt.show()
