from torch.utils.data import DataLoader
from pytorch_tabnet.utils import (
create_sampler,
SparsePredictDataset,
PredictDataset,
check_input
)
import scipy.sparse
def create_dataloaders(
X_train, eval_set, weights, batch_size, num_workers, drop_last, pin_memory
):
"""
Create dataloaders with or without subsampling depending on weights and balanced.
Parameters
----------
X_train : np.ndarray or scipy.sparse.csr_matrix
Training data
eval_set : list of np.array (for Xs and ys) or scipy.sparse.csr_matrix (for Xs)
List of eval sets
weights : either 0, 1, dict or iterable
if 0 (default) : no weights will be applied
if 1 : classification only, will balanced class with inverse frequency
if dict : keys are corresponding class values are sample weights
if iterable : list or np array must be of length equal to nb elements
in the training set
batch_size : int
how many samples per batch to load
num_workers : int
how many subprocesses to use for data loading. 0 means that the data
will be loaded in the main process
drop_last : bool
set to True to drop the last incomplete batch, if the dataset size is not
divisible by the batch size. If False and the size of dataset is not
divisible by the batch size, then the last batch will be smaller
pin_memory : bool
Whether to pin GPU memory during training
Returns
-------
train_dataloader, valid_dataloader : torch.DataLoader, torch.DataLoader
Training and validation dataloaders
"""
need_shuffle, sampler = create_sampler(weights, X_train)
    if scipy.sparse.issparse(X_train):
        train_dataset = SparsePredictDataset(X_train)
    else:
        train_dataset = PredictDataset(X_train)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=sampler,
        shuffle=need_shuffle,
        num_workers=num_workers,
        drop_last=drop_last,
        pin_memory=pin_memory,
    )
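    # Validation loaders must see every sample exactly once, in order, so no
    # sampler, shuffling or batch dropping is applied here (reusing the
    # training sampler would index into the wrong dataset when weights are
    # enabled).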
    valid_dataloaders = []
    for X in eval_set:
        if scipy.sparse.issparse(X):
            valid_dataset = SparsePredictDataset(X)
        else:
            valid_dataset = PredictDataset(X)
        valid_dataloaders.append(
            DataLoader(
                valid_dataset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=num_workers,
                pin_memory=pin_memory,
            )
        )
return train_dataloader, valid_dataloaders
def validate_eval_set(eval_set, eval_name, X_train):
"""Check if the shapes of eval_set are compatible with X_train.
Parameters
----------
eval_set : List of numpy array
The list evaluation set.
The last one is used for early stopping
X_train : np.ndarray
Train owned products
Returns
-------
eval_names : list of str
Validated list of eval_names.
"""
eval_names = eval_name or [f"val_{i}" for i in range(len(eval_set))]
    assert len(eval_set) == len(
        eval_names
    ), "eval_set and eval_name must have the same length"
for set_nb, X in enumerate(eval_set):
check_input(X)
        msg = (
            f"Number of columns is different between eval set {set_nb} "
            + f"({X.shape[1]}) and X_train ({X_train.shape[1]})"
        )
assert X.shape[1] == X_train.shape[1], msg
return eval_names
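

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the library module).
# Assumes small random dense features; the shapes, batch_size and weights=0
# below are arbitrary choices for the demo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np

    X_train = np.random.rand(256, 10).astype(np.float32)
    eval_set = [np.random.rand(64, 10).astype(np.float32)]

    # Column counts match, eval_name is None -> default names are generated.
    eval_names = validate_eval_set(eval_set, None, X_train)
    print(eval_names)  # ['val_0']

    # weights=0 disables subsampling, so the training loader only shuffles.
    train_dataloader, valid_dataloaders = create_dataloaders(
        X_train,
        eval_set,
        weights=0,
        batch_size=32,
        num_workers=0,
        drop_last=False,
        pin_memory=False,
    )
    print(len(train_dataloader), len(valid_dataloaders))  # 8 batches, 1 loader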