Coverage for src/sensai/torch/torchtext.py: 0%
51 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
1from typing import Dict, Generator, Tuple, Optional, Union
3import pandas as pd
4import torch
5import torchtext
7from .torch_data import to_tensor, TorchDataSet, TorchDataSetProvider
class TorchtextDataSetFromDataFrame(torchtext.data.Dataset):
    """
    A specialisation of torchtext.data.Dataset, where the data is taken from a pandas.DataFrame
    (one torchtext Example is created per row)
    """
    def __init__(self, df: pd.DataFrame, fields: Dict[str, torchtext.data.Field]):
        """
        :param df: the data frame from which to obtain the data
        :param fields: a mapping from column names in the given data frame to torchtext fields, i.e.
            the keys are the columns to read and the values are the fields to use for generated Example instances;
            if a value is None, the column's raw value is stored on the Example without preprocessing
        :raises ValueError: if a key of ``fields`` is not a column of ``df`` (detected while converting rows,
            so an empty data frame does not trigger it)
        """
        # Convert per record instead of via df.apply(..., axis=1): row-wise apply passes each row as a single
        # Series with one common dtype (e.g. int columns are upcast to float next to float columns), and for an
        # empty data frame it can return a DataFrame, which has no tolist().
        records = df.to_dict(orient="records")
        examples = [self._exampleFromDict(record, fields) for record in records]
        super().__init__(examples, dict(fields))

    @classmethod
    def _exampleFromSeries(cls, series: pd.Series, fields: Dict[str, torchtext.data.Field]):
        """
        Creates an Example from a single data frame row (no longer used by this class itself,
        kept for callers which convert rows individually)

        :param series: the row, indexed by column name
        :param fields: the mapping from column names to fields (or None), as in the constructor
        :return: the example
        """
        return cls._exampleFromDict(series.to_dict(), fields)

    @classmethod
    def _exampleFromDict(cls, d: dict, fields: Dict[str, torchtext.data.Field]):
        """
        Creates an Example with one attribute per key of ``fields``, named after that key

        :param d: a mapping from column names to values
        :param fields: the mapping from column names to fields (or None), as in the constructor
        :return: the example
        :raises ValueError: if a key of ``fields`` is missing from ``d``
        """
        ex = torchtext.data.Example()
        for key, field in fields.items():
            if key not in d:
                raise ValueError("Specified key {} was not found in "
                                 "the input data".format(key))
            value = d[key]
            # keys mapped to None store the raw value unchanged
            setattr(ex, key, value if field is None else field.preprocess(value))
        return ex
class TorchDataSetFromTorchtextDataSet(TorchDataSet):
    """
    Adapts a torchtext.data.Dataset to the TorchDataSet interface, generating batches via a torchtext BucketIterator
    """
    def __init__(self, dataSet: torchtext.data.Dataset, inputField: str, outputField: Optional[str], cuda: bool):
        """
        :param dataSet: the torchtext data set to adapt
        :param inputField: the name of the batch attribute holding the model input
        :param outputField: the name of the batch attribute holding the target; if None, only inputs are generated
        :param cuda: passed on to ``to_tensor`` for every generated tensor (presumably whether to place it on the GPU)
        """
        self.outputField = outputField
        self.inputField = inputField
        self.dataSet = dataSet
        self.cuda = cuda

    def iter_batches(self, batch_size: int, shuffle: bool = False, input_only=False) -> Generator[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], None, None]:
        """
        :param batch_size: the number of examples per batch
        :param shuffle: whether the iterator shall shuffle the examples
        :param input_only: whether to generate only input tensors, even if an output field is configured
        :return: a generator of (x, y) pairs, or of x tensors only if ``input_only`` is set or there is no output field
        """
        inputField = self.inputField
        iterator = torchtext.data.BucketIterator(self.dataSet,
            batch_size=batch_size,
            # Must be passed explicitly: the iterator's default (shuffle=None) falls back to its `train` flag,
            # which defaults to True and would thus shuffle regardless of the caller's choice.
            shuffle=shuffle,
            # Key on the configured input field rather than a hard-coded "text" attribute, which the examples
            # need not have.
            sort_key=lambda example: len(getattr(example, inputField)),
            sort_within_batch=False)

        for batch in iterator:
            x = to_tensor(getattr(batch, inputField), self.cuda)
            if input_only or self.outputField is None:
                yield x
            else:
                yield x, to_tensor(getattr(batch, self.outputField), self.cuda)

    def size(self) -> Optional[int]:
        """
        :return: the number of examples in the underlying torchtext data set
        """
        return len(self.dataSet)
class TorchDataSetProviderFromTorchtextDataSet(TorchDataSetProvider):
    """
    A TorchDataSetProvider which splits a torchtext.data.Dataset and wraps each part in a
    TorchDataSetFromTorchtextDataSet
    """
    def __init__(self, dataSet: torchtext.data.Dataset, inputField: str, outputField: Optional[str], cuda: bool, model_output_dim, input_dim=None):
        """
        :param dataSet: the torchtext data set to split
        :param inputField: the name of the batch attribute holding the model input
        :param outputField: the name of the batch attribute holding the target; if None, the provided data sets
            generate inputs only
        :param cuda: passed on to the created data sets (and from there to ``to_tensor``)
        :param model_output_dim: passed on to TorchDataSetProvider
        :param input_dim: passed on to TorchDataSetProvider
        """
        super().__init__(model_output_dim=model_output_dim, input_dim=input_dim)
        self.dataSet = dataSet
        self.outputField = outputField
        self.inputField = inputField
        self.cuda = cuda

    def provide_split(self, fractional_size_of_first_set: float) -> Tuple[TorchDataSet, TorchDataSet]:
        """
        :param fractional_size_of_first_set: the fraction of examples to put into the first data set
            (passed to the torchtext data set's ``split`` as the split ratio)
        :return: the two data sets
        :raises ValueError: if the split does not yield exactly two parts
        """
        splits = self.dataSet.split(fractional_size_of_first_set)
        # torchtext's Dataset.split omits empty subsets from its result, so a very small data set can yield a
        # single part; report this explicitly rather than failing with an opaque tuple-unpacking error
        if len(splits) != 2:
            raise ValueError("Splitting a data set of {} examples with ratio {} yielded {} non-empty part(s) "
                             "instead of 2".format(len(self.dataSet), fractional_size_of_first_set, len(splits)))
        first, second = splits
        return self._createDataSet(first), self._createDataSet(second)

    def _createDataSet(self, d: torchtext.data.Dataset):
        """
        :param d: a (sub-)data set of this provider's data set
        :return: ``d`` wrapped as a TorchDataSet using this provider's field names and cuda flag
        """
        return TorchDataSetFromTorchtextDataSet(d, self.inputField, self.outputField, self.cuda)