Source code for sensai.tracking.mlflow_tracking

import logging as std_logging
import re
from typing import Dict, Any

import mlflow
from matplotlib import pyplot as plt

from .tracking_base import TrackedExperiment, TrackingContext
from .. import VectorModelBase
from ..util import logging


class MLFlowTrackingContext(TrackingContext):
    """
    A tracking context which is backed by an MLflow run: the run is started when the
    context is constructed and ended via `_end`; all tracking methods write to the
    currently active MLflow run.
    """

    def __init__(self, name: str, experiment: "MLFlowExperiment", run_id=None, description=""):
        """
        :param name: the name of the context; used as the MLflow run name when a new run is started
        :param experiment: the experiment this context belongs to
        :param run_id: the ID of an existing MLflow run to start; if None, a new run is started
            using `name` and `description`
        :param description: the description of the newly started run (not passed on when `run_id` is given)
        """
        super().__init__(name, experiment)
        if run_id is None:
            started_run = mlflow.start_run(run_name=name, description=description)
        else:
            started_run = mlflow.start_run(run_id)
        self.run = started_run

    @staticmethod
    def _metric_name(name: str):
        """
        Derives the metric name to log from the given name by rewriting bracketed parts
        and substituting character sequences outside of the accepted set.

        :param name: the original metric name
        :return: the transformed metric name
        """
        # replace "foo[bar]" with "foo_bar"
        without_brackets = re.sub(r"\[(.*?)\]", r"_\1", name)
        # replace sequences of unsupported chars with underscore
        return re.sub(r"[^a-zA-Z0-9-_. /]+", "_", without_brackets)

    def _track_metrics(self, metrics: Dict[str, float]):
        """
        Logs the given metrics via MLflow after transforming each metric name with `_metric_name`.

        :param metrics: mapping from metric name to metric value
        """
        renamed_metrics = {}
        for metric_name, metric_value in metrics.items():
            renamed_metrics[self._metric_name(metric_name)] = metric_value
        mlflow.log_metrics(renamed_metrics)

    def track_figure(self, name: str, fig: plt.Figure):
        """
        Logs the given figure via MLflow under the file name `name` with suffix ".png".

        :param name: the name of the figure (without file extension)
        :param fig: the figure to log
        """
        artifact_file = name + ".png"
        mlflow.log_figure(fig, artifact_file)

    def track_text(self, name: str, content: str):
        """
        Logs the given text via MLflow under the file name `name` with suffix ".txt".

        :param name: the name of the text artifact (without file extension)
        :param content: the text to log
        """
        artifact_file = name + ".txt"
        mlflow.log_text(content, artifact_file)

    def track_tag(self, tag_name: str, tag_value: str):
        """
        Sets a tag via MLflow.

        :param tag_name: the name of the tag
        :param tag_value: the value of the tag
        """
        mlflow.set_tag(tag_name, tag_value)

    def _end(self):
        """
        Ends the active MLflow run.
        """
        mlflow.end_run()
# Standard-library logger for this module (the name `logging` refers to the project's util module here)
log = std_logging.getLogger(__name__)


class MLFlowExperiment(TrackedExperiment[MLFlowTrackingContext]):
    """
    A tracked experiment which uses MLflow for tracking; the tracking contexts it creates
    are instances of `MLFlowTrackingContext`.
    """

    def __init__(self, experiment_name: str, tracking_uri: str, additional_logging_values_dict=None,
            context_prefix: str = "", add_log_to_all_contexts=False):
        """
        :param experiment_name: the name of the experiment, which should be the same for all models of the same kind
            (i.e. all models evaluated under the same conditions)
        :param tracking_uri: the URI of the server (if any); use "" to track in the local file system
        :param additional_logging_values_dict: passed on unchanged to the base class constructor
            (NOTE(review): presumably additional values to include in every set of tracked values - confirm
            against `TrackedExperiment`)
        :param context_prefix: a prefix to add to all contexts that are created within the experiment.
            This can be used to add an identifier of a certain execution/run, such that the actual context name
            passed to `begin_context` can be concise (e.g. just model name).
        :param add_log_to_all_contexts: whether to enable in-memory logging and add the resulting log file to all
            tracking contexts that are generated for this experiment upon context exit (or process termination
            if it is not cleanly closed)
        """
        mlflow.set_tracking_uri(tracking_uri)
        mlflow.set_experiment(experiment_name=experiment_name)
        super().__init__(context_prefix=context_prefix, additional_logging_values_dict=additional_logging_values_dict)
        # maps context names to the IDs of the MLflow runs that were started for them
        self._run_name_to_id = {}
        self.add_log_to_all_contexts = add_log_to_all_contexts
        if self.add_log_to_all_contexts:
            logging.add_memory_logger()

    def _track_values(self, values_dict: Dict[str, Any]):
        """
        Starts a new MLflow run and logs the given values as metrics within it.

        :param values_dict: mapping from name to value
            (NOTE(review): the values are passed to `mlflow.log_metrics` as is and therefore presumably
            need to be numeric - confirm)
        """
        with mlflow.start_run():
            mlflow.log_metrics(values_dict)

    def _create_tracking_context(self, name: str, description: str) -> MLFlowTrackingContext:
        """
        Creates the tracking context for the given name. If a context with the same name was
        created by this experiment before, the run ID recorded for it is passed on to the new
        context instead of starting a run under a new ID.

        :param name: the name of the context
        :param description: the description to pass on to the context
        :return: the tracking context
        """
        run_id = self._run_name_to_id.get(name)
        # report via the logger rather than printing to stdout from library code
        log.debug("Creating tracking context '%s'", name)
        context = MLFlowTrackingContext(name, self, run_id=run_id, description=description)
        # remember the run ID such that a later context with the same name can reuse it
        self._run_name_to_id[name] = context.run.info.run_id
        return context

    def begin_context_for_model(self, model: VectorModelBase):
        """
        Begins a tracking context for the given model (as implemented by the base class) and
        additionally tags it with the model's class name under the tag "ModelClass".

        :param model: the model for which to begin the context
        :return: the tracking context
        """
        context = super().begin_context_for_model(model)
        context.track_tag("ModelClass", model.__class__.__name__)
        return context

    def end_context(self, instance: MLFlowTrackingContext):
        """
        Ends the given tracking context. If `add_log_to_all_contexts` is enabled, the current
        in-memory log is first added to the context as the text "log".

        :param instance: the tracking context to end
        """
        # report via the logger rather than printing to stdout from library code
        log.debug("Ending tracking context %s", instance)
        if self.add_log_to_all_contexts:
            instance.track_text("log", logging.get_memory_log())
        super().end_context(instance)