import pandas as pd
import os
import matplotlib.pyplot as plt
import cartopy
import cartopy.io.shapereader as shpreader
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from matplotlib.colors import Normalize
# Country lookup table; `african_countries` holds the "alpha-3" code of every
# row whose "region" column equals "Africa".
countries = pd.read_csv("./countrycodes.csv")
african_countries = list(countries.loc[countries["region"] == "Africa", "alpha-3"])

# Hard-coded alpha-3 codes of the countries treated as the Sahel.
sahel_countries = [
    "SEN", "MRT", "MLI", "BFA", "NER", "NGA",
    "TCD", "CAF", "SDN", "ERI", "ETH",
]

# Map indicator name -> indicator ID, skipping rows whose ID is missing.
# If an indicator name repeats, the last row wins (dict construction order).
mappings = pd.read_csv("./config_never_change.csv")
indicators_to_id = (
    mappings[["indicator", "ID"]]
    .dropna(subset=["ID"])
    .set_index("indicator")["ID"]
    .to_dict()
)
indicators = list(indicators_to_id)
def _simulate_sample_means(values, sample_size, num_simulations):
    """Return a list of `num_simulations` means of random subsamples of `values`.

    Each subsample has `sample_size` elements drawn without replacement
    (pandas' default for ``Series.sample``), so `sample_size` must not exceed
    ``len(values)``.
    """
    return [values.sample(sample_size).mean() for _ in range(num_simulations)]


def _plot_choropleth(values, extent, title):
    """Draw and show a PlateCarree map shading each country in `values` with "Reds".

    Parameters
    ----------
    values : pandas.Series
        Scores indexed by country code. Codes are matched against the Natural
        Earth ``ISO_A3`` attribute, so presumably ISO 3166-1 alpha-3 codes.
    extent : list
        ``[lon_min, lon_max, lat_min, lat_max]`` passed to ``GeoAxes.set_extent``.
    title : str
        Title of the map axes.

    Does nothing when `values` is empty, because there is no range to normalize.
    """
    if values.empty:
        return
    crs = ccrs.PlateCarree()
    ax = plt.axes(projection=crs)
    ax.add_feature(cfeature.LAND)
    ax.add_feature(cfeature.OCEAN)
    ax.add_feature(cfeature.COASTLINE)
    ax.add_feature(cfeature.BORDERS, linestyle=':')
    ax.set_extent(extent)

    shp = shpreader.natural_earth(resolution='10m', category='cultural',
                                  name='admin_0_countries')
    norm = Normalize(vmin=values.min(), vmax=values.max())
    # pyplot.get_cmap replaces plt.cm.get_cmap, which was removed in Matplotlib 3.9.
    cmap = plt.get_cmap("Reds")
    for record in shpreader.Reader(shp).records():
        code = record.attributes['ISO_A3']
        if code in values.index:
            ax.add_geometries([record.geometry], crs,
                              facecolor=cmap(norm(values[code])))

    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])  # public replacement for the old `sm._A = []` hack
    plt.colorbar(sm, ax=ax)
    ax.set_title(title)
    plt.show()


def plot_ndvariables_africa(indicator):
    """Compare African and Sahel scores for `indicator` against random country samples.

    Reads ``indicators/<id>/score.csv``, where ``<id>`` is
    ``indicators_to_id[indicator]`` lower-cased, and uses its ``'2018'`` column.
    It then shows three figures:

    1. A histogram of the means of 1000 random samples drawn from all countries
       with data. Each sample is the same size as the number of African
       countries with data. Vertical lines mark the Pan-African and Sahel means.
    2. A choropleth of the African countries' scores.
    3. A choropleth of the Sahel countries' scores.

    Parameters
    ----------
    indicator : str
        A key of ``indicators_to_id``. Its title-cased form is used in plot titles.

    Raises
    ------
    KeyError
        If `indicator` is not in ``indicators_to_id`` or the CSV has no
        ``'2018'`` column.
    """
    # NOTE: this changes the default figure size for the whole session.
    plt.rcParams['figure.figsize'] = [20, 8]
    num_simulations = 1000
    year = '2018'
    id_used = indicators_to_id[indicator].lower()
    data = pd.read_csv(os.path.join("indicators", id_used, "score.csv"), index_col=0)

    # Only `year` is used, so drop countries missing *that* value. A plain
    # `data.dropna()` would also discard countries missing some other year.
    scores = data[year].dropna()
    african_scores = scores.reindex(african_countries).dropna()
    sahel_scores = scores.reindex(sahel_countries).dropna()
    if african_scores.empty:
        print(f"No {year} data for any African country for {indicator!r}; nothing to plot.")
        return

    # The null distribution must use the same sample size as the African mean
    # it is compared with (countries without data do not contribute to that
    # mean). This also guarantees sample_size <= len(scores), so `sample` cannot raise.
    mean_sim = _simulate_sample_means(scores, len(african_scores), num_simulations)

    plt.hist(mean_sim, label=f"Global distribution over {num_simulations} simulations")
    plt.axvline(x=african_scores.mean(), linewidth=4, color='r', label="Pan-African mean")
    if not sahel_scores.empty:
        plt.axvline(x=sahel_scores.mean(), linewidth=4, color='b', label="Sahel mean")
    plt.legend()
    plt.title(f"({year}) {indicator.title()}")
    plt.show()

    _plot_choropleth(african_scores, [-20, 55, -33, 33],
                     indicator.title() + " (Pan-Africa)")
    # Shade and normalize only the Sahel countries. The previous code reused
    # the African series here, colouring every African country in view.
    _plot_choropleth(sahel_scores, [-20, 45, 5, 25],
                     indicator.title() + " (Sahel)")
# Expose every known indicator name as an ipywidgets control driving the plots.
interact(plot_ndvariables_africa, indicator=indicators)
# Notebook cell output: <function __main__.plot_ndvariables_africa>