# Source code for chariot.training_v2.checkpoint

import os
import pathlib
import re
from datetime import datetime
from typing import Literal
from urllib.parse import urlencode, urljoin

import urllib3
from pydantic import StrictInt, StrictStr, Field

from chariot import _apis, mcp_setting
from chariot.training_v2._common import (
    BaseModelWithDatetime,
    _datetime_to_epoch,
    _get_epoch_validator,
)
from chariot_api._openapi.training_v2 import models as training_v2_openapi_models

# Names exported by `from chariot.training_v2.checkpoint import *`; every entry
# is a class or function defined in this module.
__all__ = [
    "Checkpoint",
    "get_checkpoints",
    "delete_checkpoints",
    "create_model_from_checkpoint",
    "download_checkpoint",
]


class Checkpoint(BaseModelWithDatetime):
    # Every field is optional and defaults to None.
    id: StrictStr | None = None
    run_id: StrictStr | None = None
    global_step: StrictInt | None = None
    project_id: StrictStr | None = None

    # Datetime fields are advertised as integers in the JSON schema and are
    # each paired with a validator produced by `_get_epoch_validator`.
    created_at: datetime | None = Field(default=None, json_schema_extra={"type": "integer"})
    _created_at_validator = _get_epoch_validator("created_at")

    status: StrictStr | None = None
    status_updated_at: datetime | None = Field(
        default=None, json_schema_extra={"type": "integer"}
    )
    _status_updated_at_validator = _get_epoch_validator("status_updated_at")

    # Storage location fields; read by `download_checkpoint` to list and fetch files.
    bucket_name: StrictStr | None = None
    key_prefix: StrictStr | None = None

    def create_model(
        self,
        *,
        name: str,
        version: str,
        summary: str,
        project_id: str | None = None,
    ) -> str:
        """Create a model from this checkpoint

        Thin convenience wrapper that forwards to ``create_model_from_checkpoint``
        using this checkpoint's ``id``.

        Parameters
        ----------
        name : str
            The name to give the model
        version : str
            The version to give the model. Must be in SemVer format
        summary : str
            A short summary of the model
        project_id : Optional[str]
            The ID of the project to create the model in.
            If omitted, the project ID of the associated run will be used.

        Returns
        -------
        model_id : str
            The ID of the created model

        Raises
        ------
        ValueError
            If this checkpoint has no ``id`` set.
        APIException
            If api communication fails, request is unauthorized or is unauthenticated.
        """
        checkpoint_id = self.id
        if not checkpoint_id:
            raise ValueError("id must be set on checkpoint in order to create model")

        model_id = create_model_from_checkpoint(
            checkpoint_id=checkpoint_id,
            name=name,
            version=version,
            summary=summary,
            project_id=project_id,
        )
        return model_id
def get_checkpoints(
    *,
    ids: list[str] | None = None,
    run_ids: list[str] | None = None,
    project_ids: list[str] | None = None,
    global_steps: list[int] | None = None,
    statuses: list[Literal["incomplete", "complete"]] | None = None,
    created_before: datetime | None = None,
    created_after: datetime | None = None,
    select: list[
        Literal[
            "id", "created_at", "run_id", "project_id", "global_step", "status", "status_updated_at"
        ]
    ]
    | None = None,
    sort: list[Literal["id:asc", "id:desc", "created_at:asc", "created_at:desc"]] | None = None,
    limit: int | None = None,
    offset: int | None = None,
) -> list[Checkpoint]:
    """Get checkpoints matching the provided filters

    Parameters
    ----------
    ids : Optional[List[str]]
        If specified, filter to checkpoints with any of the given IDs
    run_ids : Optional[List[str]]
        If specified, filter to checkpoints for any of the given Run Ids
    project_ids : Optional[List[str]]
        If specified, filter to checkpoints with any of the given Project Ids
    global_steps : Optional[List[int]]
        If specified, filter to checkpoints from any of the given Global Steps
    statuses : Optional[List[Literal["incomplete", "complete"]]]
        If specified, filter to checkpoints with a specific status.
    created_before : Optional[datetime]
        If specified, filter to checkpoints created before the given date and time.
        This can be used for keyset pagination
    created_after : Optional[datetime]
        If specified, filter to checkpoints created after the given date and time.
        This can be used for keyset pagination
    select : Optional[List[Literal["id", "created_at", "run_id", "project_id", "global_step", "status", "status_updated_at"]]]
        If specified, only the given fields are included in the response.
    sort : Optional[List[Literal["id:asc", "id:desc", "created_at:asc", "created_at:desc"]]]
        Sort by the given fields in the given directions.
        Default: `"created_at:desc"`
    limit : Optional[int]
        Limit the response to the given number of checkpoints.
        Default: 10
    offset : Optional[int]
        Offset based pagination.
        Default: 0

    Returns
    -------
    checkpoints : List[Checkpoint]
        The checkpoints matching the filter criteria

    Raises
    ------
    APIException
        If api communication fails, request is unauthorized or is unauthenticated.
    """
    # Datetime bounds are converted with `_datetime_to_epoch` before being sent.
    after_epoch = _datetime_to_epoch(created_after) if created_after else None
    before_epoch = _datetime_to_epoch(created_before) if created_before else None

    response = _apis.training_v2.v2_checkpoints_api.v2_checkpoints_get(
        created_after=after_epoch,
        created_before=before_epoch,
        global_step=global_steps,
        id=ids,
        limit=limit,
        offset=offset,
        project_id=project_ids,
        run_id=run_ids,
        select=select,
        # Fall back to newest-first when the caller gives no sort order.
        sort=sort or ["created_at:desc"],
        status=statuses,
    )

    # Re-validate each API model into the SDK-facing `Checkpoint`, dropping unset fields.
    return [
        Checkpoint.model_validate(raw.model_dump(exclude_none=True)) for raw in response.data
    ]
@mcp_setting(mutating=True)
def delete_checkpoints(
    *,
    ids: list[str] | None = None,
    run_ids: list[str] | None = None,
) -> None:
    """Delete checkpoints matching the provided filters

    Parameters
    ----------
    ids : Optional[List[str]]
        If specified, filter to checkpoints with any of the given Checkpoint IDs.
        Note: either `ids` or `run_ids` must be specified, in order to prevent
        accidental deletion of all checkpoints.
    run_ids : Optional[List[str]]
        If specified, filter to checkpoints with any of the given Run IDs.
        Note: either `ids` or `run_ids` must be specified, in order to prevent
        accidental deletion of all checkpoints.

    Raises
    ------
    APIException
        If api communication fails, request is unauthorized or is unauthenticated.
    """
    # NOTE(review): no client-side check enforces the "either `ids` or `run_ids`"
    # rule here; presumably the API rejects an unfiltered delete - confirm.
    _apis.training_v2.v2_checkpoints_api.v2_checkpoints_delete(id=ids, run_id=run_ids)
@mcp_setting(file_based=True)
def download_checkpoint(id: str, file_dir: str) -> None:
    """Download checkpoint artifacts

    Lists every object under the checkpoint's key prefix and writes each one
    below ``file_dir``, preserving the path relative to that prefix. Every S3
    request is signed individually through the checkpoints API.

    Parameters
    ----------
    id : str
        The ID of the checkpoint to download
    file_dir : str
        The file dir for the downloaded checkpoint artifacts, file dir must exist.

    Returns
    -------
    None

    Raises
    ------
    APIException
        If api communication fails, request is unauthorized or is unauthenticated.
    ValueError
        If listing the checkpoint files, or downloading any one of them, does
        not return HTTP status 200.
    """
    # Function-scope import: only needed here, to decode XML entities in object keys.
    import html

    checkpoints_api = _apis.training_v2.v2_checkpoints_api

    def _signed_get_headers(target_url: str) -> dict[str, str]:
        # Each URL needs its own signature, so build a fresh header set per request.
        signature = checkpoints_api.v2_checkpoints_sign_s3_request_post(
            body=training_v2_openapi_models.SignS3RequestRequest(method="GET", url=target_url)
        ).data
        return {
            "X-Amz-Content-Sha256": "UNSIGNED-PAYLOAD",
            "Authorization": signature.authorization,
            "X-Amz-Date": signature.x_amz_date,
        }

    data = checkpoints_api.v2_checkpoints_id_get(id).data
    checkpoint = Checkpoint.model_validate(data.model_dump(exclude_none=True))
    s3_info = checkpoints_api.v2_checkpoints_s3_endpoint_get().data

    prefix_query_param = urlencode({"prefix": checkpoint.key_prefix})
    list_url = urljoin(
        s3_info.endpoint_url, f"{checkpoint.bucket_name}?list-type=2&{prefix_query_param}"
    )
    target_dir = pathlib.Path(file_dir)

    # The context manager clears the pool on exit, so pooled sockets are closed
    # even when an error is raised part-way through.
    with urllib3.PoolManager() as pool_manager:
        resp = pool_manager.request(
            "GET", list_url, headers=_signed_get_headers(list_url), preload_content=False
        )
        try:
            if resp.status != 200:
                raise ValueError("failed to list checkpoint files")
            # Decode the body as text; str(bytes) would yield the "b'...'" repr
            # and turn non-ASCII keys into escape sequences.
            listing = resp.data.decode("utf-8")
        finally:
            resp.release_conn()

        # NOTE(review): only this single listing response is processed; if the
        # listing is truncated/paginated, later pages are not fetched - confirm
        # whether checkpoints can exceed one page.
        # Keys arrive XML-escaped (e.g. "&amp;"), so unescape to get the real key.
        keys = [html.unescape(raw_key) for raw_key in re.findall(r"<Key>(.*?)</Key>", listing)]

        for k in keys:
            url = urljoin(s3_info.endpoint_url, f"{checkpoint.bucket_name}/{k}")
            resp = pool_manager.request(
                "GET", url, headers=_signed_get_headers(url), preload_content=False
            )
            try:
                if resp.status != 200:
                    raise ValueError(f"failed to download checkpoint file {k}")

                # Mirror the object's path relative to the checkpoint prefix.
                fpath = target_dir / pathlib.Path(os.path.relpath(k, checkpoint.key_prefix))
                fpath.parent.mkdir(parents=True, exist_ok=True)
                with open(fpath, "w+b") as out:
                    # Stream in chunks so large artifacts are never held fully in memory.
                    for chunk in resp.stream(65536):
                        out.write(chunk)
            finally:
                # No-op once the body is fully consumed; matters on the error paths.
                resp.release_conn()
@mcp_setting(mutating=True)
def create_model_from_checkpoint(
    *,
    checkpoint_id: str,
    name: str,
    version: str,
    summary: str,
    project_id: str | None = None,
) -> str:
    """Create a model from a checkpoint

    Parameters
    ----------
    checkpoint_id : str
        The ID of the checkpoint to create the model from
    name : str
        The name to give the model
    version : str
        The version to give the model. Must be in SemVer format
    summary : str
        A short summary of the model
    project_id : Optional[str]
        The ID of the project to create the model in.
        If omitted, the project ID of the associated run will be used.

    Returns
    -------
    model_id : str
        The ID of the created model

    Raises
    ------
    APIException
        If api communication fails, request is unauthorized or is unauthenticated.
    """
    # Assemble the request payload first, then post it against the checkpoint.
    request_body = training_v2_openapi_models.CreateModelFromCheckpointRequest(
        name=name,
        version=version,
        summary=summary,
        project_id=project_id,
    )
    response = _apis.training_v2.v2_checkpoints_api.v2_checkpoints_id_create_model_post(
        id=checkpoint_id,
        body=request_body,
    )
    return response.data.model_id