# Source code for sensai.tracking.azure_tracking

from azureml.core import Experiment, Workspace
from typing import Dict, Any

from .tracking_base import TrackedExperiment, TContext
from .. import VectorModel
from ..evaluation.evaluator import MetricsDictProvider


class TrackedAzureMLEvaluation:
    """
    Class to automatically track parameters, metrics and artifacts for a single model with azureml-sdk
    """

    def __init__(self, experiment_name: str, workspace: Workspace, evaluator: MetricsDictProvider):
        """
        :param experiment_name: name of the ``Experiment`` that is created in the given workspace
        :param workspace: workspace object in which the ``Experiment`` is created
        :param evaluator: object whose ``compute_metrics(model)`` supplies the values to be logged
        """
        self.experiment_name = experiment_name
        self.evaluator = evaluator
        self.experiment = Experiment(workspace=workspace, name=experiment_name)

    def eval_model(self, model: VectorModel, additional_logging_values_dict: dict = None, **start_logging_kwargs):
        """
        Computes the evaluator's metrics for the given model and logs every resulting value to a
        run obtained from ``self.experiment.start_logging``.

        In addition to the computed metrics, the string representation of the model is logged under
        the key ``'str(model)'``.

        :param model: the model to evaluate; passed to ``self.evaluator.compute_metrics``
        :param additional_logging_values_dict: optional further values to log in the same run; on key
            collisions these take precedence over the computed metrics and the ``'str(model)'`` entry
        :param start_logging_kwargs: keyword arguments forwarded to ``self.experiment.start_logging``
        """
        with self.experiment.start_logging(**start_logging_kwargs) as run:
            # Work on a shallow copy so that the entries added below do not leak into a dict
            # the evaluator may retain or hand out again.
            values_dict = dict(self.evaluator.compute_metrics(model))
            values_dict['str(model)'] = str(model)
            if additional_logging_values_dict is not None:
                values_dict.update(additional_logging_values_dict)
            for name, value in values_dict.items():
                run.log(name, value)
class TrackedAzureMLExperiment(TrackedExperiment):
    """
    Tracked experiment which logs values to an azureml-sdk ``Experiment``
    """

    def __init__(self, experiment_name: str, workspace: Workspace, additional_logging_values_dict=None):
        """
        :param experiment_name: name of experiment for tracking in workspace
        :param workspace: Azure workspace object
        :param additional_logging_values_dict: additional values to be logged for each run
        """
        self.experiment_name = experiment_name
        self.experiment = Experiment(
            workspace=workspace,
            name=experiment_name,
        )
        super().__init__(additional_logging_values_dict=additional_logging_values_dict)

    def _track_values(self, values_dict: Dict[str, Any]):
        """
        Logs each entry of the given dictionary to a run started via ``self.experiment.start_logging()``.

        :param values_dict: mapping from value name to the value to log
        """
        with self.experiment.start_logging() as azure_run:
            for key, val in values_dict.items():
                azure_run.log(key, val)

    def _create_tracking_context(self, name: str, description: str) -> TContext:
        """
        Not implemented for this class; always raises ``NotImplementedError``.
        """
        raise NotImplementedError()