Coverage for src/sensai/torch/torch_enums.py: 70%

57 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-13 22:17 +0000

1from enum import Enum 

2import functools 

3from typing import Optional, Callable, Union 

4 

5from torch.nn import functional as F 

6 

7 

class ActivationFunction(Enum):
    """
    Enumeration of named activation functions, each of which can be resolved to the
    corresponding function from torch.nn.functional (or to None for NONE).
    """
    NONE = "none"
    SIGMOID = "sigmoid"
    RELU = "relu"
    TANH = "tanh"
    LOG_SOFTMAX = "log_softmax"
    SOFTMAX = "softmax"

    @classmethod
    def from_name(cls, name: str) -> "ActivationFunction":
        """
        Finds the member whose name (as returned by get_name) equals the given string.

        :param name: the name to look up, e.g. "relu"
        :return: the matching member
        :raises ValueError: if no member has the given name
        """
        match = next((member for member in cls if member.get_name() == name), None)
        if match is None:
            raise ValueError(f"No function found for name '{name}'")
        return match

    def get_torch_function(self) -> Optional[Callable]:
        """
        :return: the torch.nn.functional function implementing this activation, or None for NONE.
            The softmax variants are bound to dim=1 (presumably the class dimension of
            (batch, classes) outputs -- TODO confirm against callers).
        """
        if self is ActivationFunction.NONE:
            return None
        if self is ActivationFunction.SIGMOID:
            return F.sigmoid
        if self is ActivationFunction.RELU:
            return F.relu
        if self is ActivationFunction.TANH:
            return F.tanh
        if self is ActivationFunction.LOG_SOFTMAX:
            return functools.partial(F.log_softmax, dim=1)
        if self is ActivationFunction.SOFTMAX:
            return functools.partial(F.softmax, dim=1)
        # every member is handled above; mirror a failed dict lookup for anything else
        raise KeyError(self)

    def get_name(self) -> str:
        """
        :return: the name of this activation function (the member's value)
        """
        return self.value

    @classmethod
    def torch_function_from_any(cls, f: Union[str, "ActivationFunction", Callable, None]) -> Optional[Callable]:
        """
        Resolves the given specification to a torch activation function.

        :param f: None, an ActivationFunction member, the name of a member (e.g. "relu"),
            the name of an attribute of torch.nn.functional, or a callable (returned unchanged)
        :return: a function that can be applied to tensors, or None if no activation applies
        :raises ValueError: if f is none of the supported kinds of specification
        """
        if f is None:
            return None
        if isinstance(f, ActivationFunction):
            return f.get_torch_function()
        if isinstance(f, str):
            try:
                member = cls.from_name(f)
            except ValueError:
                # not one of our names: fall back to looking the name up in torch.nn.functional
                return getattr(F, f)
            return member.get_torch_function()
        if callable(f):
            return f
        raise ValueError(f"Could not determine torch function from {f} of type {type(f)}")

57 

58 

class ClassificationOutputMode(Enum):
    """
    Describes what the outputs of a classification model represent, as implied by the
    output activation function that is applied to them.
    """
    PROBABILITIES = "probabilities"
    LOG_PROBABILITIES = "log_probabilities"
    UNNORMALISED_LOG_PROBABILITIES = "unnormalised_log_probabilities"

    @classmethod
    def for_activation_fn(cls, fn: Optional[Union[str, Callable, ActivationFunction]]) -> "ClassificationOutputMode":
        """
        Determines the output mode that results from applying the given output activation function.

        :param fn: the output activation function. This may be an ActivationFunction member, the
            name of a member (as accepted by ActivationFunction.from_name), a callable
            (optionally wrapped in functools.partial), or None (no activation is applied).
        :return: PROBABILITIES for softmax, LOG_PROBABILITIES for log_softmax and
            UNNORMALISED_LOG_PROBABILITIES if no activation function is applied
        :raises ValueError: if fn is an unknown name or not callable; if it is sigmoid, relu or
            tanh, which are unsuitable as classification output activations; or if it is any
            other function, which cannot be classified
        """
        if isinstance(fn, str):
            # raises ValueError for unknown names
            fn = ActivationFunction.from_name(fn)
        if isinstance(fn, ActivationFunction):
            fn = fn.get_torch_function()
        if fn is None:
            return cls.UNNORMALISED_LOG_PROBABILITIES
        if not callable(fn):
            raise ValueError(f"Expected an activation function or None, got {fn!r} of type {type(fn)}")
        # unwrap partial applications such as functools.partial(F.softmax, dim=1) down to the underlying function
        while isinstance(fn, functools.partial):
            fn = fn.func
        # callable objects (e.g. instances of classes defining __call__) need not have a __name__;
        # such callables are reported as unhandled below rather than failing with AttributeError
        name = getattr(fn, "__name__", None)
        if name in ("sigmoid", "relu", "tanh"):
            raise ValueError(f"The activation function {fn} is not suitable as an output activation function for classification")
        elif name == "log_softmax":
            return cls.LOG_PROBABILITIES
        elif name == "softmax":
            return cls.PROBABILITIES
        else:
            raise ValueError(f"Unhandled function {fn}")