Coverage for src/sensai/torch/torch_models/mlp/mlp_modules.py: 91%

34 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-13 22:17 +0000

1from typing import Callable, Optional, Sequence 

2 

3import torch 

4from torch import nn 

5 

6from ...torch_base import MCDropoutCapableNNModule 

7from ....util.string import object_repr, function_name 

8 

9 

class MultiLayerPerceptron(MCDropoutCapableNNModule):
    """
    Feed-forward network consisting of a stack of ``nn.Linear`` layers.

    One linear layer is created for every entry of ``hidden_dims`` plus a final layer
    mapping to ``output_dim``. In the forward pass, every layer except the last one is
    followed by (optional) dropout and then the hidden activation; the last layer is
    followed only by the (optional) output activation.
    """

    def __init__(self, input_dim: int, output_dim: int, hidden_dims: Sequence[int],
            hid_activation_fn: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
            output_activation_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = torch.sigmoid,
            p_dropout: Optional[float] = None):
        """
        :param input_dim: number of input features (passed to the first ``nn.Linear``)
        :param output_dim: number of output features of the final linear layer
        :param hidden_dims: output sizes of the hidden linear layers, in order; may be empty,
            in which case the network is a single linear layer
        :param hid_activation_fn: function applied after every non-final layer; ``forward``
            skips it if it is None
        :param output_activation_fn: function applied after the final layer, or None to
            return the final layer's output unchanged
        :param p_dropout: dropout probability used after every non-final layer, or None
            to disable dropout entirely
        """
        super().__init__()
        self.inputDim = input_dim
        self.outputDim = output_dim
        self.hiddenDims = hidden_dims
        self.hidActivationFn = hid_activation_fn
        self.outputActivationFn = output_activation_fn
        self.pDropout = p_dropout
        self.layers = nn.ModuleList()
        # a single dropout module is shared by all hidden layers
        if p_dropout is not None:
            self.dropout = nn.Dropout(p=p_dropout)
        else:
            self.dropout = None
        # chain the layers: each layer's input size is the previous layer's output size
        prev_dim = input_dim
        for dim in [*hidden_dims, output_dim]:
            self.layers.append(nn.Linear(prev_dim, dim))
            prev_dim = dim

    def __str__(self) -> str:
        # activation functions are reported by name (via function_name); None is kept as None
        hid_activation_name = function_name(self.hidActivationFn) if self.hidActivationFn is not None else None
        output_activation_name = function_name(self.outputActivationFn) if self.outputActivationFn is not None else None
        return object_repr(self, dict(inputDim=self.inputDim, outputDim=self.outputDim, hiddenDims=self.hiddenDims,
            hidActivationFn=hid_activation_name,
            outputActivationFn=output_activation_name,
            pDropout=self.pDropout))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Passes ``x`` through all linear layers in order.

        :param x: the input tensor; presumably its last dimension must equal ``input_dim``
            (as required by ``nn.Linear``) - confirm against callers
        :return: the output of the final layer after the output activation (if any)
        """
        last_index = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i == last_index:
                # the final layer never receives dropout
                activation = self.outputActivationFn
            else:
                # note: dropout acts on the linear output, i.e. before the activation
                if self.dropout is not None:
                    x = self.dropout(x)
                activation = self.hidActivationFn
            if activation is not None:
                x = activation(x)
        return x