Source code for sensai.tracking.mlflow_tracking
import logging as std_logging
import re
from typing import Dict, Any, Optional

import mlflow
from matplotlib import pyplot as plt

from .tracking_base import TrackedExperiment, TrackingContext
from .. import VectorModelBase
from ..util import logging
class MLFlowTrackingContext(TrackingContext):
    """
    A tracking context backed by a single MLflow run.

    Creating the context starts a new MLflow run (or resumes an existing one if a run ID is given).
    All tracking calls go through the fluent ``mlflow`` API, i.e. they target the currently active run,
    and ``_end`` ends the active run.
    """

    # turns "foo[bar]" into "foo_bar"
    _BRACKET_PATTERN = re.compile(r"\[(.*?)\]")
    # runs of characters not supported in metric names (supported: alphanumerics, '_', '-', '.', ' ', '/').
    # The hyphen is escaped so that it is taken literally; unescaped, "9-_" would form the range '9'..'_',
    # which silently admits ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']' and '^'.
    _UNSUPPORTED_CHARS_PATTERN = re.compile(r"[^a-zA-Z0-9\-_. /]+")

    def __init__(self, name: str, experiment: "MLFlowExperiment", run_id: Optional[str] = None,
            description: str = ""):
        """
        :param name: the name of the context; used as the MLflow run name when a new run is started
        :param experiment: the experiment this context belongs to
        :param run_id: the ID of an existing MLflow run to resume; if None, a new run is started
        :param description: the description of a newly started run (not used when resuming via ``run_id``)
        """
        super().__init__(name, experiment)
        if run_id is not None:
            run = mlflow.start_run(run_id)
        else:
            run = mlflow.start_run(run_name=name, description=description)
        self.run = run
        # optional in-memory log handler; its contents are stored with the run when the owning experiment ends
        # the context (None if log capturing is not enabled)
        self.log_handler: Optional[logging.MemoryStreamHandler] = None

    @staticmethod
    def _metric_name(name: str) -> str:
        """
        Converts an arbitrary metric name into one containing only supported characters.

        Bracketed parts are flattened ("foo[bar]" -> "foo_bar"), and every run of characters other than
        alphanumerics, '_', '-', '.', ' ' and '/' is replaced with a single underscore.

        :param name: the original metric name
        :return: the sanitised metric name
        """
        result = MLFlowTrackingContext._BRACKET_PATTERN.sub(r"_\1", name)
        return MLFlowTrackingContext._UNSUPPORTED_CHARS_PATTERN.sub("_", result)

    def _track_metrics(self, metrics: Dict[str, float]):
        """
        Logs the given metrics to the active MLflow run, sanitising their names via ``_metric_name`` first.
        """
        sanitised_metrics = {self._metric_name(name): value for name, value in metrics.items()}
        mlflow.log_metrics(sanitised_metrics)

    def track_figure(self, name: str, fig: plt.Figure):
        """
        Stores the given figure as the artifact ``<name>.png`` of the active MLflow run.
        """
        mlflow.log_figure(fig, name + ".png")

    def track_text(self, name: str, content: str):
        """
        Stores the given text as the artifact ``<name>.txt`` of the active MLflow run.
        """
        mlflow.log_text(content, name + ".txt")

    def track_tag(self, tag_name: str, tag_value: str):
        """
        Sets a tag on the active MLflow run.
        """
        mlflow.set_tag(tag_name, tag_value)

    def _end(self):
        # ends the currently active MLflow run (the fluent API does not take a run argument)
        mlflow.end_run()
class MLFlowExperiment(TrackedExperiment[MLFlowTrackingContext]):
    """
    A tracked experiment backed by MLflow, where each tracking context corresponds to an MLflow run.
    """

    _log = std_logging.getLogger(__name__)

    def __init__(self, experiment_name: str, tracking_uri: str, additional_logging_values_dict=None,
            context_prefix: str = "", add_log_to_all_contexts=False):
        """
        :param experiment_name: the name of the experiment, which should be the same for all models of the same kind (i.e. all models evaluated
            under the same conditions)
        :param tracking_uri: the URI of the server (if any); use "" to track in the local file system
        :param additional_logging_values_dict: passed on unchanged to the base class ``TrackedExperiment``;
            presumably additional values to be tracked alongside the tracked values (confirm against the base class)
        :param context_prefix: a prefix to add to all contexts that are created within the experiment. This can be used to add
            an identifier of a certain execution/run, such that the actual context name passed to `begin_context` can be concise (e.g. just model name).
        :param add_log_to_all_contexts: whether to enable in-memory logging and add the respective log to each context
        """
        mlflow.set_tracking_uri(tracking_uri)
        mlflow.set_experiment(experiment_name=experiment_name)
        super().__init__(context_prefix=context_prefix, additional_logging_values_dict=additional_logging_values_dict)
        # maps run names to MLflow run IDs, such that creating a context with a previously used name resumes that run
        self._run_name_to_id: Dict[str, str] = {}
        self.add_log_to_all_contexts = add_log_to_all_contexts

    def _track_values(self, values_dict: Dict[str, Any]):
        """
        Logs the given values as metrics in a newly started MLflow run (one run per call).

        Names are sanitised in the same way as for metrics tracked within a context, since names containing
        unsupported characters cannot be logged.
        """
        metrics = {MLFlowTrackingContext._metric_name(name): value for name, value in values_dict.items()}
        with mlflow.start_run():
            mlflow.log_metrics(metrics)

    def _create_tracking_context(self, name: str, description: str) -> MLFlowTrackingContext:
        """
        Creates the context for the given name, resuming the MLflow run previously created for the same name (if any).
        """
        run_id = self._run_name_to_id.get(name)
        self._log.debug("Creating tracking context '%s' (existing run ID: %s)", name, run_id)
        context = MLFlowTrackingContext(name, self, run_id=run_id, description=description)
        self._run_name_to_id[name] = context.run.info.run_id
        return context

    def begin_context_for_model(self, model: VectorModelBase):
        """
        Begins a tracking context for the given model and tags the run with the model's class name ("ModelClass").
        If ``add_log_to_all_contexts`` is enabled, an in-memory log handler is attached to the context; its
        contents are stored with the run by ``end_context``.
        """
        context = super().begin_context_for_model(model)
        if self.add_log_to_all_contexts:
            context.log_handler = logging.add_memory_logger()
        context.track_tag("ModelClass", model.__class__.__name__)
        return context

    def end_context(self, instance: MLFlowTrackingContext):
        """
        Ends the given context. If the context captured a log, the log is first stored as the text artifact
        "log.txt" (this must happen before the base class ends the run) and the handler is detached.
        """
        self._log.debug("Ending tracking context %s", instance)
        handler = instance.log_handler
        if handler is not None:
            instance.track_text("log", handler.get_log())
            # Detach the memory handler so that it stops accumulating the output of all subsequent contexts.
            # NOTE(review): assumes add_memory_logger attaches the handler to the root logger; removeHandler
            # is a no-op if the handler is not attached there.
            std_logging.getLogger().removeHandler(handler)
        super().end_context(instance)