# CaseStudy1_Drug.py
# #### Drug analysis #### #
2
3 # Import modules
4 import matplotlib.pyplot as plt
5 import numpy as np
6 import pandas as pd
7 from sklearn import model_selection
8 from sklearn import svm
9 from sklearn.tree import DecisionTreeClassifier
10 from sklearn.metrics import classification_report
11 from sklearn.metrics import confusion_matrix
12 from sklearn.metrics import accuracy_score
13
14 import plot_utility as pu
15
16
# Load the drug dataset; '?' tokens are parsed as missing values (NaN)
print('Read csv file')
data = pd.read_csv('data/drug.csv', na_values='?')
print(type(data))

# #### Data Undestanding - Data exploratory analysis #### #
print()
print('Data Undestanding - Data exploratory analysis')

# Dataset dimensions as an (n_rows, n_columns) tuple
print()
print('Data dimension:', data.shape)

print()
print('Type of shape', type(data.shape))

# Column names of the dataset
print()
print('Attribute list')
attributes = list(data.columns)
print(attributes)

# Full dump of the records
print()
print('Print the data')
print(data)

# Only the first 10 records
print()
print('Print top 10 tuples')
print(data.head(10))

# Per-column dtypes and non-null counts
print()
print('Data info')
print(data.info())

# Cast the nominal attributes to the pandas 'category' dtype
print()
for nominal_col in ('Sex', 'BP', 'Cholesterol', 'Drug'):
    data[nominal_col] = data[nominal_col].astype('category')
print(data.info())

# Summary statistics for every column (numeric and categorical)
print()
print('Statistics')
print(data.describe(include='all'))
67
# Histograms for the numerical attributes
print()
print('Histograms')
data.hist(figsize=(13, 13))
plt.show()

# Bar plots for the categorical attributes, one figure each
print()
print('Histograms')
for cat_attr in ('Sex', 'BP', 'Cholesterol'):
    counts = dict(data.groupby(cat_attr).size())
    plt.bar(counts.keys(), counts.values(), align='center')
    plt.title(cat_attr)
    plt.show()

# Same bar plots arranged in a 2x2 grid: subplot(nrow, ncols, index)
plt.subplots(2, 2, figsize=(10, 10))

for cell, cat_attr in enumerate(('Sex', 'BP', 'Cholesterol'), start=1):
    plt.subplot(2, 2, cell)
    counts = dict(data.groupby(cat_attr).size())
    plt.bar(counts.keys(), counts.values(), align='center')
    plt.title(cat_attr)

# Class attribute in the fourth cell; class_values is reused further below
plt.subplot(2, 2, 4)
class_values = dict(data.groupby('Drug').size())
plt.bar(class_values.keys(), class_values.values(), align='center')
plt.title('Drug')

plt.show()
118
# Class distribution: give each class label an integer index and a colour
print()
print('Class distribution')
colors = ['red', 'green', 'orange', 'blue', 'purple']
class_map = dict()
color_map = dict()

# class_values preserves the class labels' insertion order, so indices
# and colours are assigned consistently across the whole script
for idx, class_value in enumerate(class_values):
    color_map[class_value] = colors[idx]
    class_map[class_value] = idx

plt.bar(class_values.keys(), class_values.values(), align='center', color=color_map.values())
plt.show()
135
# Bar plots of the categorical attributes, split by class colour
print()
print('Color bar plots according to the class colors')
plt.subplots(1, 3, figsize=(15, 6))

for cell, cat_attr in enumerate(('Sex', 'BP', 'Cholesterol'), start=1):
    plt.subplot(1, 3, cell)
    pu.colored_bar(data, class_map, color_map, cat_attr, 'Drug')

plt.show()

# Histograms of the numerical attributes, split by class colour
print()
print('Color histograms according to the class colors')
plt.subplots(1, 3, figsize=(15, 6))

for cell, num_attr in enumerate(('Instance_number', 'ID', 'Age'), start=1):
    plt.subplot(1, 3, cell)
    pu.colored_hist(data, class_map, color_map, num_attr, 'Drug', n_bins=10)

plt.show()

# The two chemistry attributes on a second row of histograms
plt.subplots(1, 2, figsize=(15, 6))

for cell, num_attr in enumerate(('Na', 'K'), start=1):
    plt.subplot(1, 2, cell)
    pu.colored_hist(data, class_map, color_map, num_attr, 'Drug', n_bins=10)

plt.show()
177
# Box-and-whisker plots for the numerical attributes
print()
print('Box plots')
data.plot(kind='box', subplots=True, sharex=False, sharey=False, figsize=(15, 5))
plt.show()
print('No outliers')

# Scatter plot matrix over the numerical attributes
print()
print('Scatter plot matrix')
pd.plotting.scatter_matrix(data, figsize=(10, 10))
plt.show()

# Per-record colour vector derived from the class label
print()
print('Colored Scatter plots')
colors = data['Drug'].apply(lambda class_val: color_map[class_val])

numeric_attributes = list(data.select_dtypes(include=['float64', 'int64']).columns)
n_numeric_attributes = len(numeric_attributes)

# One standalone figure per unordered pair of numeric attributes
for i in range(n_numeric_attributes):
    for j in range(i + 1, n_numeric_attributes):
        plt.figure(figsize=(5, 5))
        pu.colored_scatter(data, numeric_attributes[i], numeric_attributes[j], colors)
        plt.show()

print()
print('In grid')

# One subplot row per attribute: histogram on the diagonal, scatter elsewhere
for i in range(n_numeric_attributes):
    plt.subplots(1, n_numeric_attributes, figsize=(15, 5))

    for cell, j in enumerate(range(n_numeric_attributes), start=1):
        plt.subplot(1, n_numeric_attributes, cell)

        if i == j:
            pu.colored_hist(data, class_map, color_map, numeric_attributes[i], 'Drug', n_bins=10)
        else:
            pu.colored_scatter(data, numeric_attributes[i], numeric_attributes[j], colors)

    plt.show()

# Zoom in on the most informative pair
print()
print('Focus on Na-K scatter plot')
plt.figure(figsize=(6, 6))
pu.colored_scatter(data, 'Na', 'K', colors)
plt.show()

# Hand-drawn candidate separating line over the Na-K plane
print()
print('A hyperplane can separate the drugY (purple) class')
x = [0.5, 0.9]
y = [0.033, 0.061]

plt.figure(figsize=(6, 6))
pu.colored_scatter(data, 'Na', 'K', colors)
plt.plot(x, y, linewidth=3, color='black')
plt.show()
239
# #### Data preparation #### #

# Drop attributes with too high variability (typically IDs)
# or too low variability (columns constant across all records)
print()
print('Remove useless attributes')
kept_attributes = ['Sex', 'BP', 'Cholesterol', 'Age', 'Na', 'K', 'Drug']
data = data[kept_attributes]
attributes = list(data.columns)
n_attributes = len(attributes)
print(data.head())

# Inspect where the missing values are
print()
print('Highlight missing values')
print(data.info())
print()
print('Missing for each attribute')
print(data.isna().sum())
258
# Embed the hyperplane-of-separation information as a new attribute.
# A linear SVM finds the line separating drugY from the rest on the Na-K plane.
print()
print('Embed the hyperplane of separation information')

# Linear SVM with a large C so the margin hugs the data tightly
print('Create model')
svm_model = svm.SVC(kernel='linear', C=1000)

# Project onto (Na, K); binary target: +1 for drugY, -1 otherwise
print('Prepare data projection')
x = data[['Na', 'K']].values
y = np.where(data['Drug'] == 'drugY', 1.0, -1.0)

# Replace NaN with 0 in place so the SVM can be fitted
print('Replace missing values')
np.nan_to_num(x, copy=False)

# Train the model
svm_model.fit(x, y)

# Recover the separating line: K = a * Na - intercept / w[1]
w = svm_model.coef_[0]
a = -w[0] / w[1]
xx = np.linspace(0.5, 0.9)
yy = a * xx - (svm_model.intercept_[0]) / w[1]

# Visual check of the learned line over the coloured scatter
plt.figure(figsize=(10, 10))
pu.colored_scatter(data, 'Na', 'K', colors)
plt.plot(xx, yy, linewidth=3, color='black')
plt.show()

# New categorical attribute: which side of the hyperplane each record lies on.
# Records with missing Na or K are explicitly marked 'Unknown'.
print('Create a new attribute')
line_at_na = a * data['Na'] - svm_model.intercept_[0] / w[1]
new_attr_val = np.where(data['K'] > line_at_na, 'Above_threshold', 'Under_threshold')
new_attr_val[data['K'].isna() | data['Na'].isna()] = 'Unknown'
data['Hyper'] = new_attr_val
data['Hyper'] = data['Hyper'].astype('category')
print(data.head(10))
print()
print(data.info())
plt.figure(figsize=(6, 6))
pu.colored_bar(data, class_map, color_map, 'Hyper', 'Drug')
plt.show()
306
# #### Modeling #### #
print()
print('Modeling')

# One-hot encode the categorical attributes and drop the class column
print()
print('Binarize the dataset')
data2 = pd.get_dummies(data, columns=["Sex", "BP", "Cholesterol", 'Hyper'])
feature_columns = list(data2.columns)
feature_columns.remove('Drug')
data2 = data2[feature_columns]
print(data2.head(10))

# Hold out 20% of the records for validation
print()
print('Split-out validation dataset')
X = np.array(data2.values)

# Any NaN left in the numeric attributes becomes 0
X[np.isnan(X)] = 0

Y = np.array(data['Drug'].values)
validation_size = 0.20
seed = 121

X_train, X_validation, Y_train, Y_validation = model_selection.train_test_split(
    X, Y, test_size=validation_size, random_state=seed)
334
# Candidate classifiers: entropy-based (C4.5-like) and gini-based (CART) trees
print()
print('Create the models')
models = []
models.append(('C45', DecisionTreeClassifier(criterion='entropy', min_samples_leaf=3)))
models.append(('CART', DecisionTreeClassifier(criterion='gini', min_samples_leaf=3)))

# Evaluate each model with 10-fold cross-validation on the training split
results = []
names = []

# BUG FIX: KFold(n_splits=10, random_state=seed) without shuffle=True raises
# ValueError on scikit-learn >= 0.24 (random_state only applies when the folds
# are shuffled). shuffle=True with a fixed seed keeps the run reproducible.
# The splitter is stateless, so it is hoisted out of the loop.
kfold = model_selection.KFold(n_splits=10, shuffle=True, random_state=seed)

for name, model in models:
    cv_results = model_selection.cross_val_score(model, X_train, Y_train, cv=kfold, scoring='accuracy')
    results.append(cv_results)
    names.append(name)
    msg = "%s: %f (%f)" % (name, cv_results.mean(), cv_results.std())
    print(msg)
353
# #### Evaluation #### #
print()
print('Evaluation')

# Compare the cross-validation score distributions of the models
print()
print('Model comparison')
fig = plt.figure()
fig.suptitle('Model Comparison')
plt.boxplot(results)
plt.xticks(range(1, len(names) + 1), names)
plt.show()

# Refit the chosen model on the full training split and score it
# on the held-out validation set
print()
print('Do predictions on test set')
final_model = DecisionTreeClassifier(criterion='gini', min_samples_leaf=3)
final_model.fit(X_train, Y_train)
val_predictions = final_model.predict(X_validation)
print('Accuracy:', accuracy_score(Y_validation, val_predictions))
print()
print()
print('Confusion matrix:')
print(confusion_matrix(Y_validation, val_predictions))
print()
print()
print('Classification report')
print(classification_report(Y_validation, val_predictions))
# Allegati
# Non è consentito inserire allegati su questa pagina.