Coverage for src/sensai/tracking/mlflow_tracking.py: 0%
58 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-29 18:29 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-29 18:29 +0000
1import re
2from typing import Dict, Any, Optional
4import mlflow
5from matplotlib import pyplot as plt
7from .tracking_base import TrackedExperiment, TrackingContext
8from .. import VectorModelBase
9from ..util import logging
class MLFlowTrackingContext(TrackingContext):
    """
    Tracking context which is backed by an MLflow run.

    Constructing the context calls ``mlflow.start_run``: if ``run_id`` is given, that existing run is resumed,
    otherwise a new run named ``name`` is started. The ``track_*`` methods use the module-level ``mlflow``
    logging functions, i.e. they log to the currently active run, and ``_end`` ends the active run.
    """

    # captures the content of "[...]" so that "foo[bar]" can be rewritten as "foo_bar"
    _INDEX_PATTERN = re.compile(r"\[(.*?)\]")
    # matches runs of characters outside letters, digits, "-", "_", ".", space and "/";
    # note that the "-" directly after the completed range "0-9" is parsed as a literal hyphen, not as a range
    _UNSUPPORTED_CHARS_PATTERN = re.compile(r"[^a-zA-Z0-9-_. /]+")

    def __init__(self, name: str, experiment: "MLFlowExperiment", run_id=None, description=""):
        """
        :param name: the name of the context; used as the run name when a new MLflow run is started
        :param experiment: the experiment this context belongs to
        :param run_id: the ID of an existing MLflow run to resume; if None, a new run is started
        :param description: the description of a newly started run (not passed on when resuming via ``run_id``)
        """
        super().__init__(name, experiment)
        if run_id is None:
            active_run = mlflow.start_run(run_name=name, description=description)
        else:
            active_run = mlflow.start_run(run_id)
        self.run = active_run
        # optional in-memory log handler; it is set from outside (e.g. by MLFlowExperiment.begin_context_for_model)
        self.log_handler: Optional[logging.MemoryStreamHandler] = None

    @staticmethod
    def _metric_name(name: str):
        """
        Converts a name into a metric name which uses only the characters supported here.

        :param name: the original name, e.g. "foo[bar]"
        :return: the name with bracketed parts rewritten ("foo[bar]" -> "foo_bar") and every run of
            unsupported characters replaced by a single underscore
        """
        without_brackets = MLFlowTrackingContext._INDEX_PATTERN.sub(r"_\1", name)
        return MLFlowTrackingContext._UNSUPPORTED_CHARS_PATTERN.sub("_", without_brackets)

    def _track_metrics(self, metrics: Dict[str, float]):
        sanitized_metrics = {self._metric_name(key): val for key, val in metrics.items()}
        mlflow.log_metrics(sanitized_metrics)

    def track_figure(self, name: str, fig: plt.Figure):
        """Logs the given figure to the active run as the artifact file ``<name>.png``."""
        artifact_file = name + ".png"
        mlflow.log_figure(fig, artifact_file)

    def track_text(self, name: str, content: str):
        """Logs the given text to the active run as the artifact file ``<name>.txt``."""
        artifact_file = name + ".txt"
        mlflow.log_text(content, artifact_file)

    def track_tag(self, tag_name: str, tag_value: str):
        """Sets a tag on the active run."""
        mlflow.set_tag(tag_name, tag_value)

    def _end(self):
        """Ends the active MLflow run."""
        mlflow.end_run()
class MLFlowExperiment(TrackedExperiment[MLFlowTrackingContext]):
    """
    Tracked experiment whose contexts are MLflow runs.

    The experiment remembers the run ID created for each context name, so that creating a context again
    under a previously used name (during the lifetime of this object) resumes that run.
    """

    _log = logging.getLogger(__name__)

    def __init__(self, experiment_name: str, tracking_uri: str, additional_logging_values_dict=None,
            context_prefix: str = "", add_log_to_all_contexts=False):
        """
        :param experiment_name: the name of the experiment, which should be the same for all models of the same kind (i.e. all models evaluated
            under the same conditions)
        :param tracking_uri: the URI of the server (if any); use "" to track in the local file system
        :param additional_logging_values_dict: additional values which are passed on to the base class (`TrackedExperiment`);
            presumably they are included in every set of tracked values - see the base class
        :param context_prefix: a prefix to add to all contexts that are created within the experiment. This can be used to add
            an identifier of a certain execution/run, such that the actual context name passed to `begin_context` can be concise (e.g. just model name).
        :param add_log_to_all_contexts: whether to enable in-memory logging and add the respective log to each context
        """
        # NOTE: these calls configure MLflow's module-level state (tracking URI and active experiment)
        mlflow.set_tracking_uri(tracking_uri)
        mlflow.set_experiment(experiment_name=experiment_name)
        super().__init__(context_prefix=context_prefix, additional_logging_values_dict=additional_logging_values_dict)
        # run name -> MLflow run ID, used to resume the run when a context name is reused
        self._run_name_to_id: Dict[str, str] = {}
        self.add_log_to_all_contexts = add_log_to_all_contexts

    def _track_values(self, values_dict: Dict[str, Any]):
        # sanitise the names in the same way as MLFlowTrackingContext._track_metrics does, such that names
        # containing unsupported characters (e.g. "foo[bar]") do not cause MLflow to reject the metrics
        metrics = {MLFlowTrackingContext._metric_name(name): value for name, value in values_dict.items()}
        # NOTE(review): this starts a separate run for each call; mlflow.start_run (without nested=True) is
        # expected to fail if another run is currently active - confirm this is never called while a context is open
        with mlflow.start_run():
            mlflow.log_metrics(metrics)

    def _create_tracking_context(self, name: str, description: str) -> MLFlowTrackingContext:
        run_id = self._run_name_to_id.get(name)
        self._log.debug("Creating MLflow tracking context '%s' (existing run ID: %s)", name, run_id)
        context = MLFlowTrackingContext(name, self, run_id=run_id, description=description)
        # remember the run ID so that a later context with the same name resumes this run
        self._run_name_to_id[name] = context.run.info.run_id
        return context

    def begin_context_for_model(self, model: VectorModelBase):
        context = super().begin_context_for_model(model)
        if self.add_log_to_all_contexts:
            # NOTE(review): nothing in this file detaches this handler after the context ends, so handlers
            # may accumulate across contexts - confirm against sensai.util.logging.add_memory_logger
            context.log_handler = logging.add_memory_logger()
        context.track_tag("ModelClass", model.__class__.__name__)
        return context

    def end_context(self, instance: MLFlowTrackingContext):
        self._log.debug("Ending MLflow tracking context %s", instance)
        try:
            if instance.log_handler is not None:
                instance.track_text("log", instance.log_handler.get_log())
        finally:
            # end the context even if uploading the log failed; presumably the base class ends the MLflow run
            # (via MLFlowTrackingContext._end), which would otherwise remain active
            super().end_context(instance)