Coverage for src/sensai/tracking/mlflow_tracking.py: 0%

57 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-13 22:17 +0000

1import re 

2from typing import Dict, Any 

3 

4import mlflow 

5from matplotlib import pyplot as plt 

6 

7from .tracking_base import TrackedExperiment, TrackingContext 

8from .. import VectorModelBase 

9from ..util import logging 

10 

11 

class MLFlowTrackingContext(TrackingContext):
    """
    Tracking context which is backed by an MLflow run.

    The run is started (or resumed) upon construction; metrics, figures, texts and tags are
    recorded through the module-level ``mlflow`` logging functions.
    """

    def __init__(self, name: str, experiment: "MLFlowExperiment", run_id=None, description=""):
        """
        :param name: the name of the context; used as the run name when a new run is started
        :param experiment: the experiment to which this context belongs
        :param run_id: the ID of an existing run to continue; if None, a new run is started
        :param description: the description passed to MLflow when a new run is started
            (not used when an existing run is continued via `run_id`)
        """
        super().__init__(name, experiment)
        if run_id is None:
            self.run = mlflow.start_run(run_name=name, description=description)
        else:
            self.run = mlflow.start_run(run_id)

    @staticmethod
    def _metric_name(name: str):
        """
        Converts the given name into the form that is used as the metric name when logging to MLflow.

        :param name: the original metric name
        :return: the name with bracketed parts rewritten and every run of characters outside the
            set ``a-zA-Z0-9-_. /`` collapsed into a single underscore
        """
        # "foo[bar]" becomes "foo_bar"
        without_brackets = re.sub(r"\[(.*?)\]", r"_\1", name)
        # each sequence of unsupported characters becomes a single underscore
        return re.sub(r"[^a-zA-Z0-9-_. /]+", "_", without_brackets)

    def _track_metrics(self, metrics: Dict[str, float]):
        """
        Logs the given metrics via ``mlflow.log_metrics``, converting each metric name with
        :meth:`_metric_name` beforehand.

        :param metrics: mapping from metric name to metric value
        """
        sanitized = {}
        for key, value in metrics.items():
            sanitized[self._metric_name(key)] = value
        mlflow.log_metrics(sanitized)

    def track_figure(self, name: str, fig: plt.Figure):
        """
        Logs the given figure via ``mlflow.log_figure``.

        :param name: the name of the figure; the suffix ".png" is appended to obtain the file name
        :param fig: the figure to log
        """
        artifact_file = name + ".png"
        mlflow.log_figure(fig, artifact_file)

    def track_text(self, name: str, content: str):
        """
        Logs the given text via ``mlflow.log_text``.

        :param name: the name of the text; the suffix ".txt" is appended to obtain the file name
        :param content: the text content to log
        """
        artifact_file = name + ".txt"
        mlflow.log_text(content, artifact_file)

    def track_tag(self, tag_name: str, tag_value: str):
        """
        Sets a tag via ``mlflow.set_tag``.

        :param tag_name: the name of the tag
        :param tag_value: the value of the tag
        """
        mlflow.set_tag(tag_name, tag_value)

    def _end(self):
        """
        Ends the context by calling ``mlflow.end_run``.
        """
        mlflow.end_run()

42 

43 

class MLFlowExperiment(TrackedExperiment[MLFlowTrackingContext]):
    """
    Tracked experiment which uses MLflow; the tracking contexts it creates are instances of
    :class:`MLFlowTrackingContext`.
    """

    def __init__(self, experiment_name: str, tracking_uri: str, additional_logging_values_dict=None,
            context_prefix: str = "", add_log_to_all_contexts=False):
        """
        :param experiment_name: the name of the experiment, which should be the same for all models of the same kind
            (i.e. all models evaluated under the same conditions)
        :param tracking_uri: the URI of the server (if any); use "" to track in the local file system
        :param additional_logging_values_dict: passed on unchanged to the base class constructor
        :param context_prefix: a prefix to add to all contexts that are created within the experiment. This can be
            used to add an identifier of a certain execution/run, such that the actual context name passed to
            `begin_context` can be concise (e.g. just model name).
        :param add_log_to_all_contexts: whether to enable in-memory logging and add the resulting log file to all
            tracking contexts that are generated for this experiment upon context exit (or process termination if
            it is not cleanly closed)
        """
        # MLflow is configured before the base class is initialised
        mlflow.set_tracking_uri(tracking_uri)
        mlflow.set_experiment(experiment_name=experiment_name)
        super().__init__(context_prefix=context_prefix, additional_logging_values_dict=additional_logging_values_dict)
        # maps context names to the IDs of the runs that were started for them
        self._run_name_to_id = {}
        self.add_log_to_all_contexts = add_log_to_all_contexts
        if add_log_to_all_contexts:
            logging.add_memory_logger()

    def _track_values(self, values_dict: Dict[str, Any]):
        """
        Starts an MLflow run (without a name) and logs the given values as metrics within it.

        :param values_dict: mapping from name to value, passed to ``mlflow.log_metrics`` as is
        """
        with mlflow.start_run():
            mlflow.log_metrics(values_dict)

    def _create_tracking_context(self, name: str, description: str) -> MLFlowTrackingContext:
        """
        Creates the tracking context for the given name; if a context with the same name was
        created by this experiment before, the run ID recorded for it is passed on to the new context.

        :param name: the name of the context
        :param description: the description of the context
        :return: the new tracking context
        """
        known_run_id = self._run_name_to_id.get(name)
        # NOTE(review): writes to stdout; looks like leftover debug output - confirm whether it is intended
        print(f"create {name}")
        new_context = MLFlowTrackingContext(name, self, run_id=known_run_id, description=description)
        # remember the run ID such that a later context of the same name can refer to the same run
        self._run_name_to_id[name] = new_context.run.info.run_id
        return new_context

    def begin_context_for_model(self, model: VectorModelBase):
        """
        Begins a context for the given model via the base class implementation and additionally
        tags it with the name of the model's class (tag "ModelClass").

        :param model: the model for which to begin a context
        :return: the context
        """
        model_context = super().begin_context_for_model(model)
        model_class_name = model.__class__.__name__
        model_context.track_tag("ModelClass", model_class_name)
        return model_context

    def end_context(self, instance: MLFlowTrackingContext):
        """
        Ends the given context; if `add_log_to_all_contexts` is enabled, the in-memory log is
        added to the context as text (under the name "log") beforehand.

        :param instance: the context to end
        """
        # NOTE(review): writes to stdout; looks like leftover debug output - confirm whether it is intended
        print(f"end {instance}")
        if self.add_log_to_all_contexts:
            memory_log = logging.get_memory_log()
            instance.track_text("log", memory_log)
        super().end_context(instance)