Source code for sensai.pytorch_lightning.pl_models

import numpy as np
from pytorch_lightning import Trainer, LightningModule
from torch import tensor

from .. import tensor_model as tm
from ..data import InputOutputArrays, DataSplitterFractional


def _fit_model_with_trainer(model: LightningModule, trainer: Trainer, io_data,
        batch_size: int, splitter: DataSplitterFractional = None):
    if splitter is not None:
        train_io_data, validation_io_data = splitter.split(io_data)
        train_data_loader = train_io_data.toTorchDataLoader(batchSize=batch_size)
        val_data_loader = validation_io_data.toTorchDataLoader(batchSize=batch_size)
    else:
        train_data_loader = io_data.to_torch_data_loader(batch_size=batch_size)
        val_data_loader = None
    trainer.fit(model, train_data_loader, val_dataloaders=val_data_loader)


[docs]class PLWrappedModel: def __init__(self, model: LightningModule, trainer: Trainer, validation_fraction=0.1, shuffle=True, batch_size=32): if not 0 <= validation_fraction <= 1: raise ValueError(f"Invalid validationFraction: {validation_fraction}. Has to be in interval [0, 1]") self.trainer = trainer self.model = model self.validationFraction = validation_fraction self.shuffle = shuffle self.batchSize = batch_size
[docs] def fit(self, x: np.ndarray, y: np.ndarray): io_data = InputOutputArrays(x, y) splitter = DataSplitterFractional(1 - self.validationFraction, shuffle=self.shuffle) _fit_model_with_trainer(self.model, self.trainer, io_data, self.batchSize, splitter=splitter)
[docs] def predict(self, x: np.ndarray) -> np.ndarray: x = tensor(x) return self.model(x).detach().cpu().numpy()
# noinspection DuplicatedCode
[docs]class PLTensorToScalarClassificationModel(tm.TensorToScalarClassificationModel): def __init__(self, model: LightningModule, trainer: Trainer, validation_fraction=0.1, shuffle=True, batch_size=64, check_input_shape=True, check_input_columns=True): super().__init__(check_input_shape=check_input_shape, check_input_columns=check_input_columns) self.wrapped_model = PLWrappedModel(model, trainer, validation_fraction=validation_fraction, shuffle=shuffle, batch_size=batch_size) def _predict_probabilities_array(self, x: np.ndarray) -> np.ndarray: return self.wrapped_model.predict(x) def _fit_to_array(self, x: np.ndarray, y: np.ndarray): self.wrapped_model.fit(x, y)
# noinspection DuplicatedCode
[docs]class PLTensorToScalarRegressionModel(tm.TensorToScalarRegressionModel): def __init__(self, model: LightningModule, trainer: Trainer, validation_fraction=0.1, shuffle=True, batch_size=32, check_input_shape=True, check_input_columns=True): super().__init__(check_input_shape=check_input_shape, check_input_columns=check_input_columns) self.wrapped_model = PLWrappedModel(model, trainer, validation_fraction=validation_fraction, shuffle=shuffle, batch_size=batch_size) def _predict_array(self, x: np.ndarray) -> np.ndarray: return self.wrapped_model.predict(x) def _fit_to_array(self, x: np.ndarray, y: np.ndarray): self.wrapped_model.fit(x, y)
# noinspection DuplicatedCode
[docs]class PLTensorToTensorClassificationModel(tm.TensorToTensorClassificationModel): def __init__(self, model: LightningModule, trainer: Trainer, validation_fraction=0.1, shuffle=True, batch_size=32, check_input_shape=True, check_input_columns=True): super().__init__(check_input_shape=check_input_shape, check_input_columns=check_input_columns) self.wrapped_model = PLWrappedModel(model, trainer, validation_fraction=validation_fraction, shuffle=shuffle, batch_size=batch_size) def _predict_probabilities_array(self, x: np.ndarray) -> np.ndarray: return self.wrapped_model.predict(x) def _fit_to_array(self, x: np.ndarray, y: np.ndarray): self.wrapped_model.fit(x, y)
# noinspection DuplicatedCode
[docs]class PLTensorToTensorRegressionModel(tm.TensorToTensorRegressionModel): def __init__(self, model: LightningModule, trainer: Trainer, validation_fraction=0.1, shuffle=True, batch_size=32, check_input_shape=True, check_input_columns=True): super().__init__(check_input_shape=check_input_shape, check_input_columns=check_input_columns) self.wrapped_model = PLWrappedModel(model, trainer, validation_fraction=validation_fraction, shuffle=shuffle, batch_size=batch_size) def _predict_array(self, x: np.ndarray) -> np.ndarray: return self.wrapped_model.predict(x) def _fit_to_array(self, x: np.ndarray, y: np.ndarray): self.wrapped_model.fit(x, y)