Source code for sensai.tracking.clearml_tracking
import logging
from typing import Dict
from matplotlib import pyplot as plt
from .tracking_base import TrackingContext, TContext
from ..tracking import TrackedExperiment
from clearml import Task
# Module-level logger, named after this module via __name__.
log = logging.getLogger(__name__)
class ClearMLTrackingContext(TrackingContext):
    """
    Tracking context that forwards tracked items to a ClearML task.
    """

    def __init__(self, name, experiment, task: Task):
        """
        :param name: the context name; passed on to the base class initialiser
        :param experiment: the owning experiment; passed on to the base class initialiser
        :param task: the ClearML task that receives everything tracked through this context
        """
        super().__init__(name, experiment)
        self.task = task

    def _track_metrics(self, metrics: Dict[str, float]):
        """
        Hands the metrics dictionary to the task via ``Task.connect``.

        :param metrics: mapping from metric name to metric value
        """
        clearml_task = self.task
        clearml_task.connect(metrics)

    def track_figure(self, name: str, fig: plt.Figure):
        """
        Tracks a figure by showing it.

        :param name: not used by this implementation
        :param fig: the figure to show
        """
        # Per the original author's note, any shown figure is automatically tracked,
        # so no explicit reporting call is made here.
        fig.show()

    def track_text(self, name: str, content: str):
        """
        Reports the given text through the task's logger without echoing it to the console.

        :param name: not used by this implementation
        :param content: the text to report
        """
        # TODO upload_artifact might be more appropriate, but it seems to require saving to a file first.
        #  What's the best way to do this?
        task_logger = self.task.get_logger()
        task_logger.report_text(content, print_console=False)

    def _end(self):
        """
        Does nothing; this context performs no action when it ends.
        """
        return None
# TODO: this is an initial working implementation, it should eventually be improved
class ClearMLExperiment(TrackedExperiment):
    """
    Tracked experiment that records its values in a ClearML task.
    """

    def __init__(self, task: Task = None, project_name: str = None, task_name: str = None,
            additional_logging_values_dict=None):
        """
        :param task: instance of clearml.Task; if None, a new task is created via ``Task.init``
            from ``project_name`` and ``task_name``
        :param project_name: only necessary if task is not provided
        :param task_name: only necessary if task is not provided
        :param additional_logging_values_dict: forwarded unchanged to the base class initialiser
        :raises ValueError: if ``task`` is None and ``project_name`` or ``task_name`` is missing
        """
        if task is None:
            if project_name is None or task_name is None:
                raise ValueError("Either the ClearML task or the project name and task name have to be provided")
            self.task = Task.init(project_name=project_name, task_name=task_name, reuse_last_task_id=False)
        else:
            # An explicitly given task takes precedence; warn about any redundant arguments,
            # referring to them by their actual parameter names.
            if project_name is not None:
                log.warning(
                    "project_name parameter with value %s passed even though task has been given, "
                    "will ignore this parameter",
                    project_name,
                )
            if task_name is not None:
                log.warning(
                    "task_name parameter with value %s passed even though task has been given, "
                    "will ignore this parameter",
                    task_name,
                )
            self.task = task
        self.logger = self.task.get_logger()
        # The task and logger are assigned before the base initialiser runs (original order preserved).
        super().__init__(additional_logging_values_dict=additional_logging_values_dict)

    def _track_values(self, values_dict):
        """
        Hands the values dictionary to the task via ``Task.connect``.

        :param values_dict: the dictionary of values to track
        """
        self.task.connect(values_dict)

    def _create_tracking_context(self, name: str, description: str) -> TContext:
        """
        Creates a tracking context bound to this experiment and its task.

        :param name: the name of the context
        :param description: not used by this implementation
        :return: a new ClearMLTrackingContext
        """
        return ClearMLTrackingContext(name, self, self.task)