Machine Learning (8) - 분류 / 결정 트리
분류
개요
앙상블 방법(Ensemble Method)을 집중적으로 볼 것이다.
앙상블은 일반적으로 배깅(Bagging)과 부스팅(Boosting) 방식으로 나뉜다.
배깅에는 대표적으로 랜덤 포레스트(Random Forest)가 있다.
부스팅에는 그래디언트 부스팅(Gradient Boosting)이 있지만 수행시간이 너무 오래 걸린다.
XgBoost(eXtra Gradient Boost)와 LightGBM 등 기존 그래디언트 부스팅을 향상시킨 알고리즘도 있다.
결정 트리
최대한 균일한 데이터 세트를 구상할 수 있도록 분할하는 것이 중요하다.
결정노드는 정보 균일도가 높은 데이터 세트를 먼저 선택할 수 있도록 규칙 조건을 만든다.
이러한 정보의 균일도를 측정하는 대표적인 방법은 엔트로피를 이용한 정보 이득(Information Gain)지수와 지니 계수가 있다.
- 엔트로피란 데이터 집합의 혼잡도를 의미한다. 정보이득은 1-엔트로피로 단순할 수록 값이 높다. 이 정보 이득 지수로 분할 기준을 정한다. 즉, 정보 이득이 높은 속성을 기준으로 분할한다.
- 지니 계수는 경제학에서 불평등 지수를 나타내는 계수이다. 0이 가장 평등하고 1로 갈수록 불평등하다. 머신러닝에 적용될 때는 지니 계수가 낮을수록 데이터 균일도가 높은 것으로 해석해 지니 계수가 낮은 속성을 기준으로 분할한다.
결정 트리 모델의 특징
-
결정 트리 장점
정보의 균일도라는 룰을 기반으로 하고 있어서 알고리즘이 쉽고 직관적이다.
피처의 스케일링이나 정규화 등의 사전 가공 영향도가 크지 않다. -
결정 트리 단점
하지만 과적합으로 정확도가 떨어져 알고리즘 성능이 떨어진다.
하이퍼 파라미터 튜닝을 통해 과적합을 방지하는데 튜닝에 시간이 걸릴 수 있다.
결정 트리 파라미터
-
min_samples_split
노드를 분할하기 위한 최소한의 샘플 데이터 수
default는 2이고 작게 설정할수록 분할되는 노드가 많아져 과적합 가능성이 증가 -
min_samples_leaf
리프 노드가 되기 위한 최소한의 샘플 데이터 수
비대칭적(imbalanced) 데이터의 경우 특정 클래스의 데이터가 극도로 작을 수 있으므로 이 경우는 작게 설정 필요 -
max_features
최적의 분할을 위해 고려할 최대 피처 갯수
default는 None으로 데이터 세트의 모든 피처를 사용해 분할 수행
int형은 피처의 갯수, float형은 퍼센트, ‘sqrt’와 ‘auto’는 sqrt(전체 피처 갯수) ‘log’는 log2(전체 피처 갯수) -
max_depth
트리의 최대 깊이
default는 None으로 노드가 가지는 데이터 갯수가 min_samples_split보다 작아질 때까지 계속 깊이를 증가시킴 깊이가 깊어지면 과적합할 수 있으므로 제어 필요 -
max_leaf_nodes
리프 노드의 최대 갯수
결정 트리 모델의 시각화
Graphviz 패키지를 사용해 어떤 규칙을 가지고 트리를 생성하는지 시작적으로 보여줄 수 있다.
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
dt_clf = DecisionTreeClassifier(random_state=156)
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=11)
dt_clf.fit(X_train, y_train)
DecisionTreeClassifier(random_state=156)
from sklearn.tree import export_graphviz
export_graphviz(dt_clf, out_file='iris.dot', class_names=iris.target_names,
feature_names=iris.feature_names, impurity=True, filled=True) # filled: 색상 채우기
import graphviz
with open('iris.dot') as f:
dot_graph = f.read()
graphviz.Source(dot_graph)
- max_depth 제약 변경
dt_clf = DecisionTreeClassifier(max_depth=3, random_state=156)
dt_clf.fit(X_train, y_train)
export_graphviz(dt_clf, out_file='iris1.dot', class_names=iris.target_names,
feature_names=iris.feature_names, impurity=True, filled=True)
with open('iris1.dot') as f:
dot_graph = f.read()
graphviz.Source(dot_graph)
- min_samples_split 제약 변경
dt_clf = DecisionTreeClassifier(min_samples_split=4, random_state=156)
dt_clf.fit(X_train, y_train)
export_graphviz(dt_clf, out_file='iris1.dot', class_names=iris.target_names,
feature_names=iris.feature_names, impurity=True, filled=True)
with open('iris1.dot') as f:
dot_graph = f.read()
graphviz.Source(dot_graph)
- min_samples_leaf 제약 변경
dt_clf = DecisionTreeClassifier(min_samples_leaf=4, random_state=156)
dt_clf.fit(X_train, y_train)
export_graphviz(dt_clf, out_file='iris1.dot', class_names=iris.target_names,
feature_names=iris.feature_names, impurity=True, filled=True)
with open('iris1.dot') as f:
dot_graph = f.read()
graphviz.Source(dot_graph)
dt_clf = DecisionTreeClassifier(random_state=156)
dt_clf.fit(X_train, y_train)
DecisionTreeClassifier(random_state=156)
- feathre 별 중요도 확인
dt_clf.feature_importances_
array([0.02500521, 0. , 0.55490281, 0.42009198])
for name, value in zip(iris.feature_names, dt_clf.feature_importances_):
print(f"{name} : {value:.3f}")
sepal length (cm) : 0.025
sepal width (cm) : 0.000
petal length (cm) : 0.555
petal width (cm) : 0.420
import seaborn as sns
sns.barplot(x=dt_clf.feature_importances_, y=iris.feature_names, palette='pastel')
<AxesSubplot:>
타이타닉 데이터
불러와 의사결정트리로 학습하고 feature별 중요도 확인
import pandas as pd
import func01 # 미리 저장한 전처리 함수
df = pd.read_csv('titanic.csv')
y = df['Survived']
X = df.drop(columns=['Survived'])
X = func01.transform_features(X)
['female' 'male']
['A' 'B' 'C' 'D' 'E' 'F' 'G' 'N' 'T']
['C' 'N' 'Q' 'S']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=11)
dt_clf = DecisionTreeClassifier()
dt_clf.fit(X_train, y_train)
DecisionTreeClassifier()
export_graphviz(dt_clf, out_file='titanic.dot', class_names=['Died', 'Survived'],
feature_names=X.columns, filled=True)
with open('titanic.dot') as f:
dot_graph = f.read()
graphviz.Source(dot_graph)
for name, value in zip(X.columns, dt_clf.feature_importances_):
print(f"{name} : {value:.3f}")
Pclass : 0.081
Sex : 0.284
Age : 0.242
SibSp : 0.048
Parch : 0.019
Fare : 0.242
Cabin : 0.049
Embarked : 0.035
sns.barplot(x=dt_clf.feature_importances_, y=X.columns, palette='hls')
<AxesSubplot:>
결정 트리 과적합(Overfitting)
2개의 피처가 3가지 유형의 클래스 값을 가지는 데이터 세트 생성, 시각화
# 분류를 위한 테스트용 데이터 만들기
from sklearn.datasets import make_classification
import matplotlib.pyplot as plt
plt.title("2개의 피처가 3가지 클래스 값을 가지는 데이터 세트", fontdict = {'fontsize' : 12})
# plt.rc('figure', titlesize=10)
# 2차원 시각화를 위해서 피처는 2개, 클래스는 3가지 유형의 분류 샘플 데이터 생성
X, y = make_classification(n_features=2, n_redundant=0, n_classes=3,
n_clusters_per_class=1, random_state=0)
# 그래프 형태로 2개의 피처로 2차원 좌표 시각화, 각 클래스 값은 다른 색깔로 표시됨.
plt.scatter(X[:, 0], X[:, 1], marker='o', c=y, s=25, edgecolor='k') # c:컬러값
plt.grid()
# 어떻게 학습되는지 보기만 하기에 데이터 분리는 하지 않음
dt_clf = DecisionTreeClassifier(random_state=156).fit(X, y)
func01.visualize_boundary(dt_clf, X, y)
C:\Workspace\Python\05_machine_learning\func01.py:54: UserWarning: The following kwargs were not used by contour: 'clim'
contours = ax.contourf(xx, yy, Z, alpha=0.3,
일부 이상치(Outlier) 데이터까지 분류하기 위해 분할이 자주 일어나 결정 기준 경계가 매우 많아지고 복잡해졌다.
이렇게 복잡한 모델은 학습 데이터 세트의 특성과 약간만 다른 형태의 데이터 세트를 예측하면 예측 정확도가 떨어진다.
=> 과대적합(Overfitting)
# min_samples_lea=6으로 트리 생성 조건을 제약한 결정 경계 시각화
dt_clf = DecisionTreeClassifier(min_samples_leaf=6, random_state=156).fit(X, y)
func01.visualize_boundary(dt_clf, X, y)
C:\Workspace\Python\05_machine_learning\func01.py:54: UserWarning: The following kwargs were not used by contour: 'clim'
contours = ax.contourf(xx, yy, Z, alpha=0.3,
이상치에 크게 반응하지 않으면서 좀 더 일반화된 분류 규칙에 따라 분류되었다!
Reference
- 이 포스트는 SeSAC 인공지능 자연어처리, 컴퓨터비전 기술을 활용한 응용 SW 개발자 양성 과정 - 심선조 강사님의 강의를 정리한 내용입니다.
댓글남기기