머신러닝
사용자 행동 인식 데이터 세트
haventmetyou
2024. 1. 17. 16:36
라이브러리 로드¶
In [2]:
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
데이터 로드¶
In [6]:
# https://archive.ics.uci.edu/dataset/240/human+activity+recognition+using+smartphones
# features.txt 파일에는 피처 이름 index와 피처명이 공백으로 분리되어 있음, 이를 DataFrame으로 로드
feature_name_df = pd.read_csv('./human_activity/features.txt', sep='\s+',
header=None, names=['column_index', 'column_name'])
# 피처명 index를 제거하고 피처명만 리스트 객체로 생성한 뒤 샘플로 10개만 추출
feature_name = feature_name_df.iloc[:, 1].values.tolist()
print('전체 피처명에서 10개만 추출:', feature_name[:10])
전체 피처명에서 10개만 추출: ['tBodyAcc-mean()-X', 'tBodyAcc-mean()-Y', 'tBodyAcc-mean()-Z', 'tBodyAcc-std()-X', 'tBodyAcc-std()-Y', 'tBodyAcc-std()-Z', 'tBodyAcc-mad()-X', 'tBodyAcc-mad()-Y', 'tBodyAcc-mad()-Z', 'tBodyAcc-max()-X']
- 피처명을 보면 인체의 움직임과 관련된 속성의 평균 / 표준편차가 X, Y, Z 축 값으로 되어 있음
- 위에서 피처명을 가지고 있는 features.txt 파일은 중복된 피처명을 가지고 있어 전처리가 필요함
- 중복된 피처명에 대해서는 원본 피처명에 _1 또는 _2를 추가로 부여해 변경한 뒤 이를 이용해 데이터를 DataFrame에 로드
데이터 전처리¶
In [7]:
feature_dup_df = feature_name_df.groupby('column_name').count()
print(feature_dup_df[feature_dup_df['column_index'] > 1].count())
feature_dup_df[feature_dup_df['column_index'] > 1].head()
column_index 42
dtype: int64
Out[7]:
column_index | |
---|---|
column_name | |
fBodyAcc-bandsEnergy()-1,16 | 3 |
fBodyAcc-bandsEnergy()-1,24 | 3 |
fBodyAcc-bandsEnergy()-1,8 | 3 |
fBodyAcc-bandsEnergy()-17,24 | 3 |
fBodyAcc-bandsEnergy()-17,32 | 3 |
총 42개의 피처명이 중복돼 있음. 중복된 피처명을 처리하기 위해 원본 피처명에 _1 또는 _2를 추가로 부여해 새로운 피처명을 가지는 DataFrame을 반환하는 함수 생성.
In [8]:
def get_new_feature_name_df(old_feature_name_df):
feature_dup_df = pd.DataFrame(data=old_feature_name_df.groupby('column_name').cumcount(), columns=['dup_cnt'])
feature_dup_df = feature_dup_df.reset_index()
new_feature_name_df = pd.merge(old_feature_name_df.reset_index(), feature_dup_df, how='outer')
new_feature_name_df['column_name'] = new_feature_name_df[['column_name', 'dup_cnt']].apply(lambda x: x[0]+'_'+str(x[1])
if x[1] > 0 else x[0], axis=1)
new_feature_name_df = new_feature_name_df.drop(['index'], axis=1)
return new_feature_name_df
In [9]:
import pandas as pd
def get_human_dataset():
# 각 데이터 파일을 공백으로 분리되어 있으므로 read_csv에서 공백 문자를 sep으로 할당
feature_name_df = pd.read_csv('./human_activity/features.txt', sep='\s+',
header=None, names=['column_index', 'column_name'])
# 중복된 피처명을 수정하는 get_new_feature_name_df()를 이용, 신규 피처명 DataFrame 생성
new_feature_name_df = get_new_feature_name_df(feature_name_df)
# DataFrame에 피처명을 칼럼으로 부여하기 위해 리스트 객체로 다시 변환
feature_name = new_feature_name_df.iloc[:, 1].values.tolist()
# 학습 피처 데이터 세트와 테스트 피처 데이터를 DataFrame으로 로딩, 칼럼명은 feature_name 적용
X_train = pd.read_csv('./human_activity/train/X_train.txt', sep='\s+', names=feature_name)
X_test = pd.read_csv('./human_activity/test/X_test.txt', sep='\s+', names=feature_name)
# 학습 레이블과 테스트 레이블 데이터를 DataFrame으로 로딩, 칼럼명은 action 부여
y_train = pd.read_csv('./human_activity/train/y_train.txt', sep='\s+', header=None, names=['action'])
y_test = pd.read_csv('./human_activity/test/y_test.txt', sep='\s+', header=None, names=['action'])
# 로드된 학습/테스트용 DataFrame을 모두 반환
return X_train, X_test, y_train, y_test
X_train, X_test, y_train, y_test = get_human_dataset()
In [10]:
print('## 학습 피처 데이터셋 info()')
X_train.info()
## 학습 피처 데이터셋 info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7352 entries, 0 to 7351
Columns: 561 entries, tBodyAcc-mean()-X to angle(Z,gravityMean)
dtypes: float64(561)
memory usage: 31.5 MB
학습 데이터 세트는 7352개의 레코드로 561개의 피처를 가지고 있음. 피처는 전부 float 형의 숫자형이므로 카테고리 인코딩은 필요하지 않음.
In [11]:
X_train.head()
Out[11]:
tBodyAcc-mean()-X | tBodyAcc-mean()-Y | tBodyAcc-mean()-Z | tBodyAcc-std()-X | tBodyAcc-std()-Y | tBodyAcc-std()-Z | tBodyAcc-mad()-X | tBodyAcc-mad()-Y | tBodyAcc-mad()-Z | tBodyAcc-max()-X | tBodyAcc-max()-Y | tBodyAcc-max()-Z | tBodyAcc-min()-X | tBodyAcc-min()-Y | tBodyAcc-min()-Z | tBodyAcc-sma() | tBodyAcc-energy()-X | tBodyAcc-energy()-Y | tBodyAcc-energy()-Z | tBodyAcc-iqr()-X | tBodyAcc-iqr()-Y | tBodyAcc-iqr()-Z | tBodyAcc-entropy()-X | tBodyAcc-entropy()-Y | tBodyAcc-entropy()-Z | tBodyAcc-arCoeff()-X,1 | tBodyAcc-arCoeff()-X,2 | tBodyAcc-arCoeff()-X,3 | tBodyAcc-arCoeff()-X,4 | tBodyAcc-arCoeff()-Y,1 | tBodyAcc-arCoeff()-Y,2 | tBodyAcc-arCoeff()-Y,3 | tBodyAcc-arCoeff()-Y,4 | tBodyAcc-arCoeff()-Z,1 | tBodyAcc-arCoeff()-Z,2 | tBodyAcc-arCoeff()-Z,3 | tBodyAcc-arCoeff()-Z,4 | tBodyAcc-correlation()-X,Y | tBodyAcc-correlation()-X,Z | tBodyAcc-correlation()-Y,Z | tGravityAcc-mean()-X | tGravityAcc-mean()-Y | tGravityAcc-mean()-Z | tGravityAcc-std()-X | tGravityAcc-std()-Y | tGravityAcc-std()-Z | tGravityAcc-mad()-X | tGravityAcc-mad()-Y | tGravityAcc-mad()-Z | tGravityAcc-max()-X | tGravityAcc-max()-Y | tGravityAcc-max()-Z | tGravityAcc-min()-X | tGravityAcc-min()-Y | tGravityAcc-min()-Z | tGravityAcc-sma() | tGravityAcc-energy()-X | tGravityAcc-energy()-Y | tGravityAcc-energy()-Z | tGravityAcc-iqr()-X | ... | fBodyGyro-bandsEnergy()-25,48_2 | fBodyAccMag-mean() | fBodyAccMag-std() | fBodyAccMag-mad() | fBodyAccMag-max() | fBodyAccMag-min() | fBodyAccMag-sma() | fBodyAccMag-energy() | fBodyAccMag-iqr() | fBodyAccMag-entropy() | fBodyAccMag-maxInds | fBodyAccMag-meanFreq() | fBodyAccMag-skewness() | fBodyAccMag-kurtosis() | fBodyBodyAccJerkMag-mean() | fBodyBodyAccJerkMag-std() | fBodyBodyAccJerkMag-mad() | fBodyBodyAccJerkMag-max() | fBodyBodyAccJerkMag-min() | fBodyBodyAccJerkMag-sma() | fBodyBodyAccJerkMag-energy() | fBodyBodyAccJerkMag-iqr() | fBodyBodyAccJerkMag-entropy() | fBodyBodyAccJerkMag-maxInds | fBodyBodyAccJerkMag-meanFreq() | fBodyBodyAccJerkMag-skewness() | fBodyBodyAccJerkMag-kurtosis() | fBodyBodyGyroMag-mean() | fBodyBodyGyroMag-std() | fBodyBodyGyroMag-mad() | fBodyBodyGyroMag-max() | fBodyBodyGyroMag-min() | fBodyBodyGyroMag-sma() | fBodyBodyGyroMag-energy() | fBodyBodyGyroMag-iqr() | fBodyBodyGyroMag-entropy() | fBodyBodyGyroMag-maxInds | fBodyBodyGyroMag-meanFreq() | fBodyBodyGyroMag-skewness() | fBodyBodyGyroMag-kurtosis() | fBodyBodyGyroJerkMag-mean() | fBodyBodyGyroJerkMag-std() | fBodyBodyGyroJerkMag-mad() | fBodyBodyGyroJerkMag-max() | fBodyBodyGyroJerkMag-min() | fBodyBodyGyroJerkMag-sma() | fBodyBodyGyroJerkMag-energy() | fBodyBodyGyroJerkMag-iqr() | fBodyBodyGyroJerkMag-entropy() | fBodyBodyGyroJerkMag-maxInds | fBodyBodyGyroJerkMag-meanFreq() | fBodyBodyGyroJerkMag-skewness() | fBodyBodyGyroJerkMag-kurtosis() | angle(tBodyAccMean,gravity) | angle(tBodyAccJerkMean),gravityMean) | angle(tBodyGyroMean,gravityMean) | angle(tBodyGyroJerkMean,gravityMean) | angle(X,gravityMean) | angle(Y,gravityMean) | angle(Z,gravityMean) | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.288585 | -0.020294 | -0.132905 | -0.995279 | -0.983111 | -0.913526 | -0.995112 | -0.983185 | -0.923527 | -0.934724 | -0.567378 | -0.744413 | 0.852947 | 0.685845 | 0.814263 | -0.965523 | -0.999945 | -0.999863 | -0.994612 | -0.994231 | -0.987614 | -0.943220 | -0.407747 | -0.679338 | -0.602122 | 0.929294 | -0.853011 | 0.359910 | -0.058526 | 0.256892 | -0.224848 | 0.264106 | -0.095246 | 0.278851 | -0.465085 | 0.491936 | -0.190884 | 0.376314 | 0.435129 | 0.660790 | 0.963396 | -0.140840 | 0.115375 | -0.985250 | -0.981708 | -0.877625 | -0.985001 | -0.984416 | -0.894677 | 0.892055 | -0.161265 | 0.124660 | 0.977436 | -0.123213 | 0.056483 | -0.375426 | 0.899469 | -0.970905 | -0.975510 | -0.984325 | ... | -0.999959 | -0.952155 | -0.956134 | -0.948870 | -0.974321 | -0.925722 | -0.952155 | -0.998285 | -0.973273 | -0.646376 | -0.793103 | -0.088436 | -0.436471 | -0.796840 | -0.993726 | -0.993755 | -0.991976 | -0.993365 | -0.988175 | -0.993726 | -0.999918 | -0.991364 | -1.0 | -0.936508 | 0.346989 | -0.516080 | -0.802760 | -0.980135 | -0.961309 | -0.973653 | -0.952264 | -0.989498 | -0.980135 | -0.999240 | -0.992656 | -0.701291 | -1.000000 | -0.128989 | 0.586156 | 0.374605 | -0.991990 | -0.990697 | -0.989941 | -0.992448 | -0.991048 | -0.991990 | -0.999937 | -0.990458 | -0.871306 | -1.000000 | -0.074323 | -0.298676 | -0.710304 | -0.112754 | 0.030400 | -0.464761 | -0.018446 | -0.841247 | 0.179941 | -0.058627 |
1 | 0.278419 | -0.016411 | -0.123520 | -0.998245 | -0.975300 | -0.960322 | -0.998807 | -0.974914 | -0.957686 | -0.943068 | -0.557851 | -0.818409 | 0.849308 | 0.685845 | 0.822637 | -0.981930 | -0.999991 | -0.999788 | -0.998405 | -0.999150 | -0.977866 | -0.948225 | -0.714892 | -0.500930 | -0.570979 | 0.611627 | -0.329549 | 0.284213 | 0.284595 | 0.115705 | -0.090963 | 0.294310 | -0.281211 | 0.085988 | -0.022153 | -0.016657 | -0.220643 | -0.013429 | -0.072692 | 0.579382 | 0.966561 | -0.141551 | 0.109379 | -0.997411 | -0.989447 | -0.931639 | -0.997884 | -0.989614 | -0.933240 | 0.892060 | -0.161343 | 0.122586 | 0.984520 | -0.114893 | 0.102764 | -0.383430 | 0.907829 | -0.970583 | -0.978500 | -0.999188 | ... | -0.999971 | -0.980857 | -0.975866 | -0.975777 | -0.978226 | -0.986911 | -0.980857 | -0.999472 | -0.984479 | -0.816674 | -1.000000 | -0.044150 | -0.122040 | -0.449522 | -0.990335 | -0.991960 | -0.989732 | -0.994489 | -0.989549 | -0.990335 | -0.999867 | -0.991134 | -1.0 | -0.841270 | 0.532061 | -0.624871 | -0.900160 | -0.988296 | -0.983322 | -0.982659 | -0.986321 | -0.991829 | -0.988296 | -0.999811 | -0.993979 | -0.720683 | -0.948718 | -0.271958 | -0.336310 | -0.720015 | -0.995854 | -0.996399 | -0.995442 | -0.996866 | -0.994440 | -0.995854 | -0.999981 | -0.994544 | -1.000000 | -1.000000 | 0.158075 | -0.595051 | -0.861499 | 0.053477 | -0.007435 | -0.732626 | 0.703511 | -0.844788 | 0.180289 | -0.054317 |
2 | 0.279653 | -0.019467 | -0.113462 | -0.995380 | -0.967187 | -0.978944 | -0.996520 | -0.963668 | -0.977469 | -0.938692 | -0.557851 | -0.818409 | 0.843609 | 0.682401 | 0.839344 | -0.983478 | -0.999969 | -0.999660 | -0.999470 | -0.997130 | -0.964810 | -0.974675 | -0.592235 | -0.485821 | -0.570979 | 0.273025 | -0.086309 | 0.337202 | -0.164739 | 0.017150 | -0.074507 | 0.342256 | -0.332564 | 0.239281 | -0.136204 | 0.173863 | -0.299493 | -0.124698 | -0.181105 | 0.608900 | 0.966878 | -0.142010 | 0.101884 | -0.999574 | -0.992866 | -0.992917 | -0.999635 | -0.992605 | -0.992934 | 0.892401 | -0.163711 | 0.094566 | 0.986770 | -0.114893 | 0.102764 | -0.401602 | 0.908668 | -0.970368 | -0.981672 | -0.999679 | ... | -0.999956 | -0.987795 | -0.989015 | -0.985594 | -0.993062 | -0.989836 | -0.987795 | -0.999807 | -0.989237 | -0.907014 | -0.862069 | 0.257899 | -0.618725 | -0.879685 | -0.989280 | -0.990867 | -0.987274 | -0.993179 | -0.999890 | -0.989280 | -0.999845 | -0.986658 | -1.0 | -0.904762 | 0.660795 | -0.724697 | -0.928539 | -0.989255 | -0.986028 | -0.984274 | -0.990979 | -0.995703 | -0.989255 | -0.999854 | -0.993238 | -0.736521 | -0.794872 | -0.212728 | -0.535352 | -0.871914 | -0.995031 | -0.995127 | -0.994640 | -0.996060 | -0.995866 | -0.995031 | -0.999973 | -0.993755 | -1.000000 | -0.555556 | 0.414503 | -0.390748 | -0.760104 | -0.118559 | 0.177899 | 0.100699 | 0.808529 | -0.848933 | 0.180637 | -0.049118 |
3 | 0.279174 | -0.026201 | -0.123283 | -0.996091 | -0.983403 | -0.990675 | -0.997099 | -0.982750 | -0.989302 | -0.938692 | -0.576159 | -0.829711 | 0.843609 | 0.682401 | 0.837869 | -0.986093 | -0.999976 | -0.999736 | -0.999504 | -0.997180 | -0.983799 | -0.986007 | -0.627446 | -0.850930 | -0.911872 | 0.061436 | 0.074840 | 0.198204 | -0.264307 | 0.072545 | -0.155320 | 0.323154 | -0.170813 | 0.294938 | -0.306081 | 0.482148 | -0.470129 | -0.305693 | -0.362654 | 0.507459 | 0.967615 | -0.143976 | 0.099850 | -0.996646 | -0.981393 | -0.978476 | -0.996457 | -0.980962 | -0.978456 | 0.893817 | -0.163711 | 0.093425 | 0.986821 | -0.121336 | 0.095753 | -0.400278 | 0.910621 | -0.969400 | -0.982420 | -0.995976 | ... | -0.999952 | -0.987519 | -0.986742 | -0.983524 | -0.990230 | -0.998185 | -0.987519 | -0.999770 | -0.983215 | -0.907014 | -1.000000 | 0.073581 | -0.468422 | -0.756494 | -0.992769 | -0.991700 | -0.989055 | -0.994455 | -0.995562 | -0.992769 | -0.999895 | -0.988055 | -1.0 | 1.000000 | 0.678921 | -0.701131 | -0.909639 | -0.989413 | -0.987836 | -0.986850 | -0.986749 | -0.996199 | -0.989413 | -0.999876 | -0.989136 | -0.720891 | -1.000000 | -0.035684 | -0.230091 | -0.511217 | -0.995221 | -0.995237 | -0.995722 | -0.995273 | -0.995732 | -0.995221 | -0.999974 | -0.995226 | -0.955696 | -0.936508 | 0.404573 | -0.117290 | -0.482845 | -0.036788 | -0.012892 | 0.640011 | -0.485366 | -0.848649 | 0.181935 | -0.047663 |
4 | 0.276629 | -0.016570 | -0.115362 | -0.998139 | -0.980817 | -0.990482 | -0.998321 | -0.979672 | -0.990441 | -0.942469 | -0.569174 | -0.824705 | 0.849095 | 0.683250 | 0.837869 | -0.992653 | -0.999991 | -0.999856 | -0.999757 | -0.998004 | -0.981232 | -0.991325 | -0.786553 | -0.559477 | -0.761434 | 0.313276 | -0.131208 | 0.191161 | 0.086904 | 0.257615 | -0.272505 | 0.434728 | -0.315375 | 0.439744 | -0.269069 | 0.179414 | -0.088952 | -0.155804 | -0.189763 | 0.599213 | 0.968224 | -0.148750 | 0.094486 | -0.998429 | -0.988098 | -0.978745 | -0.998411 | -0.988654 | -0.978936 | 0.893817 | -0.166786 | 0.091682 | 0.987434 | -0.121834 | 0.094059 | -0.400477 | 0.912235 | -0.967051 | -0.984363 | -0.998318 | ... | -0.999973 | -0.993591 | -0.990063 | -0.992324 | -0.990506 | -0.987805 | -0.993591 | -0.999873 | -0.997343 | -0.907014 | -1.000000 | 0.394310 | -0.112663 | -0.481805 | -0.995523 | -0.994389 | -0.993305 | -0.995485 | -0.982177 | -0.995523 | -0.999941 | -0.994169 | -1.0 | -1.000000 | 0.559058 | -0.528901 | -0.858933 | -0.991433 | -0.989059 | -0.987744 | -0.991462 | -0.998353 | -0.991433 | -0.999902 | -0.989321 | -0.763372 | -0.897436 | -0.273582 | -0.510282 | -0.830702 | -0.995093 | -0.995465 | -0.995279 | -0.995609 | -0.997418 | -0.995093 | -0.999974 | -0.995487 | -1.000000 | -0.936508 | 0.087753 | -0.351471 | -0.699205 | 0.123320 | 0.122542 | 0.693578 | -0.615971 | -0.847865 | 0.185151 | -0.043892 |
5 rows × 561 columns
- 많은 칼럼들의 대부분이 움직임 위치와 관련된 속성
In [13]:
print(y_train['action'].value_counts())
action
6 1407
5 1374
4 1286
1 1226
2 1073
3 986
Name: count, dtype: int64
- 레이블 값은 1, 2, 3, 4, 5, 6의 6개 값이고 분포도는 특정 값으로 왜곡되지 않고 비교적 고르게 분포되어 있음
동작 예측 분류¶
In [14]:
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
# 예제 반복 시마다 동일한 예측 결과 도출을 위해 random_state 설정
dt_clf = DecisionTreeClassifier(random_state=156)
dt_clf.fit(X_train, y_train)
pred = dt_clf.predict(X_test)
accuracy = accuracy_score(y_test, pred)
print('결정 트리 예측 정확도: {0:.4f}'.format(accuracy))
# DecisionTreeClassifier의 하이퍼 파라미터 추출
print('DecisionTreeClassifier 기본 하이퍼 파라미터:\n', dt_clf.get_params())
Out[14]:
DecisionTreeClassifier(random_state=156)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
DecisionTreeClassifier(random_state=156)
결정 트리 예측 정확도: 0.8548
DecisionTreeClassifier 기본 하이퍼 파라미터:
{'ccp_alpha': 0.0, 'class_weight': None, 'criterion': 'gini', 'max_depth': None, 'max_features': None, 'max_leaf_nodes': None, 'min_impurity_decrease': 0.0, 'min_samples_leaf': 1, 'min_samples_split': 2, 'min_weight_fraction_leaf': 0.0, 'random_state': 156, 'splitter': 'best'}
- 약 85.48%의 정확도를 나타내고 있음
- 결정 트리의 트리 깊이(Tree Depth)가 예측 정확도에 주는 영향을 살펴보기
- 결정 트리의 경우 분류를 위해 리프 노드(클래스 결정 노드)가 될 수 있는 적합한 수준이 될 때까지 지속해서 트리의 분할을 수행하며 깊이가 깊어짐
- GridSearchCV를 이용해 사이킷런 결정 트리의 깊이를 조절할 수 있는 하이퍼 파라미터인 max_depth 값을 변화시키면서 예측 성능을 확인
GridSearchCV로 결정 트리 깊이 조절¶
In [15]:
from sklearn.model_selection import GridSearchCV
params = {
'max_depth' : [6, 8, 10, 12, 16, 20, 24],
'min_samples_split' : [16]
}
grid_cv = GridSearchCV(dt_clf, param_grid=params, scoring='accuracy', cv=5, verbose=1) # verbose에 옵션 값을 주면 자세한 출력, 0으로 설정하면 기본적인 결과 정보면 출력
grid_cv.fit(X_train, y_train)
print('GridSearchCV 최고 평균 정확도 수치: {0:.4f}'.format(grid_cv.best_score_))
print('GridSearchCV 최적 하이퍼 파라미터:', grid_cv.best_params_)
Fitting 5 folds for each of 7 candidates, totalling 35 fits
Out[15]:
GridSearchCV(cv=5, estimator=DecisionTreeClassifier(random_state=156),
param_grid={'max_depth': [6, 8, 10, 12, 16, 20, 24],
'min_samples_split': [16]},
scoring='accuracy', verbose=1)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
GridSearchCV(cv=5, estimator=DecisionTreeClassifier(random_state=156),
param_grid={'max_depth': [6, 8, 10, 12, 16, 20, 24],
'min_samples_split': [16]},
scoring='accuracy', verbose=1)
DecisionTreeClassifier(random_state=156)
DecisionTreeClassifier(random_state=156)
GridSearchCV 최고 평균 정확도 수치: 0.8549
GridSearchCV 최적 하이퍼 파라미터: {'max_depth': 8, 'min_samples_split': 16}
- max_depth가 8일 때 5개의 폴드 세트의 최고 평균 정확도 결과가 약 85.49%로 도출
수행 목표: max_depth 값의 증가에 따라 어떻게 예측 성능이 변했는지 확인하는 것
In [16]:
# GridSearchCV 객체의 cv_results_ 속성을 DataFrame으로 생성
# cv_results_ 속성은 CV 세트에 하이퍼 파라미터를 순차적으로 입력했을 때의 성능 수치를 가지고 있음
cv_results_df = pd.DataFrame(grid_cv.cv_results_)
# max_depth 파라미터 값과 그때의 테스트 세트, 학습 데이터 세트의 정확도 수치 추출
cv_results_df[['param_max_depth', 'mean_test_score']]
Out[16]:
param_max_depth | mean_test_score | |
---|---|---|
0 | 6 | 0.847662 |
1 | 8 | 0.854879 |
2 | 10 | 0.852705 |
3 | 12 | 0.845768 |
4 | 16 | 0.847127 |
5 | 20 | 0.848624 |
6 | 24 | 0.848624 |
- score는 max_depth가 9일 때 0.854로 정확도가 정점, 이를 넘어가면서 정확도는 계속 떨어짐
- 결정 트리는 더 완벽한 규칙을 학습 데이터 세트에 적용하기 위해 노드를 지속적으로 분할하며 깊이가 깊어지고, 더욱 복잡한 모델이 됨
- 이로 인해 과적합으로 인한 성능 저하를 유발하게 될 수 있음
In [17]:
# 테스트 세트에서 결정 트리 정확도 측정
max_depths = [6, 8, 10, 12, 16, 20, 24]
# max_depth 값을 증가시키면서 그때마다 학습과 테스트 세트에서의 예측 성능 측정, min_samples_split 값은 16으로 고정
for depth in max_depths:
dt_clf = DecisionTreeClassifier(max_depth=depth, min_samples_split=16, random_state=156)
dt_clf.fit(X_train, y_train)
pred = dt_clf.predict(X_test)
accuracy = accuracy_score(y_test, pred)
print('max_depth = {0} 정확도: {1:.4f}'.format(depth, accuracy))
Out[17]:
DecisionTreeClassifier(max_depth=6, min_samples_split=16, random_state=156)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
DecisionTreeClassifier(max_depth=6, min_samples_split=16, random_state=156)
max_depth = 6 정확도: 0.8551
Out[17]:
DecisionTreeClassifier(max_depth=8, min_samples_split=16, random_state=156)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
DecisionTreeClassifier(max_depth=8, min_samples_split=16, random_state=156)
max_depth = 8 정확도: 0.8717
Out[17]:
DecisionTreeClassifier(max_depth=10, min_samples_split=16, random_state=156)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
DecisionTreeClassifier(max_depth=10, min_samples_split=16, random_state=156)
max_depth = 10 정확도: 0.8599
Out[17]:
DecisionTreeClassifier(max_depth=12, min_samples_split=16, random_state=156)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
DecisionTreeClassifier(max_depth=12, min_samples_split=16, random_state=156)
max_depth = 12 정확도: 0.8571
Out[17]:
DecisionTreeClassifier(max_depth=16, min_samples_split=16, random_state=156)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
DecisionTreeClassifier(max_depth=16, min_samples_split=16, random_state=156)
max_depth = 16 정확도: 0.8599
Out[17]:
DecisionTreeClassifier(max_depth=20, min_samples_split=16, random_state=156)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
DecisionTreeClassifier(max_depth=20, min_samples_split=16, random_state=156)
max_depth = 20 정확도: 0.8565
Out[17]:
DecisionTreeClassifier(max_depth=24, min_samples_split=16, random_state=156)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
DecisionTreeClassifier(max_depth=24, min_samples_split=16, random_state=156)
max_depth = 24 정확도: 0.8565
- max_depth가 8일 경우 약 87.17%로 가장 높은 정확도를 나타내고, max_depth가 8을 넘어가면서 정확도는 계속 감소
- 결정 트리는 깊이가 깊어질수록 과적합의 영향력이 커지므로 하이퍼 파라미터를 이용해 깊이를 제어할 수 있어야 함
- 복잡한 모델보다 깊이를 낮춘 단순한 모델이 더 효과적인 결과를 가져올 수 있음
정확도 성능 튜닝¶
max_depth와 min_samples_split을 같이 변경
In [18]:
params = {
'max_depth' : [8, 12, 16, 20],
'min_samples_split' : [16, 24],
}
grid_cv = GridSearchCV(dt_clf, param_grid=params, scoring='accuracy', cv=5, verbose=1)
grid_cv.fit(X_train, y_train)
print('GridSearchCV 최고 평균 정확도 수치: {0:.4f}'.format(grid_cv.best_score_))
print('GridSearchCV 최적 하이퍼 파라미터:', grid_cv.best_params_)
Fitting 5 folds for each of 8 candidates, totalling 40 fits
Out[18]:
GridSearchCV(cv=5,
estimator=DecisionTreeClassifier(max_depth=24,
min_samples_split=16,
random_state=156),
param_grid={'max_depth': [8, 12, 16, 20],
'min_samples_split': [16, 24]},
scoring='accuracy', verbose=1)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
GridSearchCV(cv=5,
estimator=DecisionTreeClassifier(max_depth=24,
min_samples_split=16,
random_state=156),
param_grid={'max_depth': [8, 12, 16, 20],
'min_samples_split': [16, 24]},
scoring='accuracy', verbose=1)
DecisionTreeClassifier(max_depth=24, min_samples_split=16, random_state=156)
DecisionTreeClassifier(max_depth=24, min_samples_split=16, random_state=156)
GridSearchCV 최고 평균 정확도 수치: 0.8549
GridSearchCV 최적 하이퍼 파라미터: {'max_depth': 8, 'min_samples_split': 16}
- max_depth가 8, min_samples_split이 16일 때 가장 최고의 정확도로 약 85.49%를 나타냄
최적 하이퍼 파라미터 적용¶
In [19]:
# grid_cv.best_estimator_는 최적 하이퍼 파라미터인 max_depth 8, min_samples_split 16으로 학습이 완료된 Estimator 객체
best_df_clf = grid_cv.best_estimator_
pred1 = best_df_clf.predict(X_test) # 테스트 데이터 세트에 예측 수행
accuracy = accuracy_score(y_test, pred1)
print('결정 트리 예측 정확도: {0:.4f}'.format(accuracy))
결정 트리 예측 정확도: 0.8717
- max_depth 8, min_samples_split 16일 때 테스트 데이터 세트의 예측 정확도는 약 87.17%
각 피처의 중요도 Top 20 시각화¶
In [20]:
import seaborn as sns
ftr_importances_values = best_df_clf.feature_importances_
# Top 중요도로 정렬을 쉽게 하고, Seaborn의 막대 그래프로 쉽게 표현하기 위해 Series 변환
ftr_importances = pd.Series(ftr_importances_values, index=X_train.columns)
# 중요도 값 순으로 Series 정렬
ftr_top20 = ftr_importances.sort_values(ascending=False)[:20]
plt.figure(figsize=(8, 6))
plt.title('Feature importances Top 20')
sns.barplot(x=ftr_top20, y = ftr_top20.index)
Out[20]:
중요도가 높은 순으로 Top 20 피처를 막대그래프로 표현해본 결과, 이중 가장 높은 중요도를 가진 Top 5의 피처들이 매우 중요하게 규칙 생성에 영향을 미치고 있는 것을 알 수 있음