Coverage for src/sensai/data/dataset.py: 73%
37 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
1"""
2This module contains sample datasets to facilitate testing and development.
3"""
4from abc import ABC, abstractmethod
6import sklearn.datasets
8from sensai.data import InputOutputData
9import pandas as pd
11from sensai.util.string import ToStringMixin
class DataSet(ToStringMixin, ABC):
    """
    Abstract base class for datasets which can be loaded as input/output data.
    """

    @abstractmethod
    def load_io_data(self) -> InputOutputData:
        """
        Loads the dataset; must be implemented by subclasses.

        :return: the loaded data as an InputOutputData instance
        """
class DataSetClassificationIris(DataSet):
    """
    Classification dataset obtained from sklearn.datasets.load_iris.
    """

    def load_io_data(self) -> InputOutputData:
        """
        Loads the data via sklearn and wraps it in data frames.

        :return: the data, where the input columns are named after sklearn's "feature_names" and the
            single output column "class" holds the target name of each sample
        """
        bunch = sklearn.datasets.load_iris()

        # translate the numeric targets into their names
        target_names = bunch["target_names"]
        class_labels = [target_names[class_idx] for class_idx in bunch["target"]]

        features_df = pd.DataFrame(bunch["data"], columns=bunch["feature_names"])
        labels_df = pd.DataFrame({"class": class_labels})
        return InputOutputData(features_df.reset_index(drop=True), labels_df.reset_index(drop=True))
class DataSetClassificationTitanicSurvival(DataSet):
    """
    Classification dataset read from the CSV file at :attr:`URL`, where the column :attr:`COL_SURVIVAL`
    is the output and all remaining columns are inputs.
    """
    URL = "https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv"

    COL_INDEX = "PassengerId"
    """
    unique identifier for each passenger
    """
    COL_SURVIVAL = "Survived"
    """
    0 = No, 1 = Yes
    """
    COL_NAME = "Name"
    """
    passenger name
    """
    COL_PASSENGER_CLASS = "Pclass"
    """
    Ticket class as an integer (1 = first, 2 = second, 3 = third)
    """
    COL_SEX = "Sex"
    """
    'male' or 'female'
    """
    COL_AGE_YEARS = "Age"
    """
    age in years (integer)
    """
    COL_SIBLINGS_SPOUSES = "SibSp"
    """
    number of siblings/spouses aboard the Titanic
    """
    COL_PARENTS_CHILDREN = "Parch"
    """
    number of parents/children aboard the Titanic
    """
    COL_FARE_PRICE = "Fare"
    """
    amount of money paid for the ticket
    """
    COL_CABIN = "Cabin"
    """
    the cabin number (if available)
    """
    COL_PORT_EMBARKED = "Embarked"
    """
    port of embarkation ('C' = Cherbourg, 'Q' = Queenstown, 'S' = Southampton)
    """
    COL_TICKET = "Ticket"
    """
    the ticket number
    """
    COLS_METADATA = [COL_NAME, COL_TICKET, COL_CABIN]
    """
    list of columns containing meta-data which are not useful for generalising prediction models
    """

    def __init__(self, drop_metadata_columns: bool = False):
        """
        :param drop_metadata_columns: if True, the columns listed in :attr:`COLS_METADATA` are removed
            when the data is loaded
        """
        self.drop_metadata_columns = drop_metadata_columns

    def load_io_data(self) -> InputOutputData:
        """
        Reads the CSV file from :attr:`URL`, using the column :attr:`COL_INDEX` as the index.

        :return: the data, with :attr:`COL_SURVIVAL` as the output column
        """
        passengers_df = pd.read_csv(self.URL)
        passengers_df = passengers_df.set_index(self.COL_INDEX, drop=True)
        if self.drop_metadata_columns:
            passengers_df = passengers_df.drop(columns=self.COLS_METADATA)
        return InputOutputData.from_data_frame(passengers_df, self.COL_SURVIVAL)