Coverage for src/sensai/tracking/clearml_tracking.py: 0%
36 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-29 18:29 +0000
1import logging
2from typing import Dict
4from matplotlib import pyplot as plt
6from .tracking_base import TrackingContext, TContext
7from ..tracking import TrackedExperiment
9from clearml import Task
11log = logging.getLogger(__name__)
class ClearMLTrackingContext(TrackingContext):
    """
    Tracking context which reports metrics, figures and texts to a ClearML task.
    """

    def __init__(self, name, experiment, task: Task):
        """
        :param name: the name of this tracking context
        :param experiment: the experiment this context belongs to
        :param task: the ClearML task to which all tracked data is reported
        """
        super().__init__(name, experiment)
        self.task = task

    def _track_metrics(self, metrics: Dict[str, float]):
        # Report each metric as a single-value result via the task's logger.
        # Task.connect is not used here because it registers a dictionary as task
        # parameters/configuration rather than as experiment results.
        logger = self.task.get_logger()
        for metric_name, value in metrics.items():
            logger.report_single_value(metric_name, value)

    def track_figure(self, name: str, fig: plt.Figure):
        # Report the figure explicitly under the given name. This does not rely on the
        # implicit capture of shown figures, which ignores `name` and requires the
        # matplotlib backend to support fig.show().
        self.task.get_logger().report_matplotlib_figure(title=name, series=name, figure=fig)

    def track_text(self, name: str, content: str):
        # NOTE: `name` is not forwarded; the text is reported to the task's log as-is.
        # TODO upload_artifact might be more appropriate, but it seems to require saving to a file first. What's the best way to do this?
        self.task.get_logger().report_text(content, print_console=False)

    def _end(self):
        # Nothing to finalize here; this context does not close or modify the ClearML task.
        pass
# TODO: this is an initial working implementation, it should eventually be improved
class ClearMLExperiment(TrackedExperiment):
    """
    Tracked experiment which records its values in a ClearML task.
    """

    def __init__(self, task: Task = None, project_name: str = None, task_name: str = None,
            additional_logging_values_dict=None):
        """
        :param task: an existing ClearML task to log to. If None, a new task is created via ``Task.init``
            from `project_name` and `task_name`.
        :param project_name: the ClearML project name; only necessary (and only used) if `task` is not provided
        :param task_name: the ClearML task name; only necessary (and only used) if `task` is not provided
        :param additional_logging_values_dict: passed on unchanged to the ``TrackedExperiment`` base class
            (presumably additional values to be logged alongside tracked values; confirm in the base class)
        :raises ValueError: if `task` is None and `project_name` or `task_name` is missing
        """
        if task is None:
            if project_name is None or task_name is None:
                raise ValueError("Either the ClearML task or the project name and task name have to be provided")
            # reuse_last_task_id=False requests a new task instead of reusing a previous one
            self.task = Task.init(project_name=project_name, task_name=task_name, reuse_last_task_id=False)
        else:
            if project_name is not None:
                log.warning(
                    "project_name parameter with value %s passed even though task has been given, "
                    "will ignore this parameter",
                    project_name,
                )
            if task_name is not None:
                log.warning(
                    "task_name parameter with value %s passed even though task has been given, "
                    "will ignore this parameter",
                    task_name,
                )
            self.task = task
        self.logger = self.task.get_logger()
        # self.task is set before the base initializer runs (original order preserved)
        super().__init__(additional_logging_values_dict=additional_logging_values_dict)

    def _track_values(self, values_dict):
        # Register the values with the ClearML task via Task.connect.
        self.task.connect(values_dict)

    def _create_tracking_context(self, name: str, description: str) -> TContext:
        # `description` is not used by the ClearML context; all contexts share this experiment's task.
        return ClearMLTrackingContext(name, self, self.task)