2.6. XGBoost#
XGBoost is an optimized distributed gradient boosting library designed to be highly efficient, flexible and portable.
import numpy as np
import pandas as pd
import xgboost as xgb
np.random.seed(0)
from sklearn.datasets import make_classification
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
%config InlineBackend.figure_format = 'retina'
# Build a synthetic binary-classification dataset (50k rows, 15 features,
# 5 of them informative) with a fixed seed for reproducibility
X_arr, y_arr = make_classification(
    n_samples=50000, n_features=15, n_classes=2, n_informative=5, random_state=42
)
X = pd.DataFrame(
    X_arr, columns=[f"feature_{i}" for i in range(1, X_arr.shape[1] + 1)]
)
y = pd.DataFrame(y_arr)
# Stratified split of the row index into train / test partitions
ix_train, ix_test = train_test_split(X.index, stratify=y, random_state=62)
# Configure a gradient-boosted tree model; num_parallel_tree > 1 makes each
# boosting round train a small forest (boosted random forests)
gbdt_params = {
    "booster": "gbtree",
    "n_estimators": 1000,  # maximum boosting rounds
    "max_depth": 10,
    "subsample": 0.5,  # row sampling per round
    "gamma": 0.5,  # minimum loss reduction to split
    "reg_alpha": 20,  # L1 regularization
    "reg_lambda": 20,  # L2 regularization
    "min_child_weight": 10,
    "colsample_bynode": 0.8,  # feature sampling per node
    "num_parallel_tree": 2,  # trees per boosting round
    "eval_metric": "logloss",  # evaluation metric
    "early_stopping_rounds": 10,  # stop if no improvement for 10 rounds
    "grow_policy": "lossguide",
    "random_state": 42,
}
model_gbdt = xgb.XGBClassifier(**gbdt_params)
# Fit on the training partition while monitoring log loss on both partitions
# (the second entry drives early stopping)
evalset = [
    (X.loc[ix_train], y.loc[ix_train]),
    (X.loc[ix_test], y.loc[ix_test]),
]
model_gbdt.fit(X.loc[ix_train], y.loc[ix_train], eval_set=evalset)
# Score both partitions with the positive-class probability and
# convert AUC to the Gini coefficient (Gini = 2 * AUC - 1)
proba_train = model_gbdt.predict_proba(X.loc[ix_train])[:, 1]
proba_test = model_gbdt.predict_proba(X.loc[ix_test])[:, 1]
gini_trn = 2 * roc_auc_score(y.loc[ix_train], proba_train) - 1
gini_tst = 2 * roc_auc_score(y.loc[ix_test], proba_test) - 1
print(f"Train Gini score: {gini_trn:.2%}\n" f"Test Gini score: {gini_tst:.2%}")
[0] validation_0-logloss:0.54626 validation_1-logloss:0.54642
[1] validation_0-logloss:0.45594 validation_1-logloss:0.45765
[2] validation_0-logloss:0.39858 validation_1-logloss:0.40107
[3] validation_0-logloss:0.35277 validation_1-logloss:0.35585
[4] validation_0-logloss:0.31624 validation_1-logloss:0.32011
[5] validation_0-logloss:0.29152 validation_1-logloss:0.29628
[6] validation_0-logloss:0.26528 validation_1-logloss:0.27063
[7] validation_0-logloss:0.24296 validation_1-logloss:0.24831
[8] validation_0-logloss:0.22727 validation_1-logloss:0.23330
[9] validation_0-logloss:0.21415 validation_1-logloss:0.22061
[10] validation_0-logloss:0.20207 validation_1-logloss:0.20886
[11] validation_0-logloss:0.19335 validation_1-logloss:0.20034
[12] validation_0-logloss:0.18805 validation_1-logloss:0.19538
[13] validation_0-logloss:0.18084 validation_1-logloss:0.18858
[14] validation_0-logloss:0.17538 validation_1-logloss:0.18351
[15] validation_0-logloss:0.17257 validation_1-logloss:0.18117
[16] validation_0-logloss:0.16776 validation_1-logloss:0.17659
[17] validation_0-logloss:0.16356 validation_1-logloss:0.17265
[18] validation_0-logloss:0.16064 validation_1-logloss:0.16980
[19] validation_0-logloss:0.15700 validation_1-logloss:0.16640
[20] validation_0-logloss:0.15391 validation_1-logloss:0.16389
[21] validation_0-logloss:0.15185 validation_1-logloss:0.16218
[22] validation_0-logloss:0.15053 validation_1-logloss:0.16106
[23] validation_0-logloss:0.14841 validation_1-logloss:0.15911
[24] validation_0-logloss:0.14723 validation_1-logloss:0.15813
[25] validation_0-logloss:0.14551 validation_1-logloss:0.15651
[26] validation_0-logloss:0.14437 validation_1-logloss:0.15562
[27] validation_0-logloss:0.14407 validation_1-logloss:0.15545
[28] validation_0-logloss:0.14320 validation_1-logloss:0.15454
[29] validation_0-logloss:0.14204 validation_1-logloss:0.15350
[30] validation_0-logloss:0.14118 validation_1-logloss:0.15280
[31] validation_0-logloss:0.14047 validation_1-logloss:0.15227
[32] validation_0-logloss:0.13975 validation_1-logloss:0.15167
[33] validation_0-logloss:0.13954 validation_1-logloss:0.15151
[34] validation_0-logloss:0.13935 validation_1-logloss:0.15128
[35] validation_0-logloss:0.13932 validation_1-logloss:0.15122
[36] validation_0-logloss:0.13858 validation_1-logloss:0.15055
[37] validation_0-logloss:0.13796 validation_1-logloss:0.14996
[38] validation_0-logloss:0.13742 validation_1-logloss:0.14966
[39] validation_0-logloss:0.13674 validation_1-logloss:0.14929
[40] validation_0-logloss:0.13611 validation_1-logloss:0.14868
[41] validation_0-logloss:0.13611 validation_1-logloss:0.14868
[42] validation_0-logloss:0.13522 validation_1-logloss:0.14794
[43] validation_0-logloss:0.13515 validation_1-logloss:0.14796
[44] validation_0-logloss:0.13479 validation_1-logloss:0.14759
[45] validation_0-logloss:0.13473 validation_1-logloss:0.14759
[46] validation_0-logloss:0.13425 validation_1-logloss:0.14741
[47] validation_0-logloss:0.13407 validation_1-logloss:0.14742
[48] validation_0-logloss:0.13367 validation_1-logloss:0.14724
[49] validation_0-logloss:0.13336 validation_1-logloss:0.14702
[50] validation_0-logloss:0.13311 validation_1-logloss:0.14680
[51] validation_0-logloss:0.13307 validation_1-logloss:0.14680
[52] validation_0-logloss:0.13278 validation_1-logloss:0.14661
[53] validation_0-logloss:0.13230 validation_1-logloss:0.14628
[54] validation_0-logloss:0.13222 validation_1-logloss:0.14626
[55] validation_0-logloss:0.13219 validation_1-logloss:0.14625
[56] validation_0-logloss:0.13170 validation_1-logloss:0.14569
[57] validation_0-logloss:0.13141 validation_1-logloss:0.14539
[58] validation_0-logloss:0.13141 validation_1-logloss:0.14539
[59] validation_0-logloss:0.13080 validation_1-logloss:0.14492
[60] validation_0-logloss:0.13080 validation_1-logloss:0.14492
[61] validation_0-logloss:0.13080 validation_1-logloss:0.14492
[62] validation_0-logloss:0.13011 validation_1-logloss:0.14428
[63] validation_0-logloss:0.13011 validation_1-logloss:0.14428
[64] validation_0-logloss:0.13011 validation_1-logloss:0.14428
[65] validation_0-logloss:0.13004 validation_1-logloss:0.14431
[66] validation_0-logloss:0.13004 validation_1-logloss:0.14431
[67] validation_0-logloss:0.13004 validation_1-logloss:0.14431
[68] validation_0-logloss:0.13004 validation_1-logloss:0.14431
[69] validation_0-logloss:0.12946 validation_1-logloss:0.14371
[70] validation_0-logloss:0.12939 validation_1-logloss:0.14361
[71] validation_0-logloss:0.12939 validation_1-logloss:0.14361
[72] validation_0-logloss:0.12930 validation_1-logloss:0.14364
[73] validation_0-logloss:0.12917 validation_1-logloss:0.14361
[74] validation_0-logloss:0.12907 validation_1-logloss:0.14365
[75] validation_0-logloss:0.12907 validation_1-logloss:0.14365
[76] validation_0-logloss:0.12907 validation_1-logloss:0.14365
[77] validation_0-logloss:0.12867 validation_1-logloss:0.14326
[78] validation_0-logloss:0.12824 validation_1-logloss:0.14290
[79] validation_0-logloss:0.12809 validation_1-logloss:0.14273
[80] validation_0-logloss:0.12802 validation_1-logloss:0.14270
[81] validation_0-logloss:0.12796 validation_1-logloss:0.14270
[82] validation_0-logloss:0.12781 validation_1-logloss:0.14264
[83] validation_0-logloss:0.12781 validation_1-logloss:0.14264
[84] validation_0-logloss:0.12770 validation_1-logloss:0.14262
[85] validation_0-logloss:0.12765 validation_1-logloss:0.14263
[86] validation_0-logloss:0.12745 validation_1-logloss:0.14256
[87] validation_0-logloss:0.12727 validation_1-logloss:0.14237
[88] validation_0-logloss:0.12717 validation_1-logloss:0.14238
[89] validation_0-logloss:0.12717 validation_1-logloss:0.14238
[90] validation_0-logloss:0.12710 validation_1-logloss:0.14241
[91] validation_0-logloss:0.12682 validation_1-logloss:0.14217
[92] validation_0-logloss:0.12682 validation_1-logloss:0.14217
[93] validation_0-logloss:0.12682 validation_1-logloss:0.14217
[94] validation_0-logloss:0.12675 validation_1-logloss:0.14215
[95] validation_0-logloss:0.12675 validation_1-logloss:0.14215
[96] validation_0-logloss:0.12660 validation_1-logloss:0.14212
[97] validation_0-logloss:0.12660 validation_1-logloss:0.14212
[98] validation_0-logloss:0.12660 validation_1-logloss:0.14212
[99] validation_0-logloss:0.12625 validation_1-logloss:0.14187
[100] validation_0-logloss:0.12599 validation_1-logloss:0.14169
[101] validation_0-logloss:0.12563 validation_1-logloss:0.14133
[102] validation_0-logloss:0.12552 validation_1-logloss:0.14130
[103] validation_0-logloss:0.12552 validation_1-logloss:0.14130
[104] validation_0-logloss:0.12552 validation_1-logloss:0.14130
[105] validation_0-logloss:0.12532 validation_1-logloss:0.14121
[106] validation_0-logloss:0.12527 validation_1-logloss:0.14122
[107] validation_0-logloss:0.12527 validation_1-logloss:0.14122
[108] validation_0-logloss:0.12527 validation_1-logloss:0.14122
[109] validation_0-logloss:0.12517 validation_1-logloss:0.14126
[110] validation_0-logloss:0.12517 validation_1-logloss:0.14126
[111] validation_0-logloss:0.12509 validation_1-logloss:0.14124
[112] validation_0-logloss:0.12470 validation_1-logloss:0.14092
[113] validation_0-logloss:0.12470 validation_1-logloss:0.14092
[114] validation_0-logloss:0.12470 validation_1-logloss:0.14092
[115] validation_0-logloss:0.12457 validation_1-logloss:0.14072
[116] validation_0-logloss:0.12447 validation_1-logloss:0.14066
[117] validation_0-logloss:0.12434 validation_1-logloss:0.14052
[118] validation_0-logloss:0.12430 validation_1-logloss:0.14057
[119] validation_0-logloss:0.12412 validation_1-logloss:0.14045
[120] validation_0-logloss:0.12401 validation_1-logloss:0.14042
[121] validation_0-logloss:0.12401 validation_1-logloss:0.14042
[122] validation_0-logloss:0.12401 validation_1-logloss:0.14042
[123] validation_0-logloss:0.12401 validation_1-logloss:0.14042
[124] validation_0-logloss:0.12398 validation_1-logloss:0.14041
[125] validation_0-logloss:0.12384 validation_1-logloss:0.14038
[126] validation_0-logloss:0.12379 validation_1-logloss:0.14038
[127] validation_0-logloss:0.12379 validation_1-logloss:0.14038
[128] validation_0-logloss:0.12379 validation_1-logloss:0.14038
[129] validation_0-logloss:0.12379 validation_1-logloss:0.14038
[130] validation_0-logloss:0.12379 validation_1-logloss:0.14038
[131] validation_0-logloss:0.12341 validation_1-logloss:0.14001
[132] validation_0-logloss:0.12341 validation_1-logloss:0.14001
[133] validation_0-logloss:0.12341 validation_1-logloss:0.14001
[134] validation_0-logloss:0.12336 validation_1-logloss:0.14000
[135] validation_0-logloss:0.12315 validation_1-logloss:0.13985
[136] validation_0-logloss:0.12315 validation_1-logloss:0.13985
[137] validation_0-logloss:0.12281 validation_1-logloss:0.13958
[138] validation_0-logloss:0.12281 validation_1-logloss:0.13958
[139] validation_0-logloss:0.12281 validation_1-logloss:0.13958
[140] validation_0-logloss:0.12281 validation_1-logloss:0.13958
[141] validation_0-logloss:0.12270 validation_1-logloss:0.13948
[142] validation_0-logloss:0.12262 validation_1-logloss:0.13947
[143] validation_0-logloss:0.12247 validation_1-logloss:0.13936
[144] validation_0-logloss:0.12247 validation_1-logloss:0.13936
[145] validation_0-logloss:0.12243 validation_1-logloss:0.13934
[146] validation_0-logloss:0.12237 validation_1-logloss:0.13936
[147] validation_0-logloss:0.12237 validation_1-logloss:0.13936
[148] validation_0-logloss:0.12237 validation_1-logloss:0.13936
[149] validation_0-logloss:0.12232 validation_1-logloss:0.13933
[150] validation_0-logloss:0.12232 validation_1-logloss:0.13933
[151] validation_0-logloss:0.12226 validation_1-logloss:0.13934
[152] validation_0-logloss:0.12226 validation_1-logloss:0.13934
[153] validation_0-logloss:0.12201 validation_1-logloss:0.13921
[154] validation_0-logloss:0.12187 validation_1-logloss:0.13920
[155] validation_0-logloss:0.12181 validation_1-logloss:0.13921
[156] validation_0-logloss:0.12177 validation_1-logloss:0.13923
[157] validation_0-logloss:0.12177 validation_1-logloss:0.13923
[158] validation_0-logloss:0.12149 validation_1-logloss:0.13904
[159] validation_0-logloss:0.12149 validation_1-logloss:0.13904
[160] validation_0-logloss:0.12131 validation_1-logloss:0.13888
[161] validation_0-logloss:0.12115 validation_1-logloss:0.13898
[162] validation_0-logloss:0.12099 validation_1-logloss:0.13895
[163] validation_0-logloss:0.12099 validation_1-logloss:0.13895
[164] validation_0-logloss:0.12095 validation_1-logloss:0.13893
[165] validation_0-logloss:0.12091 validation_1-logloss:0.13892
[166] validation_0-logloss:0.12091 validation_1-logloss:0.13892
[167] validation_0-logloss:0.12089 validation_1-logloss:0.13891
[168] validation_0-logloss:0.12089 validation_1-logloss:0.13891
[169] validation_0-logloss:0.12089 validation_1-logloss:0.13891
[170] validation_0-logloss:0.12089 validation_1-logloss:0.13891
Train Gini score: 98.15%
Test Gini score: 97.30%
# Check how much of the model early stopping kept.
booster_ = model_gbdt.get_booster()
# best_iteration counts boosting ROUNDS, not trees: with
# num_parallel_tree=2 each round trains 2 trees, so the original label
# "number of trees" was misleading.
num_rounds = booster_.best_iteration + 1
print(f"Boosting rounds kept in the final model: {num_rounds}")
# The dump lists every tree that was trained, including the extra rounds
# fitted after the best iteration before early stopping fired
# (hence it exceeds num_rounds * num_parallel_tree for the kept rounds).
booster_dump = booster_.get_dump()
num_trees_dump = len(booster_dump)
print(f"Total number of trees trained (booster dump): {num_trees_dump}")
Total number of trees in the final model: 161
Total number of trees used in estimation: 342
# Get margins (log odds, leaf weights) for the first 10 leaf nodes.
# Materialize the tree table ONCE — the original called
# booster_.trees_to_dataframe() twice, rebuilding the full table for the mask.
trees_df = booster_.trees_to_dataframe()
trees_df.loc[trees_df["Feature"] == "Leaf", "Gain"][:10]
9 -0.225236
10 0.190313
20 -0.143501
21 -0.134992
27 -0.000000
29 -0.222734
31 0.010724
32 -0.110243
33 0.161738
34 -0.000000
Name: Gain, dtype: float64
Loss Curve
The animation approach below is adapted from PythonGuides.com, "Matplotlib Update Plot in Loop".
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
%matplotlib widget
# Pull the per-round evaluation history recorded during fit;
# validation_1 is the test partition from the eval_set
results = model_gbdt.evals_result()
loss_values = results["validation_1"]["logloss"]
# Build the figure with the object-oriented Axes API
fig, ax = plt.subplots(figsize=(6, 4))
x_plot = []
y_plot = []
# Start with an empty line; the animation fills it in frame by frame
(plot_1,) = ax.plot(x_plot, y_plot, color="#a7fe01", label="Log Loss")
# Fix the axis limits up front so the view doesn't jump between frames
ax.set_xlim(0, len(loss_values))
ax.set_ylim(min(loss_values), max(loss_values))
for side in ("top", "right"):
    ax.spines[side].set_visible(False)
ax.set_xlabel("Number of Trees")
ax.set_ylabel("Log Loss")
ax.set_title("Learning Curve")
fig.tight_layout()
# Animation callback: frame i reveals the loss curve up to round i
def animate_plot(i):
    """Update the line to show the first i + 1 loss values for frame i.

    Uses i + 1 so the final frame (i = len(loss_values) - 1) draws the
    complete curve; with the original [:i] slice the last loss point was
    never rendered and frame 0 drew nothing (off-by-one).
    """
    x_plot = np.arange(i + 1)
    y_plot = loss_values[: i + 1]
    plot_1.set_data(x_plot, y_plot)
    return (plot_1,)
# Build the animation: one frame per boosting round, 100 ms between frames
ani = FuncAnimation(fig, animate_plot, frames=len(loss_values), interval=100)
# Save as gif (requires the "pillow" writer to be installed)
# NOTE(review): hardcoded absolute local path below — prefer a configurable
# output directory (e.g. a Path constant in the config cell) before enabling.
# path_to_export = "/Users/deburky/Documents/python/risk-practitioner-ebook/risk-practitioner/images"
# ani.save(f"{path_to_export}/learning_curve.gif", writer="pillow", fps=10, dpi=600)
# ani.save(f"learning_curve.gif", writer="pillow", fps=10, dpi=600)