Coverage for src/sensai/tracking/azure_tracking.py: 0%

29 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-13 22:17 +0000

1from azureml.core import Experiment, Workspace 

2from typing import Dict, Any 

3 

4from .tracking_base import TrackedExperiment, TContext 

5from .. import VectorModel 

6from ..evaluation.evaluator import MetricsDictProvider 

7 

8 

class TrackedAzureMLEvaluation:
    """
    Evaluates a single model and logs the resulting values to an Azure ML experiment (azureml-sdk).
    """

    def __init__(self, experiment_name: str, workspace: Workspace,
            evaluator: MetricsDictProvider):
        """
        :param experiment_name: the name under which the Azure ML experiment is created/looked up
        :param workspace: the workspace passed on to the Azure ML ``Experiment``
        :param evaluator: the object whose ``compute_metrics`` method is applied to the model in :meth:`eval_model`
        """
        self.experiment = Experiment(workspace=workspace, name=experiment_name)
        self.evaluator = evaluator
        self.experiment_name = experiment_name

    def eval_model(self, model: VectorModel, additional_logging_values_dict: dict = None, **start_logging_kwargs):
        """
        Starts a logging run on the experiment, computes the evaluator's metrics for the given model
        and logs every resulting entry to that run.

        Besides the metrics, the model's string representation is logged under the key ``'str(model)'``.
        Note that the dictionary returned by the evaluator is extended in place.

        :param model: the model to evaluate
        :param additional_logging_values_dict: further entries to log alongside the metrics (may be None);
            entries with the same key as a metric take precedence
        :param start_logging_kwargs: keyword arguments forwarded to the experiment's ``start_logging``
        """
        # the run is opened first, so the metrics computation itself takes place within the run's context
        with self.experiment.start_logging(**start_logging_kwargs) as active_run:
            logged_values = self.evaluator.compute_metrics(model)
            logged_values['str(model)'] = str(model)
            if additional_logging_values_dict is not None:
                logged_values.update(additional_logging_values_dict)
            for entry_name, entry_value in logged_values.items():
                active_run.log(entry_name, entry_value)

32 

33 

class TrackedAzureMLExperiment(TrackedExperiment):
    """
    A tracked experiment which logs values to an Azure ML experiment (azureml-sdk).
    """

    def __init__(self, experiment_name: str, workspace: Workspace, additional_logging_values_dict=None):
        """
        :param experiment_name: the name of the experiment within the workspace under which values are tracked
        :param workspace: the Azure workspace object
        :param additional_logging_values_dict: additional values to be logged for each run;
            passed on to the base class constructor
        """
        self.experiment_name = experiment_name
        self.experiment = Experiment(workspace=workspace, name=experiment_name)
        super().__init__(additional_logging_values_dict=additional_logging_values_dict)

    def _track_values(self, values_dict: Dict[str, Any]):
        """
        Logs all entries of the given dictionary to a newly started run of the experiment.

        :param values_dict: mapping from value name to the value to log
        """
        with self.experiment.start_logging() as active_run:
            for entry_name, entry_value in values_dict.items():
                active_run.log(entry_name, entry_value)

    def _create_tracking_context(self, name: str, description: str) -> TContext:
        """
        Not supported by this implementation; always raises.

        :raises NotImplementedError: always
        """
        raise NotImplementedError()