

Loading the data

In [1]:
from sklearn.datasets import fetch_openml
In [2]:
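# as_frame=True returns pandas objects, which later cells rely on (.value_counts, .iloc)
mnist = fetch_openml('mnist_784', version=1, as_frame=True)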

Each sample is a 28*28 grayscale image flattened into 784 pixel features (70,000 samples in total).

In [3]:
type(mnist)  # dictionary-like object
Out [3]:
sklearn.utils.Bunch
In [4]:
mnist.keys()
Out [4]:
dict_keys(['data', 'target', 'frame', 'categories', 'feature_names', 'target_names', 'DESCR', 'details', 'url'])
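
`Bunch` is a `dict` subclass whose keys can also be read as attributes, which is why `mnist.data` and `mnist['data']` are interchangeable. A minimal sketch with a tiny hand-built `Bunch` (the values are placeholders, not MNIST data):

```python
from sklearn.utils import Bunch

# Keyword arguments become dictionary keys
toy = Bunch(data=[[0, 255], [128, 64]], target=['0', '1'])

# Attribute access and key access return the same object
print(toy.data is toy['data'])  # True
print(isinstance(toy, dict))    # True: Bunch subclasses dict
print(sorted(toy.keys()))       # ['data', 'target']
```
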
In [5]:
mnist.data.shape
Out [5]:
(70000, 784)
In [6]:
mnist.target.shape
Out [6]:
(70000,)
In [7]:
# Check the class distribution of the target
mnist.target.value_counts()
Out [7]:
1    7877
7    7293
3    7141
2    6990
9    6958
0    6903
6    6876
8    6825
4    6824
5    6313
Name: class, dtype: int64
In [8]:
mnist.data.min().min(), mnist.data.max().max()
# Pixels take 256 gray levels (0-255); for scale-sensitive algorithms, divide by 255 to rescale them to 0-1
Out [8]:
(0.0, 255.0)
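
Dividing by 255 maps every pixel from the integer range 0-255 onto 0.0-1.0. A small NumPy sketch on a made-up 2×2 patch (not real MNIST pixels):

```python
import numpy as np

# A made-up 2x2 patch of 8-bit grayscale values
patch = np.array([[0, 51], [204, 255]], dtype=np.uint8)

# Rescale to [0, 1]; astype makes the float dtype explicit
scaled = patch.astype(np.float64) / 255.0
print(scaled)
print(scaled.min(), scaled.max())  # 0.0 1.0
```

Tree-based models such as the random forest used below split on feature thresholds, so they do not need this rescaling; it matters for scale-sensitive methods such as neural networks trained by gradient descent or distance-based models like k-NN.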

Random Forest

In [9]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
In [10]:
X_train, X_test, y_train, y_test = train_test_split(mnist.data, mnist.target, test_size=0.1)
In [11]:
# Check the class distribution of the training split
y_train.value_counts()
Out [11]:
1    7117
7    6489
3    6432
2    6291
9    6262
0    6223
6    6207
8    6140
4    6131
5    5708
Name: class, dtype: int64
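
The split above is random, so the class proportions in the training set only approximately match the full dataset. If they should match by construction, `train_test_split` accepts a `stratify` argument (not used in the original post). A sketch on synthetic, deliberately imbalanced labels:

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Synthetic labels: 900 of class 'a', 100 of class 'b' (illustrative only)
y = pd.Series(['a'] * 900 + ['b'] * 100)
X = np.arange(len(y)).reshape(-1, 1)

# stratify=y keeps the a:b ratio the same in both splits
X_tr, X_te, y_tr, y_te = train_test_split(
    X, y, test_size=0.1, stratify=y, random_state=0
)

print(y_tr.value_counts(normalize=True))
print(y_te.value_counts(normalize=True))
```

For MNIST this would be `train_test_split(mnist.data, mnist.target, test_size=0.1, stratify=mnist.target)`.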
In [12]:
clf = RandomForestClassifier()
clf.fit(X_train, y_train)
pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, pred)
In [13]:
print(f'Accuracy: {accuracy:.4f}')
Out [13]:
Accuracy: 0.9723
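
Because downloading MNIST takes a while, here is a self-contained sketch of the same fit/predict/score workflow on scikit-learn's bundled `load_digits` dataset (8×8 images, 64 features), used only as a stand-in for MNIST. It also checks that `accuracy_score` is simply the fraction of predictions that match the true labels. Variable names are chosen so they do not overwrite the notebook's `X_train`, `clf`, and so on.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Small bundled digit dataset: 1797 samples of 8x8 = 64 pixels
X_dig, y_dig = load_digits(return_X_y=True)

Xd_tr, Xd_te, yd_tr, yd_te = train_test_split(
    X_dig, y_dig, test_size=0.1, random_state=0
)

rf = RandomForestClassifier(random_state=0)
rf.fit(Xd_tr, yd_tr)
pred_dig = rf.predict(Xd_te)

# accuracy_score equals the mean of the element-wise matches
acc = accuracy_score(yd_te, pred_dig)
manual = np.mean(pred_dig == yd_te)
print(f'Accuracy: {acc:.4f} (manual: {manual:.4f})')
```
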

Checking predictions on actual images

In [14]:
import numpy as np
import matplotlib.pyplot as plt
In [15]:
tmp = X_test.iloc[666]
tmp = np.array(tmp)
tmp = tmp.reshape(28, 28)
plt.imshow(tmp, cmap='gray')
y_test.iloc[666]
Out [15]:
'1'
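
Each row of `X_test` is one image flattened row by row into 784 values, so `reshape(28, 28)` restores the grid and `reshape(1, 784)` undoes it. A NumPy sketch with a synthetic image whose pixel values are just their own indices:

```python
import numpy as np

# A synthetic 28x28 "image" whose pixel values are 0..783
image = np.arange(28 * 28).reshape(28, 28)

# Flatten into a single 784-feature row, the layout the classifier expects
row = image.reshape(1, 784)
print(row.shape)  # (1, 784)

# Pixel (r, c) lands at column r * 28 + c of the flattened row
r, c = 3, 5
print(row[0, r * 28 + c] == image[r, c])  # True

# Reshaping back recovers the original grid exactly
print(np.array_equal(row.reshape(28, 28), image))  # True
```
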


In [16]:
import glob
from PIL import Image
In [17]:
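import pandas as pd

for path in glob.glob('./img/*.png'):
    img = Image.open(path).convert('L')
    plt.imshow(img, cmap='gray')
    # Resize to 28x28 and flatten to one 784-pixel row (np.resize would only repeat/truncate values)
    img = np.asarray(img.resize((28, 28)), dtype=np.float64).reshape(1, 784)
    # Invert black/white (MNIST digits are white on black); the array is already float
    img = 255.0 - img
    # Reuse the training column names so predict() does not warn about missing feature names
    pred = clf.predict(pd.DataFrame(img, columns=X_train.columns))
    print(pred)
    plt.show()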
Out [17]:
['9']
['1']
['2']
['3']
['4']
['4']
['6']
['2']
['8']
['8']
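
The preprocessing inside the loop can be factored into a helper. The sketch below uses a hypothetical function name, `to_mnist_row`, and assumes (as the post's inversion step implies) that the input images show a dark digit on a light background. It is tested on a synthetic image rather than the post's `./img/*.png` files.

```python
import numpy as np
from PIL import Image, ImageDraw

def to_mnist_row(img: Image.Image) -> np.ndarray:
    """Hypothetical helper: PIL image -> (1, 784) float row in MNIST layout."""
    gray = img.convert('L').resize((28, 28))  # grayscale, 28x28 like MNIST
    row = np.asarray(gray, dtype=np.float64).reshape(1, 784)
    return 255.0 - row                        # invert: light digit on dark background

# Synthetic test image: a black vertical bar on a white 56x56 canvas
canvas = Image.new('L', (56, 56), color=255)
ImageDraw.Draw(canvas).rectangle([24, 8, 31, 47], fill=0)

row = to_mnist_row(canvas)
print(row.shape)  # (1, 784)
print(row.min(), row.max())
```

MNIST digits are also size-normalized and centered within the 28×28 frame, so photos of handwriting usually need cropping and centering as well before predictions become reliable.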

Reference

  • This post summarizes the lectures of instructor 심선조 in SeSAC's "Application SW Developer Training Course Using AI Natural Language Processing and Computer Vision Technologies".
