
# Required python packages
import pandas as pd
import ternary
import matplotlib.pyplot as plt
from matplotlib import gridspec

# ---------------------------------------------------------------------------
# User settings -- modify only this section to create plots
# ---------------------------------------------------------------------------
# Family of output to plot; one of:
#   Minerals, Elements, Aqueous, Solutions, Isotopes, fO2, pH
plottingtype = 'Minerals'
# Name of the column (variable) to plot from the selected data file
var_to_plot = 'ANTIGORITE prod.'
# Reaction-progress step to plot, as a positional row index within each run
# (-1 selects the last step)
logzi_to_plot = -1
# Folder that contains the result CSV files
path = '/Research_data/Data/5.6Ma_1GPa_500C/Serpentinites/FR1'




# Beginning of the plotting section
# CSV file (inside `path`) holding the data for each supported plot type.
data_files = {
    'Minerals': 'Mineral_modes.csv',
    'Elements': 'Dissolved_elements.csv',
    'Aqueous': 'Aqueous_species_conc.csv',
    'Solutions': 'Solid_solutions.csv',
    'Isotopes': 'Isotopic_Data.csv',
    'fO2': 'fO2.csv',
    'pH': 'pH.csv',
}
# Fail early with a clear message instead of leaving `data` undefined,
# which would otherwise surface later as a confusing NameError.
if plottingtype not in data_files:
    raise ValueError(
        'Unknown plottingtype {!r}; expected one of: {}'.format(
            plottingtype, ', '.join(data_files)))
data = pd.read_csv(path + '/' + data_files[plottingtype])

# Take the data for the variable to plot, together with its run identifier
to_plot = data[[var_to_plot, 'Run #']]

# Create the values to plot
# Read the table linking run numbers to ternary proportions. Columns 1-3 are
# taken as the three end-member proportions (column 0 is presumably 'Run #'
# -- confirm against List_proportions.csv).
df = pd.read_csv(path + '/List_proportions.csv')
name_A = df.columns[1]
name_B = df.columns[2]
name_C = df.columns[3]
A = df[name_A]
B = df[name_B]
C = df[name_C]
# One [A, B, C] ternary coordinate per run, in file order.
points = [[a, b, c] for a, b, c in zip(A, B, C)]
runs = df['Run #']

# These plot types are exponentiated below, i.e. their files appear to store
# log10 values.
log10_types = ('Elements', 'Aqueous')
val_to_plot = []
for run_id in runs:
    run = to_plot.loc[to_plot['Run #'].values == run_id]
    # Without this check, .iloc on an empty frame raises an opaque IndexError.
    if run.empty:
        raise ValueError(
            'Run {!r} from List_proportions.csv has no rows in the {} '
            'data file'.format(run_id, plottingtype))
    # Pick the requested reaction-progress step (positional row in this run).
    value = run[var_to_plot].iloc[logzi_to_plot]
    val_to_plot.append(10**value if plottingtype in log10_types else value)

# Boundary and Gridlines
# Ternary plot setup: one grid cell hosting the triangle at scale 100.
scale = 100
fig_size = [5.05, 3.9]
plt.figure(figsize=fig_size)
gs = gridspec.GridSpec(nrows=1, ncols=1)
ax = plt.subplot(gs[0, 0])
figure, tax = ternary.figure(ax=ax, scale=scale)

# Triangle outline, then a black gridline every 5 units.
tax.boundary(linewidth=2.0)
tax.gridlines(multiple=5, color="black")

# Ticks every 10 units on the 'lbr' (left, bottom, right) edges; the host
# axes' Cartesian frame is hidden so only the triangle is visible.
tax.ticks(axis='lbr', multiple=10, linewidth=1, offset=0.025)
tax.get_axes().axis('off')

# Plot peridotite composition field
# Outline vertices in drawing order; the last vertex repeats the first so the
# polygon is closed. This replaces twelve separate point variables, one of
# which (p11) was defined but never used.
peridotite_field = [
    (0, 100, 0),
    (0, 60, 40),
    (22, 47, 31),
    (22, 59, 19),
    (10, 70, 20),
    (10, 90, 0),
    (0, 100, 0),
]
# Draw each edge between consecutive vertices as a thick white line.
for start, end in zip(peridotite_field[:-1], peridotite_field[1:]):
    tax.line(start, end, linewidth=4., color='white', linestyle="-")

# Colour-bar placement options, passed to tax.scatter() as cb_kwargs.
cb_kwargs = dict(
    shrink=1,
    orientation="vertical",
    location="left",
    fraction=0.1,
    pad=0.05,
    aspect=30,
)

# Plots points
# Colour-bar label for each supported plot type. A single lookup replaces
# seven copies of the same scatter() call that differed only in this label.
cbar_labels = {
    'Minerals': var_to_plot + ' (vol%)',
    'Elements': 'Dissolved ' + var_to_plot + ' (mol/kg)',
    'Aqueous': var_to_plot + ' (mol/kg)',
    'Solutions': var_to_plot,
    # Raw string: '\d' is not a valid escape and warns on recent Pythons.
    'Isotopes': r'$\delta $' + var_to_plot + '$(‰)$',
    'fO2': 'log $f_{O_2}$',
    'pH': 'pH',
}
# An unknown type would otherwise silently produce an empty triangle.
if plottingtype not in cbar_labels:
    raise ValueError(
        'Unknown plottingtype {!r}; expected one of: {}'.format(
            plottingtype, ', '.join(cbar_labels)))

# Stretch the colour map over the full range of plotted values.
vmin = min(val_to_plot)
vmax = max(val_to_plot)
# `colormap` presumably drives python-ternary's colour bar and `cmap` the
# marker colours, so keep both identical. cb_kwargs is applied for every
# plot type so that all figures share the same colour-bar layout.
tax.scatter(points, vmin=vmin, vmax=vmax,
            colormap=plt.cm.viridis, colorbar=True,
            cbarlabel=cbar_labels[plottingtype], c=val_to_plot,
            cmap=plt.cm.viridis, cb_kwargs=cb_kwargs)

# End-member labels near the triangle corners; positions assume a ternary
# scale of 100.
plt.text(47, 95, 'Ol', fontsize=12)
plt.text(-10, -12, 'Opx', fontsize=12)
plt.text(102, -12, 'Cpx', fontsize=12)

figure.tight_layout()

plt.show()