import os
import pathlib
import re
import xml.etree.ElementTree as ET
from datetime import datetime
from typing import Literal
from urllib.parse import urlencode, urljoin

import urllib3
from pydantic import StrictInt, StrictStr

from chariot import _apis
from chariot.training_v2._common import (
    BaseModelWithDatetime,
    _datetime_to_epoch,
    _get_epoch_validator,
)
from chariot_api._openapi.training_v2 import models as training_v2_openapi_models
# Names exported by a star-import of this module.
__all__ = [
    "Checkpoint",
    "get_checkpoints",
    "delete_checkpoints",
    "create_model_from_checkpoint",
    "download_checkpoint",
]
[docs]
class Checkpoint(BaseModelWithDatetime):
    """A training checkpoint record.

    Every field is optional and defaults to ``None``, so an instance may
    carry only a subset of the attributes below.
    """

    # Identity of the checkpoint and of the run / project it is attached to.
    id: StrictStr | None = None
    run_id: StrictStr | None = None
    global_step: StrictInt | None = None
    project_id: StrictStr | None = None

    # Timestamps. Each one is paired with a validator built by
    # `_get_epoch_validator` (presumably converting epoch values coming from
    # the API into `datetime` objects -- confirm in `_common`).
    created_at: datetime | None = None
    _created_at_validator = _get_epoch_validator("created_at")
    status: StrictStr | None = None
    status_updated_at: datetime | None = None
    _status_updated_at_validator = _get_epoch_validator("status_updated_at")

    # Object-storage location; `download_checkpoint` uses these to list and
    # fetch the checkpoint's artifacts.
    bucket_name: StrictStr | None = None
    key_prefix: StrictStr | None = None

    def create_model(
        self,
        *,
        name: str,
        version: str,
        summary: str,
        project_id: str | None = None,
    ) -> str:
        """Create a model from this checkpoint.

        Convenience wrapper that forwards to `create_model_from_checkpoint`
        using this checkpoint's own ``id``.

        Parameters
        ----------
        name : str
            The name to give the model
        version : str
            The version to give the model. Must be in SemVer format
        summary : str
            A short summary of the model
        project_id : Optional[str]
            The ID of the project to create the model in.
            If omitted, the project ID of the associated run will be used.

        Returns
        -------
        model_id : str
            The ID of the created model

        Raises
        ------
        ValueError
            If this checkpoint has no ``id`` set.
        APIException
            If api communication fails, request is
            unauthorized or is unauthenticated.
        """
        checkpoint_id = self.id
        if not checkpoint_id:
            raise ValueError("id must be set on checkpoint in order to create model")

        model_id = create_model_from_checkpoint(
            checkpoint_id=checkpoint_id,
            name=name,
            version=version,
            summary=summary,
            project_id=project_id,
        )
        return model_id
[docs]
def get_checkpoints(
    *,
    ids: list[str] | None = None,
    run_ids: list[str] | None = None,
    project_ids: list[str] | None = None,
    global_steps: list[int] | None = None,
    statuses: list[Literal["incomplete", "complete"]] | None = None,
    created_before: datetime | None = None,
    created_after: datetime | None = None,
    select: list[
        Literal[
            "id", "created_at", "run_id", "project_id", "global_step", "status", "status_updated_at"
        ]
    ]
    | None = None,
    sort: list[Literal["id:asc", "id:desc", "created_at:asc", "created_at:desc"]] | None = None,
    limit: int | None = None,
    offset: int | None = None,
) -> list[Checkpoint]:
    """Fetch the checkpoints that match the given filters.

    Parameters
    ----------
    ids : Optional[List[str]]
        Restrict the result to checkpoints with any of these IDs
    run_ids : Optional[List[str]]
        Restrict the result to checkpoints belonging to any of these runs
    project_ids : Optional[List[str]]
        Restrict the result to checkpoints belonging to any of these projects
    global_steps : Optional[List[int]]
        Restrict the result to checkpoints taken at any of these global steps
    statuses : Optional[List[Literal["incomplete", "complete"]]]
        Restrict the result to checkpoints with any of these statuses
    created_before : Optional[datetime]
        Only return checkpoints created before this date and time.
        This can be used for keyset pagination
    created_after : Optional[datetime]
        Only return checkpoints created after this date and time.
        This can be used for keyset pagination
    select : Optional[List[Literal["id", "created_at", "run_id",
            "project_id", "global_step", "status", "status_updated_at"]]]
        If specified, only the given fields are included in the response.
    sort : Optional[List[Literal["id:asc", "id:desc", "created_at:asc", "created_at:desc"]]]
        Sort by the given fields in the given directions.
        Default: `"created_at:desc"`
    limit : Optional[int]
        Limit the response to the given number of checkpoints. Default: 10
    offset : Optional[int]
        Offset based pagination. Default: 0

    Returns
    -------
    checkpoints : List[Checkpoint]
        The checkpoints matching the filter criteria

    Raises
    ------
    APIException
        If api communication fails, request is
        unauthorized or is unauthenticated.
    """
    # The API takes the creation bounds as epoch values rather than datetimes.
    after_epoch = _datetime_to_epoch(created_after) if created_after else None
    before_epoch = _datetime_to_epoch(created_before) if created_before else None

    # Newest-first is applied here when the caller gives no ordering.
    ordering = sort or ["created_at:desc"]

    response = _apis.training_v2.v2_checkpoints_api.v2_checkpoints_get(
        created_after=after_epoch,
        created_before=before_epoch,
        global_step=global_steps,
        id=ids,
        limit=limit,
        offset=offset,
        project_id=project_ids,
        run_id=run_ids,
        select=select,
        sort=ordering,
        status=statuses,
    )

    # Re-validate each generated API model into the public `Checkpoint` type,
    # leaving out attributes the API did not set.
    checkpoints = []
    for api_checkpoint in response.data:
        fields = api_checkpoint.model_dump(exclude_none=True)
        checkpoints.append(Checkpoint.model_validate(fields))
    return checkpoints
[docs]
def delete_checkpoints(
*,
ids: list[str] | None = None,
run_ids: list[str] | None = None,
) -> None:
"""Delete checkpoints matching the provided filters
Parameters
----------
id : Optional[List[str]]
If specified, filter to checkpoints with any of the given Checkpoint IDs.
Note: either `id` or `run_id` must be specified, in order to prevent
accidental deletion of all checkpoints.
run_id : Optional[List[str]]
If specified, filter to checkpoints with any of the given Run IDs.
Note: either `id` or `run_id` must be specified, in order to prevent
accidental deletion of all checkpoints.
Raises
------
APIException
If api communication fails, request is
unauthorized or is unauthenticated.
"""
_apis.training_v2.v2_checkpoints_api.v2_checkpoints_delete(id=ids, run_id=run_ids)
[docs]
# Read size used when streaming an artifact to disk.
_DOWNLOAD_CHUNK_SIZE = 64 * 1024


def _signed_s3_get(pool_manager: "urllib3.PoolManager", url: str, *, preload_content: bool = True):
    """Send a GET to `url` using headers signed by the checkpoints API."""
    signature = _apis.training_v2.v2_checkpoints_api.v2_checkpoints_sign_s3_request_post(
        body=training_v2_openapi_models.SignS3RequestRequest(method="GET", url=url)
    ).data
    headers = {
        "X-Amz-Content-Sha256": "UNSIGNED-PAYLOAD",
        "Authorization": signature.authorization,
        "X-Amz-Date": signature.x_amz_date,
    }
    return pool_manager.request("GET", url, headers=headers, preload_content=preload_content)


def _parse_s3_object_keys(listing_xml: bytes) -> list[str]:
    """Return the text of every ``<Key>`` element in an S3 object-listing document."""
    root = ET.fromstring(listing_xml)
    # Listing documents are XML-namespaced, so tags look like "{namespace}Key";
    # compare on the local name only.
    return [
        element.text
        for element in root.iter()
        if element.tag.rpartition("}")[2] == "Key" and element.text
    ]


def download_checkpoint(id: str, file_dir: str) -> None:
    """Download checkpoint artifacts.

    Lists the objects stored under the checkpoint's key prefix and writes each
    one below `file_dir`, preserving the layout relative to that prefix.

    Parameters
    ----------
    id : str
        The ID of the checkpoint to download
    file_dir : str
        The file dir for the downloaded checkpoint artifacts, file dir must exist.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If the checkpoint has no bucket name or key prefix, if the artifact
        listing cannot be retrieved or parsed, or if an artifact cannot be
        downloaded.
    APIException
        If api communication fails, request is
        unauthorized or is unauthenticated.
    """
    data = _apis.training_v2.v2_checkpoints_api.v2_checkpoints_id_get(id).data
    checkpoint = Checkpoint.model_validate(data.model_dump(exclude_none=True))

    # Both fields are optional on the model; without them the URLs built
    # below would contain the literal text "None".
    if not checkpoint.bucket_name or not checkpoint.key_prefix:
        raise ValueError(f"checkpoint {id} has no bucket name or key prefix to download from")

    s3_info = _apis.training_v2.v2_checkpoints_api.v2_checkpoints_s3_endpoint_get().data
    prefix_query_param = urlencode({"prefix": checkpoint.key_prefix})
    listing_url = urljoin(
        s3_info.endpoint_url, f"{checkpoint.bucket_name}?list-type=2&{prefix_query_param}"
    )
    target_root = pathlib.Path(file_dir)

    # The context manager closes the pooled connections when the download ends.
    with urllib3.PoolManager() as pool_manager:
        # NOTE(review): only the first page of the listing is read. S3 caps a
        # listing page (1000 keys by default), so larger checkpoints would need
        # continuation-token handling -- confirm whether that can occur here.
        listing = _signed_s3_get(pool_manager, listing_url)
        if listing.status != 200:
            raise ValueError("failed to list checkpoint files")
        try:
            # Parse the raw bytes as XML so non-ASCII keys and XML entities
            # are decoded instead of being read out of a bytes repr.
            keys = _parse_s3_object_keys(listing.data)
        except ET.ParseError as err:
            raise ValueError("failed to list checkpoint files") from err

        for key in keys:
            if key.endswith("/"):
                # Zero-byte "folder" marker: writing it would create a file
                # where a directory is needed for the objects beneath it.
                continue

            object_url = urljoin(s3_info.endpoint_url, f"{checkpoint.bucket_name}/{key}")
            resp = _signed_s3_get(pool_manager, object_url, preload_content=False)
            try:
                if resp.status != 200:
                    raise ValueError(f"failed to download checkpoint file {key}")

                # NOTE(review): prefix matching on the listing is textual, so a
                # key can share the prefix without being under it as a path;
                # relpath would then contain ".." -- confirm key_prefix always
                # ends on a path boundary.
                destination = target_root / os.path.relpath(key, checkpoint.key_prefix)
                destination.parent.mkdir(parents=True, exist_ok=True)
                with open(destination, "wb") as out:
                    for chunk in resp.stream(_DOWNLOAD_CHUNK_SIZE):
                        out.write(chunk)
            finally:
                # Streaming responses hold their connection until released.
                resp.release_conn()
[docs]
def create_model_from_checkpoint(
    *,
    checkpoint_id: str,
    name: str,
    version: str,
    summary: str,
    project_id: str | None = None,
) -> str:
    """Create a model from an existing checkpoint.

    Parameters
    ----------
    checkpoint_id : str
        The ID of the checkpoint to create the model from
    name : str
        The name to give the model
    version : str
        The version to give the model. Must be in SemVer format
    summary : str
        A short summary of the model
    project_id : Optional[str]
        The ID of the project to create the model in.
        If omitted, the project ID of the associated run will be used.

    Returns
    -------
    model_id : str
        The ID of the created model

    Raises
    ------
    APIException
        If api communication fails, request is
        unauthorized or is unauthenticated.
    """
    create_model_post = _apis.training_v2.v2_checkpoints_api.v2_checkpoints_id_create_model_post

    # Describe the model to create; the checkpoint itself is addressed by ID.
    request_body = training_v2_openapi_models.CreateModelFromCheckpointRequest(
        name=name,
        version=version,
        summary=summary,
        project_id=project_id,
    )

    response = create_model_post(id=checkpoint_id, body=request_body)
    return response.data.model_id