# #### Drug analysis #### #

# Import modules
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import model_selection
from sklearn import svm
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score

import plot_utility as pu


# Read csv file; '?' marks a missing value in this dataset.
print('Read csv file')
data = pd.read_csv('data/drug.csv', na_values='?')
print(type(data))

# #### Data Understanding - Data exploratory analysis #### #
print()
print('Data Understanding - Data exploratory analysis')

# Data dimensions (n_rows x n_columns)
print()
print('Data dimension:', data.shape)

# Shape is a tuple object
print()
print('Type of shape', type(data.shape))

# List the attributes
print()
print('Attribute list')
attributes = list(data.columns)
print(attributes)

# Print the data
print()
print('Print the data')
print(data)

# Print only the first 10 records
print()
print('Print top 10 tuples')
print(data.head(10))

# Get attribute types
print()
print('Data info')
print(data.info())

# Convert wrong attribute types: these columns hold a small closed set of
# labels, so the 'category' dtype is more descriptive (and compact) than object.
print()
for column in ('Sex', 'BP', 'Cholesterol', 'Drug'):
    data[column] = data[column].astype('category')
print(data.info())

# Get general statistics (include='all' covers categorical columns too)
print()
print('Statistics')
print(data.describe(include='all'))

# Building histograms for numerical attributes
print()
print('Histograms')
data.hist(figsize=(13, 13))
plt.show()

# Bar plots for categorical attributes: one standalone figure per column
print()
print('Histograms')
for column in ('Sex', 'BP', 'Cholesterol'):
    counts = dict(data.groupby(column).size())
    plt.bar(counts.keys(), counts.values(), align='center')
    plt.title(column)
    plt.show()

# In grid
# subplot(nrow, ncols)
plt.subplots(2, 2, figsize=(10, 10))

# subplot(nrow, ncols, index)
for position, column in enumerate(('Sex', 'BP', 'Cholesterol'), start=1):
    plt.subplot(2, 2, position)
    counts = dict(data.groupby(column).size())
    plt.bar(counts.keys(), counts.values(), align='center')
    plt.title(column)

plt.subplot(2, 2, 4)
# Keep the class counts bound: the colour-mapping section below reuses them.
class_values = dict(data.groupby('Drug').size())
plt.bar(class_values.keys(), class_values.values(), align='center')
plt.title('Drug')

plt.show()

# Class distribution
print()
print('Class distribution')
colors = ['red', 'green', 'orange', 'blue', 'purple']

# Give every class label a stable integer index and a fixed colour.
class_map = {label: index for index, label in enumerate(class_values)}
color_map = {label: colors[index] for index, label in enumerate(class_values)}

plt.bar(class_values.keys(), class_values.values(), align='center', color=color_map.values())
plt.show()

# Color bar plots according to the class colors
print()
print('Color bar plots according to the class colors')
plt.subplots(1, 3, figsize=(15, 6))

for position, column in enumerate(('Sex', 'BP', 'Cholesterol'), start=1):
    plt.subplot(1, 3, position)
    pu.colored_bar(data, class_map, color_map, column, 'Drug')

plt.show()

# Color histograms according to the class colors
print()
print('Color histograms according to the class colors')
plt.subplots(1, 3, figsize=(15, 6))

for position, column in enumerate(('Instance_number', 'ID', 'Age'), start=1):
    plt.subplot(1, 3, position)
    pu.colored_hist(data, class_map, color_map, column, 'Drug', n_bins=10)

plt.show()

plt.subplots(1, 2, figsize=(15, 6))

for position, column in enumerate(('Na', 'K'), start=1):
    plt.subplot(1, 2, position)
    pu.colored_hist(data, class_map, color_map, column, 'Drug', n_bins=10)

plt.show()

# box and whisker plots (numerical attributes)
print()
print('Box plots')
data.plot(kind='box', subplots=True, sharex=False, sharey=False, figsize=(15, 5))
plt.show()
print('No outliers')

# Scatter plot matrix (numerical attributes)
print()
print('Scatter plot matrix')
pd.plotting.scatter_matrix(data, figsize=(10, 10))
plt.show()

print()
print('Colored Scatter plots')
# One colour per record, keyed on its class label.
colors = data['Drug'].apply(lambda class_val: color_map[class_val])

numeric_attributes = list(data.select_dtypes(include=['float64', 'int64']).columns)
n_numeric_attributes = len(numeric_attributes)

# One scatter plot for every unordered pair of numeric attributes.
for first in range(n_numeric_attributes):
    for second in range(first + 1, n_numeric_attributes):
        plt.figure(figsize=(5, 5))
        pu.colored_scatter(data, numeric_attributes[first], numeric_attributes[second], colors)
        plt.show()

print()
print('In grid')

# Hand-built scatter-plot matrix: class-coloured histograms on the
# diagonal, pairwise scatter plots everywhere else. One figure per row.
for row in range(n_numeric_attributes):
    plt.subplots(1, n_numeric_attributes, figsize=(15, 5))

    for col in range(n_numeric_attributes):
        plt.subplot(1, n_numeric_attributes, col + 1)

        if row == col:
            pu.colored_hist(data, class_map, color_map, numeric_attributes[row], 'Drug', n_bins=10)
        else:
            pu.colored_scatter(data, numeric_attributes[row], numeric_attributes[col], colors)

    plt.show()

# Focus on this scatter plot
print()
print('Focus on Na-K scatter plot')
plt.figure(figsize=(6, 6))
pu.colored_scatter(data, 'Na', 'K', colors)
plt.show()

print()
print('A hyperplane can separate the drugY (purple) class')
# Hand-picked endpoints of a separating line in (Na, K) space.
line_x = [0.5, 0.9]
line_y = [0.033, 0.061]

plt.figure(figsize=(6, 6))
pu.colored_scatter(data, 'Na', 'K', colors)
plt.plot(line_x, line_y, linewidth=3, color='black')
plt.show()

# #### Data preparation #### #

# Remove all the attributes with too high (typically IDs)
# or too low variability (column that exhibit the same value for all the records)
print()
print('Remove useless attributes')
data = data.loc[:, ['Sex', 'BP', 'Cholesterol', 'Age', 'Na', 'K', 'Drug']]
# NOTE(review): attributes / n_attributes appear unused downstream — confirm before removing.
attributes = list(data.columns)
n_attributes = len(attributes)
print(data.head())

# Highlight missing values
print()
print('Highlight missing values')
print(data.info())
print()
print('Missing for each attribute')
print(data.isna().sum())

# Embed the hyperplane of separation information
# Compute the hyperplane
print()
print('Embed the hyperplane of separation information')

# Create model
print('Create model')
# Large C: almost hard-margin linear SVM, so the fitted line follows the
# visually obvious separation in the Na-K scatter plot.
svm_model = svm.SVC(kernel='linear', C=1000)

# Prepare data projection
print('Prepare data projection')
features = data[['Na', 'K']].values
# Binary target: +1 for drugY, -1 for every other class.
labels = np.where(data['Drug'] == 'drugY', 1.0, -1.0)

# Replace missing values
print('Replace missing values')
np.nan_to_num(features, copy=False)

# Train the model
svm_model.fit(features, labels)

# get the separating hyperplane: w0*Na + w1*K + b = 0  =>  K = a*Na - b/w1
w = svm_model.coef_[0]
a = -w[0] / w[1]
xx = np.linspace(0.5, 0.9)
yy = a * xx - (svm_model.intercept_[0]) / w[1]

# plot
plt.figure(figsize=(10, 10))
pu.colored_scatter(data, 'Na', 'K', colors)
plt.plot(xx, yy, linewidth=3, color='black')
plt.show()

# Create a new attribute: on which side of the fitted line each record sits.
print('Create a new attribute')
threshold = a * data['Na'] - (svm_model.intercept_[0]) / w[1]
new_attr_val = np.where(data['K'] > threshold, 'Above_threshold', 'Under_threshold')
# Records with a missing coordinate cannot be placed w.r.t. the line.
new_attr_val[data['K'].isna() | data['Na'].isna()] = 'Unknown'
data['Hyper'] = new_attr_val
data['Hyper'] = data['Hyper'].astype('category')
print(data.head(10))
print()
print(data.info())
plt.figure(figsize=(6, 6))
pu.colored_bar(data, class_map, color_map, 'Hyper', 'Drug')
plt.show()

# #### Modeling #### #
print()
print('Modeling')

# Binarize the dataset
print()
print('Binarize the dataset')
# One-hot encode the categorical attributes; the class column is kept aside.
data2 = pd.get_dummies(data, columns=["Sex", "BP", "Cholesterol", 'Hyper'])
new_attr_list = list(data2.columns)
new_attr_list.remove('Drug')
data2 = data2[new_attr_list]
print(data2.head(10))

# Split-out validation dataset
print()
print('Split-out validation dataset')
# Force a float matrix: since pandas 2.0 get_dummies emits bool columns,
# which would make .values an object-dtype array and break np.isnan below.
X = data2.to_numpy(dtype=float)

# Replace missing values
X[np.isnan(X)] = 0

Y = np.array(data['Drug'].values)
validation_size = 0.20  # hold out 20% of the records for final evaluation
seed = 121  # fixed seed => reproducible split

X_train, X_validation, Y_train, Y_validation = model_selection.train_test_split(
    X, Y, test_size=validation_size, random_state=seed)

# Create the models
print()
print('Create the models')
# Two decision-tree variants: C4.5-style (entropy) and CART (gini).
models = []
models.append(('C45', DecisionTreeClassifier(criterion='entropy', min_samples_leaf=3)))
models.append(('CART', DecisionTreeClassifier(criterion='gini', min_samples_leaf=3)))

# evaluate each model in turn with 10-fold cross validation on the train set
results = []
names = []

for name, model in models:
    # shuffle=True is required when passing random_state to KFold:
    # scikit-learn >= 0.24 raises ValueError for random_state with shuffle=False.
    kfold = model_selection.KFold(n_splits=10, shuffle=True, random_state=seed)
    cv_results = model_selection.cross_val_score(model, X_train, Y_train, cv=kfold, scoring='accuracy')
    results.append(cv_results)
    names.append(name)
    msg = "%s: %f (%f)" % (name, cv_results.mean(), cv_results.std())
    print(msg)

# #### Evaluation #### #
print()
print('Evaluation')

# Model comparison (Cross validation)
print()
print('Model comparison')
comparison_fig = plt.figure()
comparison_fig.suptitle('Model Comparison')
plt.boxplot(results)
plt.xticks(ticks=range(1, len(names) + 1), labels=names)
plt.show()

# Do predictions on test set: refit the best model on the full training split
print()
print('Do predictions on test set')
final_model = DecisionTreeClassifier(criterion='gini', min_samples_leaf=3)
final_model.fit(X_train, Y_train)
val_predictions = final_model.predict(X_validation)
print('Accuracy:', accuracy_score(Y_validation, val_predictions))
print()
print()
print('Confusion matrix:')
print(confusion_matrix(Y_validation, val_predictions))
print()
print()
print('Classification report')
print(classification_report(Y_validation, val_predictions))
