2.6. XGBoost#

XGBoost is an optimized distributed gradient boosting library designed to be highly efficient, flexible and portable.
import numpy as np
import pandas as pd
import xgboost as xgb

np.random.seed(0)

from sklearn.datasets import make_classification
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

%config InlineBackend.figure_format = 'retina'
# Generate a synthetic dataset
X, y = make_classification(
    n_samples=50000, n_features=15, n_classes=2, n_informative=5, random_state=42
)

X, y = pd.DataFrame(X), pd.DataFrame(y)

X.columns = [f"feature_{i+1}" for i in range(len(X.columns))]

ix_train, ix_test = train_test_split(X.index, stratify=y, random_state=62)
# XGBoost gradient-boosted trees; with num_parallel_tree=2 each boosting
# round fits a small 2-tree forest (a "boosted random forest").
model_gbdt = xgb.XGBClassifier(
    booster="gbtree",
    n_estimators=1000,  # max boosting rounds (early stopping trims this)
    max_depth=10,
    subsample=0.5,  # row subsampling per tree
    gamma=0.5,  # min loss reduction required to split
    reg_alpha=20,  # L1 regularization on leaf weights
    reg_lambda=20,  # L2 regularization on leaf weights
    min_child_weight=10,
    colsample_bynode=0.8,  # feature subsampling per split
    num_parallel_tree=2,  # trees per boosting round (forest size)
    eval_metric="logloss",  # evaluation metric
    early_stopping_rounds=10,  # stop when the LAST eval set stalls for 10 rounds
    grow_policy="lossguide",
    random_state=42,
)

# Materialize the splits once instead of repeating .loc lookups below.
X_trn, y_trn = X.loc[ix_train], y.loc[ix_train]
X_tst, y_tst = X.loc[ix_test], y.loc[ix_test]

# Early stopping monitors the last entry, i.e. the test set (validation_1).
evalset = [(X_trn, y_trn), (X_tst, y_tst)]

# Fit; verbose=50 logs every 50th round instead of flooding the output
# with one line per boosting round.
model_gbdt.fit(X_trn, y_trn, eval_set=evalset, verbose=50)

# Predicted probability of the positive class.
predictions_trn = model_gbdt.predict_proba(X_trn)[:, 1]
predictions_tst = model_gbdt.predict_proba(X_tst)[:, 1]

# Gini = 2 * AUC - 1.
gini_trn = roc_auc_score(y_trn, predictions_trn) * 2 - 1
gini_tst = roc_auc_score(y_tst, predictions_tst) * 2 - 1

print(f"Train Gini score: {gini_trn:.2%}\n" f"Test Gini score: {gini_tst:.2%}")
[0]	validation_0-logloss:0.54626	validation_1-logloss:0.54642
[1]	validation_0-logloss:0.45594	validation_1-logloss:0.45765
[2]	validation_0-logloss:0.39858	validation_1-logloss:0.40107
[3]	validation_0-logloss:0.35277	validation_1-logloss:0.35585
[4]	validation_0-logloss:0.31624	validation_1-logloss:0.32011
[5]	validation_0-logloss:0.29152	validation_1-logloss:0.29628
[6]	validation_0-logloss:0.26528	validation_1-logloss:0.27063
[7]	validation_0-logloss:0.24296	validation_1-logloss:0.24831
[8]	validation_0-logloss:0.22727	validation_1-logloss:0.23330
[9]	validation_0-logloss:0.21415	validation_1-logloss:0.22061
[10]	validation_0-logloss:0.20207	validation_1-logloss:0.20886
[11]	validation_0-logloss:0.19335	validation_1-logloss:0.20034
[12]	validation_0-logloss:0.18805	validation_1-logloss:0.19538
[13]	validation_0-logloss:0.18084	validation_1-logloss:0.18858
[14]	validation_0-logloss:0.17538	validation_1-logloss:0.18351
[15]	validation_0-logloss:0.17257	validation_1-logloss:0.18117
[16]	validation_0-logloss:0.16776	validation_1-logloss:0.17659
[17]	validation_0-logloss:0.16356	validation_1-logloss:0.17265
[18]	validation_0-logloss:0.16064	validation_1-logloss:0.16980
[19]	validation_0-logloss:0.15700	validation_1-logloss:0.16640
[20]	validation_0-logloss:0.15391	validation_1-logloss:0.16389
[21]	validation_0-logloss:0.15185	validation_1-logloss:0.16218
[22]	validation_0-logloss:0.15053	validation_1-logloss:0.16106
[23]	validation_0-logloss:0.14841	validation_1-logloss:0.15911
[24]	validation_0-logloss:0.14723	validation_1-logloss:0.15813
[25]	validation_0-logloss:0.14551	validation_1-logloss:0.15651
[26]	validation_0-logloss:0.14437	validation_1-logloss:0.15562
[27]	validation_0-logloss:0.14407	validation_1-logloss:0.15545
[28]	validation_0-logloss:0.14320	validation_1-logloss:0.15454
[29]	validation_0-logloss:0.14204	validation_1-logloss:0.15350
[30]	validation_0-logloss:0.14118	validation_1-logloss:0.15280
[31]	validation_0-logloss:0.14047	validation_1-logloss:0.15227
[32]	validation_0-logloss:0.13975	validation_1-logloss:0.15167
[33]	validation_0-logloss:0.13954	validation_1-logloss:0.15151
[34]	validation_0-logloss:0.13935	validation_1-logloss:0.15128
[35]	validation_0-logloss:0.13932	validation_1-logloss:0.15122
[36]	validation_0-logloss:0.13858	validation_1-logloss:0.15055
[37]	validation_0-logloss:0.13796	validation_1-logloss:0.14996
[38]	validation_0-logloss:0.13742	validation_1-logloss:0.14966
[39]	validation_0-logloss:0.13674	validation_1-logloss:0.14929
[40]	validation_0-logloss:0.13611	validation_1-logloss:0.14868
[41]	validation_0-logloss:0.13611	validation_1-logloss:0.14868
[42]	validation_0-logloss:0.13522	validation_1-logloss:0.14794
[43]	validation_0-logloss:0.13515	validation_1-logloss:0.14796
[44]	validation_0-logloss:0.13479	validation_1-logloss:0.14759
[45]	validation_0-logloss:0.13473	validation_1-logloss:0.14759
[46]	validation_0-logloss:0.13425	validation_1-logloss:0.14741
[47]	validation_0-logloss:0.13407	validation_1-logloss:0.14742
[48]	validation_0-logloss:0.13367	validation_1-logloss:0.14724
[49]	validation_0-logloss:0.13336	validation_1-logloss:0.14702
[50]	validation_0-logloss:0.13311	validation_1-logloss:0.14680
[51]	validation_0-logloss:0.13307	validation_1-logloss:0.14680
[52]	validation_0-logloss:0.13278	validation_1-logloss:0.14661
[53]	validation_0-logloss:0.13230	validation_1-logloss:0.14628
[54]	validation_0-logloss:0.13222	validation_1-logloss:0.14626
[55]	validation_0-logloss:0.13219	validation_1-logloss:0.14625
[56]	validation_0-logloss:0.13170	validation_1-logloss:0.14569
[57]	validation_0-logloss:0.13141	validation_1-logloss:0.14539
[58]	validation_0-logloss:0.13141	validation_1-logloss:0.14539
[59]	validation_0-logloss:0.13080	validation_1-logloss:0.14492
[60]	validation_0-logloss:0.13080	validation_1-logloss:0.14492
[61]	validation_0-logloss:0.13080	validation_1-logloss:0.14492
[62]	validation_0-logloss:0.13011	validation_1-logloss:0.14428
[63]	validation_0-logloss:0.13011	validation_1-logloss:0.14428
[64]	validation_0-logloss:0.13011	validation_1-logloss:0.14428
[65]	validation_0-logloss:0.13004	validation_1-logloss:0.14431
[66]	validation_0-logloss:0.13004	validation_1-logloss:0.14431
[67]	validation_0-logloss:0.13004	validation_1-logloss:0.14431
[68]	validation_0-logloss:0.13004	validation_1-logloss:0.14431
[69]	validation_0-logloss:0.12946	validation_1-logloss:0.14371
[70]	validation_0-logloss:0.12939	validation_1-logloss:0.14361
[71]	validation_0-logloss:0.12939	validation_1-logloss:0.14361
[72]	validation_0-logloss:0.12930	validation_1-logloss:0.14364
[73]	validation_0-logloss:0.12917	validation_1-logloss:0.14361
[74]	validation_0-logloss:0.12907	validation_1-logloss:0.14365
[75]	validation_0-logloss:0.12907	validation_1-logloss:0.14365
[76]	validation_0-logloss:0.12907	validation_1-logloss:0.14365
[77]	validation_0-logloss:0.12867	validation_1-logloss:0.14326
[78]	validation_0-logloss:0.12824	validation_1-logloss:0.14290
[79]	validation_0-logloss:0.12809	validation_1-logloss:0.14273
[80]	validation_0-logloss:0.12802	validation_1-logloss:0.14270
[81]	validation_0-logloss:0.12796	validation_1-logloss:0.14270
[82]	validation_0-logloss:0.12781	validation_1-logloss:0.14264
[83]	validation_0-logloss:0.12781	validation_1-logloss:0.14264
[84]	validation_0-logloss:0.12770	validation_1-logloss:0.14262
[85]	validation_0-logloss:0.12765	validation_1-logloss:0.14263
[86]	validation_0-logloss:0.12745	validation_1-logloss:0.14256
[87]	validation_0-logloss:0.12727	validation_1-logloss:0.14237
[88]	validation_0-logloss:0.12717	validation_1-logloss:0.14238
[89]	validation_0-logloss:0.12717	validation_1-logloss:0.14238
[90]	validation_0-logloss:0.12710	validation_1-logloss:0.14241
[91]	validation_0-logloss:0.12682	validation_1-logloss:0.14217
[92]	validation_0-logloss:0.12682	validation_1-logloss:0.14217
[93]	validation_0-logloss:0.12682	validation_1-logloss:0.14217
[94]	validation_0-logloss:0.12675	validation_1-logloss:0.14215
[95]	validation_0-logloss:0.12675	validation_1-logloss:0.14215
[96]	validation_0-logloss:0.12660	validation_1-logloss:0.14212
[97]	validation_0-logloss:0.12660	validation_1-logloss:0.14212
[98]	validation_0-logloss:0.12660	validation_1-logloss:0.14212
[99]	validation_0-logloss:0.12625	validation_1-logloss:0.14187
[100]	validation_0-logloss:0.12599	validation_1-logloss:0.14169
[101]	validation_0-logloss:0.12563	validation_1-logloss:0.14133
[102]	validation_0-logloss:0.12552	validation_1-logloss:0.14130
[103]	validation_0-logloss:0.12552	validation_1-logloss:0.14130
[104]	validation_0-logloss:0.12552	validation_1-logloss:0.14130
[105]	validation_0-logloss:0.12532	validation_1-logloss:0.14121
[106]	validation_0-logloss:0.12527	validation_1-logloss:0.14122
[107]	validation_0-logloss:0.12527	validation_1-logloss:0.14122
[108]	validation_0-logloss:0.12527	validation_1-logloss:0.14122
[109]	validation_0-logloss:0.12517	validation_1-logloss:0.14126
[110]	validation_0-logloss:0.12517	validation_1-logloss:0.14126
[111]	validation_0-logloss:0.12509	validation_1-logloss:0.14124
[112]	validation_0-logloss:0.12470	validation_1-logloss:0.14092
[113]	validation_0-logloss:0.12470	validation_1-logloss:0.14092
[114]	validation_0-logloss:0.12470	validation_1-logloss:0.14092
[115]	validation_0-logloss:0.12457	validation_1-logloss:0.14072
[116]	validation_0-logloss:0.12447	validation_1-logloss:0.14066
[117]	validation_0-logloss:0.12434	validation_1-logloss:0.14052
[118]	validation_0-logloss:0.12430	validation_1-logloss:0.14057
[119]	validation_0-logloss:0.12412	validation_1-logloss:0.14045
[120]	validation_0-logloss:0.12401	validation_1-logloss:0.14042
[121]	validation_0-logloss:0.12401	validation_1-logloss:0.14042
[122]	validation_0-logloss:0.12401	validation_1-logloss:0.14042
[123]	validation_0-logloss:0.12401	validation_1-logloss:0.14042
[124]	validation_0-logloss:0.12398	validation_1-logloss:0.14041
[125]	validation_0-logloss:0.12384	validation_1-logloss:0.14038
[126]	validation_0-logloss:0.12379	validation_1-logloss:0.14038
[127]	validation_0-logloss:0.12379	validation_1-logloss:0.14038
[128]	validation_0-logloss:0.12379	validation_1-logloss:0.14038
[129]	validation_0-logloss:0.12379	validation_1-logloss:0.14038
[130]	validation_0-logloss:0.12379	validation_1-logloss:0.14038
[131]	validation_0-logloss:0.12341	validation_1-logloss:0.14001
[132]	validation_0-logloss:0.12341	validation_1-logloss:0.14001
[133]	validation_0-logloss:0.12341	validation_1-logloss:0.14001
[134]	validation_0-logloss:0.12336	validation_1-logloss:0.14000
[135]	validation_0-logloss:0.12315	validation_1-logloss:0.13985
[136]	validation_0-logloss:0.12315	validation_1-logloss:0.13985
[137]	validation_0-logloss:0.12281	validation_1-logloss:0.13958
[138]	validation_0-logloss:0.12281	validation_1-logloss:0.13958
[139]	validation_0-logloss:0.12281	validation_1-logloss:0.13958
[140]	validation_0-logloss:0.12281	validation_1-logloss:0.13958
[141]	validation_0-logloss:0.12270	validation_1-logloss:0.13948
[142]	validation_0-logloss:0.12262	validation_1-logloss:0.13947
[143]	validation_0-logloss:0.12247	validation_1-logloss:0.13936
[144]	validation_0-logloss:0.12247	validation_1-logloss:0.13936
[145]	validation_0-logloss:0.12243	validation_1-logloss:0.13934
[146]	validation_0-logloss:0.12237	validation_1-logloss:0.13936
[147]	validation_0-logloss:0.12237	validation_1-logloss:0.13936
[148]	validation_0-logloss:0.12237	validation_1-logloss:0.13936
[149]	validation_0-logloss:0.12232	validation_1-logloss:0.13933
[150]	validation_0-logloss:0.12232	validation_1-logloss:0.13933
[151]	validation_0-logloss:0.12226	validation_1-logloss:0.13934
[152]	validation_0-logloss:0.12226	validation_1-logloss:0.13934
[153]	validation_0-logloss:0.12201	validation_1-logloss:0.13921
[154]	validation_0-logloss:0.12187	validation_1-logloss:0.13920
[155]	validation_0-logloss:0.12181	validation_1-logloss:0.13921
[156]	validation_0-logloss:0.12177	validation_1-logloss:0.13923
[157]	validation_0-logloss:0.12177	validation_1-logloss:0.13923
[158]	validation_0-logloss:0.12149	validation_1-logloss:0.13904
[159]	validation_0-logloss:0.12149	validation_1-logloss:0.13904
[160]	validation_0-logloss:0.12131	validation_1-logloss:0.13888
[161]	validation_0-logloss:0.12115	validation_1-logloss:0.13898
[162]	validation_0-logloss:0.12099	validation_1-logloss:0.13895
[163]	validation_0-logloss:0.12099	validation_1-logloss:0.13895
[164]	validation_0-logloss:0.12095	validation_1-logloss:0.13893
[165]	validation_0-logloss:0.12091	validation_1-logloss:0.13892
[166]	validation_0-logloss:0.12091	validation_1-logloss:0.13892
[167]	validation_0-logloss:0.12089	validation_1-logloss:0.13891
[168]	validation_0-logloss:0.12089	validation_1-logloss:0.13891
[169]	validation_0-logloss:0.12089	validation_1-logloss:0.13891
[170]	validation_0-logloss:0.12089	validation_1-logloss:0.13891
Train Gini score: 98.15%
Test Gini score: 97.30%
# Check how many boosting rounds / trees were trained.
# best_iteration is 0-based, so +1 gives the number of boosting ROUNDS
# kept by early stopping — NOT the number of trees: with
# num_parallel_tree=2 every round contains two trees.
booster_ = model_gbdt.get_booster()
num_rounds = booster_.best_iteration + 1
print(f"Best number of boosting rounds: {num_rounds}")

# The text dump lists every individual tree that was trained, i.e. all
# rounds up to where early stopping fired, times num_parallel_tree.
booster_dump = booster_.get_dump()
num_trees_dump = len(booster_dump)
print(f"Total number of trees trained: {num_trees_dump}")
Total number of trees in the final model: 161
Total number of trees used in estimation: 342
# Get margins (log odds, leaf weights) of the first 10 leaf nodes.
# Build the trees DataFrame once — the original called
# trees_to_dataframe() twice for the same result.
trees_df = booster_.trees_to_dataframe()
trees_df.loc[trees_df.Feature == "Leaf", "Gain"][:10]
9    -0.225236
10    0.190313
20   -0.143501
21   -0.134992
27   -0.000000
29   -0.222734
31    0.010724
32   -0.110243
33    0.161738
34   -0.000000
Name: Gain, dtype: float64

Loss Curve

Below we animate the test-set log-loss as a function of the number of boosting rounds. The animation pattern follows PythonGuides.com, "Matplotlib Update Plot in Loop".

from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation

%matplotlib widget

# Import the loss curve
results = model_gbdt.evals_result()
loss_values = results["validation_1"]["logloss"]

# Create figure and subplot
fig, ax = plt.subplots(figsize=(6, 4))
x_plot = []
y_plot = []

# Plot
# (plot_1,) = ax.plot(x, y)
(plot_1,) = ax.plot(x_plot, y_plot, color="#a7fe01", label="Log Loss")
plt.axis([0, len(loss_values), min(loss_values), max(loss_values)])
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.set_xlabel("Number of Trees")
ax.set_ylabel("Log Loss")
plt.title("Learning Curve")
plt.tight_layout()

# Animation Function


def animate_plot(i):
    x_plot = np.arange(i)
    y_plot = loss_values[:i]
    plot_1.set_data(x_plot, y_plot)
    return (plot_1,)


# Animated Function
ani = FuncAnimation(fig, animate_plot, frames=len(loss_values), interval=100)

# Save as gif
# path_to_export = "/Users/deburky/Documents/python/risk-practitioner-ebook/risk-practitioner/images"
# ani.save(f"{path_to_export}/learning_curve.gif", writer="pillow", fps=10, dpi=600)
# ani.save(f"learning_curve.gif", writer="pillow", fps=10, dpi=600)