Kaggle 산탄데르 고객 만족 예측(Santander Customer Satisfaction)
- XGBoost와 LightGBM을 활용해서 예측해보자
- 산탄데르 은행이 주최한 경연이라 피처 이름은 익명 처리되어있음
- 레이블값이 1이면 불만 가진 고객, 0이면 만족한 고객
- 모델 성능 평가는 roc-auc로 평가
https://www.kaggle.com/competitions/santander-customer-satisfaction/data
데이터 전처리
In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
cust_df = pd.read_csv('/content/train.csv', encoding='latin-1')
print('dataset shape:', cust_df.shape)
cust_df.head(3)
dataset shape: (76020, 371)
Out[3]:
ID | var3 | var15 | imp_ent_var16_ult1 | imp_op_var39_comer_ult1 | imp_op_var39_comer_ult3 | imp_op_var40_comer_ult1 | imp_op_var40_comer_ult3 | imp_op_var40_efect_ult1 | imp_op_var40_efect_ult3 | ... | saldo_medio_var33_hace2 | saldo_medio_var33_hace3 | saldo_medio_var33_ult1 | saldo_medio_var33_ult3 | saldo_medio_var44_hace2 | saldo_medio_var44_hace3 | saldo_medio_var44_ult1 | saldo_medio_var44_ult3 | var38 | TARGET | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 2 | 23 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 39205.17 | 0 |
1 | 3 | 2 | 34 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 49278.03 | 0 |
2 | 4 | 2 | 23 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 67333.77 | 0 |
3 rows × 371 columns
In [4]:
cust_df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 76020 entries, 0 to 76019
Columns: 371 entries, ID to TARGET
dtypes: float64(111), int64(260)
memory usage: 215.2 MB
- 전체 데이터에서 만족, 불만족의 비율 살펴보기
In [5]:
cust_df['TARGET'].value_counts()
Out[5]:
0 73012
1 3008
Name: TARGET, dtype: int64
In [6]:
unsatisfied_cnt = cust_df[cust_df['TARGET']==1].TARGET.count()
total_cnt = cust_df.TARGET.count()
print('비율은 {0:.2f}'.format((unsatisfied_cnt/total_cnt)))
비율은 0.04
In [7]:
cust_df.describe()
Out[7]:
ID | var3 | var15 | imp_ent_var16_ult1 | imp_op_var39_comer_ult1 | imp_op_var39_comer_ult3 | imp_op_var40_comer_ult1 | imp_op_var40_comer_ult3 | imp_op_var40_efect_ult1 | imp_op_var40_efect_ult3 | ... | saldo_medio_var33_hace2 | saldo_medio_var33_hace3 | saldo_medio_var33_ult1 | saldo_medio_var33_ult3 | saldo_medio_var44_hace2 | saldo_medio_var44_hace3 | saldo_medio_var44_ult1 | saldo_medio_var44_ult3 | var38 | TARGET | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 76020.000000 | 76020.000000 | 76020.000000 | 76020.000000 | 76020.000000 | 76020.000000 | 76020.000000 | 76020.000000 | 76020.000000 | 76020.000000 | ... | 76020.000000 | 76020.000000 | 76020.000000 | 76020.000000 | 76020.000000 | 76020.000000 | 76020.000000 | 76020.000000 | 7.602000e+04 | 76020.000000 |
mean | 75964.050723 | -1523.199277 | 33.212865 | 86.208265 | 72.363067 | 119.529632 | 3.559130 | 6.472698 | 0.412946 | 0.567352 | ... | 7.935824 | 1.365146 | 12.215580 | 8.784074 | 31.505324 | 1.858575 | 76.026165 | 56.614351 | 1.172358e+05 | 0.039569 |
std | 43781.947379 | 39033.462364 | 12.956486 | 1614.757313 | 339.315831 | 546.266294 | 93.155749 | 153.737066 | 30.604864 | 36.513513 | ... | 455.887218 | 113.959637 | 783.207399 | 538.439211 | 2013.125393 | 147.786584 | 4040.337842 | 2852.579397 | 1.826646e+05 | 0.194945 |
min | 1.000000 | -999999.000000 | 5.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | ... | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 5.163750e+03 | 0.000000 |
25% | 38104.750000 | 2.000000 | 23.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | ... | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 6.787061e+04 | 0.000000 |
50% | 76043.000000 | 2.000000 | 28.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | ... | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.064092e+05 | 0.000000 |
75% | 113748.750000 | 2.000000 | 40.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | ... | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.187563e+05 | 0.000000 |
max | 151838.000000 | 238.000000 | 105.000000 | 210000.000000 | 12888.030000 | 21024.810000 | 8237.820000 | 11073.570000 | 6600.000000 | 6600.000000 | ... | 50003.880000 | 20385.720000 | 138831.630000 | 91778.730000 | 438329.220000 | 24650.010000 | 681462.900000 | 397884.300000 | 2.203474e+07 | 1.000000 |
8 rows × 371 columns
In [8]:
cust_df['var3'].value_counts()
Out[8]:
2 74165
8 138
-999999 116
9 110
3 108
...
231 1
188 1
168 1
135 1
87 1
Name: var3, Length: 208, dtype: int64
- var3의 min값이 -999999임 >> 116개 있음
- 다른 값에 비해 편차가 심하므로 값이 가장 많은 2로 변환하자
In [9]:
cust_df['var3'].replace(-999999, 2, inplace=True)
cust_df.drop('ID', axis=1, inplace=True)
In [10]:
X_features = cust_df.iloc[:, :-1]
y_labels = cust_df.iloc[:, -1]
print(X_features.shape)
(76020, 369)
- 데이터 분리하기
- 데이터셋이 비대칭이므로 타겟값 분포도가 학습용, 테스트용에 비슷하게 추출됐는지 확인하기
In [11]:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X_features, y_labels, test_size=0.2, random_state=0)
train_cnt = y_train.count()
test_cnt = y_test.count()
print(X_train.shape, X_test.shape)
print(y_train.value_counts()/train_cnt)
print(y_test.value_counts()/test_cnt)
(60816, 369) (15204, 369)
0 0.960964
1 0.039036
Name: TARGET, dtype: float64
0 0.9583
1 0.0417
Name: TARGET, dtype: float64
- 원본데이터와 유사하게 타겟값의 분포가 불만족값 4% 정도로 만들어짐
- 조기중단의 검증데이터셋으로 사용하기 위해 다시 한번 데이터 쪼개기
In [12]:
X_tr, X_val, y_tr, y_val = train_test_split(X_train, y_train, test_size=0.3, random_state=0)
XGBoost 모델 학습, 하이퍼 파라미터 튜닝
- 사이킷런 래퍼 XGBoost기반으로 학습 수행하자
- 조기중단은 100회로 설정
In [15]:
from xgboost import XGBClassifier
from sklearn.metrics import roc_auc_score
xgb_clf = XGBClassifier(n_estimators=500, learning_rate=0.05, random_state=156)
xgb_clf.fit(X_tr, y_tr, early_stopping_rounds=100, eval_metric='auc', eval_set=[(X_tr, y_tr), (X_val, y_val)])
xgb_roc_score = roc_auc_score(y_test, xgb_clf.predict_proba(X_test)[:, 1])
print('ROC AUC: {0:.4f}'.format(xgb_roc_score))
[0] validation_0-auc:0.82179 validation_1-auc:0.80068
[1] validation_0-auc:0.82347 validation_1-auc:0.80523
[2] validation_0-auc:0.83178 validation_1-auc:0.81097
[3] validation_0-auc:0.83401 validation_1-auc:0.81091
[4] validation_0-auc:0.83443 validation_1-auc:0.81040
[5] validation_0-auc:0.83570 validation_1-auc:0.81089
[6] validation_0-auc:0.83597 validation_1-auc:0.81057
[7] validation_0-auc:0.83643 validation_1-auc:0.81082
[8] validation_0-auc:0.83682 validation_1-auc:0.81147
[9] validation_0-auc:0.83769 validation_1-auc:0.81188
[10] validation_0-auc:0.83770 validation_1-auc:0.81163
[11] validation_0-auc:0.83911 validation_1-auc:0.81355
[12] validation_0-auc:0.83976 validation_1-auc:0.81336
[13] validation_0-auc:0.84038 validation_1-auc:0.81365
[14] validation_0-auc:0.84176 validation_1-auc:0.81419
[15] validation_0-auc:0.84306 validation_1-auc:0.81586
[16] validation_0-auc:0.84343 validation_1-auc:0.81610
[17] validation_0-auc:0.84373 validation_1-auc:0.81584
[18] validation_0-auc:0.84542 validation_1-auc:0.81581
[19] validation_0-auc:0.84580 validation_1-auc:0.81622
[20] validation_0-auc:0.84656 validation_1-auc:0.81641
[21] validation_0-auc:0.84732 validation_1-auc:0.81740
[22] validation_0-auc:0.84769 validation_1-auc:0.81748
[23] validation_0-auc:0.84833 validation_1-auc:0.81744
[24] validation_0-auc:0.84836 validation_1-auc:0.81704
[25] validation_0-auc:0.84892 validation_1-auc:0.81650
[26] validation_0-auc:0.85114 validation_1-auc:0.81730
[27] validation_0-auc:0.85293 validation_1-auc:0.81843
[28] validation_0-auc:0.85335 validation_1-auc:0.81883
[29] validation_0-auc:0.85441 validation_1-auc:0.82072
[30] validation_0-auc:0.85604 validation_1-auc:0.82169
[31] validation_0-auc:0.85755 validation_1-auc:0.82256
[32] validation_0-auc:0.85830 validation_1-auc:0.82267
[33] validation_0-auc:0.85895 validation_1-auc:0.82321
[34] validation_0-auc:0.85964 validation_1-auc:0.82332
[35] validation_0-auc:0.85988 validation_1-auc:0.82408
[36] validation_0-auc:0.86037 validation_1-auc:0.82452
[37] validation_0-auc:0.86075 validation_1-auc:0.82466
[38] validation_0-auc:0.86171 validation_1-auc:0.82554
[39] validation_0-auc:0.86241 validation_1-auc:0.82582
[40] validation_0-auc:0.86277 validation_1-auc:0.82590
[41] validation_0-auc:0.86347 validation_1-auc:0.82628
[42] validation_0-auc:0.86401 validation_1-auc:0.82668
[43] validation_0-auc:0.86473 validation_1-auc:0.82686
[44] validation_0-auc:0.86523 validation_1-auc:0.82692
[45] validation_0-auc:0.86622 validation_1-auc:0.82774
[46] validation_0-auc:0.86721 validation_1-auc:0.82834
[47] validation_0-auc:0.86776 validation_1-auc:0.82831
[48] validation_0-auc:0.86837 validation_1-auc:0.82909
[49] validation_0-auc:0.86916 validation_1-auc:0.82997
[50] validation_0-auc:0.86939 validation_1-auc:0.82984
[51] validation_0-auc:0.87005 validation_1-auc:0.83006
[52] validation_0-auc:0.87057 validation_1-auc:0.83039
[53] validation_0-auc:0.87088 validation_1-auc:0.83066
[54] validation_0-auc:0.87145 validation_1-auc:0.83105
[55] validation_0-auc:0.87222 validation_1-auc:0.83139
[56] validation_0-auc:0.87285 validation_1-auc:0.83167
[57] validation_0-auc:0.87329 validation_1-auc:0.83175
[58] validation_0-auc:0.87375 validation_1-auc:0.83170
[59] validation_0-auc:0.87419 validation_1-auc:0.83196
[60] validation_0-auc:0.87465 validation_1-auc:0.83191
[61] validation_0-auc:0.87537 validation_1-auc:0.83186
[62] validation_0-auc:0.87612 validation_1-auc:0.83187
[63] validation_0-auc:0.87663 validation_1-auc:0.83192
[64] validation_0-auc:0.87751 validation_1-auc:0.83194
[65] validation_0-auc:0.87825 validation_1-auc:0.83218
[66] validation_0-auc:0.87880 validation_1-auc:0.83248
[67] validation_0-auc:0.87937 validation_1-auc:0.83241
[68] validation_0-auc:0.87980 validation_1-auc:0.83239
[69] validation_0-auc:0.88044 validation_1-auc:0.83249
[70] validation_0-auc:0.88105 validation_1-auc:0.83229
[71] validation_0-auc:0.88163 validation_1-auc:0.83261
[72] validation_0-auc:0.88212 validation_1-auc:0.83251
[73] validation_0-auc:0.88263 validation_1-auc:0.83257
[74] validation_0-auc:0.88311 validation_1-auc:0.83258
[75] validation_0-auc:0.88357 validation_1-auc:0.83255
[76] validation_0-auc:0.88411 validation_1-auc:0.83264
[77] validation_0-auc:0.88458 validation_1-auc:0.83241
[78] validation_0-auc:0.88490 validation_1-auc:0.83232
[79] validation_0-auc:0.88539 validation_1-auc:0.83249
[80] validation_0-auc:0.88609 validation_1-auc:0.83251
[81] validation_0-auc:0.88656 validation_1-auc:0.83261
[82] validation_0-auc:0.88688 validation_1-auc:0.83265
[83] validation_0-auc:0.88735 validation_1-auc:0.83258
[84] validation_0-auc:0.88770 validation_1-auc:0.83270
[85] validation_0-auc:0.88825 validation_1-auc:0.83251
[86] validation_0-auc:0.88890 validation_1-auc:0.83240
[87] validation_0-auc:0.88928 validation_1-auc:0.83229
[88] validation_0-auc:0.88958 validation_1-auc:0.83225
[89] validation_0-auc:0.88997 validation_1-auc:0.83205
[90] validation_0-auc:0.89030 validation_1-auc:0.83210
[91] validation_0-auc:0.89067 validation_1-auc:0.83223
[92] validation_0-auc:0.89098 validation_1-auc:0.83212
[93] validation_0-auc:0.89125 validation_1-auc:0.83198
[94] validation_0-auc:0.89166 validation_1-auc:0.83198
[95] validation_0-auc:0.89191 validation_1-auc:0.83196
[96] validation_0-auc:0.89219 validation_1-auc:0.83181
[97] validation_0-auc:0.89253 validation_1-auc:0.83184
[98] validation_0-auc:0.89286 validation_1-auc:0.83176
[99] validation_0-auc:0.89310 validation_1-auc:0.83184
[100] validation_0-auc:0.89337 validation_1-auc:0.83176
[101] validation_0-auc:0.89375 validation_1-auc:0.83168
[102] validation_0-auc:0.89392 validation_1-auc:0.83175
[103] validation_0-auc:0.89418 validation_1-auc:0.83162
[104] validation_0-auc:0.89446 validation_1-auc:0.83162
[105] validation_0-auc:0.89483 validation_1-auc:0.83173
[106] validation_0-auc:0.89530 validation_1-auc:0.83182
[107] validation_0-auc:0.89551 validation_1-auc:0.83179
[108] validation_0-auc:0.89576 validation_1-auc:0.83190
[109] validation_0-auc:0.89621 validation_1-auc:0.83190
[110] validation_0-auc:0.89631 validation_1-auc:0.83198
[111] validation_0-auc:0.89645 validation_1-auc:0.83200
[112] validation_0-auc:0.89657 validation_1-auc:0.83214
[113] validation_0-auc:0.89690 validation_1-auc:0.83221
[114] validation_0-auc:0.89720 validation_1-auc:0.83223
[115] validation_0-auc:0.89735 validation_1-auc:0.83218
[116] validation_0-auc:0.89769 validation_1-auc:0.83236
[117] validation_0-auc:0.89799 validation_1-auc:0.83245
[118] validation_0-auc:0.89851 validation_1-auc:0.83253
[119] validation_0-auc:0.89868 validation_1-auc:0.83256
[120] validation_0-auc:0.89875 validation_1-auc:0.83269
[121] validation_0-auc:0.89930 validation_1-auc:0.83258
[122] validation_0-auc:0.89939 validation_1-auc:0.83269
[123] validation_0-auc:0.89989 validation_1-auc:0.83274
[124] validation_0-auc:0.90029 validation_1-auc:0.83308
[125] validation_0-auc:0.90067 validation_1-auc:0.83324
[126] validation_0-auc:0.90115 validation_1-auc:0.83325
[127] validation_0-auc:0.90124 validation_1-auc:0.83324
[128] validation_0-auc:0.90147 validation_1-auc:0.83324
[129] validation_0-auc:0.90185 validation_1-auc:0.83337
[130] validation_0-auc:0.90211 validation_1-auc:0.83336
[131] validation_0-auc:0.90236 validation_1-auc:0.83341
[132] validation_0-auc:0.90266 validation_1-auc:0.83351
[133] validation_0-auc:0.90290 validation_1-auc:0.83346
[134] validation_0-auc:0.90299 validation_1-auc:0.83362
[135] validation_0-auc:0.90310 validation_1-auc:0.83363
[136] validation_0-auc:0.90327 validation_1-auc:0.83355
[137] validation_0-auc:0.90330 validation_1-auc:0.83354
[138] validation_0-auc:0.90337 validation_1-auc:0.83357
[139] validation_0-auc:0.90353 validation_1-auc:0.83348
[140] validation_0-auc:0.90363 validation_1-auc:0.83353
[141] validation_0-auc:0.90371 validation_1-auc:0.83346
[142] validation_0-auc:0.90395 validation_1-auc:0.83341
[143] validation_0-auc:0.90397 validation_1-auc:0.83341
[144] validation_0-auc:0.90422 validation_1-auc:0.83340
[145] validation_0-auc:0.90446 validation_1-auc:0.83335
[146] validation_0-auc:0.90467 validation_1-auc:0.83354
[147] validation_0-auc:0.90482 validation_1-auc:0.83355
[148] validation_0-auc:0.90484 validation_1-auc:0.83356
[149] validation_0-auc:0.90500 validation_1-auc:0.83353
[150] validation_0-auc:0.90513 validation_1-auc:0.83352
[151] validation_0-auc:0.90531 validation_1-auc:0.83349
[152] validation_0-auc:0.90548 validation_1-auc:0.83349
[153] validation_0-auc:0.90551 validation_1-auc:0.83351
[154] validation_0-auc:0.90563 validation_1-auc:0.83348
[155] validation_0-auc:0.90572 validation_1-auc:0.83343
[156] validation_0-auc:0.90579 validation_1-auc:0.83344
[157] validation_0-auc:0.90591 validation_1-auc:0.83345
[158] validation_0-auc:0.90621 validation_1-auc:0.83355
[159] validation_0-auc:0.90630 validation_1-auc:0.83359
[160] validation_0-auc:0.90643 validation_1-auc:0.83356
[161] validation_0-auc:0.90657 validation_1-auc:0.83352
[162] validation_0-auc:0.90672 validation_1-auc:0.83346
[163] validation_0-auc:0.90694 validation_1-auc:0.83347
[164] validation_0-auc:0.90714 validation_1-auc:0.83343
[165] validation_0-auc:0.90720 validation_1-auc:0.83343
[166] validation_0-auc:0.90727 validation_1-auc:0.83337
[167] validation_0-auc:0.90730 validation_1-auc:0.83335
[168] validation_0-auc:0.90737 validation_1-auc:0.83333
[169] validation_0-auc:0.90741 validation_1-auc:0.83337
[170] validation_0-auc:0.90772 validation_1-auc:0.83335
[171] validation_0-auc:0.90778 validation_1-auc:0.83332
[172] validation_0-auc:0.90781 validation_1-auc:0.83337
[173] validation_0-auc:0.90786 validation_1-auc:0.83337
[174] validation_0-auc:0.90797 validation_1-auc:0.83326
[175] validation_0-auc:0.90802 validation_1-auc:0.83334
[176] validation_0-auc:0.90818 validation_1-auc:0.83335
[177] validation_0-auc:0.90832 validation_1-auc:0.83330
[178] validation_0-auc:0.90836 validation_1-auc:0.83332
[179] validation_0-auc:0.90850 validation_1-auc:0.83339
[180] validation_0-auc:0.90856 validation_1-auc:0.83340
[181] validation_0-auc:0.90864 validation_1-auc:0.83338
[182] validation_0-auc:0.90890 validation_1-auc:0.83335
[183] validation_0-auc:0.90899 validation_1-auc:0.83327
[184] validation_0-auc:0.90902 validation_1-auc:0.83330
[185] validation_0-auc:0.90913 validation_1-auc:0.83330
[186] validation_0-auc:0.90934 validation_1-auc:0.83352
[187] validation_0-auc:0.90939 validation_1-auc:0.83356
[188] validation_0-auc:0.90947 validation_1-auc:0.83346
[189] validation_0-auc:0.90955 validation_1-auc:0.83347
[190] validation_0-auc:0.90978 validation_1-auc:0.83341
[191] validation_0-auc:0.90982 validation_1-auc:0.83340
[192] validation_0-auc:0.90987 validation_1-auc:0.83342
[193] validation_0-auc:0.90999 validation_1-auc:0.83339
[194] validation_0-auc:0.91010 validation_1-auc:0.83338
[195] validation_0-auc:0.91015 validation_1-auc:0.83329
[196] validation_0-auc:0.91018 validation_1-auc:0.83331
[197] validation_0-auc:0.91021 validation_1-auc:0.83333
[198] validation_0-auc:0.91029 validation_1-auc:0.83338
[199] validation_0-auc:0.91050 validation_1-auc:0.83338
[200] validation_0-auc:0.91056 validation_1-auc:0.83338
[201] validation_0-auc:0.91061 validation_1-auc:0.83334
[202] validation_0-auc:0.91065 validation_1-auc:0.83333
[203] validation_0-auc:0.91068 validation_1-auc:0.83334
[204] validation_0-auc:0.91079 validation_1-auc:0.83336
[205] validation_0-auc:0.91091 validation_1-auc:0.83323
[206] validation_0-auc:0.91098 validation_1-auc:0.83316
[207] validation_0-auc:0.91119 validation_1-auc:0.83319
[208] validation_0-auc:0.91131 validation_1-auc:0.83320
[209] validation_0-auc:0.91160 validation_1-auc:0.83331
[210] validation_0-auc:0.91165 validation_1-auc:0.83329
[211] validation_0-auc:0.91182 validation_1-auc:0.83325
[212] validation_0-auc:0.91189 validation_1-auc:0.83330
[213] validation_0-auc:0.91198 validation_1-auc:0.83329
[214] validation_0-auc:0.91219 validation_1-auc:0.83321
[215] validation_0-auc:0.91225 validation_1-auc:0.83326
[216] validation_0-auc:0.91228 validation_1-auc:0.83329
[217] validation_0-auc:0.91234 validation_1-auc:0.83326
[218] validation_0-auc:0.91258 validation_1-auc:0.83311
[219] validation_0-auc:0.91269 validation_1-auc:0.83311
[220] validation_0-auc:0.91279 validation_1-auc:0.83305
[221] validation_0-auc:0.91301 validation_1-auc:0.83310
[222] validation_0-auc:0.91311 validation_1-auc:0.83307
[223] validation_0-auc:0.91318 validation_1-auc:0.83310
[224] validation_0-auc:0.91333 validation_1-auc:0.83308
[225] validation_0-auc:0.91340 validation_1-auc:0.83309
[226] validation_0-auc:0.91340 validation_1-auc:0.83307
[227] validation_0-auc:0.91357 validation_1-auc:0.83310
[228] validation_0-auc:0.91360 validation_1-auc:0.83310
[229] validation_0-auc:0.91367 validation_1-auc:0.83312
[230] validation_0-auc:0.91378 validation_1-auc:0.83313
[231] validation_0-auc:0.91382 validation_1-auc:0.83313
[232] validation_0-auc:0.91383 validation_1-auc:0.83314
[233] validation_0-auc:0.91405 validation_1-auc:0.83312
[234] validation_0-auc:0.91414 validation_1-auc:0.83307
ROC AUC: 0.8429
- HyperOpt를 이용해 베이지안 최적화 기법으로 튜닝하기
In [21]:
#검색 공간 설정
from hyperopt import hp
xgb_search_space = {'max_depth':hp.quniform('max_depth', 5, 15, 1),
'min_child_weight':hp.quniform('min_child_weight', 1, 6, 1),
'colsample_bytree':hp.uniform('colsample_bytree', 0.5, 0.95),
'learning_rate':hp.uniform('learning_rate', 0.01, 0.2)}
In [27]:
#목적함수 만들기
from sklearn.model_selection import KFold
from sklearn.metrics import roc_auc_score
def objective_func(search_space):
xgb_clf = XGBClassifier(n_estimators=100, max_depth=int(search_space['max_depth']),
min_child_weight=int(search_space['min_child_weight']),
learning_rate=search_space['learning_rate'],
colsample_bytree=search_space['colsample_bytree'])
roc_auc_list = []
#3개 kfold 방식 적용
kf = KFold(n_splits=3)
#X_train을 다시 학습용, 검증용으로 나누기
for tr_index, val_index in kf.split(X_train):
#kf.split(X_train)으로 추출된 학습과 검증 index값으로 학습, 검증데이터셋 분리
X_tr, y_tr = X_train.iloc[tr_index], y_train.iloc[tr_index]
X_val, y_val = X_train.iloc[val_index], y_train.iloc[val_index]
#조기중단 30회, 추출된 학습, 검증데이터로 학습 수행
xgb_clf.fit(X_tr, y_tr, early_stopping_rounds=30, eval_metric='auc', eval_set=[(X_tr, y_tr), (X_val, y_val)])
#1로 예측한 확률값 추출 후 roe auc 계산하고 평균 계산을 위해 리스트에 담기
score = roc_auc_score(y_val, xgb_clf.predict_proba(X_val)[:, 1])
roc_auc_list.append(score)
return -1 * np.mean( roc_auc_list)
In [36]:
#fmin()함수로 최적 하이퍼 파라미터 도출하기
from hyperopt import fmin, tpe, Trials
trials = Trials()
best = fmin(fn=objective_func,
space=xgb_search_space,
algo=tpe.suggest,
max_evals=50,
trials=trials,
rstate=np.random.default_rng(seed=30))
print('best:', best)
In [ ]:
#도출된 하이퍼 파라미터 기반으로 재학습, 예측 수행, ROC AUC 측정
xgb_clf = XGBClassifier(n_estimators=500, max_depth=int(best['max_depth']),
min_child_weight=int(best['min_child_weight']),
learning_rate=round(best['learning_rate'], 5),
colsample_bytree=round(best['colsample_bytree'], 5)
xgb_clf.fit(X_tr, y_tr, early_stopping_rounds=100, eval_metric='auc', eval_set=[(X_tr, y_tr), (X_val, y_val)])
xgb_roc_score = roc_auc_score(y_test, xgb_clf.predict_proba(X_test)[:, 1])
print('ROC AUC: {0:.4f}'.format(xgb_roc_score))
- 하이퍼 파라미터 튜닝 후 roc auc가 개선됨
In [30]:
import matplotlib.pyplot as plt
from xgboost import plot_importance
fig, ax = plt.subplots(1, 1, figsize=(10,8))
plot_importance(xgb_clf, ax=ax, max_num_features=20, height=0.4)
Out[30]:
<Axes: title={'center': 'Feature importance'}, xlabel='F score', ylabel='Features'>
LightGBM 모델 학습, 하이퍼 파라미터 튜닝
In [31]:
from lightgbm import LGBMClassifier
lgbm_clf = LGBMClassifier(n_estimators=500)
eval_set = [(X_tr, y_tr), (X_val, y_val)]
lgbm_clf.fit(X_tr, y_tr, early_stopping_rounds=100, eval_metric='auc', eval_set=eval_set)
lgbm_roc_score = roc_auc_score(y_test, lgbm_clf.predict_proba(X_test)[:, 1])
print('ROC AUC: {0:.4f}'.format(lgbm_roc_score))
In [32]:
#검색공간 설정
lgbm_search_space = {'num_leaves':hp.quniform('num_leaves', 32, 64, 1),
'max_depth':hp.quniform('max_depth', 100, 160, 1),
'min_child_samples':hp.quniform('min_child_samples', 60, 100, 1),
'subsample':hp.uniform('subsample', 0.7, 1),
'learning_rate':hp.uniform('learning_rate', 0.01, 0.2)}
In [34]:
#목적함수 생성
from sklearn.model_selection import KFold
from sklearn.metrics import roc_auc_score
def objective_func(search_space):
xgb_clf = XGBClassifier(n_estimators=100,
num_leaves=int(search_space['num_leaves']),
max_depth=int(search_space['max_depth']),
min_child_samples=int(search_space['min_child_samples']),
learning_rate=search_space['learning_rate'],
subsample=search_space['subsample'])
roc_auc_list = []
#3개 kfold 방식 적용
kf = KFold(n_splits=3)
#X_train을 다시 학습용, 검증용으로 나누기
for tr_index, val_index in kf.split(X_train):
#kf.split(X_train)으로 추출된 학습과 검증 index값으로 학습, 검증데이터셋 분리
X_tr, y_tr = X_train.iloc[tr_index], y_train.iloc[tr_index]
X_val, y_val = X_train.iloc[val_index], y_train.iloc[val_index]
#조기중단 30회, 추출된 학습, 검증데이터로 학습 수행
lgbm_clf.fit(X_tr, y_tr, early_stopping_rounds=30, eval_metric='auc', eval_set=[(X_tr, y_tr), (X_val, y_val)])
#1로 예측한 확률값 추출 후 roe auc 계산하고 평균 계산을 위해 리스트에 담기
score = roc_auc_score(y_val, lgbm_clf.predict_proba(X_val)[:, 1])
roc_auc_list.append(score)
return -1 * np.mean(roc_auc_list)
In [ ]:
from hyperopt import fmin, tpe, Trials
trials = Trials()
best = fmin(fn=objective_func,
space=lgbm_search_space,
algo=tpe.suggest,
max_evals=50,
trials=trials,
rstate=np.random.default_rng(seed=30))
print('best:', best)
In [39]:
#fmin()호출해서 최적 하이퍼 파라미터 도출
lgbm_clf = LGBMClassifier(n_estimators=500,
num_leaves=int(best['num_leaves']),
max_depth=int(best['max_depth']),
min_child_samples=int(best['min_child_samples']),
learning_rate=round(best['learning_rate'], 5),
subsample=round(best['subsample'], 5))
lgbm_clf.fit(X_tr, y_tr, early_stopping_rounds=100, eval_metric='auc', eval_set=[(X_tr, y_tr), (X_val, y_val)])
lgbm_roc_score = roc_auc_score(y_test, lgbm_clf.predict_proba(X_test)[:, 1])
print('ROC AUC: {0:.4f}'.format(lgbm_roc_score))
ROC AUC: 0.8425
'Data Science > 파이썬 머신러닝 완벽 가이드' 카테고리의 다른 글
[sklearn] (29) Stacking 스태킹 모델 (0) | 2023.06.06 |
---|---|
[sklearn] (28) Kaggle 신용카드 사기 거래 탐지, Credit Card Fraud Classification (0) | 2023.06.01 |
[sklearn] (26) 베이지안 최적화 기반의 HyperOpt (1) | 2023.05.30 |
[sklearn] (25) - LightGBM (0) | 2023.05.30 |
[sklearn] (24) 사이킷런 Wrapper XGBoost (0) | 2023.05.30 |