Coverage for src/sensai/tracking/tracking_base.py: 64%
107 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-29 18:29 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-29 18:29 +0000
1from abc import ABC, abstractmethod
2from typing import Dict, Any, Optional, Generic, TypeVar, List
4from matplotlib import pyplot as plt
6from ..util import count_none
7from ..util.deprecation import deprecated
8from ..vector_model import VectorModelBase
class TrackingContext(ABC):
    """
    Base class for a tracking context, i.e. a named scope of an experiment within which metrics, figures and texts
    can be tracked.
    Instances are context managers: entering the context marks it as running, exiting it ends it (see `end`).
    """
    def __init__(self, name: str, experiment: Optional["TrackedExperiment"]):
        """
        :param name: the name of the context
        :param experiment: the experiment to which the context belongs
        """
        # NOTE: `experiment` is optional only because of DummyTrackingContext
        self.name = name
        self._experiment = experiment
        self._isRunning = False
        # guards against ending the context (and thus calling `_end`) more than once
        self._isEnded = False

    @staticmethod
    def from_optional_experiment(experiment: Optional["TrackedExperiment"], model: Optional[VectorModelBase] = None,
            name: Optional[str] = None, description: str = ""):
        """
        Creates a tracking context within the given experiment or, if there is no experiment, a dummy context
        which performs no actual tracking.

        :param experiment: the experiment in which to begin the context; if None, a DummyTrackingContext is returned
        :param model: the model for which to begin the context; if an experiment is given, exactly one of
            `model` and `name` must be provided
        :param name: the name of the context; if an experiment is given, exactly one of `model` and `name`
            must be provided
        :param description: the context description (only used when a context is begun by name within an experiment)
        :return: the tracking context
        :raises ValueError: if an experiment is given and not exactly one of `model` and `name` is provided
        """
        if experiment is None:
            # name the dummy context after the model (as a real context would be named) if no name was given
            if name is None and model is not None:
                name = model.get_name()
            return DummyTrackingContext(name)
        if count_none(name, model) != 1:
            raise ValueError("Must provide exactly one of {model, name}")
        if model is not None:
            return experiment.begin_context_for_model(model)
        return experiment.begin_context(name, description)

    def is_enabled(self):
        """
        :return: True if tracking is enabled, i.e. whether results can be saved via this context
        """
        return True

    @abstractmethod
    def _track_metrics(self, metrics: Dict[str, float]):
        pass

    def track_metrics(self, metrics: Dict[str, float], predicted_var_name: Optional[str] = None):
        """
        :param metrics: the metrics to be logged
        :param predicted_var_name: the name of the predicted variable for the case where there is more than one. If it is provided,
            the variable name will be prepended to every metric name.
        """
        if predicted_var_name is not None:
            metrics = {f"{predicted_var_name}_{k}": v for k, v in metrics.items()}
        self._track_metrics(metrics)

    @abstractmethod
    def track_figure(self, name: str, fig: plt.Figure):
        """
        :param name: the name of the figure (not a filename, should not include file extension)
        :param fig: the figure
        """
        pass

    @abstractmethod
    def track_text(self, name: str, content: str):
        """
        :param name: the name of the text (not a filename, should not include file extension)
        :param content: the content (arbitrarily long text, e.g. a log)
        """
        pass

    def __enter__(self):
        self._isRunning = True
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # exceptions raised within the with-block are not suppressed (implicit None return)
        self.end()

    @abstractmethod
    def _end(self):
        pass

    def end(self):
        """
        Ends the context. If the context is running, it is first ended within the experiment; then the context
        itself is ended via `_end`. Calling this method again after the context has been ended has no effect.
        """
        if self._isEnded:
            return
        # first end the context in the experiment (which may add final stuff)
        if self._isRunning:
            if self._experiment is not None:
                self._experiment.end_context(self)
            self._isRunning = False
        # then end the context for good
        self._end()
        self._isEnded = True
class DummyTrackingContext(TrackingContext):
    """
    A tracking context that discards everything passed to it and reports itself as disabled.
    Returning such a context when no tracked experiment exists lets callers write their tracking code
    unconditionally.
    """
    def __init__(self, name):
        # a dummy context never belongs to an experiment
        super().__init__(name, experiment=None)

    def is_enabled(self):
        # nothing passed to this context is ever persisted
        return False

    def _track_metrics(self, metrics: Dict[str, float]):
        # metrics are discarded
        return None

    def track_figure(self, name: str, fig: plt.Figure):
        # figures are discarded
        return None

    def track_text(self, name: str, content: str):
        # texts are discarded
        return None

    def _end(self):
        # there are no resources to release
        return None
# type variable for the concrete tracking context type that a TrackedExperiment implementation creates and manages
TContext = TypeVar("TContext", bound=TrackingContext)
class TrackedExperiment(Generic[TContext], ABC):
    """
    Base class for experiments in which results are tracked.
    Information is tracked within contexts (see `begin_context`), which are managed as a stack:
    only the most recently begun context that is still registered can be ended via `end_context`.
    """
    def __init__(self, context_prefix: str = "", additional_logging_values_dict=None):
        """
        Base class for tracking
        :param context_prefix: a prefix which is prepended to the name of every context begun in this experiment
        :param additional_logging_values_dict: additional values to be logged for each run
        """
        # TODO additional_logging_values_dict probably needs to be removed
        self.instancePrefix = context_prefix
        self.additionalLoggingValuesDict = additional_logging_values_dict
        self._contexts: List[TContext] = []

    @deprecated("Use a tracking context instead")
    def track_values(self, values_dict: Dict[str, Any], add_values_dict: Optional[Dict[str, Any]] = None):
        """
        Tracks the given values (deprecated; use a tracking context instead).

        :param values_dict: the values to track (the dictionary is copied, not modified)
        :param add_values_dict: further values to add
        """
        values_dict = dict(values_dict)
        if add_values_dict is not None:
            values_dict.update(add_values_dict)
        if self.additionalLoggingValuesDict is not None:
            values_dict.update(self.additionalLoggingValuesDict)
        self._track_values(values_dict)

    @abstractmethod
    def _track_values(self, values_dict: Dict[str, Any]):
        pass

    @abstractmethod
    def _create_tracking_context(self, name: str, description: str) -> TContext:
        pass

    def begin_context(self, name: str, description: str = "") -> TContext:
        """
        Begins a context in which actual information will be tracked.
        The returned object is a context manager, which can be used in a with-statement.

        :param name: the name of the context (e.g. model name); the experiment's context prefix is prepended to it
        :param description: a description (e.g. full model parameters/specification)
        :return: the context, which can subsequently be used to track information
        """
        instance = self._create_tracking_context(self.instancePrefix + name, description)
        self._contexts.append(instance)
        return instance

    def begin_context_for_model(self, model: VectorModelBase):
        """
        Begins a tracking context for the case where we want to track information about a model (wrapper around `begin_context` for convenience).
        The model name is used as the context name, and the model's string representation is used as the description.
        The returned object is a context manager, which can be used in a with-statement.

        :param model: the model
        :return: the context, which can subsequently be used to track information
        """
        return self.begin_context(model.get_name(), model.pprints())

    def end_context(self, instance: TContext):
        """
        Removes the given context from the experiment's stack of contexts.

        :param instance: the context to end; it must be the most recently begun context that is still registered
        :raises ValueError: if there is no registered context or if `instance` is not the most recent one
        """
        if not self._contexts:
            raise ValueError(f"Cannot end context ({instance}): there is no currently running instance")
        running_instance = self._contexts[-1]
        # identity (not equality) matters here: we must remove exactly the instance on top of the stack
        if instance is not running_instance:
            raise ValueError(f"Passed instance ({instance}) is not the currently running instance ({running_instance})")
        self._contexts.pop()

    def __del__(self):
        # make sure all contexts that are still running are eventually closed (most recent first).
        # We iterate over a snapshot, because ending a running context removes it from self._contexts.
        # getattr guards against __init__ not having completed (e.g. a subclass constructor failing early).
        for c in list(reversed(getattr(self, "_contexts", []))):
            c.end()
class TrackingMixin(ABC):
    """
    Mixin that lets an object be associated with an (optional) tracked experiment and that opens tracking
    contexts for models on demand.

    NOTE(review): the association is kept in a class-level dict keyed by ``id(self)`` rather than in an instance
    attribute. Entries are never removed (unsetting stores None). Because ids may be reused once an object has been
    garbage-collected, a newly created object could observe a stale entry. Confirm whether this is acceptable.
    """
    _objectId2trackedExperiment = {}

    def set_tracked_experiment(self, tracked_experiment: Optional[TrackedExperiment]):
        """
        :param tracked_experiment: the experiment to associate with this object (None to disassociate)
        """
        registry = self._objectId2trackedExperiment
        registry[id(self)] = tracked_experiment

    def unset_tracked_experiment(self):
        """
        Removes the association with any tracked experiment.
        """
        self.set_tracked_experiment(None)

    @property
    def tracked_experiment(self) -> Optional[TrackedExperiment]:
        """
        :return: the tracked experiment associated with this object, or None if there is none
        """
        object_key = id(self)
        return self._objectId2trackedExperiment.get(object_key)

    def begin_optional_tracking_context_for_model(self, model: VectorModelBase, track: bool = True) -> TrackingContext:
        """
        Begins a tracking context for the given model. The result is a context manager, so this method should
        preferably be called in a `with` statement.
        The method works whether or not a tracked experiment is present (hence 'optional'). Without one, the
        returned context performs no tracking. Passing `track=False` forces such a non-tracking context even if
        an experiment is present.

        :param model: the model for which to begin tracking
        :param track: whether tracking shall be enabled; if False, a dummy context which performs no actual tracking
            is used even if a tracked experiment is present
        :return: a context manager that can be used to track results for the given model
        """
        experiment = self.tracked_experiment if track else None
        return TrackingContext.from_optional_experiment(experiment, model=model)