Coverage for src/sensai/torch/torch_models/mlp/mlp_modules.py: 91%
34 statements
coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
from typing import Callable, Optional, Sequence

import torch
from torch import nn

from ...torch_base import MCDropoutCapableNNModule
from ....util.string import object_repr, function_name
class MultiLayerPerceptron(MCDropoutCapableNNModule):
    def __init__(self, input_dim: int, output_dim: int, hidden_dims: Sequence[int],
            hid_activation_fn: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
            output_activation_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = torch.sigmoid,
            p_dropout: Optional[float] = None):
        """
        :param input_dim: the number of input features
        :param output_dim: the number of output dimensions
        :param hidden_dims: the dimensions of the hidden layers, in order
        :param hid_activation_fn: the activation function to apply after each hidden layer
        :param output_activation_fn: the activation function to apply to the output layer (None for no activation)
        :param p_dropout: the dropout probability to apply after each non-final layer (None for no dropout)
        """
        super().__init__()
        self.inputDim = input_dim
        self.outputDim = output_dim
        self.hiddenDims = hidden_dims
        self.hidActivationFn = hid_activation_fn
        self.outputActivationFn = output_activation_fn
        self.pDropout = p_dropout
        self.layers = nn.ModuleList()
        if p_dropout is not None:
            self.dropout = nn.Dropout(p=p_dropout)
        else:
            self.dropout = None
        # create one linear layer per transition: input -> hidden layers -> output
        prev_dim = input_dim
        for dim in [*hidden_dims, output_dim]:
            self.layers.append(nn.Linear(prev_dim, dim))
            prev_dim = dim
    def __str__(self):
        return object_repr(self, dict(inputDim=self.inputDim, outputDim=self.outputDim, hiddenDims=self.hiddenDims,
            hidActivationFn=function_name(self.hidActivationFn) if self.hidActivationFn is not None else None,
            outputActivationFn=function_name(self.outputActivationFn) if self.outputActivationFn is not None else None,
            pDropout=self.pDropout))
    def forward(self, x):
        for i, layer in enumerate(self.layers):
            is_last = i + 1 == len(self.layers)
            x = layer(x)
            # apply dropout after each non-final linear layer (before its activation)
            if not is_last and self.dropout is not None:
                x = self.dropout(x)
            # hidden layers use the hidden activation; the final layer uses the output activation
            activation = self.hidActivationFn if not is_last else self.outputActivationFn
            if activation is not None:
                x = activation(x)
        return x
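
A minimal usage sketch, not part of the covered module: it assumes the package imports above resolve and simply constructs a small MLP and runs a random batch through it. The dimensions, the ReLU hidden activation, and the dropout probability are illustrative choices, not defaults of the class.

# usage sketch (illustrative assumptions: 8 input features, 2 outputs, two hidden layers)
mlp = MultiLayerPerceptron(input_dim=8, output_dim=2, hidden_dims=[32, 16],
    hid_activation_fn=torch.relu, output_activation_fn=None, p_dropout=0.1)
batch = torch.randn(4, 8)  # batch of 4 samples with 8 features each
out = mlp(batch)           # resulting shape: (4, 2)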