TabNet - simple binary classification
reference:
https://github.com/dreamquark-ai/tabnet/blob/develop/census_example.ipynb
steps:
- download market data using yfinance: download S&P 500 (‘^GSPC')
- calculate return 20-day max return (i.e. target in supervised learning problem):
- for each date (T):
- calculate the max price change in next 20 trading dates: price_change = (max{close price in T+1 to T+20} - {close price on T})/({close price on T})
- for each date (T):
- convert the 20-day max return into binary target
- engineer a few features
- lag21: previous 21 day target
- lag31: previous 31 day target
- lag41: previous 41 day target
- day price change: the difference between open and closing prices
- (Close - Open)/Open
- day max price change: the difference between high and low prices
- (High-Low)/Open
- one day close price change: day T close price versus day T-1 close price.
- 100*({Close on T} - {Close on T-1})/{Close on T-1}
- 10 day close price change: day T close price versus day T-10 close price.
- 100*({Close on T} - {Close on T-10})/{Close on T-10}
- 20 day close price change: day T close price versus day T-20 close price.
- 100*({Close on T} - {Close on T-20})/{Close on T-20}
- one day/10day/20day volume change
- feed data into tabnet classifier
- visualize the loss/performance in each epoch html
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
import yfinance as yf #to download stock price data
from pytorch_tabnet.tab_model import TabNetClassifier
import torch
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import os
from pathlib import Path
import shutil
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
#initiate random seed
import random
def init_seed(random_seed):
random.seed(random_seed)
os.environ['PYTHONHASHSEED'] = str(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
init_seed(5678)
download S&P 500 price data
ticker = '^GSPC'
cur_data = yf.Ticker(ticker)
hist = cur_data.history(period="max")
print(ticker, hist.shape, hist.index.min())
^GSPC (19721, 7) 1927-12-30 00:00:00
df=hist[hist.index>='2000-01-01'].copy(deep=True)
df.head()
Open | High | Low | Close | Volume | Dividends | Stock Splits | |
---|---|---|---|---|---|---|---|
Date | |||||||
2000-01-03 | 1469.250000 | 1478.000000 | 1438.359985 | 1455.219971 | 931800000 | 0 | 0 |
2000-01-04 | 1455.219971 | 1455.219971 | 1397.430054 | 1399.420044 | 1009000000 | 0 | 0 |
2000-01-05 | 1399.420044 | 1413.270020 | 1377.680054 | 1402.109985 | 1085500000 | 0 | 0 |
2000-01-06 | 1402.109985 | 1411.900024 | 1392.099976 | 1403.449951 | 1092300000 | 0 | 0 |
2000-01-07 | 1403.449951 | 1441.469971 | 1400.729980 | 1441.469971 | 1225200000 | 0 | 0 |
calcualte max return in next 20 trading days
#for each stock_id, get the max close in next 20 trading days
price_col = 'Close'
roll_len=20
new_col = 'next_20day_max'
target_list = []
df.sort_index(ascending=True, inplace=True)
df.head(3)
Open | High | Low | Close | Volume | Dividends | Stock Splits | |
---|---|---|---|---|---|---|---|
Date | |||||||
2000-01-03 | 1469.250000 | 1478.000000 | 1438.359985 | 1455.219971 | 931800000 | 0 | 0 |
2000-01-04 | 1455.219971 | 1455.219971 | 1397.430054 | 1399.420044 | 1009000000 | 0 | 0 |
2000-01-05 | 1399.420044 | 1413.270020 | 1377.680054 | 1402.109985 | 1085500000 | 0 | 0 |
df_next20dmax=df[[price_col]].shift(1).rolling(roll_len).max()
df_next20dmax.columns=[new_col]
df = df.merge(df_next20dmax, right_index=True, left_index=True, how='inner')
df.dropna(how='any', inplace=True)
df['target']= 100*(df[new_col]-df[price_col])/df[price_col]
df['target'].describe()
count 5479.000000
mean 2.450868
std 4.077580
min -3.743456
25% 0.135604
50% 1.130147
75% 3.318523
max 44.809803
Name: target, dtype: float64
df['target'].hist(bins=100)
<AxesSubplot:>
df['binary_target'] = 0
df.loc[df['target']>5, 'binary_target'] = 1
df['binary_target'].value_counts()
0 4643
1 836
Name: binary_target, dtype: int64
df.head(3)
Open | High | Low | Close | Volume | Dividends | Stock Splits | next_20day_max | target | binary_target | |
---|---|---|---|---|---|---|---|---|---|---|
Date | ||||||||||
2000-02-01 | 1394.459961 | 1412.489990 | 1384.790039 | 1409.280029 | 981000000 | 0 | 0 | 1465.150024 | 3.964435 | 0 |
2000-02-02 | 1409.280029 | 1420.609985 | 1403.489990 | 1409.119995 | 1038600000 | 0 | 0 | 1465.150024 | 3.976243 | 0 |
2000-02-03 | 1409.119995 | 1425.780029 | 1398.520020 | 1424.969971 | 1146500000 | 0 | 0 | 1465.150024 | 2.819712 | 0 |
create additional input features
df['lag21']=df['target'].shift(21)
df['lag31']=df['target'].shift(31)
df['lag41']=df['target'].shift(41)
df['open_close_diff'] = df['Close'] - df['Open']
df['day_change']=(100*df['open_close_diff']/df['Open']).round(2)
df['day_max_change'] = (100*(df['High'] - df['Low'])/df['Open']).round(2)
#create a binary feature: 1 day change
#0: decrease; 1: increase
df['oneday_change']=(df['Close'].diff()>0)+1-1
df['10day_change']=df['Close'].diff(10)
df['20day_change']=df['Close'].diff(20)
df['oneday_volchange']=(df['Volume'].diff()>0)+1-1
df['10day_volchange']=df['Volume'].diff(10)
df['20day_volchange']=df['Volume'].diff(20)
df.head(3)
Open | High | Low | Close | Volume | Dividends | Stock Splits | next_20day_max | target | binary_target | ... | lag41 | open_close_diff | day_change | day_max_change | oneday_change | 10day_change | 20day_change | oneday_volchange | 10day_volchange | 20day_volchange | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
Date | |||||||||||||||||||||
2000-02-01 | 1394.459961 | 1412.489990 | 1384.790039 | 1409.280029 | 981000000 | 0 | 0 | 1465.150024 | 3.964435 | 0 | ... | NaN | 14.820068 | 1.06 | 1.99 | 0 | NaN | NaN | 0 | NaN | NaN |
2000-02-02 | 1409.280029 | 1420.609985 | 1403.489990 | 1409.119995 | 1038600000 | 0 | 0 | 1465.150024 | 3.976243 | 0 | ... | NaN | -0.160034 | -0.01 | 1.21 | 0 | NaN | NaN | 1 | NaN | NaN |
2000-02-03 | 1409.119995 | 1425.780029 | 1398.520020 | 1424.969971 | 1146500000 | 0 | 0 | 1465.150024 | 2.819712 | 0 | ... | NaN | 15.849976 | 1.12 | 1.93 | 1 | NaN | NaN | 1 | NaN | NaN |
3 rows × 22 columns
df['day_change'].hist(bins=50)
<AxesSubplot:>
#convert day_change into categorical feature
#above 2- class 1; below -2 - class -1, in the middle - class0
df['day_change_cat']=0
df.loc[df['day_change']<=-2, 'day_change_cat']=-1
df.loc[df['day_change']>=2, 'day_change_cat']=1
df['day_change_cat'].value_counts()
0 5095
-1 210
1 174
Name: day_change_cat, dtype: int64
df.dropna(how='any', inplace=True)
print(df.shape, df.index.min())
df.head(3)
(5438, 23) 2000-03-30 00:00:00
Open | High | Low | Close | Volume | Dividends | Stock Splits | next_20day_max | target | binary_target | ... | open_close_diff | day_change | day_max_change | oneday_change | 10day_change | 20day_change | oneday_volchange | 10day_volchange | 20day_volchange | day_change_cat | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
Date | |||||||||||||||||||||
2000-03-30 | 1508.520020 | 1517.380005 | 1474.630005 | 1487.920044 | 1193400000 | 0 | 0 | 1527.459961 | 2.657395 | 0 | ... | -20.599976 | -1.37 | 2.83 | 0 | 29.450073 | 106.160034 | 1 | -288900000.0 | -5200000.0 | 0 |
2000-03-31 | 1487.920044 | 1519.810059 | 1484.380005 | 1498.579956 | 1227400000 | 0 | 0 | 1527.459961 | 1.927158 | 0 | ... | 10.659912 | 0.72 | 2.38 | 1 | 34.109985 | 89.409912 | 1 | -67700000.0 | 77100000.0 | 0 |
2000-04-03 | 1498.579956 | 1507.189941 | 1486.959961 | 1505.969971 | 1021700000 | 0 | 0 | 1527.459961 | 1.426987 | 0 | ... | 7.390015 | 0.49 | 1.35 | 1 | 49.339966 | 114.689941 | 0 | 100900000.0 | -7300000.0 | 0 |
3 rows × 23 columns
split data into simple training and testing subsets
target='binary_target'
bool_columns = ['oneday_change', 'oneday_volchange']
df.dropna(how='any', inplace=True)
train = df.copy(deep=True)
Simple preprocessing
Label encode categorical features and fill empty cells.
categorical_columns = ['day_change_cat']
categorical_dims = {}
for col in categorical_columns:
print(col, train[col].nunique())
l_enc = LabelEncoder()
train[col] = l_enc.fit_transform(train[col].values)
categorical_dims[col] = len(l_enc.classes_)
categorical_dims
day_change_cat 3
{'day_change_cat': 3}
categorical_columns, categorical_dims
(['day_change_cat'], {'day_change_cat': 3})
Define categorical features for categorical embeddings
unused_feat = ['Dividends', 'Stock Splits', 'next_20day_max',
'open_close_diff', 'day_change' ]
features = [ col for col in train.columns if col not in unused_feat+[target]]
cat_idxs = [ i for i, f in enumerate(features) if f in categorical_columns]
cat_dims = [ categorical_dims[f] for i, f in enumerate(features) if f in categorical_columns]
print(features)
['Open', 'High', 'Low', 'Close', 'Volume', 'target', 'lag21', 'lag31', 'lag41', 'day_max_change', 'oneday_change', '10day_change', '20day_change', 'oneday_volchange', '10day_volchange', '20day_volchange', 'day_change_cat']
cat_idxs
[16]
cat_dims
[3]
Network parameters
clf = TabNetClassifier(
n_d=32, n_a=32, n_steps=5,
gamma=1.5, n_independent=2, n_shared=2,
cat_idxs=cat_idxs,
cat_dims=cat_dims,
cat_emb_dim=1,
lambda_sparse=1e-4, momentum=0.3, clip_value=2.,
optimizer_fn=torch.optim.Adam,
optimizer_params=dict(lr=2e-2),
scheduler_params = {"gamma": 0.95,
"step_size": 20},
scheduler_fn=torch.optim.lr_scheduler.StepLR, epsilon=1e-15
)
Device used : cpu
clf2 = TabNetClassifier()
clf2
Device used : cpu
TabNetClassifier(n_d=8, n_a=8, n_steps=3, gamma=1.3, cat_idxs=[], cat_dims=[], cat_emb_dim=1, n_independent=2, n_shared=2, epsilon=1e-15, momentum=0.02, lambda_sparse=0.001, seed=0, clip_value=1, verbose=1, optimizer_fn=<class 'torch.optim.adam.Adam'>, optimizer_params={'lr': 0.02}, scheduler_fn=None, scheduler_params={}, mask_type='sparsemax', input_dim=None, output_dim=None, device_name='auto')
Training
train.shape
(5438, 23)
X_train = train[features].values[:-1500,:]
y_train = train[target].values[:-1500]
X_valid = train[features].values[-1450:-650,:]
y_valid = train[target].values[-1450:-650]
X_test = train[features].values[-600:, ]
y_test = train[target].values[-600:]
X_train.shape, X_valid.shape, X_test.shape
((3938, 17), (800, 17), (600, 17))
max_epochs = 50
clf.fit(
X_train=X_train, y_train=y_train,
eval_set=[(X_train, y_train), (X_valid, y_valid)],
eval_name=['train', 'valid'],
max_epochs=max_epochs, patience=100,
batch_size=1024, virtual_batch_size=256
)
epoch 0 | loss: 0.6085 | train_auc: 0.5946 | valid_auc: 0.55134 | 0:00:00s
epoch 1 | loss: 0.25079 | train_auc: 0.72439 | valid_auc: 0.62383 | 0:00:01s
epoch 2 | loss: 0.20278 | train_auc: 0.74185 | valid_auc: 0.67725 | 0:00:02s
epoch 3 | loss: 0.16259 | train_auc: 0.77868 | valid_auc: 0.73814 | 0:00:03s
epoch 4 | loss: 0.1455 | train_auc: 0.823 | valid_auc: 0.88184 | 0:00:03s
epoch 5 | loss: 0.12248 | train_auc: 0.86481 | valid_auc: 0.80974 | 0:00:04s
epoch 6 | loss: 0.12148 | train_auc: 0.87193 | valid_auc: 0.73214 | 0:00:05s
epoch 7 | loss: 0.09959 | train_auc: 0.92461 | valid_auc: 0.855 | 0:00:05s
epoch 8 | loss: 0.1053 | train_auc: 0.90809 | valid_auc: 0.90169 | 0:00:06s
epoch 9 | loss: 0.10171 | train_auc: 0.92914 | valid_auc: 0.84598 | 0:00:07s
epoch 10 | loss: 0.09968 | train_auc: 0.88691 | valid_auc: 0.92718 | 0:00:07s
epoch 11 | loss: 0.07191 | train_auc: 0.88264 | valid_auc: 0.90451 | 0:00:08s
epoch 12 | loss: 0.07202 | train_auc: 0.89332 | valid_auc: 0.91838 | 0:00:09s
epoch 13 | loss: 0.06804 | train_auc: 0.9532 | valid_auc: 0.9241 | 0:00:10s
epoch 14 | loss: 0.05959 | train_auc: 0.97328 | valid_auc: 0.9259 | 0:00:10s
epoch 15 | loss: 0.05237 | train_auc: 0.9798 | valid_auc: 0.92061 | 0:00:11s
epoch 16 | loss: 0.0527 | train_auc: 0.98212 | valid_auc: 0.95056 | 0:00:12s
epoch 17 | loss: 0.04247 | train_auc: 0.98444 | valid_auc: 0.96823 | 0:00:12s
epoch 18 | loss: 0.04855 | train_auc: 0.98537 | valid_auc: 0.91772 | 0:00:13s
epoch 19 | loss: 0.05294 | train_auc: 0.98751 | valid_auc: 0.97337 | 0:00:14s
epoch 20 | loss: 0.05513 | train_auc: 0.98903 | valid_auc: 0.95136 | 0:00:15s
epoch 21 | loss: 0.05105 | train_auc: 0.98586 | valid_auc: 0.95794 | 0:00:15s
epoch 22 | loss: 0.05698 | train_auc: 0.98526 | valid_auc: 0.95633 | 0:00:16s
epoch 23 | loss: 0.04063 | train_auc: 0.98628 | valid_auc: 0.94675 | 0:00:17s
epoch 24 | loss: 0.04106 | train_auc: 0.99017 | valid_auc: 0.96589 | 0:00:17s
epoch 25 | loss: 0.04789 | train_auc: 0.99188 | valid_auc: 0.96874 | 0:00:18s
epoch 26 | loss: 0.03689 | train_auc: 0.99214 | valid_auc: 0.95453 | 0:00:19s
epoch 27 | loss: 0.04332 | train_auc: 0.99126 | valid_auc: 0.96753 | 0:00:20s
epoch 28 | loss: 0.04386 | train_auc: 0.9947 | valid_auc: 0.98204 | 0:00:20s
epoch 29 | loss: 0.03909 | train_auc: 0.99469 | valid_auc: 0.99208 | 0:00:21s
epoch 30 | loss: 0.04479 | train_auc: 0.99478 | valid_auc: 0.9625 | 0:00:22s
epoch 31 | loss: 0.04565 | train_auc: 0.99615 | valid_auc: 0.9235 | 0:00:22s
epoch 32 | loss: 0.03824 | train_auc: 0.99425 | valid_auc: 0.98025 | 0:00:23s
epoch 33 | loss: 0.02534 | train_auc: 0.99519 | valid_auc: 0.99611 | 0:00:24s
epoch 34 | loss: 0.02821 | train_auc: 0.99603 | valid_auc: 0.99306 | 0:00:25s
epoch 35 | loss: 0.03126 | train_auc: 0.99478 | valid_auc: 0.99035 | 0:00:25s
epoch 36 | loss: 0.02935 | train_auc: 0.9962 | valid_auc: 0.9991 | 0:00:26s
epoch 37 | loss: 0.04743 | train_auc: 0.99714 | valid_auc: 0.99868 | 0:00:27s
epoch 38 | loss: 0.04275 | train_auc: 0.99752 | valid_auc: 0.99935 | 0:00:27s
epoch 39 | loss: 0.03126 | train_auc: 0.99728 | valid_auc: 0.99664 | 0:00:28s
epoch 40 | loss: 0.03763 | train_auc: 0.99903 | valid_auc: 0.99352 | 0:00:29s
epoch 41 | loss: 0.05079 | train_auc: 0.999 | valid_auc: 0.9938 | 0:00:29s
epoch 42 | loss: 0.03019 | train_auc: 0.99848 | valid_auc: 0.99407 | 0:00:30s
epoch 43 | loss: 0.02432 | train_auc: 0.99936 | valid_auc: 0.98669 | 0:00:31s
epoch 44 | loss: 0.04208 | train_auc: 0.99914 | valid_auc: 0.98569 | 0:00:32s
epoch 45 | loss: 0.03639 | train_auc: 0.99944 | valid_auc: 0.98695 | 0:00:32s
epoch 46 | loss: 0.03375 | train_auc: 0.99922 | valid_auc: 0.9797 | 0:00:33s
epoch 47 | loss: 0.03105 | train_auc: 0.99863 | valid_auc: 0.98265 | 0:00:34s
epoch 48 | loss: 0.02795 | train_auc: 0.99946 | valid_auc: 0.98412 | 0:00:34s
epoch 49 | loss: 0.03275 | train_auc: 0.9993 | valid_auc: 0.98 | 0:00:35s
Stop training because you reached max_epochs = 50 with best_epoch = 38 and best_valid_auc = 0.99935
Best weights from best epoch are automatically used!
clf2.fit(
X_train=X_train, y_train=y_train,
eval_set=[(X_train, y_train), (X_valid, y_valid)],
eval_name=['train', 'valid'],
max_epochs=max_epochs, patience=50,
batch_size=1024, virtual_batch_size=128
)
epoch 0 | loss: 0.04553 | train_auc: 0.98935 | valid_auc: 0.99878 | 0:00:00s
epoch 1 | loss: 0.04278 | train_auc: 0.98504 | valid_auc: 0.97955 | 0:00:00s
epoch 2 | loss: 0.02881 | train_auc: 0.98514 | valid_auc: 0.99199 | 0:00:00s
epoch 3 | loss: 0.04076 | train_auc: 0.98423 | valid_auc: 0.99251 | 0:00:01s
epoch 4 | loss: 0.03414 | train_auc: 0.98811 | valid_auc: 0.99318 | 0:00:01s
epoch 5 | loss: 0.03644 | train_auc: 0.99269 | valid_auc: 0.99511 | 0:00:01s
epoch 6 | loss: 0.02725 | train_auc: 0.9956 | valid_auc: 0.99439 | 0:00:01s
epoch 7 | loss: 0.0272 | train_auc: 0.99641 | valid_auc: 0.99542 | 0:00:02s
epoch 8 | loss: 0.03312 | train_auc: 0.99755 | valid_auc: 0.99533 | 0:00:02s
epoch 9 | loss: 0.03144 | train_auc: 0.9983 | valid_auc: 0.99587 | 0:00:02s
epoch 10 | loss: 0.03173 | train_auc: 0.99891 | valid_auc: 0.9969 | 0:00:02s
epoch 11 | loss: 0.02985 | train_auc: 0.99912 | valid_auc: 0.99827 | 0:00:03s
epoch 12 | loss: 0.02482 | train_auc: 0.99905 | valid_auc: 0.99755 | 0:00:03s
epoch 13 | loss: 0.02069 | train_auc: 0.99904 | valid_auc: 0.99718 | 0:00:03s
epoch 14 | loss: 0.02969 | train_auc: 0.99907 | valid_auc: 0.99257 | 0:00:03s
epoch 15 | loss: 0.02153 | train_auc: 0.999 | valid_auc: 0.98468 | 0:00:04s
epoch 16 | loss: 0.01556 | train_auc: 0.99883 | valid_auc: 0.97252 | 0:00:04s
epoch 17 | loss: 0.02614 | train_auc: 0.99903 | valid_auc: 0.94694 | 0:00:04s
epoch 18 | loss: 0.02147 | train_auc: 0.99924 | valid_auc: 0.9476 | 0:00:05s
epoch 19 | loss: 0.03949 | train_auc: 0.99965 | valid_auc: 0.97237 | 0:00:05s
epoch 20 | loss: 0.03701 | train_auc: 0.99963 | valid_auc: 0.9469 | 0:00:05s
epoch 21 | loss: 0.05568 | train_auc: 0.99977 | valid_auc: 0.93323 | 0:00:06s
epoch 22 | loss: 0.02871 | train_auc: 0.99986 | valid_auc: 0.94559 | 0:00:06s
epoch 23 | loss: 0.02558 | train_auc: 0.99988 | valid_auc: 1.0 | 0:00:06s
epoch 24 | loss: 0.01854 | train_auc: 0.9999 | valid_auc: 1.0 | 0:00:06s
epoch 25 | loss: 0.04108 | train_auc: 0.99992 | valid_auc: 1.0 | 0:00:07s
epoch 26 | loss: 0.02608 | train_auc: 0.99991 | valid_auc: 1.0 | 0:00:07s
epoch 27 | loss: 0.0196 | train_auc: 0.99992 | valid_auc: 1.0 | 0:00:07s
epoch 28 | loss: 0.01689 | train_auc: 0.99994 | valid_auc: 1.0 | 0:00:07s
epoch 29 | loss: 0.02221 | train_auc: 0.99993 | valid_auc: 1.0 | 0:00:08s
epoch 30 | loss: 0.01336 | train_auc: 0.99988 | valid_auc: 1.0 | 0:00:08s
epoch 31 | loss: 0.01686 | train_auc: 0.99994 | valid_auc: 1.0 | 0:00:08s
epoch 32 | loss: 0.01983 | train_auc: 0.99994 | valid_auc: 1.0 | 0:00:09s
epoch 33 | loss: 0.01631 | train_auc: 0.99997 | valid_auc: 0.99957 | 0:00:09s
epoch 34 | loss: 0.02721 | train_auc: 0.99994 | valid_auc: 0.99994 | 0:00:09s
epoch 35 | loss: 0.02679 | train_auc: 0.99985 | valid_auc: 1.0 | 0:00:09s
epoch 36 | loss: 0.04 | train_auc: 0.99984 | valid_auc: 1.0 | 0:00:10s
epoch 37 | loss: 0.03121 | train_auc: 0.99996 | valid_auc: 0.99984 | 0:00:10s
epoch 38 | loss: 0.03083 | train_auc: 0.99997 | valid_auc: 0.99904 | 0:00:10s
epoch 39 | loss: 0.02723 | train_auc: 0.99994 | valid_auc: 0.99827 | 0:00:10s
epoch 40 | loss: 0.0429 | train_auc: 0.99991 | valid_auc: 0.99788 | 0:00:11s
epoch 41 | loss: 0.02929 | train_auc: 0.99994 | valid_auc: 0.99902 | 0:00:11s
epoch 42 | loss: 0.03338 | train_auc: 0.99997 | valid_auc: 1.0 | 0:00:11s
epoch 43 | loss: 0.03398 | train_auc: 0.99998 | valid_auc: 1.0 | 0:00:11s
epoch 44 | loss: 0.0267 | train_auc: 0.99995 | valid_auc: 0.9998 | 0:00:12s
epoch 45 | loss: 0.02937 | train_auc: 0.99991 | valid_auc: 0.99996 | 0:00:12s
epoch 46 | loss: 0.02838 | train_auc: 0.99996 | valid_auc: 1.0 | 0:00:12s
epoch 47 | loss: 0.01855 | train_auc: 0.99998 | valid_auc: 1.0 | 0:00:13s
epoch 48 | loss: 0.02415 | train_auc: 0.99998 | valid_auc: 1.0 | 0:00:13s
epoch 49 | loss: 0.01826 | train_auc: 0.99998 | valid_auc: 1.0 | 0:00:13s
Stop training because you reached max_epochs = 50 with best_epoch = 23 and best_valid_auc = 1.0
Best weights from best epoch are automatically used!
fig_list =[]
# Create figure with secondary y-axis
fig = make_subplots(specs=[[{"secondary_y": True}]])
x_vals=list(range(1, max_epochs+1))
fig.add_trace(go.Scatter(
name="loss",
mode="lines", x=x_vals, y=clf.history['loss']),
secondary_y=False
)
fig.add_trace(go.Scatter(
name="train_auc",
mode="lines", x=x_vals,y=clf.history['train_auc']),
secondary_y=True
)
fig.add_trace(go.Scatter(
name="valid_auc",
mode="lines", x=x_vals,y=clf.history['valid_auc']),
secondary_y=True
)
fig.update_layout(hovermode="x unified",
title_text="training data - loss and auc"
)
#fig.show()
fig_list.append(fig)
# Create figure with secondary y-axis
fig = make_subplots(specs=[[{"secondary_y": True}]])
x_vals=list(range(1, max_epochs+1))
fig.add_trace(go.Scatter(
name="loss",
mode="lines", x=x_vals, y=clf2.history['loss']),
secondary_y=False
)
fig.add_trace(go.Scatter(
name="train_auc",
mode="lines", x=x_vals,y=clf2.history['train_auc']),
secondary_y=True
)
fig.add_trace(go.Scatter(
name="valid_auc",
mode="lines", x=x_vals,y=clf2.history['valid_auc']),
secondary_y=True
)
fig.update_layout(hovermode="x unified",
title_text="training data - loss and auc - default hyperparameters"
)
#fig.show()
fig_list.append(fig)
Predictions
preds_mapper = { idx : class_name for idx, class_name in enumerate(clf.classes_)}
preds = clf.predict_proba(X_test)
y_pred = np.vectorize(preds_mapper.get)(np.argmax(preds, axis=1))
test_acc = accuracy_score(y_pred=y_pred, y_true=y_test)
preds_mapper2 = { idx : class_name for idx, class_name in enumerate(clf2.classes_)}
preds2 = clf2.predict_proba(X_test)
y_pred2 = np.vectorize(preds_mapper2.get)(np.argmax(preds2, axis=1))
test_acc2 = accuracy_score(y_pred=y_pred2, y_true=y_test)
print(f"BEST VALID SCORE FOR : {clf.best_cost}, {clf2.best_cost}")
print(f"FINAL TEST SCORE FOR : {test_acc}, {test_acc2}")
BEST VALID SCORE FOR : 0.9993484148154181, 1.0
FINAL TEST SCORE FOR : 0.9733333333333334, 0.9483333333333334
y_pred = clf.predict(X_test)
test_acc = accuracy_score(y_pred=y_pred, y_true=y_test)
y_pred2 = clf2.predict(X_test)
test_acc2 = accuracy_score(y_pred=y_pred2, y_true=y_test)
print(f"FINAL TEST SCORE FOR : {test_acc}, {test_acc2}")
FINAL TEST SCORE FOR : 0.9733333333333334, 0.9483333333333334
Save and load Model
# save state dict
saved_filename = clf.save_model('binary_model')
Successfully saved model at binary_model.zip
# define new model and load save parameters
loaded_clf = TabNetClassifier()
loaded_clf.load_model(saved_filename)
Device used : cpu
Device used : cpu
loaded_preds = loaded_clf.predict_proba(X_test)
loaded_y_pred = np.vectorize(preds_mapper.get)(np.argmax(loaded_preds, axis=1))
loaded_test_acc = accuracy_score(y_pred=loaded_y_pred, y_true=y_test)
print(f"FINAL TEST SCORE FOR : {loaded_test_acc}")
FINAL TEST SCORE FOR : 0.9733333333333334
test_acc == loaded_test_acc
True
Global explainability : feat importance summing to 1
clf.feature_importances_
array([0.01141833, 0.10568571, 0.01697402, 0.0599244 , 0.0296432 ,
0.2647944 , 0.01209274, 0.03980985, 0.02885467, 0.00494661,
0.04379065, 0.04271051, 0.20845756, 0.00243536, 0.0882974 ,
0.02919195, 0.01097264])
Local explainability and masks
from matplotlib import pyplot as plt
explain_matrix, masks = clf.explain(X_test)
fig, axs = plt.subplots(1, 5, figsize=(20,20))
for i in range(5):
axs[i].imshow(masks[i][:50])
axs[i].set_title(f"mask {i}")
Export graphs to a html file
fig_path = r'html/tabnet_binary.html'
fig_list[0].write_html(fig_path)
with open(fig_path, 'a') as f:
for fig_i in fig_list[1:]:
f.write(fig_i.to_html(full_html=False, include_plotlyjs='cdn'))