Source code for sensai.torch.torch_enums

from enum import Enum
import functools
from typing import Optional, Callable, Union

from torch.nn import functional as F


[docs]class ActivationFunction(Enum): NONE = "none" SIGMOID = "sigmoid" RELU = "relu" TANH = "tanh" LOG_SOFTMAX = "log_softmax" SOFTMAX = "softmax"
[docs] @classmethod def from_name(cls, name: str) -> "ActivationFunction": for item in cls: if item.get_name() == name: return item raise ValueError(f"No function found for name '{name}'")
[docs] def get_torch_function(self) -> Optional[Callable]: return { ActivationFunction.NONE: None, ActivationFunction.SIGMOID: F.sigmoid, ActivationFunction.RELU: F.relu, ActivationFunction.TANH: F.tanh, ActivationFunction.LOG_SOFTMAX: functools.partial(F.log_softmax, dim=1), ActivationFunction.SOFTMAX: functools.partial(F.softmax, dim=1) }[self]
[docs] def get_name(self) -> str: return self.value
[docs] @classmethod def torch_function_from_any(cls, f: Union[str, "ActivationFunction", Callable, None]) -> Optional[Callable]: """ Gets the torch activation for the given argument :param f: either an instance of ActivationFunction, the name of a function from torch.nn.functional or an actual function :return: a function that can be applied to tensors (or None) """ if f is None: return None elif isinstance(f, str): try: return cls.from_name(f).get_torch_function() except ValueError: return getattr(F, f) elif isinstance(f, ActivationFunction): return f.get_torch_function() elif callable(f): return f else: raise ValueError(f"Could not determine torch function from {f} of type {type(f)}")
[docs]class ClassificationOutputMode(Enum): PROBABILITIES = "probabilities" LOG_PROBABILITIES = "log_probabilities" UNNORMALISED_LOG_PROBABILITIES = "unnormalised_log_probabilities"
[docs] @classmethod def for_activation_fn(cls, fn: Optional[Union[Callable, ActivationFunction]]): if isinstance(fn, ActivationFunction): fn = fn.get_torch_function() if fn is None: return cls.UNNORMALISED_LOG_PROBABILITIES if not callable(fn): raise ValueError(fn) if isinstance(fn, functools.partial): fn = fn.func name = fn.__name__ if name in ("sigmoid", "relu", "tanh"): raise ValueError(f"The activation function {fn} is not suitable as an output activation function for classification") elif name in ("log_softmax",): return cls.LOG_PROBABILITIES elif name in ("softmax",): return cls.PROBABILITIES else: raise ValueError(f"Unhandled function {fn}")