Coverage for src/sensai/torch/torch_enums.py: 70%
57 statements
« prev ^ index » next — coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
from enum import Enum
import functools
from typing import Optional, Callable, Union

from torch.nn import functional as F
class ActivationFunction(Enum):
    """
    Enumeration of common activation functions, each of which can be resolved to its
    torch.nn.functional counterpart.
    """
    NONE = "none"
    SIGMOID = "sigmoid"
    RELU = "relu"
    TANH = "tanh"
    LOG_SOFTMAX = "log_softmax"
    SOFTMAX = "softmax"

    @classmethod
    def from_name(cls, name: str) -> "ActivationFunction":
        """
        :param name: the name of the activation function (i.e. the enum item's value)
        :return: the matching enum item
        :raises ValueError: if no item matches the given name
        """
        matching = [item for item in cls if item.get_name() == name]
        if not matching:
            raise ValueError(f"No function found for name '{name}'")
        return matching[0]

    def get_torch_function(self) -> Optional[Callable]:
        """
        :return: the torch function corresponding to this item, or None for NONE;
            the softmax variants are bound to dim=1 (the class/feature dimension)
        """
        mapping = {
            ActivationFunction.NONE: None,
            ActivationFunction.SIGMOID: F.sigmoid,
            ActivationFunction.RELU: F.relu,
            ActivationFunction.TANH: F.tanh,
            ActivationFunction.LOG_SOFTMAX: functools.partial(F.log_softmax, dim=1),
            ActivationFunction.SOFTMAX: functools.partial(F.softmax, dim=1),
        }
        return mapping[self]

    def get_name(self) -> str:
        """
        :return: the name of this activation function (the enum item's value)
        """
        return self.value

    @classmethod
    def torch_function_from_any(cls, f: Union[str, "ActivationFunction", Callable, None]) -> Optional[Callable]:
        """
        Gets the torch activation for the given argument

        :param f: either an instance of ActivationFunction, the name of a function from torch.nn.functional or an actual function
        :return: a function that can be applied to tensors (or None)
        """
        if f is None:
            return None
        if isinstance(f, str):
            try:
                return cls.from_name(f).get_torch_function()
            except ValueError:
                # not one of this enum's names: fall back to a lookup in torch.nn.functional
                # (raises AttributeError if no such function exists there)
                return getattr(F, f)
        if isinstance(f, ActivationFunction):
            return f.get_torch_function()
        if callable(f):
            return f
        raise ValueError(f"Could not determine torch function from {f} of type {type(f)}")
class ClassificationOutputMode(Enum):
    """
    Describes the semantics of a classification model's output values.
    """
    PROBABILITIES = "probabilities"
    LOG_PROBABILITIES = "log_probabilities"
    UNNORMALISED_LOG_PROBABILITIES = "unnormalised_log_probabilities"

    @classmethod
    def for_activation_fn(cls, fn: Optional[Union[Callable, ActivationFunction]]):
        """
        :param fn: the output activation function: an ActivationFunction item, a torch
            function (possibly wrapped in functools.partial), or None
        :return: the output mode implied by the given activation function
        :raises ValueError: if fn is not callable, is unsuitable as an output activation
            for classification (sigmoid/relu/tanh), or is not recognised
        """
        if isinstance(fn, ActivationFunction):
            fn = fn.get_torch_function()
        if fn is None:
            # no activation at all: the model emits raw scores
            return cls.UNNORMALISED_LOG_PROBABILITIES
        if not callable(fn):
            raise ValueError(fn)
        # unwrap functools.partial (as used for the softmax variants) to inspect
        # the underlying function's name
        underlying = fn.func if isinstance(fn, functools.partial) else fn
        fn_name = underlying.__name__
        if fn_name == "log_softmax":
            return cls.LOG_PROBABILITIES
        if fn_name == "softmax":
            return cls.PROBABILITIES
        if fn_name in ("sigmoid", "relu", "tanh"):
            raise ValueError(f"The activation function {underlying} is not suitable as an output activation function for classification")
        raise ValueError(f"Unhandled function {underlying}")