welcome: please sign in

Cerca

Link Dipartimentali

Link Esterni

Allegato "CaseStudy1_Drug.py"

Scarica

   1 # #### Drug analysis #### #
   2 
   3 # Import modules
   4 import matplotlib.pyplot as plt
   5 import numpy as np
   6 import pandas as pd
   7 from sklearn import model_selection
   8 from sklearn import svm
   9 from sklearn.tree import DecisionTreeClassifier
  10 from sklearn.metrics import classification_report
  11 from sklearn.metrics import confusion_matrix
  12 from sklearn.metrics import accuracy_score
  13 
  14 import plot_utility as pu
  15 
  16 
  17 # Read csv file
  18 print('Read csv file')
  19 data = pd.read_csv('data/drug.csv', na_values='?')
  20 print(type(data))
  21 
  22 # #### Data Undestanding - Data exploratory analysis #### #
  23 print()
  24 print('Data Undestanding - Data exploratory analysis')
  25 
  26 # Data dimensions (n_rows x n_columns)
  27 print()
  28 print('Data dimension:', data.shape)
  29 
  30 # Shape is a tuple object
  31 print()
  32 print('Type of shape', type(data.shape))
  33 
  34 # List the attributes
  35 print()
  36 print('Attribute list')
  37 attributes = list(data.columns)
  38 print(attributes)
  39 
  40 # Print the data
  41 print()
  42 print('Print the data')
  43 print(data)
  44 
  45 # Print only the first 10 records
  46 print()
  47 print('Print top 10 tuples')
  48 print(data.head(10))
  49 
  50 # Get attribute types
  51 print()
  52 print('Data info')
  53 print(data.info())
  54 
  55 # Convert wrong attribute types
  56 print()
  57 data['Sex'] = data['Sex'].astype('category')
  58 data['BP'] = data['BP'].astype('category')
  59 data['Cholesterol'] = data['Cholesterol'].astype('category')
  60 data['Drug'] = data['Drug'].astype('category')
  61 print(data.info())
  62 
  63 # Get general statistics
  64 print()
  65 print('Statistics')
  66 print(data.describe(include='all'))
  67 
  68 # Building histograms for numerical attributes
  69 print()
  70 print('Histograms')
  71 data.hist(figsize=(13, 13))
  72 plt.show()
  73 
  74 # Bar plots for categorical attributes
  75 print()
  76 print('Histograms')
  77 groups = dict(data.groupby('Sex').size())
  78 plt.bar(groups.keys(), groups.values(), align='center')
  79 plt.title('Sex')
  80 plt.show()
  81 
  82 groups = dict(data.groupby('BP').size())
  83 plt.bar(groups.keys(), groups.values(), align='center')
  84 plt.title('BP')
  85 plt.show()
  86 
  87 groups = dict(data.groupby('Cholesterol').size())
  88 plt.bar(groups.keys(), groups.values(), align='center')
  89 plt.title('Cholesterol')
  90 plt.show()
  91 
  92 # In grid
  93 # subplot(nrow, ncols)
  94 plt.subplots(2, 2, figsize=(10, 10))
  95 
  96 # subplot(nrow, ncols, index)
  97 plt.subplot(2, 2, 1)
  98 groups = dict(data.groupby('Sex').size())
  99 plt.bar(groups.keys(), groups.values(), align='center')
 100 plt.title('Sex')
 101 
 102 plt.subplot(2, 2, 2)
 103 groups = dict(data.groupby('BP').size())
 104 plt.bar(groups.keys(), groups.values(), align='center')
 105 plt.title('BP')
 106 
 107 plt.subplot(2, 2, 3)
 108 groups = dict(data.groupby('Cholesterol').size())
 109 plt.bar(groups.keys(), groups.values(), align='center')
 110 plt.title('Cholesterol')
 111 
 112 plt.subplot(2, 2, 4)
 113 class_values = dict(data.groupby('Drug').size())
 114 plt.bar(class_values.keys(), class_values.values(), align='center')
 115 plt.title('Drug')
 116 
 117 plt.show()
 118 
 119 # Class distribution
 120 print()
 121 print('Class distribution')
 122 i = 0
 123 colors = ['red', 'green', 'orange', 'blue', 'purple']
 124 class_map = dict()
 125 color_map = dict()
 126 
 127 
 128 for class_value in class_values:
 129     color_map[class_value] = colors[i]
 130     class_map[class_value] = i
 131     i += 1
 132 
 133 plt.bar(class_values.keys(), class_values.values(), align='center', color=color_map.values())
 134 plt.show()
 135 
 136 # Color bar plots according to the class colors
 137 print()
 138 print('Color bar plots according to the class colors')
 139 plt.subplots(1, 3, figsize=(15, 6))
 140 
 141 plt.subplot(1, 3, 1)
 142 pu.colored_bar(data, class_map, color_map, 'Sex', 'Drug')
 143 
 144 plt.subplot(1, 3, 2)
 145 pu.colored_bar(data, class_map, color_map, 'BP', 'Drug')
 146 
 147 plt.subplot(1, 3, 3)
 148 pu.colored_bar(data, class_map, color_map, 'Cholesterol', 'Drug')
 149 
 150 plt.show()
 151 
 152 # Color histograms according to the class colors
 153 print()
 154 print('Color histograms according to the class colors')
 155 plt.subplots(1, 3, figsize=(15, 6))
 156 
 157 plt.subplot(1, 3, 1)
 158 pu.colored_hist(data, class_map, color_map, 'Instance_number', 'Drug', n_bins=10)
 159 
 160 plt.subplot(1, 3, 2)
 161 pu.colored_hist(data, class_map, color_map, 'ID', 'Drug', n_bins=10)
 162 
 163 plt.subplot(1, 3, 3)
 164 pu.colored_hist(data, class_map, color_map, 'Age', 'Drug', n_bins=10)
 165 
 166 plt.show()
 167 
 168 plt.subplots(1, 2, figsize=(15, 6))
 169 
 170 plt.subplot(1, 2, 1)
 171 pu.colored_hist(data, class_map, color_map, 'Na', 'Drug', n_bins=10)
 172 
 173 plt.subplot(1, 2, 2)
 174 pu.colored_hist(data, class_map, color_map, 'K', 'Drug', n_bins=10)
 175 
 176 plt.show()
 177 
 178 # box and whisker plots (numerical attributes)
 179 print()
 180 print('Box plots')
 181 data.plot(kind='box', subplots=True, sharex=False, sharey=False, figsize=(15, 5))
 182 plt.show()
 183 print('No outliers')
 184 
 185 # Scatter plot matrix (numerical attributes)
 186 print()
 187 print('Scatter plot matrix')
 188 pd.plotting.scatter_matrix(data, figsize=(10, 10))
 189 plt.show()
 190 
 191 print()
 192 print('Colored Scatter plots')
 193 colors = data['Drug'].apply(lambda class_val: color_map[class_val])
 194 
 195 numeric_attributes = list(data.select_dtypes(include=['float64', 'int64']).columns)
 196 n_numeric_attributes = len(numeric_attributes)
 197 
 198 for i in range(n_numeric_attributes):
 199     for j in range(i + 1, n_numeric_attributes):
 200         plt.figure(figsize=(5, 5))
 201         pu.colored_scatter(data, numeric_attributes[i], numeric_attributes[j], colors)
 202         plt.show()
 203 
 204 print()
 205 print('In grid')
 206 
 207 for i in range(n_numeric_attributes):
 208     k = 1
 209     plt.subplots(1,  n_numeric_attributes, figsize=(15, 5))
 210 
 211     for j in range(n_numeric_attributes):
 212         plt.subplot(1, n_numeric_attributes, k)
 213 
 214         if i == j:
 215             pu.colored_hist(data, class_map, color_map, numeric_attributes[i], 'Drug', n_bins=10)
 216         else:
 217             pu.colored_scatter(data, numeric_attributes[i], numeric_attributes[j], colors)
 218 
 219         k += 1
 220 
 221     plt.show()
 222 
 223 # Focus on this scatter plot
 224 print()
 225 print('Focus on Na-K scatter plot')
 226 plt.figure(figsize=(6, 6))
 227 pu.colored_scatter(data, 'Na', 'K', colors)
 228 plt.show()
 229 
 230 print()
 231 print('A hyperplane can separate the drugY (purple) class')
 232 x = [0.5, 0.9]
 233 y = [0.033, 0.061]
 234 
 235 plt.figure(figsize=(6, 6))
 236 pu.colored_scatter(data, 'Na', 'K', colors)
 237 plt.plot(x, y, linewidth=3, color='black')
 238 plt.show()
 239 
 240 # #### Data preparation #### #
 241 
 242 # Remove all the attributes with too high (typically IDs)
 243 # or too low variability (column that exhibit the same value for all the records)
 244 print()
 245 print('Remove useless attributes')
 246 data = data[['Sex', 'BP', 'Cholesterol', 'Age', 'Na', 'K', 'Drug']]
 247 attributes = list(data.columns)
 248 n_attributes = len(attributes)
 249 print(data.head())
 250 
 251 # Highlight missing values
 252 print()
 253 print('Highlight missing values')
 254 print(data.info())
 255 print()
 256 print('Missing for each attribute')
 257 print(data.isna().sum())
 258 
 259 # Embed the hyperplane of separation information
 260 # Compute the hyperplane
 261 print()
 262 print('Embed the hyperplane of separation information')
 263 
 264 # Create model
 265 print('Create model')
 266 svm_model = svm.SVC(kernel='linear', C=1000)
 267 
 268 # Prepare data projection
 269 print('Prepare data projection')
 270 x = data[['Na', 'K']].values
 271 y = np.zeros(len(data)) - 1
 272 y[data['Drug'] == 'drugY'] = 1
 273 
 274 # Replace missing values
 275 print('Replace missing values')
 276 np.nan_to_num(x, copy=False)
 277 
 278 # Train the model
 279 svm_model.fit(x, y)
 280 
 281 # get the separating hyperplane
 282 w = svm_model.coef_[0]
 283 a = -w[0] / w[1]
 284 xx = np.linspace(0.5, 0.9)
 285 yy = a * xx - (svm_model.intercept_[0]) / w[1]
 286 
 287 # plot
 288 plt.figure(figsize=(10, 10))
 289 pu.colored_scatter(data, 'Na', 'K', colors)
 290 plt.plot(xx, yy, linewidth=3, color='black')
 291 plt.show()
 292 
 293 # Create a new attribute
 294 print('Create a new attribute')
 295 new_attr_val = np.array(['Under_threshold'] * len(data))
 296 new_attr_val[data['K'] > a * data['Na'] - (svm_model.intercept_[0]) / w[1]] = 'Above_threshold'
 297 new_attr_val[data['K'].isna() | data['Na'].isna()] = 'Unknown'
 298 data['Hyper'] = new_attr_val
 299 data['Hyper'] = data['Hyper'].astype('category')
 300 print(data.head(10))
 301 print()
 302 print(data.info())
 303 plt.figure(figsize=(6, 6))
 304 pu.colored_bar(data, class_map, color_map, 'Hyper', 'Drug')
 305 plt.show()
 306 
 307 # #### Modeling #### #
 308 print()
 309 print('Modeling')
 310 
 311 # Binarize the dataset
 312 print()
 313 print('Binarize the dataset')
 314 data2 = pd.get_dummies(data, columns=["Sex", "BP", "Cholesterol", 'Hyper'])
 315 new_attr_list = list(data2.columns)
 316 new_attr_list.remove('Drug')
 317 data2 = data2[new_attr_list]
 318 print(data2.head(10))
 319 
 320 # Split-out validation dataset
 321 print()
 322 print('Split-out validation dataset')
 323 X = np.array(data2.values)
 324 
 325 # Replace missing values
 326 X[np.isnan(X)] = 0
 327 
 328 Y = np.array(data['Drug'].values)
 329 validation_size = 0.20
 330 seed = 121
 331 
 332 X_train, X_validation, Y_train, Y_validation = model_selection.train_test_split(
 333     X, Y, test_size=validation_size, random_state=seed)
 334 
 335 # Create the models
 336 print()
 337 print('Create the models')
 338 models = []
 339 models.append(('C45', DecisionTreeClassifier(criterion='entropy', min_samples_leaf=3)))
 340 models.append(('CART', DecisionTreeClassifier(criterion='gini', min_samples_leaf=3)))
 341 
 342 # evaluate each model in turn
 343 results = []
 344 names = []
 345 
 346 for name, model in models:
 347     kfold = model_selection.KFold(n_splits=10, random_state=seed)
 348     cv_results = model_selection.cross_val_score(model, X_train, Y_train, cv=kfold, scoring='accuracy')
 349     results.append(cv_results)
 350     names.append(name)
 351     msg = "%s: %f (%f)" % (name, cv_results.mean(), cv_results.std())
 352     print(msg)
 353 
 354 # #### Evaluation #### #
 355 print()
 356 print('Evaluation')
 357 
 358 # Model comparison (Cross validation)
 359 print()
 360 print('Model comparison')
 361 fig = plt.figure()
 362 fig.suptitle('Model Comparison')
 363 plt.boxplot(results)
 364 plt.xticks(range(1, len(names) + 1), names)
 365 plt.show()
 366 
 367 # Do predictions on test set
 368 print()
 369 print('Do predictions on test set')
 370 model = DecisionTreeClassifier(criterion='gini', min_samples_leaf=3)
 371 model.fit(X_train, Y_train)
 372 predictions = model.predict(X_validation)
 373 print('Accuracy:', accuracy_score(Y_validation, predictions))
 374 print()
 375 print()
 376 print('Confusion matrix:')
 377 print(confusion_matrix(Y_validation, predictions))
 378 print()
 379 print()
 380 print('Classification report')
 381 print(classification_report(Y_validation, predictions))

Allegati

Non è consentito inserire allegati su questa pagina.