# Source code for chariot.training_v2.checkpoint

import os
import pathlib
import re
import xml.etree.ElementTree as ET
from datetime import datetime
from typing import Literal
from urllib.parse import urlencode, urljoin

import urllib3
from pydantic import StrictInt, StrictStr

from chariot import _apis
from chariot.training_v2._common import (
    BaseModelWithDatetime,
    _datetime_to_epoch,
    _get_epoch_validator,
)
from chariot_api._openapi.training_v2 import models as training_v2_openapi_models

# Names re-exported by ``from chariot.training_v2.checkpoint import *``.
__all__ = [
    # Data model
    "Checkpoint",
    # Query / mutation helpers
    "get_checkpoints",
    "delete_checkpoints",
    "create_model_from_checkpoint",
    # Artifact retrieval
    "download_checkpoint",
]


class Checkpoint(BaseModelWithDatetime):
    """A checkpoint recorded for a training-v2 run.

    Every field defaults to ``None`` so that partially populated API
    responses (for example from ``get_checkpoints(select=...)``) still
    validate.
    """

    # Identity and ownership
    id: StrictStr | None = None
    run_id: StrictStr | None = None
    global_step: StrictInt | None = None
    project_id: StrictStr | None = None

    # Timestamps. ``_get_epoch_validator`` attaches validation to these fields;
    # presumably it converts the epoch values the API returns into datetimes
    # (confirm in chariot.training_v2._common).
    created_at: datetime | None = None
    _created_at_validator = _get_epoch_validator("created_at")
    status: StrictStr | None = None
    status_updated_at: datetime | None = None
    _status_updated_at_validator = _get_epoch_validator("status_updated_at")

    # Object-storage location of the checkpoint artifacts (read by
    # download_checkpoint to list and fetch the files).
    bucket_name: StrictStr | None = None
    key_prefix: StrictStr | None = None

    def create_model(
        self,
        *,
        name: str,
        version: str,
        summary: str,
        project_id: str | None = None,
    ) -> str:
        """Create a model from this checkpoint.

        Parameters
        ----------
        name : str
            Name for the new model.
        version : str
            Version for the new model; must be in SemVer format.
        summary : str
            Short summary describing the model.
        project_id : Optional[str]
            Project to create the model in. When omitted, the project of the
            checkpoint's run is used.

        Returns
        -------
        model_id : str
            ID of the newly created model.

        Raises
        ------
        ValueError
            If this checkpoint has no ``id``.
        APIException
            If api communication fails, request is unauthorized or is unauthenticated.
        """
        checkpoint_id = self.id
        if checkpoint_id:
            return create_model_from_checkpoint(
                checkpoint_id=checkpoint_id,
                name=name,
                version=version,
                summary=summary,
                project_id=project_id,
            )
        raise ValueError("id must be set on checkpoint in order to create model")
def get_checkpoints(
    *,
    ids: list[str] | None = None,
    run_ids: list[str] | None = None,
    project_ids: list[str] | None = None,
    global_steps: list[int] | None = None,
    statuses: list[Literal["incomplete", "complete"]] | None = None,
    created_before: datetime | None = None,
    created_after: datetime | None = None,
    select: list[
        Literal[
            "id", "created_at", "run_id", "project_id", "global_step", "status", "status_updated_at"
        ]
    ]
    | None = None,
    sort: list[Literal["id:asc", "id:desc", "created_at:asc", "created_at:desc"]] | None = None,
    limit: int | None = None,
    offset: int | None = None,
) -> list[Checkpoint]:
    """Get checkpoints matching the provided filters.

    Each list-valued filter matches checkpoints having *any* of the given
    values. Filters left as ``None`` are not applied.

    Parameters
    ----------
    ids : Optional[List[str]]
        Restrict to checkpoints with any of these IDs.
    run_ids : Optional[List[str]]
        Restrict to checkpoints belonging to any of these runs.
    project_ids : Optional[List[str]]
        Restrict to checkpoints in any of these projects.
    global_steps : Optional[List[int]]
        Restrict to checkpoints taken at any of these global steps.
    statuses : Optional[List[Literal["incomplete", "complete"]]]
        Restrict to checkpoints in any of these statuses.
    created_before : Optional[datetime]
        Restrict to checkpoints created before this moment. Usable for
        keyset pagination.
    created_after : Optional[datetime]
        Restrict to checkpoints created after this moment. Usable for
        keyset pagination.
    select : Optional[List[Literal["id", "created_at", "run_id", "project_id", "global_step", "status", "status_updated_at"]]]
        Only include these fields in the response.
    sort : Optional[List[Literal["id:asc", "id:desc", "created_at:asc", "created_at:desc"]]]
        Sort by the given fields/directions. Defaults to ``["created_at:desc"]``
        when omitted or empty.
    limit : Optional[int]
        Maximum number of checkpoints to return. Default: 10
    offset : Optional[int]
        Number of checkpoints to skip (offset-based pagination). Default: 0

    Returns
    -------
    checkpoints : List[Checkpoint]
        The checkpoints matching the filter criteria.

    Raises
    ------
    APIException
        If api communication fails, request is unauthorized or is unauthenticated.
    """
    query = {
        "created_after": _datetime_to_epoch(created_after) if created_after else None,
        "created_before": _datetime_to_epoch(created_before) if created_before else None,
        "global_step": global_steps,
        "id": ids,
        "limit": limit,
        "offset": offset,
        "project_id": project_ids,
        "run_id": run_ids,
        "select": select,
        # Always request an explicit ordering; newest first unless told otherwise.
        "sort": sort or ["created_at:desc"],
        "status": statuses,
    }
    response = _apis.training_v2.v2_checkpoints_api.v2_checkpoints_get(**query)
    # Re-validate each API object through Checkpoint (dropping None fields so
    # the model's own defaults apply), which runs its timestamp validators.
    return [
        Checkpoint.model_validate(item.model_dump(exclude_none=True))
        for item in response.data
    ]
[docs] def delete_checkpoints( *, ids: list[str] | None = None, run_ids: list[str] | None = None, ) -> None: """Delete checkpoints matching the provided filters Parameters ---------- id : Optional[List[str]] If specified, filter to checkpoints with any of the given Checkpoint IDs. Note: either `id` or `run_id` must be specified, in order to prevent accidental deletion of all checkpoints. run_id : Optional[List[str]] If specified, filter to checkpoints with any of the given Run IDs. Note: either `id` or `run_id` must be specified, in order to prevent accidental deletion of all checkpoints. Raises ------ APIException If api communication fails, request is unauthorized or is unauthenticated. """ _apis.training_v2.v2_checkpoints_api.v2_checkpoints_delete(id=ids, run_id=run_ids)
# Bytes read per chunk when streaming an object to disk.
_DOWNLOAD_CHUNK_SIZE = 1024 * 1024


def download_checkpoint(id: str, file_dir: str) -> None:
    """Download checkpoint artifacts.

    Every object stored under the checkpoint's ``key_prefix`` in its
    ``bucket_name`` is listed and fetched with signed S3 requests. Each one
    is written below ``file_dir`` at its path relative to ``key_prefix``.

    Parameters
    ----------
    id : str
        The ID of the checkpoint to download.
    file_dir : str
        Directory to write the checkpoint artifacts into. Missing
        directories (including ``file_dir`` itself) are created as needed.

    Raises
    ------
    ValueError
        If the checkpoint has no bucket name / key prefix, or an object key
        does not map to a file under ``file_dir``. Also raised if an S3
        listing or download does not return HTTP 200.
    APIException
        If api communication fails, request is unauthorized or is unauthenticated.
    """
    api = _apis.training_v2.v2_checkpoints_api
    checkpoint = Checkpoint.model_validate(
        api.v2_checkpoints_id_get(id).data.model_dump(exclude_none=True)
    )
    bucket_name = checkpoint.bucket_name
    key_prefix = checkpoint.key_prefix
    if not bucket_name or not key_prefix:
        # Without a bucket the URL is meaningless; an empty prefix would list
        # (and download) the whole bucket.
        raise ValueError(f"checkpoint {id} has no bucket_name/key_prefix to download from")
    endpoint_url = api.v2_checkpoints_s3_endpoint_get().data.endpoint_url

    destination = pathlib.Path(file_dir)
    pool_manager = urllib3.PoolManager()
    try:
        for key in _list_checkpoint_keys(pool_manager, endpoint_url, bucket_name, key_prefix):
            if key.endswith("/"):
                # Zero-byte "folder" placeholder objects have no content to write.
                continue
            relative = pathlib.PurePath(os.path.relpath(key, key_prefix))
            # "." (key == prefix) would target file_dir itself, and a leading
            # ".." would write outside file_dir; refuse both.
            if not relative.parts or relative.parts[0] == "..":
                raise ValueError(
                    f"checkpoint file {key} does not lie under key prefix {key_prefix}"
                )
            _download_object(
                pool_manager,
                urljoin(endpoint_url, f"{bucket_name}/{key}"),
                key,
                destination / relative,
            )
    finally:
        pool_manager.clear()


def _signed_s3_get(pool_manager: "urllib3.PoolManager", url: str, *, preload_content: bool):
    """Sign a GET for ``url`` through the training API and send it."""
    signature = _apis.training_v2.v2_checkpoints_api.v2_checkpoints_sign_s3_request_post(
        body=training_v2_openapi_models.SignS3RequestRequest(method="GET", url=url)
    ).data
    headers = {
        "X-Amz-Content-Sha256": "UNSIGNED-PAYLOAD",
        "Authorization": signature.authorization,
        "X-Amz-Date": signature.x_amz_date,
    }
    return pool_manager.request("GET", url, headers=headers, preload_content=preload_content)


def _list_checkpoint_keys(
    pool_manager: "urllib3.PoolManager", endpoint_url: str, bucket_name: str, key_prefix: str
) -> list[str]:
    """Return every object key under ``key_prefix``, following ListObjectsV2 pagination."""
    keys: list[str] = []
    continuation_token: str | None = None
    while True:
        params = {"list-type": 2, "prefix": key_prefix}
        if continuation_token:
            params["continuation-token"] = continuation_token
        url = urljoin(endpoint_url, f"{bucket_name}?{urlencode(params)}")
        resp = _signed_s3_get(pool_manager, url, preload_content=True)
        if resp.status != 200:
            raise ValueError("failed to list checkpoint files")
        # Parse as XML (rather than regex over the bytes repr) so keys are
        # decoded correctly and XML entities such as &amp; are unescaped.
        # NOTE(review): the listing comes from the configured S3 endpoint;
        # ElementTree is not hardened against maliciously crafted XML.
        try:
            root = ET.fromstring(resp.data)
        except ET.ParseError as err:
            raise ValueError("failed to list checkpoint files") from err
        keys.extend(
            element.text for element in root.iterfind("{*}Contents/{*}Key") if element.text
        )
        # S3 returns at most 1000 keys per page; keep going while truncated.
        if root.findtext("{*}IsTruncated", default="false").strip().lower() != "true":
            return keys
        continuation_token = root.findtext("{*}NextContinuationToken")
        if not continuation_token:
            raise ValueError("failed to list checkpoint files: missing continuation token")


def _download_object(
    pool_manager: "urllib3.PoolManager", url: str, key: str, target: pathlib.Path
) -> None:
    """Stream one signed S3 object into ``target``, creating parent directories."""
    resp = _signed_s3_get(pool_manager, url, preload_content=False)
    try:
        if resp.status != 200:
            raise ValueError(f"failed to download checkpoint file {key}")
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, "wb") as out:
            for chunk in resp.stream(_DOWNLOAD_CHUNK_SIZE):
                out.write(chunk)
    finally:
        # Streaming responses hold their connection until explicitly released.
        resp.release_conn()
def create_model_from_checkpoint(
    *,
    checkpoint_id: str,
    name: str,
    version: str,
    summary: str,
    project_id: str | None = None,
) -> str:
    """Create a model from a checkpoint.

    Parameters
    ----------
    checkpoint_id : str
        ID of the checkpoint the model is built from.
    name : str
        Name for the new model.
    version : str
        Version for the new model; must be in SemVer format.
    summary : str
        Short summary describing the model.
    project_id : Optional[str]
        Project to create the model in. When omitted, the project of the
        checkpoint's run is used.

    Returns
    -------
    model_id : str
        ID of the newly created model.

    Raises
    ------
    APIException
        If api communication fails, request is unauthorized or is unauthenticated.
    """
    endpoint = _apis.training_v2.v2_checkpoints_api.v2_checkpoints_id_create_model_post
    request_body = training_v2_openapi_models.CreateModelFromCheckpointRequest(
        name=name,
        version=version,
        summary=summary,
        project_id=project_id,
    )
    response = endpoint(id=checkpoint_id, body=request_body)
    return response.data.model_id