Coverage for src/sensai/tracking/mlflow_tracking.py: 0%

58 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-29 18:29 +0000

1import re 

2from typing import Dict, Any, Optional 

3 

4import mlflow 

5from matplotlib import pyplot as plt 

6 

7from .tracking_base import TrackedExperiment, TrackingContext 

8from .. import VectorModelBase 

9from ..util import logging 

10 

11 

class MLFlowTrackingContext(TrackingContext):
    """
    Tracking context that is backed by an MLflow run.

    Creating an instance starts a new MLflow run (or resumes an existing one, if a run ID is given) via
    `mlflow.start_run`. Metrics, figures, texts and tags tracked through the instance are logged via mlflow's
    module-level logging functions, i.e. to the currently active run, and `_end` ends the active run.
    """

    def __init__(self, name: str, experiment: "MLFlowExperiment", run_id=None, description=""):
        """
        :param name: the name of the context; used as the MLflow run name when a new run is started
        :param experiment: the experiment to which this context belongs
        :param run_id: the ID of an existing MLflow run to resume; if None, a new run is started
        :param description: the run description; only used when a new run is started
        """
        super().__init__(name, experiment)
        if run_id is None:
            self.run = mlflow.start_run(run_name=name, description=description)
        else:
            self.run = mlflow.start_run(run_id)
        # in-memory log handler whose contents are to be tracked when the context ends; remains None unless assigned externally
        self.log_handler: Optional[logging.MemoryStreamHandler] = None

    @staticmethod
    def _metric_name(name: str):
        """
        Converts a name into one that consists only of characters supported in MLflow metric keys.

        Bracketed parts are flattened first ("foo[bar]" becomes "foo_bar"); afterwards, every run of characters other
        than ASCII letters, digits, "-", "_", ".", " " and "/" is collapsed into a single underscore.

        :param name: the original metric name
        :return: the sanitised name
        """
        flattened = re.sub(r"\[(.*?)\]", r"_\1", name)
        return re.sub(r"[^a-zA-Z0-9-_. /]+", "_", flattened)

    def _track_metrics(self, metrics: Dict[str, float]):
        """
        Logs the given metrics to the active MLflow run, sanitising each name beforehand.

        :param metrics: a mapping from metric names to values
        """
        sanitised = {}
        for raw_name, metric_value in metrics.items():
            sanitised[self._metric_name(raw_name)] = metric_value
        mlflow.log_metrics(sanitised)

    def track_figure(self, name: str, fig: plt.Figure):
        """
        Logs the given figure to the active run as the PNG artifact `<name>.png`.

        :param name: the artifact's base name (without extension)
        :param fig: the figure to log
        """
        artifact_file = name + ".png"
        mlflow.log_figure(fig, artifact_file)

    def track_text(self, name: str, content: str):
        """
        Logs the given text to the active run as the artifact `<name>.txt`.

        :param name: the artifact's base name (without extension)
        :param content: the text to log
        """
        artifact_file = name + ".txt"
        mlflow.log_text(content, artifact_file)

    def track_tag(self, tag_name: str, tag_value: str):
        """
        Sets a tag on the active run.

        :param tag_name: the tag's name
        :param tag_value: the tag's value
        """
        mlflow.set_tag(tag_name, tag_value)

    def _end(self):
        """
        Ends the currently active MLflow run.
        """
        mlflow.end_run()

43 

44 

log = logging.getLogger(__name__)


class MLFlowExperiment(TrackedExperiment[MLFlowTrackingContext]):
    """
    A tracked experiment whose tracking contexts are MLflow runs within the MLflow experiment of the given name.
    """

    def __init__(self, experiment_name: str, tracking_uri: str, additional_logging_values_dict=None,
            context_prefix: str = "", add_log_to_all_contexts=False):
        """
        :param experiment_name: the name of the experiment, which should be the same for all models of the same kind (i.e. all models evaluated
            under the same conditions)
        :param tracking_uri: the URI of the server (if any); use "" to track in the local file system
        :param additional_logging_values_dict: a dictionary of additional values which is passed on to the base class
            (presumably merged into the values tracked via `track_values` - see TrackedExperiment)
        :param context_prefix: a prefix to add to all contexts that are created within the experiment. This can be used to add
            an identifier of a certain execution/run, such that the actual context name passed to `begin_context` can be concise (e.g. just model name).
        :param add_log_to_all_contexts: whether to enable in-memory logging and add the respective log to each context
        """
        mlflow.set_tracking_uri(tracking_uri)
        mlflow.set_experiment(experiment_name=experiment_name)
        super().__init__(context_prefix=context_prefix, additional_logging_values_dict=additional_logging_values_dict)
        # maps context (run) names to the IDs of the MLflow runs created for them, such that re-entering a context resumes its run
        self._run_name_to_id = {}
        self.add_log_to_all_contexts = add_log_to_all_contexts

    def _track_values(self, values_dict: Dict[str, Any]):
        """
        Tracks the given values in a new (unnamed) MLflow run.

        MLflow metrics must be numeric, so only real-valued entries are logged as metrics; all other entries are
        logged as run parameters (which MLflow stores as strings). Keys are sanitised in the same way as the metric
        names tracked via MLFlowTrackingContext.

        :param values_dict: a mapping from value names to values
        """
        import numbers

        metrics = {}
        params = {}
        for key, value in values_dict.items():
            name = MLFlowTrackingContext._metric_name(key)
            if isinstance(value, numbers.Real):
                metrics[name] = value
            else:
                # NOTE: MLflow limits the length of parameter values; very long values may still be rejected
                params[name] = value
        with mlflow.start_run():
            if metrics:
                mlflow.log_metrics(metrics)
            if params:
                mlflow.log_params(params)

    def _create_tracking_context(self, name: str, description: str) -> MLFlowTrackingContext:
        """
        Creates the tracking context for the given name. If this experiment previously created a run for the same
        name, that run is resumed; otherwise, a new run is started.

        :param name: the name of the context (and of the MLflow run)
        :param description: the description of the run (applied when a new run is started)
        :return: the tracking context
        """
        run_id = self._run_name_to_id.get(name)
        log.debug("Creating tracking context %s (existing run ID: %s)", name, run_id)
        context = MLFlowTrackingContext(name, self, run_id=run_id, description=description)
        self._run_name_to_id[name] = context.run.info.run_id
        return context

    def begin_context_for_model(self, model: VectorModelBase):
        """
        Begins a tracking context for the given model (via the base class implementation). If in-memory logging is
        enabled, a memory log handler is attached to the context. The run is tagged with the model's class name.

        :param model: the model for which to begin the context
        :return: the tracking context
        """
        context = super().begin_context_for_model(model)
        if self.add_log_to_all_contexts:
            context.log_handler = logging.add_memory_logger()
        context.track_tag("ModelClass", model.__class__.__name__)
        return context

    def end_context(self, instance: MLFlowTrackingContext):
        """
        Ends the given tracking context.

        If a memory log handler is attached to the context, its log is tracked as text "log". The handler is then
        detached, so that it does not keep capturing (and accumulating) the log output of subsequent contexts.

        :param instance: the context to end
        """
        log.debug("Ending tracking context %s", instance)
        handler = instance.log_handler
        try:
            if handler is not None:
                instance.track_text("log", handler.get_log())
        finally:
            if handler is not None:
                # NOTE(review): assumes logging.add_memory_logger registers the handler with the root logger - confirm;
                # removeHandler is a no-op for handlers that are not attached, so this is safe either way
                logging.getLogger().removeHandler(handler)
                instance.log_handler = None
            # always end the context (and thus its MLflow run), even if tracking the log failed, so no run is left active
            super().end_context(instance)