Source code for pytorch_tabnet.abstract_model

from dataclasses import dataclass, field
from typing import List, Any, Dict
import torch
from torch.nn.utils import clip_grad_norm_
import numpy as np
from scipy.sparse import csc_matrix
from abc import abstractmethod
from pytorch_tabnet import tab_network
from pytorch_tabnet.utils import (
    SparsePredictDataset,
    PredictDataset,
    create_explain_matrix,
    validate_eval_set,
    create_dataloaders,
    define_device,
    ComplexEncoder,
    check_input,
    check_warm_start,
    create_group_matrix,
    check_embedding_parameters
)
from pytorch_tabnet.callbacks import (
    CallbackContainer,
    History,
    EarlyStopping,
    LRSchedulerCallback,
)
from pytorch_tabnet.metrics import MetricContainer, check_metrics
from sklearn.base import BaseEstimator

from torch.utils.data import DataLoader
import io
import json
from pathlib import Path
import shutil
import zipfile
import warnings
import copy
import scipy


@dataclass
class TabModel(BaseEstimator):
    """ Class for TabNet model."""

    n_d: int = 8
    n_a: int = 8
    n_steps: int = 3
    gamma: float = 1.3
    cat_idxs: List[int] = field(default_factory=list)
    cat_dims: List[int] = field(default_factory=list)
    cat_emb_dim: int = 1
    n_independent: int = 2
    n_shared: int = 2
    epsilon: float = 1e-15
    momentum: float = 0.02
    lambda_sparse: float = 1e-3
    seed: int = 0
    clip_value: int = 1
    verbose: int = 1
    optimizer_fn: Any = torch.optim.Adam
    optimizer_params: Dict = field(default_factory=lambda: dict(lr=2e-2))
    scheduler_fn: Any = None
    scheduler_params: Dict = field(default_factory=dict)
    mask_type: str = "sparsemax"
    input_dim: int = None
    output_dim: int = None
    device_name: str = "auto"
    n_shared_decoder: int = 1
    n_indep_decoder: int = 1
    grouped_features: List[List[int]] = field(default_factory=list)

    def __post_init__(self):
        # These are default values needed for saving model
        self.batch_size = 1024
        self.virtual_batch_size = 128

        torch.manual_seed(self.seed)
        # Defining device
        self.device = torch.device(define_device(self.device_name))
        if self.verbose != 0:
            warnings.warn(f"Device used : {self.device}")

        # create deep copies of mutable parameters
        self.optimizer_fn = copy.deepcopy(self.optimizer_fn)
        self.scheduler_fn = copy.deepcopy(self.scheduler_fn)

        updated_params = check_embedding_parameters(self.cat_dims,
                                                    self.cat_idxs,
                                                    self.cat_emb_dim)
        self.cat_dims, self.cat_idxs, self.cat_emb_dim = updated_params

    def __update__(self, **kwargs):
        """
        Updates parameters.
        If a parameter does not already exist, it is created.
        Otherwise it is overwritten with a warning.
        """
        update_list = [
            "cat_dims",
            "cat_emb_dim",
            "cat_idxs",
            "input_dim",
            "mask_type",
            "n_a",
            "n_d",
            "n_independent",
            "n_shared",
            "n_steps",
            "grouped_features",
        ]
        for var_name, value in kwargs.items():
            if var_name in update_list:
                try:
                    exec(f"global previous_val; previous_val = self.{var_name}")
                    if previous_val != value:  # noqa
                        wrn_msg = f"Pretraining: {var_name} changed from {previous_val} to {value}"  # noqa
                        warnings.warn(wrn_msg)
                    exec(f"self.{var_name} = value")
                except AttributeError:
                    exec(f"self.{var_name} = value")
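    # Illustrative sketch (not part of this module): TabModel is abstract and is
    # normally used through a concrete subclass such as
    # pytorch_tabnet.tab_model.TabNetClassifier, which takes the dataclass fields
    # above as constructor arguments, e.g.
    #
    #   from pytorch_tabnet.tab_model import TabNetClassifier
    #   clf = TabNetClassifier(
    #       n_d=8, n_a=8, n_steps=3, gamma=1.3,
    #       optimizer_fn=torch.optim.Adam,
    #       optimizer_params=dict(lr=2e-2),
    #       scheduler_fn=torch.optim.lr_scheduler.StepLR,
    #       scheduler_params=dict(step_size=10, gamma=0.9),
    #       mask_type="sparsemax",
    #   )
    #
    # The values shown are the defaults defined above; any of them can be overridden.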
    def fit(
        self,
        X_train,
        y_train,
        eval_set=None,
        eval_name=None,
        eval_metric=None,
        loss_fn=None,
        weights=0,
        max_epochs=100,
        patience=10,
        batch_size=1024,
        virtual_batch_size=128,
        num_workers=0,
        drop_last=True,
        callbacks=None,
        pin_memory=True,
        from_unsupervised=None,
        warm_start=False,
        augmentations=None,
        compute_importance=True
    ):
        """Train a neural network stored in self.network
        Using train_dataloader for training data and
        valid_dataloader for validation.

        Parameters
        ----------
        X_train : np.ndarray
            Train set
        y_train : np.array
            Train targets
        eval_set : list of tuple
            List of eval tuple set (X, y).
            The last one is used for early stopping
        eval_name : list of str
            List of eval set names.
        eval_metric : list of str
            List of evaluation metrics.
            The last metric is used for early stopping.
        loss_fn : callable or None
            a PyTorch loss function
        weights : bool or dictionary
            0 for no balancing
            1 for automated balancing
            dict for custom weights per class
        max_epochs : int
            Maximum number of epochs during training
        patience : int
            Number of consecutive non-improving epochs before early stopping
        batch_size : int
            Training batch size
        virtual_batch_size : int
            Batch size for Ghost Batch Normalization (virtual_batch_size < batch_size)
        num_workers : int
            Number of workers used in torch.utils.data.DataLoader
        drop_last : bool
            Whether to drop the last batch during training
        callbacks : list of callback function
            List of custom callbacks
        pin_memory : bool
            Whether to set pin_memory to True or False during training
        from_unsupervised : unsupervised trained model
            Use a previously self-supervised model as starting weights
        warm_start : bool
            If True, current model parameters are used to start training
        compute_importance : bool
            Whether to compute feature importance
        """
        # update model name

        self.max_epochs = max_epochs
        self.patience = patience
        self.batch_size = batch_size
        self.virtual_batch_size = virtual_batch_size
        self.num_workers = num_workers
        self.drop_last = drop_last
        self.input_dim = X_train.shape[1]
        self._stop_training = False
        self.pin_memory = pin_memory and (self.device.type != "cpu")
        self.augmentations = augmentations
        self.compute_importance = compute_importance

        if self.augmentations is not None:
            # This ensures reproducibility
            self.augmentations._set_seed()

        eval_set = eval_set if eval_set else []

        if loss_fn is None:
            self.loss_fn = self._default_loss
        else:
            self.loss_fn = loss_fn

        check_input(X_train)
        check_warm_start(warm_start, from_unsupervised)

        self.update_fit_params(
            X_train,
            y_train,
            eval_set,
            weights,
        )

        # Validate and reformat eval set depending on training data
        eval_names, eval_set = validate_eval_set(eval_set, eval_name, X_train, y_train)

        train_dataloader, valid_dataloaders = self._construct_loaders(
            X_train, y_train, eval_set
        )

        if from_unsupervised is not None:
            # Update parameters to match self pretraining
            self.__update__(**from_unsupervised.get_params())

        if not hasattr(self, "network") or not warm_start:
            # model has never been fitted before or warm_start is False
            self._set_network()
        self._update_network_params()
        self._set_metrics(eval_metric, eval_names)
        self._set_optimizer()
        self._set_callbacks(callbacks)

        if from_unsupervised is not None:
            self.load_weights_from_unsupervised(from_unsupervised)
            warnings.warn("Loading weights from unsupervised pretraining")
        # Call method on_train_begin for all callbacks
        self._callback_container.on_train_begin()

        # Training loop over epochs
        for epoch_idx in range(self.max_epochs):

            # Call method on_epoch_begin for all callbacks
            self._callback_container.on_epoch_begin(epoch_idx)

            self._train_epoch(train_dataloader)

            # Apply predict epoch to all eval sets
            for eval_name, valid_dataloader in zip(eval_names, valid_dataloaders):
                self._predict_epoch(eval_name, valid_dataloader)

            # Call method on_epoch_end for all callbacks
            self._callback_container.on_epoch_end(
                epoch_idx, logs=self.history.epoch_metrics
            )

            if self._stop_training:
                break

        # Call method on_train_end for all callbacks
        self._callback_container.on_train_end()
        self.network.eval()

        if self.compute_importance:
            # compute feature importance once the best model is defined
            self.feature_importances_ = self._compute_feature_importances(X_train)
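    # Illustrative sketch (not part of this module): a typical call to fit on a
    # concrete subclass, where X_train, y_train, X_valid and y_valid are
    # hypothetical numpy arrays supplied by the caller.
    #
    #   clf.fit(
    #       X_train, y_train,
    #       eval_set=[(X_valid, y_valid)],
    #       eval_name=["valid"],
    #       eval_metric=["auc"],
    #       max_epochs=100, patience=10,
    #       batch_size=1024, virtual_batch_size=128,
    #   )
    #
    # A model pretrained with pytorch_tabnet.pretraining.TabNetPretrainer can be
    # passed via from_unsupervised= to start training from its encoder weights.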
    def predict(self, X):
        """
        Make predictions on a batch (valid)

        Parameters
        ----------
        X : a :tensor: `torch.Tensor` or matrix: `scipy.sparse.csr_matrix`
            Input data

        Returns
        -------
        predictions : np.array
            Predictions of the regression problem
        """
        self.network.eval()

        if scipy.sparse.issparse(X):
            dataloader = DataLoader(
                SparsePredictDataset(X),
                batch_size=self.batch_size,
                shuffle=False,
            )
        else:
            dataloader = DataLoader(
                PredictDataset(X),
                batch_size=self.batch_size,
                shuffle=False,
            )

        results = []
        for batch_nb, data in enumerate(dataloader):
            data = data.to(self.device).float()
            output, M_loss = self.network(data)
            predictions = output.cpu().detach().numpy()
            results.append(predictions)
        res = np.vstack(results)
        return self.predict_func(res)
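    # Illustrative sketch (not part of this module): inference on new data, where
    # X_test is a hypothetical numpy array or scipy.sparse.csr_matrix.
    #
    #   preds = clf.predict(X_test)
    #
    # The output is a numpy array post-processed by the subclass's predict_func
    # (e.g. class labels for a classifier, raw values for a regressor).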
    def explain(self, X, normalize=False):
        """
        Return local explanation

        Parameters
        ----------
        X : tensor: `torch.Tensor` or matrix: `scipy.sparse.csr_matrix`
            Input data
        normalize : bool (default False)
            Whether to normalize so that the sum over features equals 1

        Returns
        -------
        M_explain : matrix
            Importance per sample, per column.
        masks : matrix
            Sparse matrix showing attention masks used by network.
        """
        self.network.eval()

        if scipy.sparse.issparse(X):
            dataloader = DataLoader(
                SparsePredictDataset(X),
                batch_size=self.batch_size,
                shuffle=False,
            )
        else:
            dataloader = DataLoader(
                PredictDataset(X),
                batch_size=self.batch_size,
                shuffle=False,
            )

        res_explain = []

        for batch_nb, data in enumerate(dataloader):
            data = data.to(self.device).float()

            M_explain, masks = self.network.forward_masks(data)
            for key, value in masks.items():
                masks[key] = csc_matrix.dot(
                    value.cpu().detach().numpy(), self.reducing_matrix
                )
            original_feat_explain = csc_matrix.dot(M_explain.cpu().detach().numpy(),
                                                   self.reducing_matrix)
            res_explain.append(original_feat_explain)

            if batch_nb == 0:
                res_masks = masks
            else:
                for key, value in masks.items():
                    res_masks[key] = np.vstack([res_masks[key], value])

        res_explain = np.vstack(res_explain)

        if normalize:
            res_explain /= np.sum(res_explain, axis=1)[:, None]

        return res_explain, res_masks
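    # Illustrative sketch (not part of this module): local and global explanations,
    # where X_test is a hypothetical input array.
    #
    #   M_explain, masks = clf.explain(X_test, normalize=True)
    #
    # M_explain has one row per sample and one column per original feature; masks
    # maps each attention step to its per-sample mask. Global importances computed
    # during fit (when compute_importance=True) are exposed as clf.feature_importances_.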
    def load_weights_from_unsupervised(self, unsupervised_model):
        update_state_dict = copy.deepcopy(self.network.state_dict())
        for param, weights in unsupervised_model.network.state_dict().items():
            if param.startswith("encoder"):
                # Convert encoder's layers name to match
                new_param = "tabnet." + param
            else:
                new_param = param
            if self.network.state_dict().get(new_param) is not None:
                # update only common layers
                update_state_dict[new_param] = weights

        self.network.load_state_dict(update_state_dict)
    def load_class_attrs(self, class_attrs):
        for attr_name, attr_value in class_attrs.items():
            setattr(self, attr_name, attr_value)
    def save_model(self, path):
        """Saving TabNet model in two distinct files.

        Parameters
        ----------
        path : str
            Path of the model.

        Returns
        -------
        str
            input filepath with ".zip" appended
        """
        saved_params = {}
        init_params = {}
        for key, val in self.get_params().items():
            if isinstance(val, type):
                # Don't save torch specific params
                continue
            else:
                init_params[key] = val
        saved_params["init_params"] = init_params

        class_attrs = {
            "preds_mapper": self.preds_mapper
        }
        saved_params["class_attrs"] = class_attrs

        # Create folder
        Path(path).mkdir(parents=True, exist_ok=True)

        # Save model params
        with open(Path(path).joinpath("model_params.json"), "w", encoding="utf8") as f:
            json.dump(saved_params, f, cls=ComplexEncoder)

        # Save state_dict
        torch.save(self.network.state_dict(), Path(path).joinpath("network.pt"))
        shutil.make_archive(path, "zip", path)
        shutil.rmtree(path)
        print(f"Successfully saved model at {path}.zip")
        return f"{path}.zip"
    def load_model(self, filepath):
        """Load TabNet model.

        Parameters
        ----------
        filepath : str
            Path of the model.
        """
        try:
            with zipfile.ZipFile(filepath) as z:
                with z.open("model_params.json") as f:
                    loaded_params = json.load(f)
                    loaded_params["init_params"]["device_name"] = self.device_name
                with z.open("network.pt") as f:
                    try:
                        saved_state_dict = torch.load(f, map_location=self.device)
                    except io.UnsupportedOperation:
                        # In Python <3.7, the returned file object is not seekable (which at least
                        # some versions of PyTorch require) - so we'll try buffering it in to a
                        # BytesIO instead:
                        saved_state_dict = torch.load(
                            io.BytesIO(f.read()),
                            map_location=self.device,
                        )
        except KeyError:
            raise KeyError("Your zip file is missing at least one component")

        self.__init__(**loaded_params["init_params"])

        self._set_network()
        self.network.load_state_dict(saved_state_dict)
        self.network.eval()
        self.load_class_attrs(loaded_params["class_attrs"])

        return
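    # Illustrative sketch (not part of this module): a save/load round trip.
    # save_model writes "<path>.zip" and returns that filepath; load_model restores
    # init params, class attributes and network weights from it.
    #
    #   saved_path = clf.save_model("./tabnet_model")   # -> "./tabnet_model.zip"
    #   loaded_clf = TabNetClassifier()
    #   loaded_clf.load_model(saved_path)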
    def _train_epoch(self, train_loader):
        """
        Trains one epoch of the network in self.network

        Parameters
        ----------
        train_loader : a :class: `torch.utils.data.Dataloader`
            DataLoader with train set
        """
        self.network.train()

        for batch_idx, (X, y) in enumerate(train_loader):
            self._callback_container.on_batch_begin(batch_idx)

            batch_logs = self._train_batch(X, y)

            self._callback_container.on_batch_end(batch_idx, batch_logs)

        epoch_logs = {"lr": self._optimizer.param_groups[-1]["lr"]}
        self.history.epoch_metrics.update(epoch_logs)

        return

    def _train_batch(self, X, y):
        """
        Trains one batch of data

        Parameters
        ----------
        X : torch.Tensor
            Train matrix
        y : torch.Tensor
            Target matrix

        Returns
        -------
        batch_outs : dict
            Dictionary with "y": target and "score": prediction scores.
        batch_logs : dict
            Dictionary with "batch_size" and "loss".
        """
        batch_logs = {"batch_size": X.shape[0]}

        X = X.to(self.device).float()
        y = y.to(self.device).float()

        if self.augmentations is not None:
            X, y = self.augmentations(X, y)

        for param in self.network.parameters():
            param.grad = None

        output, M_loss = self.network(X)

        loss = self.compute_loss(output, y)
        # Add the overall sparsity loss
        loss = loss - self.lambda_sparse * M_loss

        # Perform backward pass and optimization
        loss.backward()
        if self.clip_value:
            clip_grad_norm_(self.network.parameters(), self.clip_value)
        self._optimizer.step()

        batch_logs["loss"] = loss.cpu().detach().numpy().item()

        return batch_logs

    def _predict_epoch(self, name, loader):
        """
        Predict an epoch and update metrics.

        Parameters
        ----------
        name : str
            Name of the validation set
        loader : torch.utils.data.Dataloader
            DataLoader with validation set
        """
        # Setting network on evaluation mode
        self.network.eval()

        list_y_true = []
        list_y_score = []

        # Main loop
        for batch_idx, (X, y) in enumerate(loader):
            scores = self._predict_batch(X)
            list_y_true.append(y)
            list_y_score.append(scores)

        y_true, scores = self.stack_batches(list_y_true, list_y_score)

        metrics_logs = self._metric_container_dict[name](y_true, scores)
        self.network.train()
        self.history.epoch_metrics.update(metrics_logs)
        return

    def _predict_batch(self, X):
        """
        Predict one batch of data.

        Parameters
        ----------
        X : torch.Tensor
            Owned products

        Returns
        -------
        np.array
            model scores
        """
        X = X.to(self.device).float()

        # compute model output
        scores, _ = self.network(X)

        if isinstance(scores, list):
            scores = [x.cpu().detach().numpy() for x in scores]
        else:
            scores = scores.cpu().detach().numpy()

        return scores

    def _set_network(self):
        """Setup the network and explain matrix."""
        torch.manual_seed(self.seed)

        self.group_matrix = create_group_matrix(self.grouped_features, self.input_dim)

        self.network = tab_network.TabNet(
            self.input_dim,
            self.output_dim,
            n_d=self.n_d,
            n_a=self.n_a,
            n_steps=self.n_steps,
            gamma=self.gamma,
            cat_idxs=self.cat_idxs,
            cat_dims=self.cat_dims,
            cat_emb_dim=self.cat_emb_dim,
            n_independent=self.n_independent,
            n_shared=self.n_shared,
            epsilon=self.epsilon,
            virtual_batch_size=self.virtual_batch_size,
            momentum=self.momentum,
            mask_type=self.mask_type,
            group_attention_matrix=self.group_matrix.to(self.device),
        ).to(self.device)

        self.reducing_matrix = create_explain_matrix(
            self.network.input_dim,
            self.network.cat_emb_dim,
            self.network.cat_idxs,
            self.network.post_embed_dim,
        )

    def _set_metrics(self, metrics, eval_names):
        """Set attributes relative to the metrics.

        Parameters
        ----------
        metrics : list of str
            List of eval metric names.
        eval_names : list of str
            List of eval set names.
        """
        metrics = metrics or [self._default_metric]

        metrics = check_metrics(metrics)
        # Set metric container for each set
        self._metric_container_dict = {}
        for name in eval_names:
            self._metric_container_dict.update(
                {name: MetricContainer(metrics, prefix=f"{name}_")}
            )

        self._metrics = []
        self._metrics_names = []
        for _, metric_container in self._metric_container_dict.items():
            self._metrics.extend(metric_container.metrics)
            self._metrics_names.extend(metric_container.names)

        # Early stopping metric is the last eval metric
        self.early_stopping_metric = (
            self._metrics_names[-1] if len(self._metrics_names) > 0 else None
        )

    def _set_callbacks(self, custom_callbacks):
        """Setup the callbacks functions.

        Parameters
        ----------
        custom_callbacks : list of func
            List of callback functions.
        """
        # Setup default callbacks: history, early stopping and scheduler
        callbacks = []
        self.history = History(self, verbose=self.verbose)
        callbacks.append(self.history)
        if (self.early_stopping_metric is not None) and (self.patience > 0):
            early_stopping = EarlyStopping(
                early_stopping_metric=self.early_stopping_metric,
                is_maximize=(
                    self._metrics[-1]._maximize if len(self._metrics) > 0 else None
                ),
                patience=self.patience,
            )
            callbacks.append(early_stopping)
        else:
            wrn_msg = "No early stopping will be performed, last training weights will be used."
            warnings.warn(wrn_msg)

        if self.scheduler_fn is not None:
            # Add LR scheduler callback
            is_batch_level = self.scheduler_params.pop("is_batch_level", False)
            scheduler = LRSchedulerCallback(
                scheduler_fn=self.scheduler_fn,
                scheduler_params=self.scheduler_params,
                optimizer=self._optimizer,
                early_stopping_metric=self.early_stopping_metric,
                is_batch_level=is_batch_level,
            )
            callbacks.append(scheduler)

        if custom_callbacks:
            callbacks.extend(custom_callbacks)
        self._callback_container = CallbackContainer(callbacks)
        self._callback_container.set_trainer(self)

    def _set_optimizer(self):
        """Setup optimizer."""
        self._optimizer = self.optimizer_fn(
            self.network.parameters(), **self.optimizer_params
        )

    def _construct_loaders(self, X_train, y_train, eval_set):
        """Generate dataloaders for train and eval set.

        Parameters
        ----------
        X_train : np.array
            Train set.
        y_train : np.array
            Train targets.
        eval_set : list of tuple
            List of eval tuple set (X, y).

        Returns
        -------
        train_dataloader : `torch.utils.data.Dataloader`
            Training dataloader.
        valid_dataloaders : list of `torch.utils.data.Dataloader`
            List of validation dataloaders.
        """
        # all weights are not allowed for this type of model
        y_train_mapped = self.prepare_target(y_train)
        for i, (X, y) in enumerate(eval_set):
            y_mapped = self.prepare_target(y)
            eval_set[i] = (X, y_mapped)

        train_dataloader, valid_dataloaders = create_dataloaders(
            X_train,
            y_train_mapped,
            eval_set,
            self.updated_weights,
            self.batch_size,
            self.num_workers,
            self.drop_last,
            self.pin_memory,
        )
        return train_dataloader, valid_dataloaders

    def _compute_feature_importances(self, X):
        """Compute global feature importance.

        Parameters
        ----------
        X : np.array
            Input data.
        """
        M_explain, _ = self.explain(X, normalize=False)
        sum_explain = M_explain.sum(axis=0)
        feature_importances_ = sum_explain / np.sum(sum_explain)
        return feature_importances_

    def _update_network_params(self):
        self.network.virtual_batch_size = self.virtual_batch_size
    @abstractmethod
    def update_fit_params(self, X_train, y_train, eval_set, weights):
        """
        Set attributes relative to fit function.

        Parameters
        ----------
        X_train : np.ndarray
            Train set
        y_train : np.array
            Train targets
        eval_set : list of tuple
            List of eval tuple set (X, y).
        weights : bool or dictionary
            0 for no balancing
            1 for automated balancing
        """
        raise NotImplementedError(
            "users must define update_fit_params to use this base class"
        )
    @abstractmethod
    def compute_loss(self, y_score, y_true):
        """
        Compute the loss.

        Parameters
        ----------
        y_score : a :tensor: `torch.Tensor`
            Score matrix
        y_true : a :tensor: `torch.Tensor`
            Target matrix

        Returns
        -------
        float
            Loss value
        """
        raise NotImplementedError(
            "users must define compute_loss to use this base class"
        )
    @abstractmethod
    def prepare_target(self, y):
        """
        Prepare target before training.

        Parameters
        ----------
        y : a :tensor: `torch.Tensor`
            Target matrix.

        Returns
        -------
        `torch.Tensor`
            Converted target matrix.
        """
        raise NotImplementedError(
            "users must define prepare_target to use this base class"
        )