"""Training run management."""
import math
from datetime import datetime, timedelta
from typing import Literal
from pydantic import StrictFloat, StrictInt, StrictStr
from chariot import _apis
from chariot.training_v2._common import (
BaseModelWithDatetime,
_datetime_to_epoch,
_get_epoch_validator,
)
from chariot.training_v2.checkpoint import Checkpoint, get_checkpoints
from chariot.training_v2.exceptions import RunDoesNotExistError, ValidationError
from chariot_api._openapi.training_v2 import models as training_v2_openapi_models
from chariot_api._openapi.training_v2.exceptions import ApiException
from chariot_api._openapi.training_v2.models.field_errors_response import (
FieldErrorsResponse,
)
# Names exported by a star-import of this module: the public functions first,
# then the pydantic models, then module-level constants.
__all__ = (
    ["create_run", "get_runs", "validate_run_config"]
    + ["Event", "Gpu", "Metric", "Progress", "Resources", "Run"]
    + ["STATUSES"]
)
# Every status value a training run may report. This is the same set of values
# accepted by the ``statuses`` filter of ``get_runs``.
STATUSES = set(
    """
    run_created
    run_stop_requested
    run_restart_requested
    job_create_failed
    job_created
    job_submitted
    job_pending
    job_running
    job_terminate_requested
    job_terminating
    job_terminated
    job_failed
    job_completed
    job_unknown
    """.split()
)
[docs]
class Metric(BaseModelWithDatetime):
    """Training run metric.

    A single numeric ``value`` recorded under ``tag`` at ``global_step`` for the
    run identified by ``run_id``. Instances are produced by ``Run.get_metrics``
    and ``Run.get_all_metrics``.
    """

    # metric id; ``Run.get_all_metrics`` sorts its result by this field
    id: StrictStr
    created_at: datetime
    # field validator for ``created_at``; presumably converts an epoch timestamp
    # from the API into a datetime — TODO confirm against _get_epoch_validator
    _created_at_validator: classmethod = _get_epoch_validator("created_at")
    # id of the run this metric belongs to
    run_id: StrictStr
    # training step the metric was recorded at
    global_step: StrictInt
    # metric name; metrics are filtered and paginated per tag
    tag: StrictStr
    # recorded value; ``_openapi_metric_to_metric`` substitutes ``math.nan``
    # when the API returns no value
    value: StrictFloat | StrictInt
    # optional id of the job that produced the metric
    job_id: StrictStr | None = None
def _openapi_metric_to_metric(openapi_metric) -> Metric:
    """Convert an OpenAPI metric model into a :class:`Metric`.

    A ``None`` value coming back from the API is replaced by ``math.nan`` so the
    resulting ``Metric.value`` is always numeric.
    """
    raw = openapi_metric.model_dump()
    value = raw["value"]
    return Metric.model_validate({**raw, "value": math.nan if value is None else value})
[docs]
class Gpu(BaseModelWithDatetime):
    """Gpu resource metadata.

    All available gpu types can be found by calling the function
    ``chariot.system_resources.get_available_system_gpus``.
    """

    # number of GPUs of ``type`` to allocate
    count: StrictInt
    # GPU type name, e.g. one returned by get_available_system_gpus
    type: StrictStr
[docs]
class Resources(BaseModelWithDatetime):
    """Training run scheduling resources.

    These values represent kubernetes resources that will be allocated for a training run.
    Quantities use kubernetes resource-unit strings.

    Example values:
        cpu: "1"
        cpu: "500m"
        memory: "5Gi"  # gibibytes
        memory: "5000000Ki"  # kibibytes
        ephemeral_storage: "20Gi"

    Reference: https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/#resource-units-in-kubernetes
    """

    # kubernetes CPU quantity, e.g. "1" or "500m"
    cpu: StrictStr
    # kubernetes memory quantity, e.g. "5Gi"
    memory: StrictStr
    # optional kubernetes ephemeral-storage quantity, e.g. "20Gi"
    ephemeral_storage: StrictStr | None = None
    # optional GPU request; when None no GPU entry is sent in the API request
    gpu: Gpu | None = None
[docs]
class Progress(BaseModelWithDatetime):
    """Training run progress.

    Reports the current ``value`` of an ``operation`` alongside ``final_value``
    (presumably the value at completion — TODO confirm), measured in ``units``.
    """

    # name of the operation being tracked
    operation: StrictStr
    # current progress value
    value: StrictFloat | StrictInt
    # presumably the value ``value`` reaches when the operation finishes — TODO confirm
    final_value: StrictFloat | StrictInt
    # unit label for ``value`` and ``final_value``
    units: StrictStr
[docs]
class Event(BaseModelWithDatetime):
    """Training run event.

    One entry in a run's event history, as returned by ``Run.get_events``.
    """

    # event id
    id: StrictStr
    # position of the event in the run's history; ``Run.get_events`` sorts on it
    sequence: StrictInt
    # id of the run the event belongs to
    run_id: StrictStr
    created_at: datetime
    # field validator for ``created_at``; presumably converts an epoch timestamp
    # from the API into a datetime — TODO confirm against _get_epoch_validator
    _created_at_validator: classmethod = _get_epoch_validator("created_at")
    # presumably one of STATUSES — TODO confirm
    status: StrictStr
    # free-form event payload; its schema is not defined in this module
    details: dict
[docs]
class Run(BaseModelWithDatetime):
    """Training run.

    Please use :func:`chariot.training_v2.run.Run.from_id` to get a run by id,
    or :func:`chariot.training_v2.run.get_runs` to lookup runs by name, version, etc.

    Fields marked Optional should be included by default,
    but might be missing if a ``select`` filter is applied to
    :func:`chariot.training_v2.run.get_runs`.
    """

    id: StrictStr | None = None
    name: StrictStr | None = None
    version: StrictStr | None = None
    created_at: datetime | None = None
    _created_at_validator: classmethod = _get_epoch_validator("created_at")
    blueprint_id: StrictStr | None = None
    project_id: StrictStr | None = None
    user_id: StrictStr | None = None
    progress: list[Progress] | None = None
    progress_updated_at: datetime | None = None
    _progress_updated_at_validator: classmethod = _get_epoch_validator("progress_updated_at")
    status: StrictStr | None = None
    status_updated_at: datetime | None = None
    _status_updated_at_validator: classmethod = _get_epoch_validator("status_updated_at")
    task_type: StrictStr | None = None
    resources: Resources | None = None
    config: dict | None = None
    notes: str | None = None

    def _id(self) -> str:
        """Return this run's id, resolving it from name/version/project_id if unset.

        A resolved id is cached on ``self.id`` so later calls skip the lookup.

        Raises
        ------
        ValueError
            If ``id`` is unset and ``name``/``version``/``project_id`` do not
            identify exactly one run.
        """
        if self.id is not None:
            return self.id
        if self.name and self.version and self.project_id:
            # ``name_ilikes`` takes SQL ILIKE patterns. Escape the wildcard
            # characters so a ``_`` or ``%`` in the run name matches literally.
            # ILIKE is also case-insensitive, so keep only exact-name matches.
            pattern = (
                self.name.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
            )
            candidates = get_runs(
                name_ilikes=[pattern],
                versions=[self.version],
                project_ids=[self.project_id],
                select=["id", "name"],
            )
            runs = [run for run in candidates if run.name == self.name]
            if not runs:
                raise ValueError(
                    f"no run found with name '{self.name}', version '{self.version}', and project_id '{self.project_id}'"
                )
            if len(runs) > 1:
                raise ValueError(
                    f"multiple runs found with name '{self.name}', version '{self.version}', and project_id '{self.project_id}'"
                )
            run_id = runs[0].id
            # "id" was selected, so the API must have returned it
            assert run_id is not None
            self.id = run_id
            return run_id
        raise ValueError(
            "the 'id' field is missing on this 'Run' object; please use `Run.from_id` or `get_runs` to lookup the run"
        )

    def reload(self, fields: list[str] | None = None):
        """Reload the training run.

        Parameters
        ----------
        fields: list[str] | None
            List of fields to reload.
            Options are "status", "status_updated_at", "progress", "progress_updated_at", and "notes".
            If omitted, all fields will be refreshed.

        Raises
        ------
        RunDoesNotExistError
            If the run does not exist or has been deleted, this will be raised.
        ValueError
            If `fields` are invalid, or the `id` is not set on this run.
        ApiException
            If api communication fails, request is unauthorized or is unauthenticated.
        """
        _valid_reload_fields = {
            "status",
            "status_updated_at",
            "progress",
            "progress_updated_at",
            "notes",
        }
        if not fields:
            fields = list(_valid_reload_fields)
        else:
            for field in fields:
                if field not in _valid_reload_fields:
                    raise ValueError(
                        f"encountered an invalid field: {field}. valid fields are {_valid_reload_fields}"
                    )
        run_id = self._id()
        runs = _apis.training_v2.v2_runs_api.v2_runs_get(id=[run_id], select=fields).data
        if len(runs) == 0:
            raise RunDoesNotExistError(run_id=run_id)
        # Validate the API payload through the model, the same way ``from_id``
        # and ``get_runs`` do. This applies the field validators, converts the
        # progress entries into ``Progress`` objects, and keeps a missing
        # ``progress`` list as ``None`` instead of failing while iterating it.
        refreshed = type(self).model_validate(runs[0].model_dump(exclude_none=True))
        for field in fields:
            setattr(self, field, getattr(refreshed, field))

    def get_global_steps_with_checkpoints(self) -> list[int]:
        """Get the global steps for which a checkpoint exists for this run.

        All checkpoints are fetched page by page, so the result is not limited by
        the default page size of :meth:`get_checkpoints`.

        Returns
        -------
        list[int]: global steps, newest checkpoint first

        Raises
        ------
        ApiException
            If api communication fails, request is
            unauthorized or is unauthenticated.
        """
        page_size = 100
        steps: list[int] = []
        offset = 0
        while True:
            # Ascending creation order keeps offsets stable if new checkpoints
            # are written while paging.
            page = self.get_checkpoints(
                select=["global_step"],
                sort=["created_at:asc"],
                limit=page_size,
                offset=offset,
            )
            if not page:
                break
            steps.extend(checkpoint.global_step for checkpoint in page)
            offset += len(page)
        # Return newest first, matching the default created_at:desc ordering.
        steps.reverse()
        return steps

    def delete(self):
        """Delete this run.

        Raises
        ------
        ApiException
            If api communication fails, request is
            unauthorized or is unauthenticated.
        """
        _apis.training_v2.v2_runs_api.v2_runs_id_delete(id=self._id())

    def get_metrics(
        self,
        global_steps: list[int] | None = None,
        tags: list[str] | None = None,
        limit: int = 1000,
        created_before: datetime | None = None,
    ) -> list[Metric]:
        """Get metrics for this run.

        Parameters
        ----------
        global_steps: list[int] | None
            if specified, only return metrics for these global steps
        tags: list[str] | None
            if specified, only return metrics with the given tags
        limit: int (default: 1000)
            limit the response to the given number of metrics
        created_before: datetime | None
            if specified, filter to metrics created before the given
            date and time. This can be used for keyset pagination

        Returns
        -------
        metrics: list[Metric]

        Raises
        ------
        ApiException
            If api communication fails, request is
            unauthorized or is unauthenticated.
        """
        response = _apis.training_v2.v2_metrics_api.v2_metrics_get(
            run_id=[self._id()],
            global_step=global_steps,
            tag=tags,
            limit=limit,
            created_before=_datetime_to_epoch(created_before) if created_before else None,
        )
        return [_openapi_metric_to_metric(metric) for metric in response.data]

    def get_all_metrics(self) -> list[Metric]:
        """Get all metrics for this run.

        Sort order is unspecified and may change in the future.

        Returns
        -------
        metrics: list[Metric]

        Raises
        ------
        ApiException
            If api communication fails, request is
            unauthorized or is unauthenticated.
        """
        metrics_api = _apis.training_v2.v2_metrics_api
        run_id = self._id()
        # TODO(s.maddox): once `id_after` is added to the API, just use that rather
        # than listing all tags and then all metrics for each tag.
        tags: list[str] = []
        while True:
            new_tags = metrics_api.v2_metrics_tags_get(
                run_id=[run_id],
                tag_after=tags[-1] if tags else None,
            ).data
            if not new_tags:
                break
            tags.extend(new_tags)
        metrics: list[Metric] = []
        for tag in tags:
            global_step_after = None
            while True:
                # NOTE(review): keyset pagination on global_step assumes a page
                # boundary never splits several metrics sharing one global_step;
                # if that can happen, the rest of that step is skipped — confirm.
                page = metrics_api.v2_metrics_get(
                    run_id=[run_id],
                    tag=[tag],
                    global_step_after=global_step_after,
                    sort=["global_step:asc"],
                    limit=10000,
                ).data
                if not page:
                    break
                global_step_after = page[-1].global_step
                metrics.extend(_openapi_metric_to_metric(metric) for metric in page)
        # NOTE(s.maddox): sorting by id:asc to be forward compatible with id_after pagination
        metrics.sort(key=lambda metric: metric.id)
        return metrics

    def get_events(
        self,
        limit: int | None = None,
        offset: int | None = None,
        sort: list[Literal["sequence:desc", "sequence:asc"]] | None = None,
    ) -> list[Event]:
        """Get events for this run.

        Parameters
        ----------
        limit: int | None
            Limit the response to the given number of run events. Defaults to 100.
        offset: int | None
            Offset based pagination. Defaults to 0.
        sort: list[Literal["sequence:desc", "sequence:asc"]] | None
            Sort by the given fields in the given directions. The field and direction should be
            separated by a colon. For example: ``sort=sequence:desc``.
            Defaults to ``sequence:desc``.
            Valid field is only ``sequence``.
            Valid directions are ascending (``asc``) or descending (``desc``).
            If the direction is not specified it defaults to ascending (``asc``).

        Raises
        ------
        ApiException
            If api communication fails, request is
            unauthorized or is unauthenticated.
        """
        events = _apis.training_v2.v2_runs_api.v2_runs_id_events_get(
            id=self._id(),
            limit=limit,
            offset=offset,
            sort=sort,
        ).data
        # Dump to a dict first, as every other conversion in this module does.
        # Validating the OpenAPI model instance directly requires
        # ``from_attributes`` support on the target model.
        return [Event.model_validate(event.model_dump()) for event in events]

    def get_checkpoints(
        self,
        *,
        ids: list[str] | None = None,
        project_ids: list[str] | None = None,
        global_steps: list[int] | None = None,
        statuses: list[Literal["incomplete", "complete"]] | None = None,
        created_before: datetime | None = None,
        created_after: datetime | None = None,
        select: list[
            Literal[
                "id",
                "created_at",
                "run_id",
                "project_id",
                "global_step",
                "status",
                "status_updated_at",
            ]
        ]
        | None = None,
        sort: list[Literal["id:asc", "id:desc", "created_at:asc", "created_at:desc"]] | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> list[Checkpoint]:
        """Get checkpoints for this run.

        Parameters
        ----------
        ids: list[str] | None
            If specified, filter to checkpoints with any of the given IDs
        project_ids: list[str] | None
            If specified, filter to checkpoints with any of the given Project Ids
        global_steps: list[int] | None
            If specified, filter to checkpoints from any of the given
            Global Steps
        statuses: list[Literal["incomplete", "complete"]] | None
            If specified, filter to checkpoints with a specific status.
        created_before: datetime | None
            If specified, filter to checkpoints created before the
            given date and time. This can be used for keyset pagination
        created_after: datetime | None
            If specified, filter to checkpoints created after the
            given date and time. This can be used for keyset pagination
        select: list[Literal["id", "created_at", "run_id", "project_id", "global_step", "status", "status_updated_at"]] | None
            If specified, only the given fields are included in the response.
        sort: list[Literal["id:asc", "id:desc", "created_at:asc", "created_at:desc"]] | None
            Sort by the given fields in the given directions.
            Default: `"created_at:desc"`
        limit: int | None
            Limit the response to the given number of checkpoints. Default: 10
        offset: int | None
            Offset based pagination. Default: 0

        Returns
        -------
        checkpoints: list[Checkpoint]
            The checkpoints matching the filter criteria

        Raises
        ------
        ApiException
            If api communication fails, request is
            unauthorized or is unauthenticated.
        """
        return get_checkpoints(
            run_ids=[self._id()],
            ids=ids,
            project_ids=project_ids,
            global_steps=global_steps,
            statuses=statuses,
            created_before=created_before,
            created_after=created_after,
            select=select,
            sort=sort,
            limit=limit,
            offset=offset,
        )

    def stop(self, grace_period: timedelta | None = None) -> None:
        """Stop the training run.

        Parameters
        ----------
        grace_period: timedelta | None
            Time that will be tolerated before the run should be force stopped.
            Must be greater than or equal to 1 second. If not provided,
            will default to 10 minutes. Any provided value (including zero)
            is forwarded to the API as whole seconds.

        Raises
        ------
        ApiException
            If api communication fails, request is
            unauthorized or is unauthenticated.
        """
        # ``timedelta(0)`` is falsy, so compare against None explicitly. Otherwise
        # an explicit zero would silently become the 10 minute default.
        grace_period_seconds = (
            int(grace_period.total_seconds()) if grace_period is not None else None
        )
        _apis.training_v2.v2_runs_api.v2_runs_id_stop_post(
            id=self._id(),
            grace_period_seconds=grace_period_seconds,
        )

    def restart(self, resources: Resources | None = None) -> None:
        """Request a restart of this training run.

        Parameters
        ----------
        resources: Resources | None
            If provided, these scheduling resources are sent with the restart
            request. If omitted, no resources are sent.

        Raises
        ------
        ApiException
            If api communication fails, request is
            unauthorized or is unauthenticated.
        """
        request = training_v2_openapi_models.RestartRunRequest()
        if resources:
            request.resources = training_v2_openapi_models.Resources(
                cpu=resources.cpu,
                ephemeral_storage=resources.ephemeral_storage,
                memory=resources.memory,
                gpu=None
                if not resources.gpu
                else training_v2_openapi_models.GPU(
                    type=resources.gpu.type,
                    count=resources.gpu.count,
                ),
            )
        _apis.training_v2.v2_runs_api.v2_runs_id_restart_post(self._id(), request)

    @classmethod
    def from_id(cls, run_id: str) -> "Run":
        """Get a training run by id.

        Raises
        ------
        ApiException
            If api communication fails, request is
            unauthorized or is unauthenticated.
        """
        run = _apis.training_v2.v2_runs_api.v2_runs_id_get(id=run_id).data
        # ``exclude_none`` lets fields the API omitted fall back to their defaults.
        return cls.model_validate(run.model_dump(exclude_none=True))
[docs]
def create_run(
    *,
    name: str,
    version: str,
    resources: Resources,
    config: dict,
    task_type: str,
    blueprint_id: str,
    project_id: str,
    notes: str | None = None,
) -> str:
    """Create a training run.

    Parameters
    ----------
    name: str
        name of the run
    version: str
        version of the run
    resources: Resources
        resources to allocate for scheduling the run
    config: dict
        the run config
    task_type: str
        task type of the run
    blueprint_id: str
        the id of the blueprint to use.
        To lookup a blueprint id by name, use
        ``chariot.training_v2.lookup_blueprint_id``.
    project_id: str
        the id of the project to create the run in.
        To lookup a project id by name, use
        ``chariot.projects.get_project_id``.
    notes: str, optional
        notes associated with the training run

    Returns
    -------
    run_id: str
        the created run's id

    Raises
    ------
    ValidationError
        if the provided run config is invalid according to the blueprint,
        or any required parameters are ill-formed.
    ApiException
        if api communication fails, request is unauthorized or is unauthenticated.
    """
    # TODO(ZachDougherty): add more documentation for run config
    gpu = (
        None
        if not resources.gpu
        else training_v2_openapi_models.GPU(
            type=resources.gpu.type,
            count=resources.gpu.count,
        )
    )
    body = training_v2_openapi_models.SubmitRunRequest(
        blueprint_id=blueprint_id,
        config=config,
        name=name,
        project_id=project_id,
        resources=training_v2_openapi_models.Resources(
            cpu=resources.cpu,
            ephemeral_storage=resources.ephemeral_storage,
            memory=resources.memory,
            gpu=gpu,
        ),
        task_type=task_type,
        version=version,
        notes=notes,
    )
    try:
        response = _apis.training_v2.v2_runs_api.v2_runs_post(body=body)
    except ApiException as e:
        # Only translate a 422 that carries a field-error payload. Without a
        # body, parsing would fail and hide the original API error.
        if e.status == 422 and e.body:
            field_errors = FieldErrorsResponse.from_json(e.body)
            raise ValidationError(field_errors.errors) from None
        raise
    return response.data.id
[docs]
def get_runs(
    *,
    blueprint_ids: list[str] | None = None,
    created_after: datetime | None = None,
    created_before: datetime | None = None,
    ids: list[str] | None = None,
    id_after: str | None = None,
    limit: int | None = None,
    offset: int | None = None,
    name_ilikes: list[str] | None = None,
    project_ids: list[str] | None = None,
    select: list[
        Literal[
            "*",
            "id",
            "project_id",
            "user_id",
            "created_at",
            "name",
            "version",
            "blueprint_id",
            "task_type",
            "config",
            "resources",
            "progress",
            "progress_updated_at",
            "status",
            "status_updated_at",
            "notes",
        ]
    ]
    | None = None,
    sort: list[Literal["id:asc", "id:desc", "created_at:asc", "created_at:desc"]] | None = None,
    statuses: list[
        Literal[
            "run_created",
            "run_stop_requested",
            "run_restart_requested",
            "job_create_failed",
            "job_created",
            "job_submitted",
            "job_pending",
            "job_running",
            "job_terminate_requested",
            "job_terminating",
            "job_terminated",
            "job_failed",
            "job_completed",
            "job_unknown",
        ]
    ]
    | None = None,
    task_types: list[str] | None = None,
    versions: list[str] | None = None,
    user_ids: list[str] | None = None,
) -> list[Run]:
    """Get runs matching the provided criteria.

    Parameters
    ----------
    blueprint_ids: list[str] | None
        If specified, filter to runs with any of the given Blueprint IDs
    created_after: datetime | None
        If specified, filter to runs created after the given date and time.
        This can be used for keyset pagination
    created_before: datetime | None
        If specified, filter to runs created before the given date and time.
        This can be used for keyset pagination
    ids: list[str] | None
        If specified, filter to runs with any of the given IDs.
    id_after: str | None
        If specified, filter to runs with an ID after the given ID. This can be used for keyset pagination.
    limit: int | None
        Limit the response to the given number of runs. Default: 10
    offset: int | None
        Offset based pagination. Default: 0
    name_ilikes: list[str] | None
        If specified, filter to runs with a ``name`` that matches any of the given SQL ILIKE patterns.

        Options for pattern matching are:
            ``%`` matches any sequence of zero or more characters.
            ``_`` matches any single character.

        To match the literal characters ``%`` or ``_``,
        escape the character with a ``\\``, e.g. ``\\%testrun``

        To use equality matching, simply provide a plain string with
        no special characters. Matching is case insensitive.

        For example:
            The pattern ``%test-run%`` matches ``test-run``, ``FOOtest-runBAR``, and ``test-runBAR``.
            The pattern ``\\%test_run`` matches ``%test9run`` and ``%test_run``, but not
            ``FOOtest_run``, ``%test__run``, or ``%test_runBAR``.
            The pattern ``test\\_run`` matches ``test_run`` and nothing else.
    project_ids: list[str] | None
        If specified, filter to runs with any of the given Project IDs.
        To lookup a project id by name, use
        ``chariot.projects.get_project_id``
    select: list[Literal["*", "id", "project_id", "user_id", "created_at", "name", "version",
            "blueprint_id", "task_type", "config", "resources", "progress", "progress_updated_at",
            "status", "status_updated_at", "notes"]] | None
        If specified, only the selected fields are included in the response. If all
        fields are desired, use "*". Excluded attributes will be `None` in the
        ``chariot.training_v2.Run`` responses.
    sort: list[Literal["id:asc", "id:desc", "created_at:asc", "created_at:desc"]] | None
        Sort by the given fields in the given directions.
        The field and direction should be separated by a colon.
        Default: ``"created_at:desc"``
    statuses: list[Literal["run_created", "run_stop_requested", "run_restart_requested",
            "job_create_failed", "job_created", "job_submitted", "job_pending", "job_running",
            "job_terminate_requested", "job_terminating", "job_terminated", "job_failed",
            "job_completed", "job_unknown"]] | None
        If specified, filter to runs with any of the given statuses.
    task_types: list[str] | None
        If specified, filter to runs with any of the given Task Types.
        Examples: ``"Object Detection"``, ``"Image Segmentation"``
    versions: list[str] | None
        If specified, filter to runs with any of the given Versions.
    user_ids: list[str] | None
        If specified, filter to runs with any of the given User IDs.

    Returns
    -------
    list[Run]
        Runs matching the filter criteria

    Raises
    ------
    ApiException
        If api communication fails, request is unauthorized or is unauthenticated.
    """
    runs = _apis.training_v2.v2_runs_api.v2_runs_get(
        blueprint_id=blueprint_ids,
        created_after=_datetime_to_epoch(created_after) if created_after else None,
        created_before=_datetime_to_epoch(created_before) if created_before else None,
        id=ids,
        id_after=id_after,
        limit=limit,
        name_ilike=name_ilikes,
        offset=offset,
        project_id=project_ids,
        select=select,
        sort=sort,
        status=statuses,
        task_type=task_types,
        version=versions,
        user_id=user_ids,
    ).data
    # ``exclude_none`` drops fields the API left unset (e.g. those excluded by
    # ``select``), so the Run defaults of ``None`` apply to them.
    return [Run.model_validate(run.model_dump(exclude_none=True)) for run in runs]
[docs]
def validate_run_config(
    *,
    blueprint_id: str,
    config: dict,
):
    """Validate a training run configuration against the provided blueprint id.

    Parameters
    ----------
    blueprint_id: str
        The blueprint to validate against
    config: dict
        The run configuration to validate

    Raises
    ------
    ValidationError
        if the provided run config is invalid
    ApiException
        if api communication fails, request is unauthorized or is unauthenticated.
    """
    body = training_v2_openapi_models.ValidateBlueprintRequest(
        config=config,
    )
    try:
        _apis.training_v2.v2_blueprints_api.v2_blueprints_id_validate_post_with_http_info(
            id=blueprint_id,
            body=body,
        )
    except ApiException as e:
        # Only translate a 422 that carries a field-error payload. Without a
        # body, parsing would fail and hide the original API error.
        if e.status == 422 and e.body:
            field_errors = FieldErrorsResponse.from_json(e.body)
            raise ValidationError(field_errors.errors) from None
        raise