4. Homework Exercise: MNIST Handwritten Digit Training Set
Last updated on the afternoon of January 27, 2024
Homework Exercise 3
Tasks
- Familiarize yourself with the MNIST dataset: MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges. [http://yann.lecun.com/exdb/mnist/]
- Familiarize yourself with the sklearn package: scikit-learn: machine learning in Python, scikit-learn 0.24.1 documentation [scikit-learn.org]
Programming exercise
### Q1. Use the fetch_openml function found in sklearn.datasets to load the mnist_784 dataset into Python. This will load the X and y variables for you.
- Print the dimensions of the variables returned by the function.
- Write a Python script to find how many distinct values are present in y.
- Select one sample from X for each distinct y value.
- Resize each sample to represent the 28x28 pixel image.
- Display all the selected images in one diagram using subplots in matplotlib. The following code gives you an example of how to do this:

```python
import matplotlib.pyplot as plt

# images: the ten selected samples, each already reshaped to 28x28
fig = plt.figure()
for i in range(1, 11):
    fig.add_subplot(2, 5, i)    # subplot positions are 1-based
    plt.imshow(images[i - 1])   # list indices are 0-based
plt.show()
```

Solutions:

```python
from sklearn.datasets import fetch_openml

# load 70000 handwritten digit images, each 28x28 = 784 pixels
images, labels = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)
print(images.shape)
# >> (70000, 784)
```

or

```python
from sklearn import datasets

digits = datasets.load_digits()  # load the 8x8 digits dataset bundled with sklearn (a small MNIST-like set)
images = digits.images           # 1797 images of 8x8 pixels
labels = digits.target           # 1797 labels
print(images.shape)
# >> (1797, 8, 8)
```

```python
import numpy as np
from sklearn import datasets
import matplotlib.pyplot as plt

digits = datasets.load_digits()  # load the 8x8 digits dataset bundled with sklearn
images = digits.images           # 1797 images of 8x8 pixels
labels = digits.target           # labels
print(np.unique(labels))         # list the distinct labels
print(np.unique(labels).shape)   # number of distinct labels

fig = plt.figure()
for i in range(0, 10):
    fig.add_subplot(2, 5, i + 1)  # grid of 2 rows x 5 columns; i + 1 is the position in the grid
    plt.imshow(images[i])         # the first ten samples of this dataset are the digits 0-9
plt.show()                        # display the figure
```

### Q2. Use sklearn to train a digit classifier.
- Split the X and y into a training set and testing set of 80-20 split.
- Train a Support Vector Machine (SVM) for classification of the digits using the training set. The following code shows how to train a model using sklearn.

```python
from sklearn import svm

clf = svm.SVC()
clf.fit(x_train, y_train)
```

- Test the model using the test set.
- Experiment with different parameter values for the SVM and see how it performs. Try changing the gamma value to be [0.0001, 0.0005, 0.001, 0.005, 0.01]
- Plot the accuracy value with respect to the change in gamma above.
- Plot the confusion matrix
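
As background for the last two tasks, it may help to see how accuracy relates to the confusion matrix: entry C[i, j] counts the test samples with true label i that were predicted as label j, so accuracy is the sum of the diagonal divided by the total count. The sketch below is a minimal NumPy illustration only. The helper names `confusion_counts` and `accuracy_from_confusion` and the toy labels are made up for this example and are not part of sklearn.

```python
import numpy as np


def confusion_counts(y_true, y_pred, n_classes):
    """Return an (n_classes, n_classes) matrix C where C[i, j] counts
    samples whose true label is i and whose predicted label is j."""
    counts = np.zeros((n_classes, n_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        counts[t, p] += 1
    return counts


def accuracy_from_confusion(counts):
    """Accuracy = correctly classified samples (the diagonal) / all samples."""
    return np.trace(counts) / counts.sum()


# Toy example with 3 classes (hypothetical labels, not MNIST data)
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
cm = confusion_counts(y_true, y_pred, n_classes=3)
print(cm)
print(accuracy_from_confusion(cm))
```

In sklearn, `metrics.confusion_matrix(y_test, clf.predict(x_test))` uses the same row and column convention, and `clf.score(x_test, y_test)` reports the corresponding accuracy.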
Solutions:

```python
from sklearn import datasets
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn import svm
from sklearn import metrics

digits = datasets.load_digits()          # load the 8x8 digits dataset bundled with sklearn
data = digits.images                     # 1797 images of 8x8 pixels
labels = digits.target                   # 1797 labels
images = data.reshape((len(data), -1))   # flatten each 8x8 matrix into a 64-element vector

# x: images, y: labels; 20% of the samples become the test set
x_train, x_test, y_train, y_test = train_test_split(images, labels, test_size=0.2, shuffle=False)

clf = svm.SVC()                          # create the SVM classifier
clf.fit(x_train, y_train)                # fit the classifier on the training vectors
acc = clf.score(x_test, y_test)          # run the test and return the accuracy
print(acc)                               # print the accuracy

# build and display the confusion matrix for the test set
# (plot_confusion_matrix was removed in sklearn 1.2, so the matrix is computed explicitly)
cm = metrics.confusion_matrix(y_test, clf.predict(x_test))
disp = metrics.ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=clf.classes_)
disp.plot()
plt.show()

g_ = [0.0001, 0.0005, 0.001, 0.005, 0.01]  # list of gamma values
scores = []                                 # list of accuracies
for g in g_:
    clf = svm.SVC(gamma=g)                  # create the SVM classifier with the given gamma
    clf.fit(x_train, y_train)               # fit the classifier on the training vectors
    acc = clf.score(x_test, y_test)         # run the test and return the accuracy
    scores.append(acc)
print(g_)                                   # print the gamma values
print(scores)                               # print the matching accuracies
plt.plot(g_, scores)                        # accuracy as a function of gamma
plt.show()
```
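
Returning to Q1: the task asks for one sample per distinct label and a reshape to 28x28, which the listings above do not spell out for the 784-column `mnist_784` arrays. The following is a minimal sketch, assuming `X` has shape (n_samples, 784) and `y` holds one label per row, as returned by `fetch_openml` with `as_frame=False`. The helper name `one_sample_per_label` is hypothetical, and the small random array only stands in for the real data so the snippet runs without a download.

```python
import numpy as np


def one_sample_per_label(X, y, side=28):
    """Pick the first row of X for each distinct value in y and
    reshape each picked row into a (side, side) image."""
    values, first_idx = np.unique(y, return_index=True)  # index of first occurrence of each label
    images = X[first_idx].reshape(len(values), side, side)
    return values, images


# Stand-in data: 20 fake "images" of 784 pixels with string labels '0'..'9'
# (assumption for this sketch: mnist_784 labels arrive as strings)
rng = np.random.default_rng(0)
X = rng.integers(0, 256, size=(20, 784))
y = np.array([str(d) for d in range(10)] * 2)

values, images = one_sample_per_label(X, y)
print(values)        # the distinct labels
print(images.shape)  # (10, 28, 28)
```

The resulting `images` array can then be fed to the subplot loop from Q1, one image per grid position.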
https://l61012345.top/2021/01/27/Machine Learning-NAU/4.a 课后练习-MNIST/