import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import numpy as np

import scipy.stats as stats

from sklearn.metrics import roc_curve, auc


def colored_bar(data, class_map, color_map, attr_name, class_name):
    """Draw stacked bars of class counts for every value of an attribute.

    The rows of ``data`` are grouped by ``attr_name``; within each group the
    number of rows per ``class_name`` value is drawn as one bar segment, and
    the segments of a group are stacked on top of each other. Everything is
    drawn on the current pyplot figure; nothing is shown or returned.

    Args:
        data: Table supporting ``groupby`` on ``attr_name`` and ``class_name``
            (presumably a pandas DataFrame -- confirm against callers).
        class_map: Mapping whose keys are the class values. Only the key
            order is used here, to order the legend entries.
        color_map: Mapping from class value to the colour of its segments.
        attr_name: Column whose distinct values become the bar positions.
        class_name: Column holding the class label of every row.

    Raises:
        KeyError: If a class value found in ``data`` is missing from
            ``color_map``.
    """
    width = 0.3
    # First segment drawn for each class. Keeping one handle per class lets
    # the legend pair every label with the colour actually used for it,
    # instead of relying on the order in which segments happened to be drawn.
    legend_handles = {}

    for attr_val, group in data.groupby(attr_name):
        # Height already occupied by the segments stacked at this position.
        y_offset = 0
        class_counts = group.groupby(class_name).size()

        for class_val, count in class_counts.items():
            segment = plt.bar(attr_val, count, width,
                              color=color_map[class_val], bottom=y_offset,
                              align='center')
            legend_handles.setdefault(class_val, segment)
            y_offset += count

    plt.title(attr_name)
    plt.ylabel('#count')

    # Legend follows the class_map order; classes that occur in the data but
    # not in class_map are appended so they are still labelled.
    labels = [class_val for class_val in class_map
              if class_val in legend_handles]
    labels += [class_val for class_val in legend_handles
               if class_val not in class_map]
    plt.legend([legend_handles[class_val] for class_val in labels], labels)


def class_dist_plot(data, color_map, attr_name, class_name, bins=50):
    """Plot the distribution of an attribute separately for every class.

    For each class value a figure with a normalised step histogram of
    ``attr_name`` and its Gaussian kernel density estimate is shown. After
    that, one more figure overlays the density estimates of all classes.
    Every figure is displayed with ``plt.show()``; nothing is returned.

    Args:
        data: Table supporting column selection and ``groupby``
            (presumably a pandas DataFrame -- confirm against callers).
        color_map: Mapping from class value to the colour used for it.
        attr_name: Column whose distribution is plotted.
        class_name: Column holding the class label of every row.
        bins: Bin specification forwarded to ``plt.hist``.

    Raises:
        KeyError: If a class value found in ``data`` is missing from
            ``color_map``.
    """
    # (class value, fitted density, histogram bin edges) for the final overlay.
    per_class = []

    # Group by the bare column name rather than a one-element list, so the
    # group keys are scalar class values and not 1-tuples.
    for class_val, group in data[[attr_name, class_name]].groupby(class_name):
        input_data = group[attr_name].values
        density = stats.gaussian_kde(input_data)
        _, bin_edges, _ = plt.hist(input_data, bins, color=color_map[class_val],
                                   histtype='step', density=True)
        # The density estimate is evaluated at the bin edges only; the KDE
        # callable takes the evaluation points as its single argument.
        plt.plot(bin_edges, density(bin_edges), color='k')
        plt.title('Bar plot - attribute: ' + str(attr_name)
                  + ', class label: ' + str(class_val))
        plt.ylabel('Density')
        plt.show()

        per_class.append((class_val, density, bin_edges))

    for class_val, density, bin_edges in per_class:
        plt.plot(bin_edges, density(bin_edges), color=color_map[class_val])

    plt.title('Class distributions - ' + str(attr_name))
    plt.ylabel('Density')
    # Labels are given in plotting order, one per density curve.
    plt.legend([class_val for class_val, _, _ in per_class])
    plt.show()


def colored_bar_perc(data, class_map, color_map, attr_name, class_name):
    """Draw stacked bars of class percentages for every attribute value.

    The rows of ``data`` are grouped by ``attr_name``; within each group the
    share of rows per ``class_name`` value (in percent of the group's row
    count) is drawn as one bar segment, and the segments of a group are
    stacked. Everything is drawn on the current pyplot figure; nothing is
    shown or returned.

    Args:
        data: Table supporting ``groupby`` on ``attr_name`` and ``class_name``
            (presumably a pandas DataFrame -- confirm against callers).
        class_map: Mapping whose keys are the class values. Only the key
            order is used here, to order the legend entries.
        color_map: Mapping from class value to the colour of its segments.
        attr_name: Column whose distinct values become the bar positions.
        class_name: Column holding the class label of every row.

    Raises:
        KeyError: If a class value found in ``data`` is missing from
            ``color_map``.
    """
    width = 0.3
    # First segment drawn for each class, so that every legend label is paired
    # with the colour actually used for that class.
    legend_handles = {}

    for attr_val, group in data.groupby(attr_name):
        # Percentage already occupied by the segments stacked at this position.
        y_offset = 0
        # Taken from the group itself, so the denominator can never get out
        # of step with the group being drawn.
        group_size = len(group)
        class_counts = group.groupby(class_name).size()

        for class_val, count in class_counts.items():
            val = 100 * count / group_size
            segment = plt.bar(attr_val, val, width,
                              color=color_map[class_val], bottom=y_offset,
                              align='center')
            legend_handles.setdefault(class_val, segment)
            y_offset += val

    plt.title(attr_name)
    plt.ylabel('Percentage')
    plt.xlabel('Values')

    # Legend follows the class_map order; classes that occur in the data but
    # not in class_map are appended so they are still labelled.
    labels = [class_val for class_val in class_map
              if class_val in legend_handles]
    labels += [class_val for class_val in legend_handles
               if class_val not in class_map]
    plt.legend([legend_handles[class_val] for class_val in labels], labels)


def colored_hist(data, class_map, color_map, attr_name, class_name, n_bins=10):
    """Draw a stacked histogram of an attribute, one layer per class.

    The histogram is drawn on the current pyplot figure; nothing is shown or
    returned.

    Args:
        data: Table supporting boolean-mask row selection
            (presumably a pandas DataFrame -- confirm against callers).
        class_map: Mapping whose keys are the class values; its key order
            defines the stacking order and the legend order.
        color_map: Mapping from class value to the colour of its layer.
        attr_name: Column whose values are binned.
        class_name: Column holding the class label of every row.
        n_bins: Bin specification forwarded to ``plt.hist``.

    Raises:
        KeyError: If a key of ``class_map`` is missing from ``color_map``.
    """
    classes = list(class_map)
    plot_data = [data.loc[data[class_name] == class_val, attr_name].values
                 for class_val in classes]
    # Look the colours up per class so that dataset i is always drawn in the
    # colour of class i, independent of color_map's own key order or of any
    # extra entries it may contain.
    colors = [color_map[class_val] for class_val in classes]

    plt.hist(plot_data, n_bins, histtype='bar', stacked=True, color=colors)

    plt.title(attr_name)
    plt.ylabel('#count')
    plt.legend(classes)


def colored_scatter(data, attr_name_x, attr_name_y, colors):
    """Scatter two columns of ``data`` against each other with 'x' markers.

    The points are drawn on the current pyplot figure with marker size 60 and
    the given ``colors``; the axes are labelled with the two column names.
    Nothing is shown or returned.

    Args:
        data: Container indexable by ``attr_name_x`` and ``attr_name_y``
            (presumably a pandas DataFrame -- confirm against callers).
        attr_name_x: Column plotted on the x axis; also the x-axis label.
        attr_name_y: Column plotted on the y axis; also the y-axis label.
        colors: Colour specification forwarded to ``plt.scatter`` as ``color``.
    """
    x_values = data[attr_name_x]
    y_values = data[attr_name_y]
    plt.scatter(x_values, y_values, marker='x', s=60, color=colors)

    for set_label, column in ((plt.xlabel, attr_name_x),
                              (plt.ylabel, attr_name_y)):
        set_label(column)


def heat_map(matrix, x_labels, y_labels, title):
    """Draw an annotated heat map of a 2-D matrix on a new figure.

    A new figure/axes pair is created, ``matrix`` is rendered with the
    ``Oranges`` colour map, the ticks are labelled with ``x_labels`` and
    ``y_labels``, and every cell is annotated with its value. The figure is
    not shown and nothing is returned.

    Args:
        matrix: 2-D data accepted by ``imshow``. Cell ``[i, j]`` is read for
            every ``i < len(y_labels)`` and ``j < len(x_labels)``.
        x_labels: Tick labels for the columns (x axis).
        y_labels: Tick labels for the rows (y axis).
        title: Title of the axes.
    """
    # Array view used only for the annotations, so inputs without
    # ``matrix[i, j]`` indexing (nested lists, DataFrames) work as well.
    # ``asanyarray`` leaves ndarray subclasses untouched.
    values = np.asanyarray(matrix)

    fig, ax = plt.subplots()
    ax.imshow(matrix, cmap=cm.Oranges)

    # We want to show all ticks...
    ax.set_xticks(np.arange(len(x_labels)))
    ax.set_yticks(np.arange(len(y_labels)))
    # ... and label them with the respective list entries
    ax.set_xticklabels(x_labels)
    ax.set_yticklabels(y_labels)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Loop over data dimensions and create text annotations.
    # imshow puts row i at y == i and column j at x == j, hence text(j, i).
    for i in range(len(y_labels)):
        for j in range(len(x_labels)):
            ax.text(j, i, values[i, j], ha="center", va="center", color="k")

    ax.set_title(title)
    fig.tight_layout()


def plot_roc(x_train, y_train, x_validation, y_validation, class_map, models):
    """Fit every model and show one-vs-rest ROC curves, one figure per class.

    Each model is fitted in place on the training data, its class
    probabilities for the validation data are computed, and for every class
    index a figure with the ROC curve (and AUC in the legend) of every model
    is displayed with ``plt.show()``. Nothing is returned.

    Args:
        x_train: Training features passed to ``model.fit``.
        y_train: Training labels passed to ``model.fit``.
        x_validation: Validation features passed to ``model.predict_proba``.
        y_validation: Validation labels; every label must be a key of
            ``class_map``.
        class_map: Mapping from class value to the column index used for that
            class, expected to cover ``0 .. len(class_map) - 1``.
        models: Iterable of ``(name, model)`` pairs; every model must provide
            ``fit`` and ``predict_proba``.

    Raises:
        KeyError: If a validation label is missing from ``class_map``.
    """
    class_assignments = [class_map[class_val] for class_val in y_validation]
    one_hot_encoding = np.zeros((len(class_assignments), len(class_map)))
    one_hot_encoding[np.arange(len(class_assignments)), class_assignments] = 1

    # Fit all models up front and keep each name next to its probabilities,
    # so `models` is traversed only once and no positional counter is needed.
    # NOTE(review): column i of predict_proba is assumed to belong to the
    # class that class_map maps to i -- confirm against the models in use.
    named_probs = []
    for name, model in models:
        model.fit(x_train, y_train)
        named_probs.append((name, model.predict_proba(x_validation)))

    # Reverse lookup so every figure is titled with the class that really
    # owns column i, independent of the insertion order of class_map.
    index_to_class = {index: class_val for class_val, index in class_map.items()}

    lw = 0.5

    for i in range(len(class_map)):
        plt.figure()

        for name, class_probs in named_probs:
            fpr, tpr, _thresholds = roc_curve(one_hot_encoding[:, i],
                                              class_probs[:, i])
            roc_auc = auc(fpr, tpr)

            plt.plot(fpr, tpr, lw=lw,
                     label='%s (area = %0.2f)' % (name, roc_auc))

        # Diagonal of a random classifier, for reference.
        plt.plot([0, 1], [0, 1], lw=lw, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.0])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        # Falls back to the bare index if class_map has no class for column i.
        plt.title('Class = ' + str(index_to_class.get(i, i)))
        plt.legend(loc="lower right")

        plt.show()