import copy
import datetime
import time
import warnings
from dataclasses import dataclass, field
from typing import Any, List

import numpy as np


class Callback:
    """
    Abstract base class used to build new callbacks.
    """

    def __init__(self):
        pass

    def set_params(self, params):
        self.params = params

    def set_trainer(self, model):
        self.trainer = model

    def on_epoch_begin(self, epoch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_batch_begin(self, batch, logs=None):
        pass

    def on_batch_end(self, batch, logs=None):
        pass

    def on_train_begin(self, logs=None):
        pass

    def on_train_end(self, logs=None):
        pass
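

# --- Usage sketch (illustrative only, not part of the library API) ---------
# A minimal example of extending ``Callback``: hooks receive the current
# epoch/batch index plus a ``logs`` dict, and ``self.trainer`` is attached by
# the container before training starts. The class name is hypothetical.
class _ExampleElapsedTimeCallback(Callback):
    def on_train_begin(self, logs=None):
        # ``start_time`` is injected by CallbackContainer.on_train_begin.
        self._t0 = (logs or {}).get("start_time", time.time())

    def on_epoch_end(self, epoch, logs=None):
        print(f"epoch {epoch} finished after {time.time() - self._t0:.1f}s")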


@dataclass
class CallbackContainer:
    """
    Container holding a list of callbacks.
    """

    callbacks: List[Callback] = field(default_factory=list)

    def append(self, callback):
        self.callbacks.append(callback)

    def set_params(self, params):
        for callback in self.callbacks:
            callback.set_params(params)

    def set_trainer(self, trainer):
        self.trainer = trainer
        for callback in self.callbacks:
            callback.set_trainer(trainer)

    def on_epoch_begin(self, epoch, logs=None):
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_epoch_begin(epoch, logs)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_epoch_end(epoch, logs)

    def on_batch_begin(self, batch, logs=None):
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_batch_begin(batch, logs)

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_batch_end(batch, logs)

    def on_train_begin(self, logs=None):
        logs = logs or {}
        logs["start_time"] = time.time()
        for callback in self.callbacks:
            callback.on_train_begin(logs)

    def on_train_end(self, logs=None):
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_train_end(logs)
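

# --- Usage sketch (illustrative only, not part of the library API) ---------
# How a training loop is expected to drive the container: register callbacks,
# attach the trainer, then fire the begin/end hooks around epochs and batches.
# ``trainer`` here is any object exposing whatever attributes the registered
# callbacks rely on; the function name is hypothetical.
def _callback_container_example(trainer, n_epochs=2):
    container = CallbackContainer([_ExampleElapsedTimeCallback()])
    container.set_trainer(trainer)
    container.on_train_begin()  # injects "start_time" into logs
    for epoch in range(n_epochs):
        container.on_epoch_begin(epoch)
        container.on_batch_begin(0)
        container.on_batch_end(0, logs={"batch_size": 32, "loss": 0.5})
        container.on_epoch_end(epoch, logs={"loss": 0.5})
    container.on_train_end()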


@dataclass
class EarlyStopping(Callback):
    """EarlyStopping callback to exit the training loop if early_stopping_metric
    does not improve by a certain amount for a certain number of epochs.

    Parameters
    ----------
    early_stopping_metric : str
        Early stopping metric name
    is_maximize : bool
        Whether early_stopping_metric should be maximized (True) or minimized (False)
    tol : float
        Minimum change in the monitored value to qualify as an improvement.
        This number should be positive.
    patience : int
        Number of epochs to wait for improvement before terminating.
        The counter is reset after each improvement.
    """

    early_stopping_metric: str
    is_maximize: bool
    tol: float = 0.0
    patience: int = 5

    def __post_init__(self):
        self.best_epoch = 0
        self.stopped_epoch = 0
        self.wait = 0
        self.best_weights = None
        self.best_loss = np.inf
        if self.is_maximize:
            self.best_loss = -self.best_loss
        super().__init__()

    def on_epoch_end(self, epoch, logs=None):
        current_loss = logs.get(self.early_stopping_metric)
        if current_loss is None:
            return
        loss_change = current_loss - self.best_loss
        max_improved = self.is_maximize and loss_change > self.tol
        min_improved = (not self.is_maximize) and (-loss_change > self.tol)
        if max_improved or min_improved:
            self.best_loss = current_loss
            self.best_epoch = epoch
            self.wait = 1
            self.best_weights = copy.deepcopy(self.trainer.network.state_dict())
        else:
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.trainer._stop_training = True
            self.wait += 1

    def on_train_end(self, logs=None):
        self.trainer.best_epoch = self.best_epoch
        self.trainer.best_cost = self.best_loss

        if self.best_weights is not None:
            self.trainer.network.load_state_dict(self.best_weights)

        if self.stopped_epoch > 0:
            msg = f"\nEarly stopping occurred at epoch {self.stopped_epoch}"
            msg += (
                f" with best_epoch = {self.best_epoch} and "
                + f"best_{self.early_stopping_metric} = {round(self.best_loss, 5)}"
            )
            print(msg)
        else:
            msg = (
                f"Stop training because you reached max_epochs = {self.trainer.max_epochs}"
                + f" with best_epoch = {self.best_epoch} and "
                + f"best_{self.early_stopping_metric} = {round(self.best_loss, 5)}"
            )
            print(msg)
        wrn_msg = "Best weights from best epoch are automatically used!"
        warnings.warn(wrn_msg)
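

# --- Usage sketch (illustrative only, not part of the library API) ---------
# EarlyStopping watches one entry of the ``logs`` dict passed to
# ``on_epoch_end``. The dummy trainer below only provides the attributes the
# callback actually touches (``network`` with state_dict/load_state_dict,
# ``_stop_training``, ``max_epochs``); a real trainer is a full model wrapper.
def _early_stopping_example():
    from types import SimpleNamespace

    network = SimpleNamespace(state_dict=dict, load_state_dict=lambda sd: None)
    trainer = SimpleNamespace(network=network, _stop_training=False, max_epochs=100)

    early_stopping = EarlyStopping(
        early_stopping_metric="valid_auc", is_maximize=True, tol=0.0, patience=2
    )
    early_stopping.set_trainer(trainer)
    # AUC stops improving after epoch 1, so training is flagged to stop
    # once ``patience`` non-improving epochs have elapsed.
    for epoch, auc in enumerate([0.70, 0.80, 0.79, 0.78, 0.77]):
        early_stopping.on_epoch_end(epoch, logs={"valid_auc": auc})
        if trainer._stop_training:
            break
    early_stopping.on_train_end()  # restores best weights, prints a summary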


@dataclass
class History(Callback):
    """Callback that records events into a `History` object.
    This callback is automatically applied to every SuperModule.

    Parameters
    ----------
    trainer : DeepRecoModel
        Model class to train
    verbose : int
        Print results every ``verbose`` epochs
    """

    trainer: Any
    verbose: int = 1

    def __post_init__(self):
        super().__init__()
        self.samples_seen = 0.0
        self.total_time = 0.0

    def on_train_begin(self, logs=None):
        self.history = {"loss": []}
        self.history.update({"lr": []})
        self.history.update({name: [] for name in self.trainer._metrics_names})
        self.start_time = logs["start_time"]
        self.epoch_loss = 0.0

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_metrics = {"loss": 0.0}
        self.samples_seen = 0.0

    def on_epoch_end(self, epoch, logs=None):
        self.epoch_metrics["loss"] = self.epoch_loss
        for metric_name, metric_value in self.epoch_metrics.items():
            self.history[metric_name].append(metric_value)
        if self.verbose == 0:
            return
        if epoch % self.verbose != 0:
            return
        msg = f"epoch {epoch:<3}"
        for metric_name, metric_value in self.epoch_metrics.items():
            if metric_name != "lr":
                msg += f"| {metric_name:<3}: {np.round(metric_value, 5):<8}"
        self.total_time = int(time.time() - self.start_time)
        msg += f"| {str(datetime.timedelta(seconds=self.total_time)) + 's':<6}"
        print(msg)

    def on_batch_end(self, batch, logs=None):
        batch_size = logs["batch_size"]
        # Running weighted average of the batch losses seen so far this epoch
        self.epoch_loss = (
            self.samples_seen * self.epoch_loss + batch_size * logs["loss"]
        ) / (self.samples_seen + batch_size)
        self.samples_seen += batch_size

    def __getitem__(self, name):
        return self.history[name]

    def __repr__(self):
        return str(self.history)

    def __str__(self):
        return str(self.history)
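

# --- Usage sketch (illustrative only, not part of the library API) ---------
# History accumulates a per-epoch weighted average of the batch losses and
# stores every metric series under ``history[name]``. The trainer stub only
# needs ``_metrics_names`` here; the function name is hypothetical.
def _history_example():
    from types import SimpleNamespace

    history = History(trainer=SimpleNamespace(_metrics_names=[]), verbose=1)
    history.on_train_begin(logs={"start_time": time.time()})
    history.on_epoch_begin(0)
    history.on_batch_end(0, logs={"batch_size": 32, "loss": 0.9})
    history.on_batch_end(1, logs={"batch_size": 32, "loss": 0.7})
    history.on_epoch_end(0)  # prints the epoch summary and records the loss
    return history["loss"]  # -> [0.8], the size-weighted average of 0.9 and 0.7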


@dataclass
class LRSchedulerCallback(Callback):
    """Wrapper for most torch scheduler functions.

    Parameters
    ----------
    scheduler_fn : torch.optim.lr_scheduler
        Torch scheduling class
    optimizer : torch.optim.Optimizer
        Optimizer whose learning rate is updated by the scheduler
    scheduler_params : dict
        Dictionary containing all parameters for the scheduler_fn
    early_stopping_metric : str
        Name of the metric passed to metric-dependent schedulers
    is_batch_level : bool (default = False)
        If set to False : lr updates will happen at every epoch
        If set to True : lr updates happen at every batch
        Set this to True for OneCycleLR for example
    """

    scheduler_fn: Any
    optimizer: Any
    scheduler_params: dict
    early_stopping_metric: str
    is_batch_level: bool = False

    def __post_init__(self):
        # Schedulers such as ReduceLROnPlateau expose ``is_better`` and expect
        # the monitored metric to be passed to ``step``.
        self.is_metric_related = hasattr(self.scheduler_fn, "is_better")
        self.scheduler = self.scheduler_fn(self.optimizer, **self.scheduler_params)
        super().__init__()

    def on_batch_end(self, batch, logs=None):
        if self.is_batch_level:
            self.scheduler.step()

    def on_epoch_end(self, epoch, logs=None):
        current_loss = logs.get(self.early_stopping_metric)
        if current_loss is None:
            return
        if self.is_batch_level:
            # Scheduler already stepped at batch level; nothing to do per epoch.
            return
        if self.is_metric_related:
            self.scheduler.step(current_loss)
        else:
            self.scheduler.step()
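

# --- Usage sketch (illustrative only, not part of the library API) ---------
# Wrapping a torch scheduler: ``ReduceLROnPlateau`` defines ``is_better`` and
# therefore receives the monitored metric in ``step``, whereas schedulers like
# ``StepLR`` are stepped without it. Requires torch; names of the dummy
# parameter and metric are assumptions for the example.
def _lr_scheduler_callback_example():
    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.SGD(params, lr=0.1)
    scheduler_callback = LRSchedulerCallback(
        scheduler_fn=torch.optim.lr_scheduler.ReduceLROnPlateau,
        optimizer=optimizer,
        scheduler_params={"mode": "min", "patience": 3},
        early_stopping_metric="valid_loss",
        is_batch_level=False,
    )
    # With a flat validation loss, the plateau scheduler will eventually
    # reduce the learning rate after ``patience`` epochs.
    for epoch in range(5):
        scheduler_callback.on_epoch_end(epoch, logs={"valid_loss": 1.0})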