Source code for pytorch_tabnet.sparsemax

from torch import nn
from torch.autograd import Function
import torch.nn.functional as F

import torch

"""
Other possible implementations:
https://github.com/KrisKorrel/sparsemax-pytorch/blob/master/sparsemax.py
https://github.com/msobroza/SparsemaxPytorch/blob/master/mnist/sparsemax.py
https://github.com/vene/sparse-structured-attention/blob/master/pytorch/torchsparseattn/sparsemax.py
"""


# credits to Yandex https://github.com/Qwicen/node/blob/master/lib/nn_utils.py
def _make_ix_like(input, dim=0):
    d = input.size(dim)
    rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype)
    view = [1] * input.dim()
    view[0] = -1
    return rho.view(view).transpose(0, dim)
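
# Illustrative note (not part of the original source): for a (3, 5) tensor and
# dim=1 this returns the (1, 5) tensor [[1., 2., 3., 4., 5.]], i.e. the 1-based
# ranks rho that broadcast against values sorted along `dim`.
#
#   >>> _make_ix_like(torch.zeros(3, 5), dim=1)
#   tensor([[1., 2., 3., 4., 5.]])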


class SparsemaxFunction(Function):
    """
    An implementation of sparsemax (Martins & Astudillo, 2016). See
    :cite:`DBLP:journals/corr/MartinsA16` for detailed description.

    By Ben Peters and Vlad Niculae
    """

    @staticmethod
    def forward(ctx, input, dim=-1):
        """sparsemax: normalizing sparse transform (a la softmax)

        Parameters
        ----------
        ctx : torch.autograd.function._ContextMethodMixin
        input : torch.Tensor
            any shape
        dim : int
            dimension along which to apply sparsemax

        Returns
        -------
        output : torch.Tensor
            same shape as input
        """
        ctx.dim = dim
        max_val, _ = input.max(dim=dim, keepdim=True)
        input -= max_val  # same numerical stability trick as for softmax
        tau, supp_size = SparsemaxFunction._threshold_and_support(input, dim=dim)
        output = torch.clamp(input - tau, min=0)
        ctx.save_for_backward(supp_size, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        supp_size, output = ctx.saved_tensors
        dim = ctx.dim
        grad_input = grad_output.clone()
        grad_input[output == 0] = 0

        v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze()
        v_hat = v_hat.unsqueeze(dim)
        grad_input = torch.where(output != 0, grad_input - v_hat, grad_input)
        return grad_input, None

    @staticmethod
    def _threshold_and_support(input, dim=-1):
        """Sparsemax building block: compute the threshold

        Parameters
        ----------
        input : torch.Tensor
            any dimension
        dim : int
            dimension along which to apply the sparsemax

        Returns
        -------
        tau : torch.Tensor
            the threshold value
        support_size : torch.Tensor
        """
        input_srt, _ = torch.sort(input, descending=True, dim=dim)
        input_cumsum = input_srt.cumsum(dim) - 1
        rhos = _make_ix_like(input, dim)
        support = rhos * input_srt > input_cumsum

        support_size = support.sum(dim=dim).unsqueeze(dim)
        tau = input_cumsum.gather(dim, support_size - 1)
        tau /= support_size.to(input.dtype)
        return tau, support_size
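
# Worked example (illustrative, not part of the original source): calling
# _threshold_and_support directly on the raw scores [1.0, 0.5, -1.0], the sorted
# cumulative sum minus one is [0.0, 0.5, -0.5] and rho * sorted = [1.0, 1.0, -3.0],
# so only the first two entries satisfy the support condition: support_size = 2
# and tau = 0.5 / 2 = 0.25. Clamping input - tau at zero gives [0.75, 0.25, 0.0],
# which sums to 1 and contains an exact zero.
#
#   >>> SparsemaxFunction._threshold_and_support(torch.tensor([1.0, 0.5, -1.0]))
#   (tensor([0.2500]), tensor([2]))
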
sparsemax = SparsemaxFunction.apply


class Sparsemax(nn.Module):

    def __init__(self, dim=-1):
        self.dim = dim
        super(Sparsemax, self).__init__()

    def forward(self, input):
        return sparsemax(input, self.dim)
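
# A minimal usage sketch (illustrative, not part of the original source):
# sparsemax behaves like softmax but can assign exactly zero probability to
# low-scoring entries.
#
#   >>> logits = torch.tensor([[1.0, 0.5, -1.0]])
#   >>> torch.softmax(logits, dim=-1)   # dense: every entry is non-zero
#   tensor([[0.5741, 0.3482, 0.0777]])
#   >>> Sparsemax(dim=-1)(logits)       # sparse: exact zeros appear
#   tensor([[0.7500, 0.2500, 0.0000]])
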

class Entmax15Function(Function):
    """
    An implementation of exact Entmax with alpha=1.5 (B. Peters, V. Niculae, A. Martins).
    See https://arxiv.org/abs/1905.05702 for a detailed description.
    Source: https://github.com/deep-spin/entmax
    """

    @staticmethod
    def forward(ctx, input, dim=-1):
        ctx.dim = dim

        max_val, _ = input.max(dim=dim, keepdim=True)
        input = input - max_val  # same numerical stability trick as for softmax
        input = input / 2  # divide by 2 to solve actual Entmax

        tau_star, _ = Entmax15Function._threshold_and_support(input, dim)
        output = torch.clamp(input - tau_star, min=0) ** 2
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        Y, = ctx.saved_tensors
        gppr = Y.sqrt()  # = 1 / g'' (Y)
        dX = grad_output * gppr
        q = dX.sum(ctx.dim) / gppr.sum(ctx.dim)
        q = q.unsqueeze(ctx.dim)
        dX -= q * gppr
        return dX, None

    @staticmethod
    def _threshold_and_support(input, dim=-1):
        Xsrt, _ = torch.sort(input, descending=True, dim=dim)

        rho = _make_ix_like(input, dim)
        mean = Xsrt.cumsum(dim) / rho
        mean_sq = (Xsrt ** 2).cumsum(dim) / rho
        ss = rho * (mean_sq - mean ** 2)
        delta = (1 - ss) / rho

        # NOTE this is not exactly the same as in reference algo
        # Fortunately it seems the clamped values never wrongly
        # get selected by tau <= sorted_z. Prove this!
        delta_nz = torch.clamp(delta, 0)
        tau = mean - torch.sqrt(delta_nz)

        support_size = (tau <= Xsrt).sum(dim).unsqueeze(dim)
        tau_star = tau.gather(dim, support_size - 1)
        return tau_star, support_size
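
# Worked example (illustrative, not part of the original source): the logits
# [2.0, 1.0, 0.0, -1.0] become [0.0, -0.5, -1.0, -1.5] after the shift-and-halve
# in forward. The running means are [0.0, -0.25, -0.5, -0.75], delta clamps to
# [1.0, 0.4375, 0.1667, 0.0], and tau = mean - sqrt(delta) is approximately
# [-1.0, -0.9114, -0.9082, -0.75]. Only the first two entries satisfy
# tau <= sorted input, so support_size = 2 and tau_star is about -0.9114;
# squaring the clamped difference gives [0.8307, 0.1693, 0.0, 0.0].
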

class Entmoid15(Function):
    """ A highly optimized equivalent of lambda x: Entmax15([x, 0]) """

    @staticmethod
    def forward(ctx, input):
        output = Entmoid15._forward(input)
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def _forward(input):
        input, is_pos = abs(input), input >= 0
        tau = (input + torch.sqrt(F.relu(8 - input ** 2))) / 2
        tau.masked_fill_(tau <= input, 2.0)
        y_neg = 0.25 * F.relu(tau - input, inplace=True) ** 2
        return torch.where(is_pos, 1 - y_neg, y_neg)

    @staticmethod
    def backward(ctx, grad_output):
        return Entmoid15._backward(ctx.saved_tensors[0], grad_output)

    @staticmethod
    def _backward(output, grad_output):
        gppr0, gppr1 = output.sqrt(), (1 - output).sqrt()
        grad_input = grad_output * gppr0
        q = grad_input / (gppr0 + gppr1)
        grad_input -= q * gppr0
        return grad_input


entmax15 = Entmax15Function.apply
entmoid15 = Entmoid15.apply
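
# Consistency sketch (illustrative, not part of the original source): entmoid15
# is the two-class special case of entmax15 applied to the pair [x, 0], so
# entmoid15(0) = 0.5 exactly like a sigmoid, and the two calls below agree up
# to floating-point rounding.
#
#   >>> x = torch.tensor([0.0, 1.0, -3.0])
#   >>> entmoid15(x)
#   tensor([0.5000, 0.8307, 0.0000])
#   >>> entmax15(torch.stack([x, torch.zeros_like(x)], dim=-1), -1)[..., 0]
#   tensor([0.5000, 0.8307, 0.0000])
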

class Entmax15(nn.Module):

    def __init__(self, dim=-1):
        self.dim = dim
        super(Entmax15, self).__init__()

    def forward(self, input):
        return entmax15(input, self.dim)
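
# A minimal usage sketch (illustrative, not part of the original source):
# entmax15 typically sits between softmax (fully dense) and sparsemax (the most
# aggressive sparsity) in how many entries it zeroes out, which is why both
# sparse activations are offered here for TabNet's attentive masks.
#
#   >>> logits = torch.tensor([[2.0, 1.0, 0.0, -1.0]])
#   >>> torch.softmax(logits, dim=-1)   # fully dense
#   tensor([[0.6439, 0.2369, 0.0871, 0.0321]])
#   >>> Entmax15(dim=-1)(logits)        # two entries survive
#   tensor([[0.8307, 0.1693, 0.0000, 0.0000]])
#   >>> Sparsemax(dim=-1)(logits)       # only one entry survives here
#   tensor([[1., 0., 0., 0.]])
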