Source code for chariot.training_v2.run

"""Training run management."""

import math
from datetime import datetime, timedelta
from typing import Literal

from pydantic import StrictFloat, StrictInt, StrictStr

from chariot import _apis
from chariot.training_v2._common import (
    BaseModelWithDatetime,
    _datetime_to_epoch,
    _get_epoch_validator,
)
from chariot.training_v2.checkpoint import Checkpoint, get_checkpoints
from chariot.training_v2.exceptions import RunDoesNotExistError, ValidationError
from chariot_api._openapi.training_v2 import models as training_v2_openapi_models
from chariot_api._openapi.training_v2.exceptions import ApiException
from chariot_api._openapi.training_v2.models.field_errors_response import (
    FieldErrorsResponse,
)

# Names exported by ``from chariot.training_v2.run import *``: the module-level
# helper functions first, then the model classes, then the ``STATUSES`` constant.
__all__ = [
    "create_run",
    "get_runs",
    "validate_run_config",
    "Event",
    "Gpu",
    "Metric",
    "Progress",
    "Resources",
    "Run",
    "STATUSES",
]

# Set of status strings. These are the same values accepted by the ``statuses``
# filter of ``get_runs``; keep the two in sync when a status is added or removed.
# NOTE(review): presumably these are the values ``Run.status`` / ``Event.status``
# can take -- confirm against the training API.
STATUSES = {
    "run_created",
    "run_stop_requested",
    "run_restart_requested",
    "job_create_failed",
    "job_created",
    "job_submitted",
    "job_pending",
    "job_running",
    "job_terminate_requested",
    "job_terminating",
    "job_terminated",
    "job_failed",
    "job_completed",
    "job_unknown",
}


class Metric(BaseModelWithDatetime):
    """Training run metric.

    Attributes
    ----------
    id: str
        Id of the metric record.
    created_at: datetime
        Creation timestamp; passed through the validator built by
        ``_get_epoch_validator("created_at")``.
    run_id: str
        Id of the run the metric was recorded for.
    global_step: int
        Global step the metric was recorded at.
    tag: str
        Tag of the metric.
    value: float | int
        Metric value. ``_openapi_metric_to_metric`` substitutes ``math.nan``
        when the API reports no value.
    job_id: str | None
        Optional job id associated with the metric.
    """

    id: StrictStr
    created_at: datetime
    _created_at_validator: classmethod = _get_epoch_validator("created_at")
    run_id: StrictStr
    global_step: StrictInt
    tag: StrictStr
    value: StrictFloat | StrictInt
    job_id: StrictStr | None = None
def _openapi_metric_to_metric(openapi_metric) -> Metric:
    """Convert an OpenAPI metric model into a :class:`Metric`.

    The OpenAPI model is dumped to a dict; a ``value`` of ``None`` is replaced
    with ``math.nan`` before the dict is validated as a :class:`Metric`.
    """
    payload = openapi_metric.model_dump()
    raw_value = payload["value"]
    payload["value"] = math.nan if raw_value is None else raw_value
    return Metric.model_validate(payload)
class Gpu(BaseModelWithDatetime):
    """Gpu resource metadata.

    All available gpu types can be found by calling the function
    ``chariot.system_resources.get_available_system_gpus``.

    Attributes
    ----------
    count: int
        Number of gpus.
    type: str
        Gpu type.
    """

    count: StrictInt
    type: StrictStr
class Resources(BaseModelWithDatetime):
    """Training run scheduling resources.

    These values represent kubernetes resources that will be allocated for a
    training run.

    Example values:
        cpu: "1"
        cpu: "500m"
        memory: "5Gi"  # gigabytes
        memory: "5000000Ki"  # kilobytes
        ephemeral_storage: "20Gi"

    Reference:
    https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/#resource-units-in-kubernetes

    Attributes
    ----------
    cpu: str
        Cpu quantity, in kubernetes resource units.
    memory: str
        Memory quantity, in kubernetes resource units.
    ephemeral_storage: str | None
        Optional ephemeral storage quantity, in kubernetes resource units.
    gpu: Gpu | None
        Optional gpu type and count.
    """

    cpu: StrictStr
    memory: StrictStr
    ephemeral_storage: StrictStr | None = None
    gpu: Gpu | None = None
class Progress(BaseModelWithDatetime):
    """Training run progress.

    Attributes
    ----------
    operation: str
        Name of the operation being tracked.
    value: float | int
        Current progress value.
    final_value: float | int
        Value at which the operation is finished.
    units: str
        Units that ``value`` and ``final_value`` are expressed in.
    """

    operation: StrictStr
    value: StrictFloat | StrictInt
    final_value: StrictFloat | StrictInt
    units: StrictStr
class Event(BaseModelWithDatetime):
    """Training run event.

    Attributes
    ----------
    id: str
        Id of the event.
    sequence: int
        Sequence number of the event; ``Run.get_events`` can sort on it.
    run_id: str
        Id of the run the event belongs to.
    created_at: datetime
        Creation timestamp; passed through the validator built by
        ``_get_epoch_validator("created_at")``.
    status: str
        Status recorded by the event.
    details: dict
        Additional event details.
    """

    id: StrictStr
    sequence: StrictInt
    run_id: StrictStr
    created_at: datetime
    _created_at_validator: classmethod = _get_epoch_validator("created_at")
    status: StrictStr
    details: dict
class Run(BaseModelWithDatetime):
    """Training run.

    Please use :func:`chariot.training_v2.run.Run.from_id` to get a run by id, or
    :func:`chariot.training_v2.run.get_runs` to lookup runs by name, version, etc.

    Fields marked Optional should be included by default, but might be missing if a
    ``select`` filter is applied to :func:`chariot.training_v2.run.get_runs`.

    Methods that talk to the API need the run id. When ``id`` is not set but
    ``name``, ``version`` and ``project_id`` are, the id is looked up once via
    :func:`chariot.training_v2.run.get_runs` and cached on the instance.
    """

    id: StrictStr | None = None
    name: StrictStr | None = None
    version: StrictStr | None = None
    created_at: datetime | None = None
    _created_at_validator: classmethod = _get_epoch_validator("created_at")
    blueprint_id: StrictStr | None = None
    project_id: StrictStr | None = None
    user_id: StrictStr | None = None
    progress: list[Progress] | None = None
    progress_updated_at: datetime | None = None
    _progress_updated_at_validator: classmethod = _get_epoch_validator("progress_updated_at")
    status: StrictStr | None = None
    status_updated_at: datetime | None = None
    _status_updated_at_validator: classmethod = _get_epoch_validator("status_updated_at")
    task_type: StrictStr | None = None
    resources: Resources | None = None
    config: dict | None = None
    notes: str | None = None

    def _id(self) -> str:
        """Return the run id, resolving it by name/version/project if needed.

        Returns
        -------
        str
            The id of this run. A resolved id is stored on ``self.id``.

        Raises
        ------
        ValueError
            If the id is not set and cannot be resolved to exactly one run.
        """
        if self.id is not None:
            return self.id

        if not (self.name and self.version and self.project_id):
            raise ValueError(
                "the 'id' field is missing on this 'Run' object; please use `Run.from_id` or `get_runs` to lookup the run"
            )

        # ``name_ilikes`` takes SQL ILIKE patterns, so ``%`` and ``_`` inside a run
        # name would act as wildcards and could match unrelated runs. Escape them
        # (and the escape character itself) to get equality matching on the name.
        name_pattern = self.name.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        runs = get_runs(
            name_ilikes=[name_pattern],
            versions=[self.version],
            project_ids=[self.project_id],
            select=["id"],
        )
        if not runs:
            raise ValueError(
                f"no run found with name '{self.name}', version '{self.version}', and project_id '{self.project_id}'"
            )
        if len(runs) > 1:
            raise ValueError(
                f"multiple runs found with name '{self.name}', version '{self.version}', and project_id '{self.project_id}'"
            )

        run_id = runs[0].id
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if run_id is None:
            raise ValueError(
                f"run lookup for name '{self.name}', version '{self.version}', and project_id '{self.project_id}' returned no id"
            )
        self.id = run_id
        return run_id

    def reload(self, fields: list[str] | None = None):
        """Reload the training run.

        Parameters
        ----------
        fields: list[str] | None
            List of fields to reload. Options are "status", "status_updated_at",
            "progress", "progress_updated_at", and "notes". If omitted, all fields will be refreshed.

        Raises
        ------
        RunDoesNotExistError
            If the run does not exist or has been deleted, this will be raised.
        ValueError
            If `fields` are invalid, or the `id` is not set on this run.
        APIException
            If api communication fails, request is unauthorized or is unauthenticated.
        """
        _valid_reload_fields = {
            "status",
            "status_updated_at",
            "progress",
            "progress_updated_at",
            "notes",
        }
        if not fields:
            fields = sorted(_valid_reload_fields)
        else:
            for field in fields:
                if field not in _valid_reload_fields:
                    raise ValueError(
                        f"encountered an invalid field: {field}. valid fields are {_valid_reload_fields}"
                    )

        run_id = self._id()
        runs = _apis.training_v2.v2_runs_api.v2_runs_get(id=[run_id], select=fields).data
        if not runs:
            raise RunDoesNotExistError(run_id=run_id)
        run = runs[0]

        if "status" in fields:
            self.status = run.status
        if "status_updated_at" in fields:
            self.status_updated_at = run.status_updated_at
        if "progress" in fields:
            # The API may report no progress at all; keep ``None`` instead of
            # failing while iterating over it.
            self.progress = (
                None
                if run.progress is None
                else [Progress(**p.model_dump()) for p in run.progress]
            )
        if "progress_updated_at" in fields:
            self.progress_updated_at = run.progress_updated_at
        if "notes" in fields:
            self.notes = run.notes

    def get_global_steps_with_checkpoints(self) -> list[int]:
        """Get the global steps for which a checkpoint exists for this run.

        Returns
        -------
        list[int]:
            global steps

        Raises
        ------
        APIException
            If api communication fails, request is unauthorized or is unauthenticated.
        """
        checkpoints = self.get_checkpoints(select=["global_step"])
        return [checkpoint.global_step for checkpoint in checkpoints]

    def delete(self):
        """Delete this run.

        Raises
        ------
        APIException
            If api communication fails, request is unauthorized or is unauthenticated.
        """
        _apis.training_v2.v2_runs_api.v2_runs_id_delete(id=self._id())

    def get_metrics(
        self,
        global_steps: list[int] | None = None,
        tags: list[str] | None = None,
        limit: int = 1000,
        created_before: datetime | None = None,
    ) -> list[Metric]:
        """Get metrics for this run.

        Parameters
        ----------
        global_steps: list[int] | None
            if specified, only return metrics for these global steps
        tags: list[str] | None
            if specified, only return metrics with the given tags
        limit: int (default: 1000)
            limit the response to the given number of metrics
        created_before: datetime | None
            if specified, filter to metrics created before the given date and time.
            This can be used for keyset pagination

        Returns
        -------
        metrics: list[Metric]

        Raises
        ------
        APIException
            If api communication fails, request is unauthorized or is unauthenticated.
        """
        response = _apis.training_v2.v2_metrics_api.v2_metrics_get(
            run_id=[self._id()],
            global_step=global_steps,
            tag=tags,
            limit=limit,
            created_before=(
                _datetime_to_epoch(created_before) if created_before is not None else None
            ),
        )
        return [_openapi_metric_to_metric(metric) for metric in response.data]

    def get_all_metrics(self) -> list[Metric]:
        """Get all metrics for this run.

        Sort order is unspecified and may change in the future.

        Returns
        -------
        metrics: list[Metric]

        Raises
        ------
        APIException
            If api communication fails, request is unauthorized or is unauthenticated.
        """
        # TODO(s.maddox): once `id_after` is added to the API, just use that rather
        # than listing all tags and then all metrics for each tag.
        run_id = self._id()

        # Page through every tag of the run, using the last tag seen as the cursor.
        tags = []
        while True:
            new_tags = _apis.training_v2.v2_metrics_api.v2_metrics_tags_get(
                run_id=[run_id],
                tag_after=tags[-1] if tags else None,
            ).data
            if not new_tags:
                break
            tags.extend(new_tags)

        # For each tag, page through its metrics in ascending global step order,
        # using the last global step seen as the cursor.
        metrics = []
        for tag in tags:
            global_step_after = None
            while True:
                page = _apis.training_v2.v2_metrics_api.v2_metrics_get(
                    run_id=[run_id],
                    tag=[tag],
                    global_step_after=global_step_after,
                    sort=["global_step:asc"],
                    limit=10000,
                ).data
                if not page:
                    break
                global_step_after = page[-1].global_step
                metrics.extend(_openapi_metric_to_metric(metric) for metric in page)

        # NOTE(s.maddox): sorting by id:asc to be forward compatible with id_after pagination
        metrics.sort(key=lambda metric: metric.id)
        return metrics

    def get_events(
        self,
        limit: int | None = None,
        offset: int | None = None,
        sort: list[Literal["sequence:desc", "sequence:asc"]] | None = None,
    ) -> list[Event]:
        """Get events for this run.

        Parameters
        ----------
        limit: int | None
            Limit the response to the given number of run events. Defaults to 100.
        offset: int | None
            Offset based pagination. Defaults to 0.
        sort: list[Literal["sequence:desc", "sequence:asc"]] | None
            Sort by the given fields in the given directions. The field and direction
            should be separated by a colon. For example: ``sort=sequence:desc``.
            Defaults to ``sequence:desc``. Valid field is only ``sequence``.
            Valid directions are ascending (``asc``) or descending (``desc``).
            If the direction is not specified it defaults to ascending (``asc``).

        Returns
        -------
        events: list[Event]

        Raises
        ------
        APIException
            If api communication fails, request is unauthorized or is unauthenticated.
        """
        events = _apis.training_v2.v2_runs_api.v2_runs_id_events_get(
            id=self._id(),
            limit=limit,
            offset=offset,
            sort=sort,
        ).data
        return [Event.model_validate(event) for event in events]

    def get_checkpoints(
        self,
        *,
        ids: list[str] | None = None,
        project_ids: list[str] | None = None,
        global_steps: list[int] | None = None,
        statuses: list[Literal["incomplete", "complete"]] | None = None,
        created_before: datetime | None = None,
        created_after: datetime | None = None,
        select: list[
            Literal[
                "id",
                "created_at",
                "run_id",
                "project_id",
                "global_step",
                "status",
                "status_updated_at",
            ]
        ]
        | None = None,
        sort: list[Literal["id:asc", "id:desc", "created_at:asc", "created_at:desc"]]
        | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> list[Checkpoint]:
        """Get checkpoints for this run.

        Parameters
        ----------
        ids: list[str] | None
            If specified, filter to checkpoints with any of the given IDs
        project_ids: list[str] | None
            If specified, filter to checkpoints with any of the given Project Ids
        global_steps: list[int] | None
            If specified, filter to checkpoints from any of the given Global Steps
        statuses: list[Literal["incomplete", "complete"]] | None
            If specified, filter to checkpoints with a specific status.
        created_before: datetime | None
            If specified, filter to checkpoints created before the given date and time.
            This can be used for keyset pagination
        created_after: datetime | None
            If specified, filter to checkpoints created after the given date and time.
            This can be used for keyset pagination
        select: list[Literal["id", "created_at", "run_id", "project_id", "global_step", "status", "status_updated_at"]] | None
            If specified, only the given fields are included in the response.
        sort: list[Literal["id:asc", "id:desc", "created_at:asc", "created_at:desc"]] | None
            Sort by the given fields in the given directions.
            Default: `"created_at:desc"`
        limit: int | None
            Limit the response to the given number of checkpoints.
            Default: 10
        offset: int | None
            Offset based pagination.
            Default: 0

        Returns
        -------
        checkpoints: list[Checkpoint]
            The checkpoints matching the filter criteria

        Raises
        ------
        APIException
            If api communication fails, request is unauthorized or is unauthenticated.
        """
        # Inside the method body ``get_checkpoints`` resolves to the module-level
        # function imported from ``chariot.training_v2.checkpoint``, not this method.
        return get_checkpoints(
            run_ids=[self._id()],
            ids=ids,
            project_ids=project_ids,
            global_steps=global_steps,
            statuses=statuses,
            created_before=created_before,
            created_after=created_after,
            select=select,
            sort=sort,
            limit=limit,
            offset=offset,
        )

    def stop(self, grace_period: timedelta | None = None) -> None:
        """Stop the training run.

        Parameters
        ----------
        grace_period: timedelta | None
            Time that will be tolerated before the run should be force stopped.
            Must be greater than or equal to 1 second.
            If not provided, will default to 10 minutes

        Raises
        ------
        ValueError
            If ``grace_period`` is provided but shorter than 1 second.
        APIException
            If api communication fails, request is unauthorized or is unauthenticated.
        """
        grace_period_seconds = None
        # ``timedelta(0)`` is falsy, so test against ``None`` explicitly; otherwise
        # a zero grace period would silently turn into the server-side default.
        if grace_period is not None:
            if grace_period < timedelta(seconds=1):
                raise ValueError("grace_period must be greater than or equal to 1 second")
            grace_period_seconds = int(grace_period.total_seconds())

        _apis.training_v2.v2_runs_api.v2_runs_id_stop_post(
            id=self._id(),
            grace_period_seconds=grace_period_seconds,
        )

    def restart(self, resources: Resources | None = None) -> None:
        """Restart the training run.

        Parameters
        ----------
        resources: Resources | None
            If provided, these resources are included in the restart request.
            If omitted, the request is sent without resources.

        Raises
        ------
        APIException
            If api communication fails, request is unauthorized or is unauthenticated.
        """
        request = training_v2_openapi_models.RestartRunRequest()
        if resources is not None:
            gpu = None
            if resources.gpu:
                gpu = training_v2_openapi_models.GPU(
                    type=resources.gpu.type,
                    count=resources.gpu.count,
                )
            request.resources = training_v2_openapi_models.Resources(
                cpu=resources.cpu,
                ephemeral_storage=resources.ephemeral_storage,
                memory=resources.memory,
                gpu=gpu,
            )
        _apis.training_v2.v2_runs_api.v2_runs_id_restart_post(self._id(), request)

    @classmethod
    def from_id(cls, run_id: str) -> "Run":
        """Get a training run by id.

        Parameters
        ----------
        run_id: str
            Id of the run to fetch.

        Returns
        -------
        Run
            The run built from the API response, with unset (``None``) fields
            of the response left at their defaults.

        Raises
        ------
        APIException
            If api communication fails, request is unauthorized or is unauthenticated.
        """
        run = _apis.training_v2.v2_runs_api.v2_runs_id_get(id=run_id).data
        return cls.model_validate(run.model_dump(exclude_none=True))
def create_run(
    *,
    name: str,
    version: str,
    resources: Resources,
    config: dict,
    task_type: str,
    blueprint_id: str,
    project_id: str,
    notes: str | None = None,
) -> str:
    """Create a training run.

    Parameters
    ----------
    name: str
        name of the run
    version: str
        version of the run
    resources: Resources
        resources to allocate for scheduling the run
    config: dict
        the run config
    task_type: str
        task type of the run
    project_id: str
        the id of the project to create the run in.
        To lookup a project id by name, use ``chariot.projects.get_project_id``.
    blueprint_id: str
        the id of the blueprint to use.
        To lookup a blueprint id by name, use ``chariot.training_v2.lookup_blueprint_id``.
    notes: str, optional
        notes associated with the training run

    Returns
    -------
    run_id: str
        the created run's id

    Raises
    ------
    ValidationError
        if the provided run config is invalid according to the blueprint, or
        any required parameters are ill-formed.
    APIException
        if api communication fails, request is unauthorized or is unauthenticated.
    """
    # TODO(ZachDougherty): add more documentation for run config
    # Translate the SDK resource models into their OpenAPI counterparts.
    if resources.gpu:
        gpu_request = training_v2_openapi_models.GPU(
            type=resources.gpu.type,
            count=resources.gpu.count,
        )
    else:
        gpu_request = None
    resources_request = training_v2_openapi_models.Resources(
        cpu=resources.cpu,
        ephemeral_storage=resources.ephemeral_storage,
        memory=resources.memory,
        gpu=gpu_request,
    )
    body = training_v2_openapi_models.SubmitRunRequest(
        blueprint_id=blueprint_id,
        config=config,
        name=name,
        project_id=project_id,
        resources=resources_request,
        task_type=task_type,
        version=version,
        notes=notes,
    )

    try:
        run_id = _apis.training_v2.v2_runs_api.v2_runs_post(body=body).data.id
    except ApiException as e:
        # Only a 422 response carries field errors; anything else propagates as is.
        if e.status != 422:
            raise
        field_errors = FieldErrorsResponse.from_json(e.body)
        raise ValidationError(field_errors.errors) from None
    return run_id
def get_runs(
    *,
    blueprint_ids: list[str] | None = None,
    created_after: datetime | None = None,
    created_before: datetime | None = None,
    ids: list[str] | None = None,
    id_after: str | None = None,
    limit: int | None = None,
    offset: int | None = None,
    name_ilikes: list[str] | None = None,
    project_ids: list[str] | None = None,
    select: list[
        Literal[
            "*",
            "id",
            "project_id",
            "user_id",
            "created_at",
            "name",
            "version",
            "blueprint_id",
            "task_type",
            "config",
            "resources",
            "progress",
            "progress_updated_at",
            "status",
            "status_updated_at",
            "notes",
        ]
    ]
    | None = None,
    sort: list[Literal["id:asc", "id:desc", "created_at:asc", "created_at:desc"]] | None = None,
    statuses: list[
        Literal[
            "run_created",
            "run_stop_requested",
            "run_restart_requested",
            "job_create_failed",
            "job_created",
            "job_submitted",
            "job_pending",
            "job_running",
            "job_terminate_requested",
            "job_terminating",
            "job_terminated",
            "job_failed",
            "job_completed",
            "job_unknown",
        ]
    ]
    | None = None,
    task_types: list[str] | None = None,
    versions: list[str] | None = None,
    user_ids: list[str] | None = None,
) -> list[Run]:
    """Get runs matching the provided criteria.

    Parameters
    ----------
    blueprint_ids: list[str] | None
        If specified, filter to runs with any of the given Blueprint IDs
    created_after: datetime | None
        If specified, filter to runs created after the given date and time.
        This can be used for keyset pagination
    created_before: datetime | None
        If specified, filter to runs created before the given date and time.
        This can be used for keyset pagination
    ids: list[str] | None
        If specified, filter to runs with any of the given IDs.
    id_after: str | None
        If specified, filter to runs with an ID after the given ID.
        This can be used for keyset pagination.
    limit: int | None
        Limit the response to the given number of runs.
        Default: 10
    offset: int | None
        Offset based pagination.
        Default: 0
    name_ilikes: list[str] | None
        If specified, filter to runs with a ``name`` that matches any of the given
        SQL ILIKE patterns.

        Options for pattern matching are:
        ``%`` matches any sequence of zero or more characters.
        ``_`` matches any single character.
        To match the literal characters ``%`` or ``_``, escape the character
        with a ``\\``, e.g. ``\\%testrun``

        To use equality matching, simply provide a plain string with no special
        characters. Matching is case insensitive.

        For example:
        The pattern ``%test-run%`` matches ``test-run``, ``FOOtest-runBAR``,
        and ``test-runBAR``.
        The pattern ``\\%test_run`` matches ``%test9run`` and ``%test_run``,
        but not ``FOOtest_run``, ``%test__run``, or ``%test_runBAR``.
        The pattern ``test\\_run`` matches ``test_run`` and nothing else.
    project_ids: list[str] | None
        If specified, filter to runs with any of the given Project IDs.
        To lookup a project id by name, use ``chariot.projects.get_project_id``
    select: list[Literal["*", "id", "project_id", "user_id", "created_at", "name", "version", "blueprint_id", "task_type", "config", "resources", "progress", "progress_updated_at", "status", "status_updated_at", "notes"]] | None
        If specified, only the selected fields are included in the response.
        If all fields are desired, use "*". Excluded attributes will be `None`
        in the ``chariot.training_v2.Run`` responses.
    sort: list[Literal["id:asc", "id:desc", "created_at:asc", "created_at:desc"]] | None
        Sort by the given fields in the given directions. The field and direction
        should be separated by a colon.
        Default: ``"created_at:desc"``
    statuses: list[Literal["run_created", "run_stop_requested", "run_restart_requested", "job_create_failed", "job_created", "job_submitted", "job_pending", "job_running", "job_terminate_requested", "job_terminating", "job_terminated", "job_failed", "job_completed", "job_unknown"]] | None
        If specified, filter to runs with any of the given statuses.
    task_types: list[str] | None
        If specified, filter to runs with any of the given Task Types.
        Examples: ``"Object Detection"``, ``"Image Segmentation"``
    versions: list[str] | None
        If specified, filter to runs with any of the given Versions.
    user_ids: list[str] | None
        If specified, filter to runs with any of the given User IDs.

    Returns
    -------
    list[Run]
        Runs matching the filter criteria

    Raises
    ------
    APIException
        If api communication fails, request is unauthorized or is unauthenticated.
    """
    response = _apis.training_v2.v2_runs_api.v2_runs_get(
        blueprint_id=blueprint_ids,
        created_after=_datetime_to_epoch(created_after) if created_after is not None else None,
        created_before=_datetime_to_epoch(created_before) if created_before is not None else None,
        id=ids,
        id_after=id_after,
        limit=limit,
        name_ilike=name_ilikes,
        offset=offset,
        project_id=project_ids,
        select=select,
        sort=sort,
        status=statuses,
        task_type=task_types,
        version=versions,
        user_id=user_ids,
    )
    # Unset fields are dropped so that unselected attributes fall back to the
    # ``None`` defaults declared on ``Run``.
    return [Run.model_validate(run.model_dump(exclude_none=True)) for run in response.data]
def validate_run_config(
    *,
    blueprint_id: str,
    config: dict,
):
    """Validate a training run configuration against the provided blueprint id.

    Parameters
    ----------
    blueprint_id: str
        The blueprint to validate against
    config: dict
        The run configuration to validate

    Raises
    ------
    ValidationError
        if the provided run config is invalid
    APIException
        if api communication fails, request is unauthorized or is unauthenticated.
    """
    request_body = training_v2_openapi_models.ValidateBlueprintRequest(config=config)
    blueprints_api = _apis.training_v2.v2_blueprints_api
    try:
        blueprints_api.v2_blueprints_id_validate_post_with_http_info(
            id=blueprint_id,
            body=request_body,
        )
    except ApiException as e:
        # Only a 422 response carries field errors; anything else propagates as is.
        if e.status != 422:
            raise
        field_errors = FieldErrorsResponse.from_json(e.body)
        raise ValidationError(field_errors.errors) from None