2.7. TabNet

This is a PyTorch implementation of TabNet (Arik, S. O., & Pfister, T., 2019) by DreamQuark.
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from pytorch_tabnet.metrics import Metric
from pytorch_tabnet.tab_model import TabNetClassifier

np.random.seed(0)

from sklearn.datasets import make_classification
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
# Build a synthetic binary-classification problem: 50,000 rows, 15 features
# (5 of them informative), with a fixed seed so the data is reproducible.
X, y = make_classification(
    n_samples=50000,
    n_features=15,
    n_informative=5,
    n_classes=2,
    random_state=42,
)

# Wrap the arrays in DataFrames; feature columns are named feature_1..feature_N.
X = pd.DataFrame(
    X,
    columns=[f"feature_{pos}" for pos in range(1, X.shape[1] + 1)],
)
y = pd.DataFrame(y)

# Split the row labels (not the data itself) so the same index sets can be
# used to slice both X and y later; the split is stratified on the target.
ix_train, ix_test = train_test_split(X.index, stratify=y, random_state=62)
class GiniScore(Metric):
    """Custom evaluation metric: Gini coefficient, floored at zero.

    The score is computed as ``2 * AUC - 1`` from the ROC AUC of column 1 of
    ``y_score`` (presumably the positive-class probability -- confirm against
    the caller), and never drops below ``0.0``.
    """

    def __init__(self):
        # Higher values are better; "gini" is the label that shows up in the
        # training log (e.g. ``train_gini`` / ``test_gini``).
        self._maximize = True
        self._name = "gini"

    def __call__(self, y_true, y_score):
        """Return ``max(2 * AUC - 1, 0.0)`` for the given labels and scores."""
        gini = 2 * roc_auc_score(y_true, y_score[:, 1]) - 1
        return max(gini, 0.0)
# Training budget: at most 100 epochs, with each batch holding ~10% of the
# training rows.
max_epochs = 100
batch_size = round(0.10 * X.loc[ix_train].shape[0])

# One-cycle learning-rate schedule sized from the training-set length.
# NOTE(review): OneCycleLR derives the learning rate from ``max_lr``, so the
# optimizer's ``lr=2e-2`` below is presumably superseded -- confirm.
clf = TabNetClassifier(
    mask_type="entmax",
    optimizer_fn=torch.optim.Adam,
    optimizer_params={"lr": 2e-2},
    scheduler_fn=torch.optim.lr_scheduler.OneCycleLR,
    scheduler_params={
        "max_lr": 5e-2,
        "epochs": max_epochs,
        # Whole batches per epoch plus one extra step.
        "steps_per_epoch": X.loc[ix_train].shape[0] // batch_size + 1,
        # Presumably tells pytorch-tabnet to step the scheduler after every
        # batch rather than every epoch -- verify against the library docs.
        "is_batch_level": True,
    },
)
/Users/deburky/Library/Caches/pypoetry/virtualenvs/risk-practitioner-ebook-NcspVTUP-py3.10/lib/python3.10/site-packages/pytorch_tabnet/abstract_model.py:82: UserWarning: Device used : cpu
  warnings.warn(f"Device used : {self.device}")
# Materialise the train/test slices once as NumPy arrays so the same data is
# reused for fitting, in-training evaluation and final scoring.
X_trn = X.loc[ix_train].to_numpy()
y_trn = y.loc[ix_train].to_numpy().ravel()
X_tst = X.loc[ix_test].to_numpy()
y_tst = y.loc[ix_test].to_numpy().ravel()

# Fit TabNet, tracking log-loss and Gini on both the train and test splits.
clf.fit(
    X_train=X_trn,
    y_train=y_trn,
    eval_set=[(X_trn, y_trn), (X_tst, y_tst)],
    eval_name=["train", "test"],
    # The recorded run stops early on "best_test_gini", i.e. the last metric
    # evaluated on the last eval set.
    eval_metric=["logloss", GiniScore],
    max_epochs=max_epochs,
    patience=5,
    batch_size=batch_size,
    virtual_batch_size=128,
    num_workers=0,
    # NOTE(review): in pytorch-tabnet ``weights=1`` presumably enables
    # class-balanced sampling -- confirm against the library docs.
    weights=1,
    drop_last=False,
    loss_fn=nn.CrossEntropyLoss(),
)

# Score both splits with the fitted TabNet model, keeping column 1 of the
# predicted probabilities.
predictions_trn = clf.predict_proba(X_trn)[:, 1]
predictions_tst = clf.predict_proba(X_tst)[:, 1]

# Gini = 2 * AUC - 1.
gini_trn = roc_auc_score(y_trn, predictions_trn) * 2 - 1
gini_tst = roc_auc_score(y_tst, predictions_tst) * 2 - 1

# Report each score on its own line, matching the recorded cell output.
print(f"Train Gini score: {gini_trn:.2%}\nTest Gini score: {gini_tst:.2%}")
epoch 0  | loss: 0.72843 | train_logloss: 0.69951 | train_gini: 0.36982 | test_logloss: 0.69944 | test_gini: 0.37703 |  0:00:03s
epoch 1  | loss: 0.60558 | train_logloss: 0.57118 | train_gini: 0.55522 | test_logloss: 0.5682  | test_gini: 0.56746 |  0:00:06s
epoch 2  | loss: 0.54865 | train_logloss: 0.50776 | train_gini: 0.65734 | test_logloss: 0.50583 | test_gini: 0.66613 |  0:00:09s
epoch 3  | loss: 0.50686 | train_logloss: 0.46785 | train_gini: 0.71664 | test_logloss: 0.46232 | test_gini: 0.72472 |  0:00:12s
epoch 4  | loss: 0.47736 | train_logloss: 0.43841 | train_gini: 0.75475 | test_logloss: 0.43204 | test_gini: 0.76403 |  0:00:15s
epoch 5  | loss: 0.44659 | train_logloss: 0.40937 | train_gini: 0.78939 | test_logloss: 0.40182 | test_gini: 0.79961 |  0:00:18s
epoch 6  | loss: 0.419   | train_logloss: 0.3838  | train_gini: 0.81762 | test_logloss: 0.37717 | test_gini: 0.82615 |  0:00:22s
epoch 7  | loss: 0.39237 | train_logloss: 0.36304 | train_gini: 0.83837 | test_logloss: 0.3584  | test_gini: 0.84493 |  0:00:25s
epoch 8  | loss: 0.37067 | train_logloss: 0.34365 | train_gini: 0.85592 | test_logloss: 0.341   | test_gini: 0.86041 |  0:00:28s
epoch 9  | loss: 0.3493  | train_logloss: 0.327   | train_gini: 0.87031 | test_logloss: 0.32593 | test_gini: 0.87321 |  0:00:31s
epoch 10 | loss: 0.33617 | train_logloss: 0.30814 | train_gini: 0.88638 | test_logloss: 0.30851 | test_gini: 0.88719 |  0:00:34s
epoch 11 | loss: 0.31625 | train_logloss: 0.29188 | train_gini: 0.8997  | test_logloss: 0.29363 | test_gini: 0.89853 |  0:00:37s
epoch 12 | loss: 0.30028 | train_logloss: 0.27471 | train_gini: 0.91155 | test_logloss: 0.27766 | test_gini: 0.90963 |  0:00:41s
epoch 13 | loss: 0.27936 | train_logloss: 0.26054 | train_gini: 0.92094 | test_logloss: 0.26316 | test_gini: 0.9197  |  0:00:45s
epoch 14 | loss: 0.26941 | train_logloss: 0.24573 | train_gini: 0.92986 | test_logloss: 0.24958 | test_gini: 0.92783 |  0:00:48s
epoch 15 | loss: 0.25454 | train_logloss: 0.23231 | train_gini: 0.93694 | test_logloss: 0.23713 | test_gini: 0.93471 |  0:00:51s
epoch 16 | loss: 0.24335 | train_logloss: 0.21988 | train_gini: 0.9432  | test_logloss: 0.22453 | test_gini: 0.94108 |  0:00:55s
epoch 17 | loss: 0.23444 | train_logloss: 0.20714 | train_gini: 0.94936 | test_logloss: 0.21384 | test_gini: 0.94603 |  0:00:58s
epoch 18 | loss: 0.22159 | train_logloss: 0.19731 | train_gini: 0.95398 | test_logloss: 0.20407 | test_gini: 0.95026 |  0:01:02s
epoch 19 | loss: 0.21177 | train_logloss: 0.18779 | train_gini: 0.95784 | test_logloss: 0.19429 | test_gini: 0.95422 |  0:01:05s
epoch 20 | loss: 0.20712 | train_logloss: 0.18069 | train_gini: 0.96036 | test_logloss: 0.18705 | test_gini: 0.95675 |  0:01:08s
epoch 21 | loss: 0.19632 | train_logloss: 0.1739  | train_gini: 0.96289 | test_logloss: 0.17905 | test_gini: 0.96002 |  0:01:12s
epoch 22 | loss: 0.1926  | train_logloss: 0.16838 | train_gini: 0.96485 | test_logloss: 0.17393 | test_gini: 0.96193 |  0:01:15s
epoch 23 | loss: 0.18361 | train_logloss: 0.16286 | train_gini: 0.96704 | test_logloss: 0.16845 | test_gini: 0.96444 |  0:01:18s
epoch 24 | loss: 0.18179 | train_logloss: 0.15855 | train_gini: 0.96834 | test_logloss: 0.16309 | test_gini: 0.96618 |  0:01:22s
epoch 25 | loss: 0.17338 | train_logloss: 0.1546  | train_gini: 0.96957 | test_logloss: 0.15942 | test_gini: 0.96713 |  0:01:25s
epoch 26 | loss: 0.1654  | train_logloss: 0.15049 | train_gini: 0.97089 | test_logloss: 0.15515 | test_gini: 0.96858 |  0:01:28s
epoch 27 | loss: 0.16682 | train_logloss: 0.14726 | train_gini: 0.97202 | test_logloss: 0.15245 | test_gini: 0.96925 |  0:01:31s
epoch 28 | loss: 0.16691 | train_logloss: 0.1454  | train_gini: 0.97242 | test_logloss: 0.15098 | test_gini: 0.96948 |  0:01:35s
epoch 29 | loss: 0.1584  | train_logloss: 0.1432  | train_gini: 0.97341 | test_logloss: 0.14998 | test_gini: 0.97014 |  0:01:38s
epoch 30 | loss: 0.15784 | train_logloss: 0.14012 | train_gini: 0.97434 | test_logloss: 0.14545 | test_gini: 0.97137 |  0:01:41s
epoch 31 | loss: 0.14957 | train_logloss: 0.13832 | train_gini: 0.97481 | test_logloss: 0.1438  | test_gini: 0.97188 |  0:01:44s
epoch 32 | loss: 0.15002 | train_logloss: 0.13659 | train_gini: 0.97534 | test_logloss: 0.1422  | test_gini: 0.97249 |  0:01:47s
epoch 33 | loss: 0.15564 | train_logloss: 0.13548 | train_gini: 0.97571 | test_logloss: 0.14285 | test_gini: 0.97223 |  0:01:50s
epoch 34 | loss: 0.14774 | train_logloss: 0.13269 | train_gini: 0.97659 | test_logloss: 0.13902 | test_gini: 0.97339 |  0:01:54s
epoch 35 | loss: 0.14647 | train_logloss: 0.13126 | train_gini: 0.97701 | test_logloss: 0.13874 | test_gini: 0.9734  |  0:01:57s
epoch 36 | loss: 0.14431 | train_logloss: 0.1302  | train_gini: 0.97719 | test_logloss: 0.13659 | test_gini: 0.97392 |  0:02:00s
epoch 37 | loss: 0.14394 | train_logloss: 0.12839 | train_gini: 0.9778  | test_logloss: 0.1339  | test_gini: 0.97494 |  0:02:04s
epoch 38 | loss: 0.14519 | train_logloss: 0.12697 | train_gini: 0.97837 | test_logloss: 0.13303 | test_gini: 0.97546 |  0:02:07s
epoch 39 | loss: 0.14045 | train_logloss: 0.12553 | train_gini: 0.97848 | test_logloss: 0.13164 | test_gini: 0.97542 |  0:02:10s
epoch 40 | loss: 0.14314 | train_logloss: 0.12385 | train_gini: 0.9789  | test_logloss: 0.13048 | test_gini: 0.97564 |  0:02:13s
epoch 41 | loss: 0.14054 | train_logloss: 0.12264 | train_gini: 0.97934 | test_logloss: 0.13055 | test_gini: 0.97566 |  0:02:17s
epoch 42 | loss: 0.13969 | train_logloss: 0.12265 | train_gini: 0.97932 | test_logloss: 0.12984 | test_gini: 0.97603 |  0:02:20s
epoch 43 | loss: 0.1325  | train_logloss: 0.12073 | train_gini: 0.98002 | test_logloss: 0.12825 | test_gini: 0.97647 |  0:02:23s
epoch 44 | loss: 0.13334 | train_logloss: 0.11998 | train_gini: 0.98012 | test_logloss: 0.12758 | test_gini: 0.9767  |  0:02:26s
epoch 45 | loss: 0.13415 | train_logloss: 0.11854 | train_gini: 0.98027 | test_logloss: 0.1263  | test_gini: 0.97659 |  0:02:29s
epoch 46 | loss: 0.13391 | train_logloss: 0.11763 | train_gini: 0.98035 | test_logloss: 0.12521 | test_gini: 0.97686 |  0:02:33s
epoch 47 | loss: 0.12983 | train_logloss: 0.11776 | train_gini: 0.98035 | test_logloss: 0.12609 | test_gini: 0.97682 |  0:02:36s
epoch 48 | loss: 0.12903 | train_logloss: 0.11678 | train_gini: 0.98085 | test_logloss: 0.12513 | test_gini: 0.97706 |  0:02:39s
epoch 49 | loss: 0.13373 | train_logloss: 0.11659 | train_gini: 0.98081 | test_logloss: 0.12655 | test_gini: 0.97637 |  0:02:42s
epoch 50 | loss: 0.12921 | train_logloss: 0.11574 | train_gini: 0.98115 | test_logloss: 0.12387 | test_gini: 0.97728 |  0:02:45s
epoch 51 | loss: 0.1291  | train_logloss: 0.11461 | train_gini: 0.98134 | test_logloss: 0.12383 | test_gini: 0.9773  |  0:02:49s
epoch 52 | loss: 0.124   | train_logloss: 0.11355 | train_gini: 0.98162 | test_logloss: 0.12285 | test_gini: 0.97747 |  0:02:52s
epoch 53 | loss: 0.12357 | train_logloss: 0.11361 | train_gini: 0.98155 | test_logloss: 0.12443 | test_gini: 0.97711 |  0:02:55s
epoch 54 | loss: 0.12677 | train_logloss: 0.11231 | train_gini: 0.98209 | test_logloss: 0.12153 | test_gini: 0.97792 |  0:02:58s
epoch 55 | loss: 0.12468 | train_logloss: 0.11166 | train_gini: 0.98216 | test_logloss: 0.12045 | test_gini: 0.97837 |  0:03:02s
epoch 56 | loss: 0.1228  | train_logloss: 0.11132 | train_gini: 0.9822  | test_logloss: 0.12091 | test_gini: 0.9782  |  0:03:05s
epoch 57 | loss: 0.12196 | train_logloss: 0.11084 | train_gini: 0.98226 | test_logloss: 0.1205  | test_gini: 0.97848 |  0:03:08s
epoch 58 | loss: 0.11996 | train_logloss: 0.111   | train_gini: 0.98221 | test_logloss: 0.11995 | test_gini: 0.97839 |  0:03:12s
epoch 59 | loss: 0.12504 | train_logloss: 0.1104  | train_gini: 0.98247 | test_logloss: 0.12089 | test_gini: 0.97817 |  0:03:15s
epoch 60 | loss: 0.12364 | train_logloss: 0.1101  | train_gini: 0.9828  | test_logloss: 0.12008 | test_gini: 0.97847 |  0:03:18s
epoch 61 | loss: 0.1236  | train_logloss: 0.10932 | train_gini: 0.98287 | test_logloss: 0.11955 | test_gini: 0.97838 |  0:03:21s
epoch 62 | loss: 0.12238 | train_logloss: 0.10917 | train_gini: 0.98286 | test_logloss: 0.12113 | test_gini: 0.97794 |  0:03:24s

Early stopping occurred at epoch 62 with best_epoch = 57 and best_test_gini = 0.97848
/Users/deburky/Library/Caches/pypoetry/virtualenvs/risk-practitioner-ebook-NcspVTUP-py3.10/lib/python3.10/site-packages/pytorch_tabnet/callbacks.py:172: UserWarning: Best weights from best epoch are automatically used!
  warnings.warn(wrn_msg)
Train Gini score: 96.45%
Test Gini score: 95.70%