Coverage for src/sensai/tracking/clearml_tracking.py: 0%

36 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-13 22:17 +0000

1import logging 

2from typing import Dict 

3 

4from matplotlib import pyplot as plt 

5 

6from .tracking_base import TrackingContext, TContext 

7from ..tracking import TrackedExperiment 

8 

9from clearml import Task 

10 

# module-level logger, named after this module
log = logging.getLogger(name=__name__)

12 

13 

class ClearMLTrackingContext(TrackingContext):
    """
    Tracking context which forwards tracked metrics, figures and text to a ClearML task.
    """

    def __init__(self, name, experiment, task: Task):
        """
        :param name: the name of this tracking context (passed on to TrackingContext)
        :param experiment: the experiment this context belongs to (passed on to TrackingContext)
        :param task: the ClearML task to which tracked data is sent
        """
        super().__init__(name, experiment)
        self.task = task

    def _track_metrics(self, metrics: Dict[str, float]):
        """
        Connects the given metrics dictionary to the ClearML task.

        :param metrics: mapping from metric name to value
        """
        # NOTE(review): Task.connect registers a dict as task parameters/configuration rather than
        # reporting scalar metrics; Logger.report_single_value may be the more fitting API - confirm
        task = self.task
        task.connect(metrics)

    def track_figure(self, name: str, fig: plt.Figure):
        """
        Tracks a matplotlib figure by showing it.

        :param name: the figure name (not used by this implementation)
        :param fig: the figure to track
        """
        # relies on ClearML capturing shown figures automatically; no explicit report call is made
        fig.show()

    def track_text(self, name: str, content: str):
        """
        Reports the given text via the task's ClearML logger without echoing it to the console.

        :param name: the name of the text (not used by this implementation)
        :param content: the text to report
        """
        # TODO upload_artifact might be more appropriate, but it seems to require saving to a file first. What's the best way to do this?
        clearml_logger = self.task.get_logger()
        clearml_logger.report_text(content, print_console=False)

    def _end(self):
        """
        Nothing to finalise for this context; in particular, the ClearML task is not closed here.
        """

31 

32 

# TODO: this is an initial working implementation, it should eventually be improved
class ClearMLExperiment(TrackedExperiment):
    """
    Tracked experiment which logs values to a ClearML task.
    """

    def __init__(self, task: Task = None, project_name: str = None, task_name: str = None,
            additional_logging_values_dict=None):
        """
        :param task: the ClearML task (instance of clearml.Task) to log to; if None, a new task is created
            via Task.init from project_name and task_name
        :param project_name: the ClearML project name; only necessary (and only used) if task is not provided
        :param task_name: the ClearML task name; only necessary (and only used) if task is not provided
        :param additional_logging_values_dict: passed on unchanged to TrackedExperiment.__init__
        :raises ValueError: if task is None and project_name or task_name is None
        """
        if task is None:
            if project_name is None or task_name is None:
                raise ValueError("Either the ClearML task or the project name and task name have to be provided")
            # reuse_last_task_id=False: always create a fresh task instead of continuing a previous one
            self.task = Task.init(project_name=project_name, task_name=task_name, reuse_last_task_id=False)
        else:
            # the names are only needed to create a task, so they are ignored when a task is given
            if project_name is not None:
                log.warning(
                    "project_name parameter with value %s passed even though task has been given, "
                    "will ignore this parameter",
                    project_name,
                )
            if task_name is not None:
                log.warning(
                    "task_name parameter with value %s passed even though task has been given, "
                    "will ignore this parameter",
                    task_name,
                )
            self.task = task
        self.logger = self.task.get_logger()
        # self.task is set before the base initializer runs, presumably because the base class may
        # already track values (via _track_values) during initialisation - confirm against TrackedExperiment
        super().__init__(additional_logging_values_dict=additional_logging_values_dict)

    def _track_values(self, values_dict):
        """
        Connects the given values dictionary to the ClearML task.

        :param values_dict: the values to track
        """
        # NOTE(review): Task.connect records the dict as task parameters/configuration - confirm intended
        self.task.connect(values_dict)

    def _create_tracking_context(self, name: str, description: str) -> TContext:
        """
        :param name: the name of the tracking context
        :param description: a description of the context (not used by this implementation)
        :return: a ClearMLTrackingContext which sends its data to this experiment's task
        """
        return ClearMLTrackingContext(name, self, self.task)