Coverage for src/sensai/pytorch_lightning/pl_models.py: 0%

61 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-13 22:17 +0000

1import numpy as np 

2from pytorch_lightning import Trainer, LightningModule 

3from torch import tensor 

4 

5from .. import tensor_model as tm 

6from ..data import InputOutputArrays, DataSplitterFractional 

7 

8 

9def _fit_model_with_trainer(model: LightningModule, trainer: Trainer, io_data, 

10 batch_size: int, splitter: DataSplitterFractional = None): 

11 if splitter is not None: 

12 train_io_data, validation_io_data = splitter.split(io_data) 

13 train_data_loader = train_io_data.toTorchDataLoader(batchSize=batch_size) 

14 val_data_loader = validation_io_data.toTorchDataLoader(batchSize=batch_size) 

15 else: 

16 train_data_loader = io_data.to_torch_data_loader(batch_size=batch_size) 

17 val_data_loader = None 

18 trainer.fit(model, train_data_loader, val_dataloaders=val_data_loader) 

19 

20 

class PLWrappedModel:
    """
    Couples a LightningModule with a Trainer and exposes numpy-based ``fit`` and ``predict`` methods.

    :param model: the Lightning module which is trained and used for predictions
    :param trainer: the trainer used to fit the model
    :param validation_fraction: the fraction of the data, in [0, 1], that is split off as validation
        data when fitting; with 0, all data is used for training and no validation data loader is created
    :param shuffle: passed to the DataSplitterFractional that creates the training/validation split
    :param batch_size: the batch size of the data loaders created for fitting
    :raises ValueError: if ``validation_fraction`` is not in [0, 1]
    """
    def __init__(self, model: LightningModule, trainer: Trainer, validation_fraction=0.1, shuffle=True, batch_size=32):
        if not 0 <= validation_fraction <= 1:
            raise ValueError(f"Invalid validationFraction: {validation_fraction}. Has to be in interval [0, 1]")
        self.trainer = trainer
        self.model = model
        self.validationFraction = validation_fraction
        self.shuffle = shuffle
        self.batchSize = batch_size

    def fit(self, x: np.ndarray, y: np.ndarray):
        """
        Fits the model on the given arrays, holding out the configured fraction as validation data.

        :param x: the input array
        :param y: the output (target) array
        """
        io_data = InputOutputArrays(x, y)
        if self.validationFraction == 0:
            # A split with an empty validation part would only produce an empty validation data
            # loader; passing no splitter trains on all data without any validation data loader.
            splitter = None
        else:
            splitter = DataSplitterFractional(1 - self.validationFraction, shuffle=self.shuffle)
        _fit_model_with_trainer(self.model, self.trainer, io_data, self.batchSize, splitter=splitter)

    def predict(self, x: np.ndarray) -> np.ndarray:
        """
        Applies the model to the given inputs.

        :param x: the input array, which is converted to a torch tensor
        :return: the model output, detached, moved to the CPU and converted to a numpy array
        """
        # NOTE(review): the model is called as-is (no eval-mode switch is made here) - confirm that
        # this is intended for models with train-time-only behaviour.
        x = tensor(x)
        return self.model(x).detach().cpu().numpy()

39 

40 

41# noinspection DuplicatedCode 

class PLTensorToScalarClassificationModel(tm.TensorToScalarClassificationModel):
    """
    Tensor-to-scalar classification model backed by a LightningModule, which is trained and
    queried through a :class:`PLWrappedModel`.

    :param model: the Lightning module to wrap
    :param trainer: the trainer handed to the wrapped model
    :param validation_fraction: forwarded to :class:`PLWrappedModel`
    :param shuffle: forwarded to :class:`PLWrappedModel`
    :param batch_size: forwarded to :class:`PLWrappedModel`
    :param check_input_shape: forwarded to the base class constructor
    :param check_input_columns: forwarded to the base class constructor
    """
    def __init__(
        self,
        model: LightningModule,
        trainer: Trainer,
        validation_fraction=0.1,
        shuffle=True,
        batch_size=64,
        check_input_shape=True,
        check_input_columns=True,
    ):
        super().__init__(
            check_input_shape=check_input_shape,
            check_input_columns=check_input_columns,
        )
        self.wrapped_model = PLWrappedModel(
            model,
            trainer,
            validation_fraction=validation_fraction,
            shuffle=shuffle,
            batch_size=batch_size,
        )

    def _predict_probabilities_array(self, x: np.ndarray) -> np.ndarray:
        """Return the wrapped model's output for the inputs ``x``."""
        predictions = self.wrapped_model.predict(x)
        return predictions

    def _fit_to_array(self, x: np.ndarray, y: np.ndarray):
        """Fit the wrapped model on the inputs ``x`` and the targets ``y``."""
        wrapped = self.wrapped_model
        wrapped.fit(x, y)

54 

55 

56# noinspection DuplicatedCode 

class PLTensorToScalarRegressionModel(tm.TensorToScalarRegressionModel):
    """
    Tensor-to-scalar regression model that delegates training and prediction of a
    LightningModule to a :class:`PLWrappedModel`.

    :param model: the Lightning module to wrap
    :param trainer: the trainer handed to the wrapped model
    :param validation_fraction: forwarded to :class:`PLWrappedModel`
    :param shuffle: forwarded to :class:`PLWrappedModel`
    :param batch_size: forwarded to :class:`PLWrappedModel`
    :param check_input_shape: forwarded to the base class constructor
    :param check_input_columns: forwarded to the base class constructor
    """
    def __init__(self, model: LightningModule, trainer: Trainer,
            validation_fraction=0.1, shuffle=True, batch_size=32,
            check_input_shape=True, check_input_columns=True):
        base_kwargs = dict(check_input_shape=check_input_shape, check_input_columns=check_input_columns)
        super().__init__(**base_kwargs)
        wrapper_kwargs = dict(validation_fraction=validation_fraction, shuffle=shuffle, batch_size=batch_size)
        self.wrapped_model = PLWrappedModel(model, trainer, **wrapper_kwargs)

    def _predict_array(self, x: np.ndarray) -> np.ndarray:
        """Return the wrapped model's output for the inputs ``x``."""
        wrapped = self.wrapped_model
        return wrapped.predict(x)

    def _fit_to_array(self, x: np.ndarray, y: np.ndarray):
        """Fit the wrapped model on the inputs ``x`` and the targets ``y``."""
        wrapped = self.wrapped_model
        wrapped.fit(x, y)

69 

70 

71# noinspection DuplicatedCode 

class PLTensorToTensorClassificationModel(tm.TensorToTensorClassificationModel):
    """
    Tensor-to-tensor classification model whose fitting and prediction are carried out by a
    :class:`PLWrappedModel` around the given LightningModule.

    :param model: the Lightning module to wrap
    :param trainer: the trainer handed to the wrapped model
    :param validation_fraction: forwarded to :class:`PLWrappedModel`
    :param shuffle: forwarded to :class:`PLWrappedModel`
    :param batch_size: forwarded to :class:`PLWrappedModel`
    :param check_input_shape: forwarded to the base class constructor
    :param check_input_columns: forwarded to the base class constructor
    """
    def __init__(
        self,
        model: LightningModule,
        trainer: Trainer,
        validation_fraction=0.1,
        shuffle=True,
        batch_size=32,
        check_input_shape=True,
        check_input_columns=True,
    ):
        super().__init__(check_input_shape=check_input_shape, check_input_columns=check_input_columns)
        wrapped = PLWrappedModel(
            model, trainer, validation_fraction=validation_fraction, shuffle=shuffle, batch_size=batch_size
        )
        self.wrapped_model = wrapped

    def _predict_probabilities_array(self, x: np.ndarray) -> np.ndarray:
        """Return the wrapped model's output for the inputs ``x``."""
        result = self.wrapped_model.predict(x)
        return result

    def _fit_to_array(self, x: np.ndarray, y: np.ndarray):
        """Fit the wrapped model on the inputs ``x`` and the targets ``y``."""
        self.wrapped_model.fit(x, y)

84 

85 

86# noinspection DuplicatedCode 

class PLTensorToTensorRegressionModel(tm.TensorToTensorRegressionModel):
    """
    Tensor-to-tensor regression model that uses a :class:`PLWrappedModel` to train the given
    LightningModule and to obtain its predictions.

    :param model: the Lightning module to wrap
    :param trainer: the trainer handed to the wrapped model
    :param validation_fraction: forwarded to :class:`PLWrappedModel`
    :param shuffle: forwarded to :class:`PLWrappedModel`
    :param batch_size: forwarded to :class:`PLWrappedModel`
    :param check_input_shape: forwarded to the base class constructor
    :param check_input_columns: forwarded to the base class constructor
    """
    def __init__(self, model: LightningModule, trainer: Trainer,
            validation_fraction=0.1, shuffle=True, batch_size=32,
            check_input_shape=True, check_input_columns=True):
        super().__init__(
            check_input_shape=check_input_shape,
            check_input_columns=check_input_columns,
        )
        wrapper_options = {
            "validation_fraction": validation_fraction,
            "shuffle": shuffle,
            "batch_size": batch_size,
        }
        self.wrapped_model = PLWrappedModel(model, trainer, **wrapper_options)

    def _predict_array(self, x: np.ndarray) -> np.ndarray:
        """Return the wrapped model's output for the inputs ``x``."""
        output = self.wrapped_model.predict(x)
        return output

    def _fit_to_array(self, x: np.ndarray, y: np.ndarray):
        """Fit the wrapped model on the inputs ``x`` and the targets ``y``."""
        self.wrapped_model.fit(x, y)