Coverage for src/sensai/data/dataset.py: 73%

37 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-13 22:17 +0000

1""" 

2This module contains sample datasets to facilitate testing and development. 

3""" 

4from abc import ABC, abstractmethod 

5 

6import sklearn.datasets 

7 

8from sensai.data import InputOutputData 

9import pandas as pd 

10 

11from sensai.util.string import ToStringMixin 

12 

13 

class DataSet(ToStringMixin, ABC):
    """
    Abstract base class for sample datasets.

    Concrete subclasses implement :meth:`load_io_data`, which provides the dataset
    as an :class:`InputOutputData` instance.
    """

    @abstractmethod
    def load_io_data(self) -> InputOutputData:
        """
        Loads the dataset.

        :return: the dataset's input and output data
        """

18 

19 

class DataSetClassificationIris(DataSet):
    """
    The iris flower classification dataset as bundled with scikit-learn.

    The inputs are the numeric feature columns provided by scikit-learn (named by its
    ``feature_names``); the output is a single column "class" containing the class
    (species) name of each sample.
    """

    def load_io_data(self) -> InputOutputData:
        bunch = sklearn.datasets.load_iris()
        class_names = bunch["target_names"]
        # translate the integer class codes into the corresponding class names
        labels = [class_names[code] for code in bunch["target"]]
        inputs = pd.DataFrame(bunch["data"], columns=bunch["feature_names"])
        outputs = pd.DataFrame({"class": labels})
        return InputOutputData(inputs.reset_index(drop=True), outputs.reset_index(drop=True))

27 

28 

class DataSetClassificationTitanicSurvival(DataSet):
    """
    The Titanic passenger survival dataset (binary classification target :attr:`COL_SURVIVAL`),
    read from the CSV file at :attr:`URL`.

    The passenger identifier column :attr:`COL_INDEX` is used as the data frame index.
    """

    URL = "https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv"

    COL_INDEX = "PassengerId"
    """
    unique identifier for each passenger
    """
    COL_SURVIVAL = "Survived"
    """
    0 = No, 1 = Yes
    """
    COL_NAME = "Name"
    """
    passenger name
    """
    COL_PASSENGER_CLASS = "Pclass"
    """
    Ticket class as an integer (1 = first, 2 = second, 3 = third)
    """
    COL_SEX = "Sex"
    """
    'male' or 'female'
    """
    COL_AGE_YEARS = "Age"
    """
    age in years; not necessarily an integer: the column contains missing values (NaN)
    and some fractional ages, so it is loaded as a float column
    """
    COL_SIBLINGS_SPOUSES = "SibSp"
    """
    number of siblings/spouses aboard the Titanic
    """
    COL_PARENTS_CHILDREN = "Parch"
    """
    number of parents/children aboard the Titanic
    """
    COL_FARE_PRICE = "Fare"
    """
    amount of money paid for the ticket
    """
    COL_CABIN = "Cabin"
    """
    the cabin number (if available)
    """
    COL_PORT_EMBARKED = "Embarked"
    """
    port of embarkation ('C' = Cherbourg, 'Q' = Queenstown, 'S' = Southampton)
    """
    COL_TICKET = "Ticket"
    """
    the ticket number
    """
    # NOTE: kept as a list (not a tuple): pandas' ``drop`` would interpret a tuple as a single label
    COLS_METADATA = [COL_NAME, COL_TICKET, COL_CABIN]
    """
    list of columns containing meta-data which are not useful for generalising prediction models
    """

    def __init__(self, drop_metadata_columns: bool = False):
        """
        :param drop_metadata_columns: whether to drop meta-data columns (see :attr:`COLS_METADATA`)
            which are not useful for a generalising prediction model
        """
        self.drop_metadata_columns = drop_metadata_columns

    def load_io_data(self) -> InputOutputData:
        """
        Reads the dataset from :attr:`URL` (the data is fetched anew on every call; nothing is cached),
        indexes it by :attr:`COL_INDEX` and optionally drops the meta-data columns.

        :return: the result of ``InputOutputData.from_data_frame`` with :attr:`COL_SURVIVAL`
            as the output column
        """
        df = pd.read_csv(self.URL).set_index(self.COL_INDEX, drop=True)
        if self.drop_metadata_columns:
            # reassign rather than using inplace=True, which pandas discourages
            df = df.drop(columns=self.COLS_METADATA)
        return InputOutputData.from_data_frame(df, self.COL_SURVIVAL)