Coverage for src/sensai/torch/torchtext.py: 0%

51 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-29 18:29 +0000

1from typing import Dict, Generator, Tuple, Optional, Union 

2 

3import pandas as pd 

4import torch 

5import torchtext 

6 

7from .torch_data import to_tensor, TorchDataSet, TorchDataSetProvider 

8 

9 

class TorchtextDataSetFromDataFrame(torchtext.data.Dataset):
    """
    A specialisation of torchtext.data.Dataset, where the data is taken from a pandas.DataFrame
    """
    def __init__(self, df: pd.DataFrame, fields: Dict[str, torchtext.data.Field]):
        """
        :param df: the data frame from which to obtain the data
        :param fields: a mapping from column names in the given data frame to torchtext fields, i.e.
            the keys are the columns to read and the values are the fields to use for generated Example instances
        """
        # Build the examples from per-row records instead of df.apply(axis=1):
        # apply() probes the function with a dummy row when the frame is empty and may
        # then hand back a DataFrame (which has no tolist()), and it coerces every row
        # to a single dtype before the values reach the fields.
        examples = [self._exampleFromDict(record, fields) for record in df.to_dict(orient="records")]
        super().__init__(examples, dict(fields))

    @classmethod
    def _exampleFromSeries(cls, series: pd.Series, fields: Dict[str, torchtext.data.Field]):
        """
        Creates an Example from a single data frame row.

        No longer used by the constructor; retained so that existing callers keep working.

        :param series: the row, indexed by column name
        :param fields: a mapping from column names to torchtext fields
        :return: the generated Example instance
        """
        return cls._exampleFromDict(series.to_dict(), fields)

    @classmethod
    def _exampleFromDict(cls, d: dict, fields: Dict[str, torchtext.data.Field]):
        """
        Creates an Example holding one attribute per entry in the given fields mapping.

        :param d: a mapping from column names to the values of a single row
        :param fields: a mapping from column names to torchtext fields; if a field is None,
            the raw value is stored without preprocessing
        :return: the generated Example instance
        :raises ValueError: if a key of the fields mapping is not present in d
        """
        ex = torchtext.data.Example()
        for key, field in fields.items():
            if key not in d:
                raise ValueError("Specified key {} was not found in "
                                 "the input data".format(key))
            value = d[key]
            if field is not None:
                value = field.preprocess(value)
            setattr(ex, key, value)
        return ex

40 

41 

class TorchDataSetFromTorchtextDataSet(TorchDataSet):
    """
    Exposes a torchtext data set through the TorchDataSet interface, yielding batches as tensors
    """
    def __init__(self, dataSet: torchtext.data.Dataset, inputField: str, outputField: Optional[str], cuda: bool):
        """
        :param dataSet: the torchtext data set to draw batches from
        :param inputField: the name of the batch attribute holding the model input
        :param outputField: the name of the batch attribute holding the target output;
            if None, only inputs are yielded
        :param cuda: flag that is passed on to to_tensor for every yielded tensor
        """
        self.outputField = outputField
        self.inputField = inputField
        self.dataSet = dataSet
        self.cuda = cuda

    def iter_batches(self, batch_size: int, shuffle: bool = False, input_only=False) -> Generator[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], None, None]:
        """
        Iterates over the data set in batches.

        :param batch_size: the number of examples per batch
        :param shuffle: whether the examples shall be shuffled
        :param input_only: whether to yield only the input tensors even if an output field is configured
        :return: a generator of (input, output) tensor pairs, or of input tensors only if
            input_only is True or no output field is configured
        """
        input_field = self.inputField
        iterator = torchtext.data.BucketIterator(self.dataSet,
            batch_size=batch_size,
            # sort by the length of the configured input field rather than a hard-coded attribute name
            sort_key=lambda x: len(getattr(x, input_field)),
            # pass the caller's choice on explicitly; otherwise the iterator falls back to its own default
            shuffle=shuffle,
            sort_within_batch=False)

        with_output = not input_only and self.outputField is not None
        for batch in iterator:
            x = to_tensor(getattr(batch, input_field), self.cuda)
            if with_output:
                y = to_tensor(getattr(batch, self.outputField), self.cuda)
                yield x, y
            else:
                yield x

    def size(self) -> Optional[int]:
        """
        :return: the number of examples in the underlying data set
        """
        return len(self.dataSet)

65 

66 

class TorchDataSetProviderFromTorchtextDataSet(TorchDataSetProvider):
    """
    A data set provider which splits a torchtext data set and wraps the resulting parts
    in TorchDataSetFromTorchtextDataSet instances
    """
    def __init__(self, dataSet: torchtext.data.Dataset, inputField: str, outputField: str, cuda: bool, model_output_dim, input_dim=None):
        """
        :param dataSet: the torchtext data set to provide splits of
        :param inputField: the name of the field holding the model input
        :param outputField: the name of the field holding the target output
        :param cuda: flag that is handed to every data set created by this provider
        :param model_output_dim: passed on to the super-class constructor
        :param input_dim: passed on to the super-class constructor
        """
        super().__init__(model_output_dim=model_output_dim, input_dim=input_dim)
        self.cuda = cuda
        self.inputField = inputField
        self.outputField = outputField
        self.dataSet = dataSet

    def provide_split(self, fractional_size_of_first_set: float) -> Tuple[TorchDataSet, TorchDataSet]:
        """
        Splits the underlying torchtext data set into two parts.

        :param fractional_size_of_first_set: the split ratio handed to the data set's split method
        :return: a pair of data sets wrapping the first and the second part of the split
        """
        first_part, second_part = self.dataSet.split(fractional_size_of_first_set)
        first_set = self._createDataSet(first_part)
        second_set = self._createDataSet(second_part)
        return first_set, second_set

    def _createDataSet(self, d: torchtext.data.Dataset):
        """
        Wraps the given torchtext data set using this provider's field names and cuda flag.

        :param d: the torchtext data set to wrap
        :return: the wrapping TorchDataSetFromTorchtextDataSet instance
        """
        return TorchDataSetFromTorchtextDataSet(
            d,
            self.inputField,
            self.outputField,
            self.cuda,
        )