Coverage for src/sensai/tracking/azure_tracking.py: 0%
29 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-13 22:17 +0000
1from azureml.core import Experiment, Workspace
2from typing import Dict, Any
4from .tracking_base import TrackedExperiment, TContext
5from .. import VectorModel
6from ..evaluation.evaluator import MetricsDictProvider
class TrackedAzureMLEvaluation:
    """
    Class to automatically track parameters, metrics and artifacts for a single model with azureml-sdk
    """
    def __init__(self, experiment_name: str, workspace: Workspace,
            evaluator: MetricsDictProvider):
        """
        :param experiment_name: the name under which the Azure ML ``Experiment`` is created
        :param workspace: the workspace passed to the Azure ML ``Experiment``
        :param evaluator: the object whose ``compute_metrics`` method is called for every evaluated model
        """
        self.experiment_name = experiment_name
        self.evaluator = evaluator
        self.experiment = Experiment(workspace=workspace, name=experiment_name)

    def eval_model(self, model: VectorModel, additional_logging_values_dict: dict = None, **start_logging_kwargs):
        """
        Computes the metrics for the given model via the evaluator and logs them in a newly started run.

        Apart from the metrics, the model's string representation is logged under the key ``'str(model)'``,
        followed by the additional values (if any). An additional value replaces an already present
        entry of the same name.

        :param model: the model for which to compute and log metrics
        :param additional_logging_values_dict: further name/value pairs to log in the same run (optional)
        :param start_logging_kwargs: keyword arguments passed on to ``self.experiment.start_logging``
        """
        with self.experiment.start_logging(**start_logging_kwargs) as run:
            # Work on a copy: the dict returned by the evaluator belongs to the evaluator and
            # must not be extended with our logging-only entries.
            values_dict = dict(self.evaluator.compute_metrics(model))
            values_dict['str(model)'] = str(model)
            if additional_logging_values_dict is not None:
                values_dict.update(additional_logging_values_dict)
            for name, value in values_dict.items():
                run.log(name, value)
class TrackedAzureMLExperiment(TrackedExperiment):
    """
    Tracked experiment which logs its values to an Azure ML ``Experiment``
    """
    def __init__(self, experiment_name: str, workspace: Workspace, additional_logging_values_dict=None):
        """
        Creates the Azure ML experiment object and initialises the base class.

        :param experiment_name: name of experiment for tracking in workspace
        :param workspace: Azure workspace object
        :param additional_logging_values_dict: additional values to be logged for each run
        """
        self.experiment_name = experiment_name
        self.experiment = Experiment(name=experiment_name, workspace=workspace)
        super().__init__(additional_logging_values_dict=additional_logging_values_dict)

    def _track_values(self, values_dict: Dict[str, Any]):
        """
        Logs every entry of the given dictionary in a newly started run of the experiment.

        :param values_dict: mapping from a name to the value that is logged under that name
        """
        with self.experiment.start_logging() as azure_run:
            log_entry = azure_run.log
            for key, val in values_dict.items():
                log_entry(key, val)

    def _create_tracking_context(self, name: str, description: str) -> TContext:
        """
        Not implemented for this class.

        :raises NotImplementedError: always
        """
        raise NotImplementedError()