2.7. TabNet#
This is a PyTorch implementation of TabNet (Arik, S. O., & Pfister, T., 2019) by DreamQuark.
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from pytorch_tabnet.metrics import Metric
from pytorch_tabnet.tab_model import TabNetClassifier
np.random.seed(0)
from sklearn.datasets import make_classification
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
# Build a synthetic binary-classification dataset and a stratified split.
features, target = make_classification(
    n_samples=50000,
    n_features=15,
    n_classes=2,
    n_informative=5,
    random_state=42,
)
X = pd.DataFrame(
    features,
    columns=[f"feature_{idx + 1}" for idx in range(features.shape[1])],
)
y = pd.DataFrame(target)
# Stratify on the label so class balance is preserved in both partitions.
ix_train, ix_test = train_test_split(X.index, stratify=y, random_state=62)
class GiniScore(Metric):
    """Custom TabNet eval metric: Gini coefficient (2*AUC - 1), floored at 0."""

    def __init__(self):
        self._name = "gini"       # label used in the training log
        self._maximize = True     # higher Gini is better (drives early stopping)

    def __call__(self, y_true, y_score):
        # y_score holds class probabilities; column 1 is the positive class.
        gini = 2 * roc_auc_score(y_true, y_score[:, 1]) - 1
        # Clamp negative values (worse-than-random models) to zero.
        return gini if gini > 0.0 else 0.0
# Training configuration.
max_epochs = 100
# Mini-batch size: 10% of the training partition.
batch_size = round(0.10 * len(X.loc[ix_train]))

# TabNet classifier with Adam and a one-cycle LR schedule stepped per batch.
clf = TabNetClassifier(
    optimizer_fn=torch.optim.Adam,
    optimizer_params={"lr": 2e-2},
    scheduler_fn=torch.optim.lr_scheduler.OneCycleLR,
    scheduler_params={
        "is_batch_level": True,  # step the scheduler every batch, not every epoch
        "max_lr": 5e-2,
        # +1 covers a possible final partial batch.
        "steps_per_epoch": int(X.loc[ix_train].shape[0] / batch_size) + 1,
        "epochs": max_epochs,
    },
    mask_type="entmax",  # sparse attention masks (alternative to "sparsemax")
)
/Users/deburky/Library/Caches/pypoetry/virtualenvs/risk-practitioner-ebook-NcspVTUP-py3.10/lib/python3.10/site-packages/pytorch_tabnet/abstract_model.py:82: UserWarning: Device used : cpu
warnings.warn(f"Device used : {self.device}")
# Train TabNet, monitoring log-loss and the custom Gini metric on both
# the train and test partitions; early stopping watches the LAST metric
# in eval_metric on the LAST eval_set entry (test Gini here).
clf.fit(
    X_train=X.loc[ix_train].values,
    y_train=y.loc[ix_train].values.ravel(),
    eval_set=[
        (X.loc[ix_train].values, y.loc[ix_train].values.ravel()),
        (X.loc[ix_test].values, y.loc[ix_test].values.ravel()),
    ],
    eval_name=["train", "test"],
    eval_metric=["logloss", GiniScore],
    max_epochs=max_epochs,
    patience=5,  # stop after 5 epochs without improvement on test Gini
    batch_size=batch_size,
    virtual_batch_size=128,  # ghost batch normalization chunk size
    num_workers=0,
    weights=1,  # automatic class reweighting via inverse class frequency
    drop_last=False,
    loss_fn=nn.CrossEntropyLoss(),
)
# Score the fitted TabNet model on both partitions.
predictions_trn = clf.predict_proba(X.loc[ix_train].to_numpy())[:, 1]
predictions_tst = clf.predict_proba(X.loc[ix_test].to_numpy())[:, 1]
# Gini coefficient = 2 * AUC - 1.
gini_trn = roc_auc_score(y.loc[ix_train], predictions_trn) * 2 - 1
gini_tst = roc_auc_score(y.loc[ix_test], predictions_tst) * 2 - 1
# Fix: the original used implicit f-string concatenation with no separator,
# printing both scores fused on one line; a newline keeps them separate.
print(f"Train Gini score: {gini_trn:.2%}\nTest Gini score: {gini_tst:.2%}")
epoch 0 | loss: 0.72843 | train_logloss: 0.69951 | train_gini: 0.36982 | test_logloss: 0.69944 | test_gini: 0.37703 | 0:00:03s
epoch 1 | loss: 0.60558 | train_logloss: 0.57118 | train_gini: 0.55522 | test_logloss: 0.5682 | test_gini: 0.56746 | 0:00:06s
epoch 2 | loss: 0.54865 | train_logloss: 0.50776 | train_gini: 0.65734 | test_logloss: 0.50583 | test_gini: 0.66613 | 0:00:09s
epoch 3 | loss: 0.50686 | train_logloss: 0.46785 | train_gini: 0.71664 | test_logloss: 0.46232 | test_gini: 0.72472 | 0:00:12s
epoch 4 | loss: 0.47736 | train_logloss: 0.43841 | train_gini: 0.75475 | test_logloss: 0.43204 | test_gini: 0.76403 | 0:00:15s
epoch 5 | loss: 0.44659 | train_logloss: 0.40937 | train_gini: 0.78939 | test_logloss: 0.40182 | test_gini: 0.79961 | 0:00:18s
epoch 6 | loss: 0.419 | train_logloss: 0.3838 | train_gini: 0.81762 | test_logloss: 0.37717 | test_gini: 0.82615 | 0:00:22s
epoch 7 | loss: 0.39237 | train_logloss: 0.36304 | train_gini: 0.83837 | test_logloss: 0.3584 | test_gini: 0.84493 | 0:00:25s
epoch 8 | loss: 0.37067 | train_logloss: 0.34365 | train_gini: 0.85592 | test_logloss: 0.341 | test_gini: 0.86041 | 0:00:28s
epoch 9 | loss: 0.3493 | train_logloss: 0.327 | train_gini: 0.87031 | test_logloss: 0.32593 | test_gini: 0.87321 | 0:00:31s
epoch 10 | loss: 0.33617 | train_logloss: 0.30814 | train_gini: 0.88638 | test_logloss: 0.30851 | test_gini: 0.88719 | 0:00:34s
epoch 11 | loss: 0.31625 | train_logloss: 0.29188 | train_gini: 0.8997 | test_logloss: 0.29363 | test_gini: 0.89853 | 0:00:37s
epoch 12 | loss: 0.30028 | train_logloss: 0.27471 | train_gini: 0.91155 | test_logloss: 0.27766 | test_gini: 0.90963 | 0:00:41s
epoch 13 | loss: 0.27936 | train_logloss: 0.26054 | train_gini: 0.92094 | test_logloss: 0.26316 | test_gini: 0.9197 | 0:00:45s
epoch 14 | loss: 0.26941 | train_logloss: 0.24573 | train_gini: 0.92986 | test_logloss: 0.24958 | test_gini: 0.92783 | 0:00:48s
epoch 15 | loss: 0.25454 | train_logloss: 0.23231 | train_gini: 0.93694 | test_logloss: 0.23713 | test_gini: 0.93471 | 0:00:51s
epoch 16 | loss: 0.24335 | train_logloss: 0.21988 | train_gini: 0.9432 | test_logloss: 0.22453 | test_gini: 0.94108 | 0:00:55s
epoch 17 | loss: 0.23444 | train_logloss: 0.20714 | train_gini: 0.94936 | test_logloss: 0.21384 | test_gini: 0.94603 | 0:00:58s
epoch 18 | loss: 0.22159 | train_logloss: 0.19731 | train_gini: 0.95398 | test_logloss: 0.20407 | test_gini: 0.95026 | 0:01:02s
epoch 19 | loss: 0.21177 | train_logloss: 0.18779 | train_gini: 0.95784 | test_logloss: 0.19429 | test_gini: 0.95422 | 0:01:05s
epoch 20 | loss: 0.20712 | train_logloss: 0.18069 | train_gini: 0.96036 | test_logloss: 0.18705 | test_gini: 0.95675 | 0:01:08s
epoch 21 | loss: 0.19632 | train_logloss: 0.1739 | train_gini: 0.96289 | test_logloss: 0.17905 | test_gini: 0.96002 | 0:01:12s
epoch 22 | loss: 0.1926 | train_logloss: 0.16838 | train_gini: 0.96485 | test_logloss: 0.17393 | test_gini: 0.96193 | 0:01:15s
epoch 23 | loss: 0.18361 | train_logloss: 0.16286 | train_gini: 0.96704 | test_logloss: 0.16845 | test_gini: 0.96444 | 0:01:18s
epoch 24 | loss: 0.18179 | train_logloss: 0.15855 | train_gini: 0.96834 | test_logloss: 0.16309 | test_gini: 0.96618 | 0:01:22s
epoch 25 | loss: 0.17338 | train_logloss: 0.1546 | train_gini: 0.96957 | test_logloss: 0.15942 | test_gini: 0.96713 | 0:01:25s
epoch 26 | loss: 0.1654 | train_logloss: 0.15049 | train_gini: 0.97089 | test_logloss: 0.15515 | test_gini: 0.96858 | 0:01:28s
epoch 27 | loss: 0.16682 | train_logloss: 0.14726 | train_gini: 0.97202 | test_logloss: 0.15245 | test_gini: 0.96925 | 0:01:31s
epoch 28 | loss: 0.16691 | train_logloss: 0.1454 | train_gini: 0.97242 | test_logloss: 0.15098 | test_gini: 0.96948 | 0:01:35s
epoch 29 | loss: 0.1584 | train_logloss: 0.1432 | train_gini: 0.97341 | test_logloss: 0.14998 | test_gini: 0.97014 | 0:01:38s
epoch 30 | loss: 0.15784 | train_logloss: 0.14012 | train_gini: 0.97434 | test_logloss: 0.14545 | test_gini: 0.97137 | 0:01:41s
epoch 31 | loss: 0.14957 | train_logloss: 0.13832 | train_gini: 0.97481 | test_logloss: 0.1438 | test_gini: 0.97188 | 0:01:44s
epoch 32 | loss: 0.15002 | train_logloss: 0.13659 | train_gini: 0.97534 | test_logloss: 0.1422 | test_gini: 0.97249 | 0:01:47s
epoch 33 | loss: 0.15564 | train_logloss: 0.13548 | train_gini: 0.97571 | test_logloss: 0.14285 | test_gini: 0.97223 | 0:01:50s
epoch 34 | loss: 0.14774 | train_logloss: 0.13269 | train_gini: 0.97659 | test_logloss: 0.13902 | test_gini: 0.97339 | 0:01:54s
epoch 35 | loss: 0.14647 | train_logloss: 0.13126 | train_gini: 0.97701 | test_logloss: 0.13874 | test_gini: 0.9734 | 0:01:57s
epoch 36 | loss: 0.14431 | train_logloss: 0.1302 | train_gini: 0.97719 | test_logloss: 0.13659 | test_gini: 0.97392 | 0:02:00s
epoch 37 | loss: 0.14394 | train_logloss: 0.12839 | train_gini: 0.9778 | test_logloss: 0.1339 | test_gini: 0.97494 | 0:02:04s
epoch 38 | loss: 0.14519 | train_logloss: 0.12697 | train_gini: 0.97837 | test_logloss: 0.13303 | test_gini: 0.97546 | 0:02:07s
epoch 39 | loss: 0.14045 | train_logloss: 0.12553 | train_gini: 0.97848 | test_logloss: 0.13164 | test_gini: 0.97542 | 0:02:10s
epoch 40 | loss: 0.14314 | train_logloss: 0.12385 | train_gini: 0.9789 | test_logloss: 0.13048 | test_gini: 0.97564 | 0:02:13s
epoch 41 | loss: 0.14054 | train_logloss: 0.12264 | train_gini: 0.97934 | test_logloss: 0.13055 | test_gini: 0.97566 | 0:02:17s
epoch 42 | loss: 0.13969 | train_logloss: 0.12265 | train_gini: 0.97932 | test_logloss: 0.12984 | test_gini: 0.97603 | 0:02:20s
epoch 43 | loss: 0.1325 | train_logloss: 0.12073 | train_gini: 0.98002 | test_logloss: 0.12825 | test_gini: 0.97647 | 0:02:23s
epoch 44 | loss: 0.13334 | train_logloss: 0.11998 | train_gini: 0.98012 | test_logloss: 0.12758 | test_gini: 0.9767 | 0:02:26s
epoch 45 | loss: 0.13415 | train_logloss: 0.11854 | train_gini: 0.98027 | test_logloss: 0.1263 | test_gini: 0.97659 | 0:02:29s
epoch 46 | loss: 0.13391 | train_logloss: 0.11763 | train_gini: 0.98035 | test_logloss: 0.12521 | test_gini: 0.97686 | 0:02:33s
epoch 47 | loss: 0.12983 | train_logloss: 0.11776 | train_gini: 0.98035 | test_logloss: 0.12609 | test_gini: 0.97682 | 0:02:36s
epoch 48 | loss: 0.12903 | train_logloss: 0.11678 | train_gini: 0.98085 | test_logloss: 0.12513 | test_gini: 0.97706 | 0:02:39s
epoch 49 | loss: 0.13373 | train_logloss: 0.11659 | train_gini: 0.98081 | test_logloss: 0.12655 | test_gini: 0.97637 | 0:02:42s
epoch 50 | loss: 0.12921 | train_logloss: 0.11574 | train_gini: 0.98115 | test_logloss: 0.12387 | test_gini: 0.97728 | 0:02:45s
epoch 51 | loss: 0.1291 | train_logloss: 0.11461 | train_gini: 0.98134 | test_logloss: 0.12383 | test_gini: 0.9773 | 0:02:49s
epoch 52 | loss: 0.124 | train_logloss: 0.11355 | train_gini: 0.98162 | test_logloss: 0.12285 | test_gini: 0.97747 | 0:02:52s
epoch 53 | loss: 0.12357 | train_logloss: 0.11361 | train_gini: 0.98155 | test_logloss: 0.12443 | test_gini: 0.97711 | 0:02:55s
epoch 54 | loss: 0.12677 | train_logloss: 0.11231 | train_gini: 0.98209 | test_logloss: 0.12153 | test_gini: 0.97792 | 0:02:58s
epoch 55 | loss: 0.12468 | train_logloss: 0.11166 | train_gini: 0.98216 | test_logloss: 0.12045 | test_gini: 0.97837 | 0:03:02s
epoch 56 | loss: 0.1228 | train_logloss: 0.11132 | train_gini: 0.9822 | test_logloss: 0.12091 | test_gini: 0.9782 | 0:03:05s
epoch 57 | loss: 0.12196 | train_logloss: 0.11084 | train_gini: 0.98226 | test_logloss: 0.1205 | test_gini: 0.97848 | 0:03:08s
epoch 58 | loss: 0.11996 | train_logloss: 0.111 | train_gini: 0.98221 | test_logloss: 0.11995 | test_gini: 0.97839 | 0:03:12s
epoch 59 | loss: 0.12504 | train_logloss: 0.1104 | train_gini: 0.98247 | test_logloss: 0.12089 | test_gini: 0.97817 | 0:03:15s
epoch 60 | loss: 0.12364 | train_logloss: 0.1101 | train_gini: 0.9828 | test_logloss: 0.12008 | test_gini: 0.97847 | 0:03:18s
epoch 61 | loss: 0.1236 | train_logloss: 0.10932 | train_gini: 0.98287 | test_logloss: 0.11955 | test_gini: 0.97838 | 0:03:21s
epoch 62 | loss: 0.12238 | train_logloss: 0.10917 | train_gini: 0.98286 | test_logloss: 0.12113 | test_gini: 0.97794 | 0:03:24s
Early stopping occurred at epoch 62 with best_epoch = 57 and best_test_gini = 0.97848
/Users/deburky/Library/Caches/pypoetry/virtualenvs/risk-practitioner-ebook-NcspVTUP-py3.10/lib/python3.10/site-packages/pytorch_tabnet/callbacks.py:172: UserWarning: Best weights from best epoch are automatically used!
warnings.warn(wrn_msg)
Train Gini score: 96.45%
Test Gini score: 95.70%