Coverage for src/sensai/pytorch_lightning/pl_models.py: 0%
61 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
1import numpy as np
2from pytorch_lightning import Trainer, LightningModule
3from torch import tensor
5from .. import tensor_model as tm
6from ..data import InputOutputArrays, DataSplitterFractional
def _fit_model_with_trainer(model: LightningModule, trainer: Trainer, io_data,
        batch_size: int, splitter: DataSplitterFractional = None):
    """
    Trains the given Lightning module with the given trainer on the given data.

    :param model: the module that is passed to ``trainer.fit``
    :param trainer: the trainer whose ``fit`` method performs the training
    :param io_data: the data to train on; it must provide ``to_torch_data_loader(batch_size=...)``
    :param batch_size: the batch size passed to every data loader created here
    :param splitter: if not None, ``splitter.split(io_data)`` is used to obtain a training part and a
        validation part, each of which is turned into its own data loader;
        if None, all of ``io_data`` is used for training and the trainer receives no validation data loader
    """
    if splitter is not None:
        train_io_data, validation_io_data = splitter.split(io_data)
        # Both parts are converted with the same snake_case loader API that the no-splitter branch
        # uses on io_data; the former camelCase call (toTorchDataLoader/batchSize) did not match it.
        train_data_loader = train_io_data.to_torch_data_loader(batch_size=batch_size)
        val_data_loader = validation_io_data.to_torch_data_loader(batch_size=batch_size)
    else:
        train_data_loader = io_data.to_torch_data_loader(batch_size=batch_size)
        val_data_loader = None
    trainer.fit(model, train_data_loader, val_dataloaders=val_data_loader)
class PLWrappedModel:
    """
    Pairs a Lightning module with a trainer and offers numpy-based ``fit`` and ``predict`` methods.
    """

    def __init__(self, model: LightningModule, trainer: Trainer, validation_fraction=0.1, shuffle=True, batch_size=32):
        """
        :param model: the Lightning module to train and to apply for predictions
        :param trainer: the trainer used for fitting
        :param validation_fraction: a value in [0, 1]; ``fit`` passes ``1 - validation_fraction`` as the
            first argument of the fractional data splitter
        :param shuffle: forwarded as the ``shuffle`` argument of the fractional data splitter
        :param batch_size: the batch size used when fitting
        :raises ValueError: if ``validation_fraction`` is not within [0, 1]
        """
        fraction_in_range = 0 <= validation_fraction <= 1
        if not fraction_in_range:
            raise ValueError(f"Invalid validationFraction: {validation_fraction}. Has to be in interval [0, 1]")
        self.trainer = trainer
        self.model = model
        self.validationFraction = validation_fraction
        self.shuffle = shuffle
        self.batchSize = batch_size

    def fit(self, x: np.ndarray, y: np.ndarray):
        """
        Wraps the arrays in an ``InputOutputArrays`` instance and trains the module on it, using a
        fractional splitter configured from this object's validation fraction and shuffle setting.

        :param x: the input array
        :param y: the output array
        """
        data = InputOutputArrays(x, y)
        fractional_splitter = DataSplitterFractional(1 - self.validationFraction, shuffle=self.shuffle)
        _fit_model_with_trainer(self.model, self.trainer, data, self.batchSize, splitter=fractional_splitter)

    def predict(self, x: np.ndarray) -> np.ndarray:
        """
        Converts the input to a tensor, applies the module to it and returns the result as a numpy array.

        :param x: the input array
        :return: the module's output, detached and moved to the CPU, as a numpy array
        """
        inputs = tensor(x)
        outputs = self.model(inputs)
        return outputs.detach().cpu().numpy()
41# noinspection DuplicatedCode
class PLTensorToScalarClassificationModel(tm.TensorToScalarClassificationModel):
    """
    Tensor-to-scalar classification model whose fitting and prediction are delegated to a
    ``PLWrappedModel`` built from a Lightning module and a trainer.
    """

    def __init__(self, model: LightningModule, trainer: Trainer, validation_fraction=0.1, shuffle=True, batch_size=64,
            check_input_shape=True, check_input_columns=True):
        """
        :param model: the Lightning module to wrap
        :param trainer: the trainer used for fitting
        :param validation_fraction: forwarded to ``PLWrappedModel``
        :param shuffle: forwarded to ``PLWrappedModel``
        :param batch_size: forwarded to ``PLWrappedModel``
        :param check_input_shape: forwarded to the base class constructor
        :param check_input_columns: forwarded to the base class constructor
        """
        super().__init__(check_input_shape=check_input_shape, check_input_columns=check_input_columns)
        self.wrapped_model = PLWrappedModel(
            model,
            trainer,
            validation_fraction=validation_fraction,
            shuffle=shuffle,
            batch_size=batch_size,
        )

    def _predict_probabilities_array(self, x: np.ndarray) -> np.ndarray:
        # The wrapped model's raw output is returned unchanged.
        # NOTE(review): presumably the module already outputs class probabilities - confirm
        probabilities = self.wrapped_model.predict(x)
        return probabilities

    def _fit_to_array(self, x: np.ndarray, y: np.ndarray):
        # Training is handled entirely by the wrapped model.
        self.wrapped_model.fit(x, y)
56# noinspection DuplicatedCode
class PLTensorToScalarRegressionModel(tm.TensorToScalarRegressionModel):
    """
    Tensor-to-scalar regression model whose fitting and prediction are delegated to a
    ``PLWrappedModel`` built from a Lightning module and a trainer.
    """

    def __init__(self, model: LightningModule, trainer: Trainer, validation_fraction=0.1, shuffle=True, batch_size=32,
            check_input_shape=True, check_input_columns=True):
        """
        :param model: the Lightning module to wrap
        :param trainer: the trainer used for fitting
        :param validation_fraction: forwarded to ``PLWrappedModel``
        :param shuffle: forwarded to ``PLWrappedModel``
        :param batch_size: forwarded to ``PLWrappedModel``
        :param check_input_shape: forwarded to the base class constructor
        :param check_input_columns: forwarded to the base class constructor
        """
        super().__init__(check_input_shape=check_input_shape, check_input_columns=check_input_columns)
        self.wrapped_model = PLWrappedModel(
            model,
            trainer,
            validation_fraction=validation_fraction,
            shuffle=shuffle,
            batch_size=batch_size,
        )

    def _predict_array(self, x: np.ndarray) -> np.ndarray:
        # The wrapped model's raw output is returned unchanged.
        predictions = self.wrapped_model.predict(x)
        return predictions

    def _fit_to_array(self, x: np.ndarray, y: np.ndarray):
        # Training is handled entirely by the wrapped model.
        self.wrapped_model.fit(x, y)
71# noinspection DuplicatedCode
class PLTensorToTensorClassificationModel(tm.TensorToTensorClassificationModel):
    """
    Tensor-to-tensor classification model whose fitting and prediction are delegated to a
    ``PLWrappedModel`` built from a Lightning module and a trainer.
    """

    def __init__(self, model: LightningModule, trainer: Trainer, validation_fraction=0.1, shuffle=True, batch_size=32,
            check_input_shape=True, check_input_columns=True):
        """
        :param model: the Lightning module to wrap
        :param trainer: the trainer used for fitting
        :param validation_fraction: forwarded to ``PLWrappedModel``
        :param shuffle: forwarded to ``PLWrappedModel``
        :param batch_size: forwarded to ``PLWrappedModel``
        :param check_input_shape: forwarded to the base class constructor
        :param check_input_columns: forwarded to the base class constructor
        """
        super().__init__(check_input_shape=check_input_shape, check_input_columns=check_input_columns)
        self.wrapped_model = PLWrappedModel(
            model,
            trainer,
            validation_fraction=validation_fraction,
            shuffle=shuffle,
            batch_size=batch_size,
        )

    def _predict_probabilities_array(self, x: np.ndarray) -> np.ndarray:
        # The wrapped model's raw output is returned unchanged.
        # NOTE(review): presumably the module already outputs class probabilities - confirm
        probabilities = self.wrapped_model.predict(x)
        return probabilities

    def _fit_to_array(self, x: np.ndarray, y: np.ndarray):
        # Training is handled entirely by the wrapped model.
        self.wrapped_model.fit(x, y)
86# noinspection DuplicatedCode
class PLTensorToTensorRegressionModel(tm.TensorToTensorRegressionModel):
    """
    Tensor-to-tensor regression model whose fitting and prediction are delegated to a
    ``PLWrappedModel`` built from a Lightning module and a trainer.
    """

    def __init__(self, model: LightningModule, trainer: Trainer, validation_fraction=0.1, shuffle=True, batch_size=32,
            check_input_shape=True, check_input_columns=True):
        """
        :param model: the Lightning module to wrap
        :param trainer: the trainer used for fitting
        :param validation_fraction: forwarded to ``PLWrappedModel``
        :param shuffle: forwarded to ``PLWrappedModel``
        :param batch_size: forwarded to ``PLWrappedModel``
        :param check_input_shape: forwarded to the base class constructor
        :param check_input_columns: forwarded to the base class constructor
        """
        super().__init__(check_input_shape=check_input_shape, check_input_columns=check_input_columns)
        self.wrapped_model = PLWrappedModel(
            model,
            trainer,
            validation_fraction=validation_fraction,
            shuffle=shuffle,
            batch_size=batch_size,
        )

    def _predict_array(self, x: np.ndarray) -> np.ndarray:
        # The wrapped model's raw output is returned unchanged.
        predictions = self.wrapped_model.predict(x)
        return predictions

    def _fit_to_array(self, x: np.ndarray, y: np.ndarray):
        # Training is handled entirely by the wrapped model.
        self.wrapped_model.fit(x, y)