Coverage for src/sensai/tracking/tracking_base.py: 64%

107 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-13 22:17 +0000

1from abc import ABC, abstractmethod 

2from typing import Dict, Any, Optional, Generic, TypeVar, List 

3 

4from matplotlib import pyplot as plt 

5 

6from ..util import count_none 

7from ..util.deprecation import deprecated 

8from ..vector_model import VectorModelBase 

9 

10 

class TrackingContext(ABC):
    """
    Abstract base class for a context in which metrics, figures and texts can be tracked.
    Instances are context managers: entering marks the context as running, leaving calls :meth:`end`.
    """
    def __init__(self, name: str, experiment: Optional["TrackedExperiment"]):
        """
        :param name: the name of the context
        :param experiment: the experiment the context belongs to; optional only because of DummyTrackingContext
        """
        self._isRunning = False
        self._experiment = experiment
        self.name = name

    @staticmethod
    def from_optional_experiment(experiment: Optional["TrackedExperiment"], model: Optional[VectorModelBase] = None,
            name: Optional[str] = None, description: str = ""):
        """
        Creates a context for the given experiment, or a dummy context if there is no experiment.

        :param experiment: the experiment in which to begin the context; if None, a DummyTrackingContext is returned
        :param model: the model for which to begin the context (via `begin_context_for_model`)
        :param name: the context name to use when no model is given (via `begin_context`)
        :param description: the description passed to `begin_context` when `name` is used
        :return: the tracking context
        :raises ValueError: if an experiment is given and `count_none(name, model) != 1`
        """
        # without an experiment, nothing can be tracked
        if experiment is None:
            return DummyTrackingContext(name)
        if count_none(name, model) != 1:
            raise ValueError("Must provide exactly one of {model, name}")
        if model is None:
            return experiment.begin_context(name, description)
        return experiment.begin_context_for_model(model)

    def is_enabled(self):
        """
        :return: True if tracking is enabled, i.e. whether results can be saved via this context
        """
        return True

    @abstractmethod
    def _track_metrics(self, metrics: Dict[str, float]):
        """
        Tracks the given metrics; called by :meth:`track_metrics` after the optional prefixing of metric names.

        :param metrics: the metrics to be logged
        """

    def track_metrics(self, metrics: Dict[str, float], predicted_var_name: Optional[str] = None):
        """
        :param metrics: the metrics to be logged
        :param predicted_var_name: the name of the predicted variable for the case where there is more than one.
            If it is provided, the variable name (followed by an underscore) is prepended to every metric name.
        """
        if predicted_var_name is None:
            metrics_to_track = metrics
        else:
            metrics_to_track = {f"{predicted_var_name}_{key}": value for key, value in metrics.items()}
        self._track_metrics(metrics_to_track)

    @abstractmethod
    def track_figure(self, name: str, fig: plt.Figure):
        """
        :param name: the name of the figure (not a filename, should not include file extension)
        :param fig: the figure
        """

    @abstractmethod
    def track_text(self, name: str, content: str):
        """
        :param name: the name of the text (not a filename, should not include file extension)
        :param content: the content (arbitrarily long text, e.g. a log)
        """

    def __enter__(self):
        self._isRunning = True
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.end()

    @abstractmethod
    def _end(self):
        """
        Ends the context for good; called as the last step of :meth:`end`.
        """

    def end(self):
        """
        Ends this context: if it is running, it is first ended in the experiment (if any) and marked
        as no longer running; subsequently, `_end` is called.
        """
        if self._isRunning:
            # the experiment goes first, as it may add final stuff
            experiment = self._experiment
            if experiment is not None:
                experiment.end_context(self)
            self._isRunning = False
        # only then is the context ended for good
        self._end()

86 

87 

class DummyTrackingContext(TrackingContext):
    """
    A tracking context which does not track anything: all tracking methods are no-ops.
    It allows tracking code to be written unconditionally, i.e. without special handling of the case
    where there is no tracked experiment.
    """
    def __init__(self, name):
        super().__init__(name, experiment=None)

    def is_enabled(self):
        return False

    def _track_metrics(self, metrics: Dict[str, float]):
        """Does nothing."""

    def track_figure(self, name: str, fig: plt.Figure):
        """Does nothing."""

    def track_text(self, name: str, content: str):
        """Does nothing."""

    def _end(self):
        """Does nothing."""

110 

111 

TContext = TypeVar("TContext", bound=TrackingContext)


class TrackedExperiment(Generic[TContext], ABC):
    """
    Base class for tracked experiments, which create tracking contexts and keep a stack of the
    contexts that were begun and not yet ended.
    """
    def __init__(self, context_prefix: str = "", additional_logging_values_dict=None):
        """
        Base class for tracking

        :param context_prefix: a prefix which is prepended to the name of every context created via `begin_context`
        :param additional_logging_values_dict: additional values to be logged for each run
        """
        # TODO additional_logging_values_dict probably needs to be removed
        self.instancePrefix = context_prefix
        self.additionalLoggingValuesDict = additional_logging_values_dict
        # stack of contexts that were begun; the last element is the currently running context
        self._contexts: List[TContext] = []

    @deprecated("Use a tracking context instead")
    def track_values(self, values_dict: Dict[str, Any], add_values_dict: Optional[Dict[str, Any]] = None):
        """
        Tracks the given values, extended by `add_values_dict` and the additional logging values of
        this experiment (values added later override earlier ones with the same key).

        :param values_dict: the values to track; the passed dictionary is not modified
        :param add_values_dict: optional further values to add
        """
        # work on a copy in order not to modify the caller's dictionary
        values_dict = dict(values_dict)
        if add_values_dict is not None:
            values_dict.update(add_values_dict)
        if self.additionalLoggingValuesDict is not None:
            values_dict.update(self.additionalLoggingValuesDict)
        self._track_values(values_dict)

    @abstractmethod
    def _track_values(self, values_dict: Dict[str, Any]):
        """
        :param values_dict: the (already merged) values to track
        """
        pass

    @abstractmethod
    def _create_tracking_context(self, name: str, description: str) -> TContext:
        """
        :param name: the full name of the context (already including the context prefix)
        :param description: a description of the context
        :return: the newly created context
        """
        pass

    def begin_context(self, name: str, description: str = "") -> TContext:
        """
        Begins a context in which actual information will be tracked.
        The returned object is a context manager, which can be used in a with-statement.

        :param name: the name of the context (e.g. model name)
        :param description: a description (e.g. full model parameters/specification)
        :return: the context, which can subsequently be used to track information
        """
        instance = self._create_tracking_context(self.instancePrefix + name, description)
        self._contexts.append(instance)
        return instance

    def begin_context_for_model(self, model: VectorModelBase) -> TContext:
        """
        Begins a tracking context for the case where we want to track information about a model
        (wrapper around `begin_context` for convenience).
        The model name is used as the context name, and the model's string representation is used as the description.
        The returned object is a context manager, which can be used in a with-statement.

        :param model: the model
        :return: the context, which can subsequently be used to track information
        """
        return self.begin_context(model.get_name(), model.pprints())

    def end_context(self, instance: TContext):
        """
        Removes the given context from the stack of begun contexts.

        :param instance: the context to end; must be the most recently begun context that has not been ended
        :raises ValueError: if no context is currently running or if the given instance is not the
            currently running one
        """
        if not self._contexts:
            # fail with an explicit message rather than with an IndexError from the lookup below
            raise ValueError(f"Passed instance ({instance}) cannot be ended, because no context is currently running")
        running_instance = self._contexts[-1]
        if instance != running_instance:
            raise ValueError(f"Passed instance ({instance}) is not the currently running instance ({running_instance})")
        self._contexts.pop()

    def __del__(self):
        # make sure all contexts that are still running are eventually closed
        # NOTE: `_contexts` is missing if __init__ did not run to completion (or was never called by a subclass)
        contexts = getattr(self, "_contexts", None)
        if not contexts:
            return
        # iterate over a snapshot, because ending a running context removes it from self._contexts
        for c in reversed(list(contexts)):
            c.end()

177 

178 

class TrackingMixin(ABC):
    """
    Mixin which allows a tracked experiment to be (optionally) associated with an object.
    The association is not stored in the object itself but in a class-level dictionary which is keyed
    by the object's `id`.
    """
    # NOTE: Since the keys are ids, an entry should be removed (via unset_tracked_experiment) once it is no
    # longer needed; an id may be reused by another object after the original object was garbage-collected.
    _objectId2trackedExperiment: Dict[int, "TrackedExperiment"] = {}

    def set_tracked_experiment(self, tracked_experiment: Optional["TrackedExperiment"]):
        """
        :param tracked_experiment: the experiment to associate with this object; if None, any existing
            association is removed
        """
        if tracked_experiment is None:
            # remove the entry instead of storing None, such that the registry does not grow indefinitely
            self._objectId2trackedExperiment.pop(id(self), None)
        else:
            self._objectId2trackedExperiment[id(self)] = tracked_experiment

    def unset_tracked_experiment(self):
        """
        Removes the association with a tracked experiment (if any).
        """
        self.set_tracked_experiment(None)

    @property
    def tracked_experiment(self) -> Optional["TrackedExperiment"]:
        """
        :return: the tracked experiment associated with this object or None if there is none
        """
        return self._objectId2trackedExperiment.get(id(self))

    def begin_optional_tracking_context_for_model(self, model: "VectorModelBase", track: bool = True) -> "TrackingContext":
        """
        Begins a tracking context for the given model; the returned object is a context manager and therefore
        this method should preferably be used in a `with` statement.
        This method can be called regardless of whether there actually is a tracked experiment (hence the term 'optional').
        If there is no tracked experiment, calling methods on the returned object has no effect.
        Furthermore, tracking can be disabled by passing `track=False` even if a tracked experiment is present.

        :param model: the model for which to begin tracking
        :param track: whether tracking shall be enabled; if False, force use of a dummy context which performs no
            actual tracking even if a tracked experiment is present
        :return: a context manager that can be used to track results for the given model
        """
        experiment = self.tracked_experiment if track else None
        return TrackingContext.from_optional_experiment(experiment, model=model)