skills/tabular/tabnet-sklearn-wrapper/SKILL.md
Wrap PyTorch TabNet in a scikit-learn BaseEstimator with built-in imputation and early stopping for use in VotingRegressor ensembles
npx skillsauth add wenmin-wu/ds-skills tabular-tabnet-sklearn-wrapper
Install this skill globally with one command. Works with Claude Code, Cursor, and Windsurf.
TabNet's native API is incompatible with scikit-learn's VotingRegressor/VotingClassifier. Wrapping it in a BaseEstimator with internal imputation, validation split, and early stopping lets you ensemble TabNet alongside LightGBM/XGBoost/CatBoost using standard sklearn patterns.
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from pytorch_tabnet.tab_model import TabNetRegressor


# Mixins are listed before BaseEstimator, as scikit-learn's developer guide recommends.
class TabNetWrapper(RegressorMixin, BaseEstimator):
    def __init__(self, n_d=24, n_a=24, n_steps=3, lr=0.02,
                 patience=20, max_epochs=200, batch_size=1024):
        # Store hyperparameters verbatim so clone() and get_params() work.
        self.n_d = n_d
        self.n_a = n_a
        self.n_steps = n_steps
        self.lr = lr
        self.patience = patience
        self.max_epochs = max_epochs
        self.batch_size = batch_size

    def fit(self, X, y):
        # Fitted state is created here, not in __init__, and carries a trailing underscore.
        self.imputer_ = SimpleImputer(strategy='median')
        X_imp = self.imputer_.fit_transform(X)
        y = np.asarray(y)  # accepts pandas Series as well as arrays
        # Hold out 20% of the training data to drive early stopping.
        Xt, Xv, yt, yv = train_test_split(X_imp, y, test_size=0.2)
        self.model_ = TabNetRegressor(
            n_d=self.n_d, n_a=self.n_a, n_steps=self.n_steps,
            optimizer_params=dict(lr=self.lr))
        # TabNetRegressor expects 2-D targets, hence the reshape.
        self.model_.fit(Xt, yt.reshape(-1, 1),
                        eval_set=[(Xv, yv.reshape(-1, 1))],
                        max_epochs=self.max_epochs, patience=self.patience,
                        batch_size=self.batch_size, drop_last=False)
        return self

    def predict(self, X):
        # Reuse the imputer fitted in fit(), then flatten (n, 1) predictions to (n,).
        return self.model_.predict(self.imputer_.transform(X)).flatten()
- TabNetWrapper subclasses BaseEstimator and RegressorMixin, with TabNet hyperparams as __init__ args
- fit(): impute → split for early stopping → train TabNet
- predict(): impute → predict → flatten
- Ensemble with VotingRegressor(estimators=[('lgbm', lgbm), ('tabnet', TabNetWrapper())])
- Internal validation split: callers do not pass eval_set externally
- Hyperparams stored in __init__: required for sklearn clone() and GridSearchCV compatibility
- drop_last=False: prevents silent sample loss on small datasets
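
The points above can be exercised without installing PyTorch. The sketch below is illustrative only: `StandInWrapper` is a hypothetical class that follows the same contract as the TabNet wrapper (hyperparameters stored in `__init__`, imputation inside `fit()` and `predict()`), but it swaps TabNet for scikit-learn's `Ridge` and has no early stopping. The synthetic data and the `alpha` value are likewise assumptions made for the demo.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin, clone
from sklearn.ensemble import VotingRegressor
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LinearRegression, Ridge


class StandInWrapper(RegressorMixin, BaseEstimator):
    """Hypothetical stand-in for the TabNet wrapper: Ridge replaces TabNet."""

    def __init__(self, alpha=1.0):
        # Stored verbatim so get_params() and clone() can rebuild the estimator.
        self.alpha = alpha

    def fit(self, X, y):
        # Imputer and model are fitted state, so they are created in fit().
        self.imputer_ = SimpleImputer(strategy="median")
        self.model_ = Ridge(alpha=self.alpha)
        self.model_.fit(self.imputer_.fit_transform(X), np.asarray(y))
        return self

    def predict(self, X):
        return self.model_.predict(self.imputer_.transform(X)).ravel()


# Synthetic regression data (an assumption for the demo).
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = X @ np.array([1.0, -2.0, 0.5, 0.0, 3.0]) + rng.normal(scale=0.1, size=200)

# The wrapper drops into VotingRegressor like any other sklearn regressor.
ensemble = VotingRegressor(
    estimators=[("linear", LinearRegression()),
                ("wrapped", StandInWrapper(alpha=0.5))]
)
ensemble.fit(X, y)
preds = ensemble.predict(X)

# clone() rebuilds an unfitted copy from get_params().
copy = clone(StandInWrapper(alpha=0.5))

# Used on its own, the wrapper tolerates missing values because it imputes internally.
X_missing = X.copy()
X_missing[::10, 2] = np.nan
wrapped_preds = StandInWrapper().fit(X_missing, y).predict(X_missing)
```

`VotingRegressor` clones each estimator before fitting, which is why the wrapper must expose every hyperparameter through `__init__`. Replacing `StandInWrapper` with the TabNet wrapper should follow the same pattern, provided `pytorch_tabnet` is installed.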