import os
import tempfile
import time
import zipfile
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from typing import Any
import joblib
import urllib3
from chariot import _apis
from chariot.config import getLogger
from chariot.models.enum import ArtifactType, TaskType
from chariot.models.model import Model
from chariot.projects.projects import _get_project_id_from_project_args
from chariot_api._openapi.models.models import (
InputModel,
InputSignS3Request,
ModelsUploadsCompleteMultipartPostRequest,
ModelsUploadsCreateMultipartPostRequest,
ModelsUploadsCreateMultipartPostResponseData,
ModelsUploadsIdGetResponse,
ModelsUploadsIdGetResponseData,
)
__all__ = ["import_model"]
log = getLogger(__name__)
@contextmanager
def _temporary_zip_dir(src_dir: str):
"""Zip a directory and return the path to the ZIP.
Context manager that zips `src_dir` (excluding dot‐files/dirs),
yields the path to the temp ZIP, then removes it on exit.
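
    Example (illustrative; the directory path is a placeholder):

        with _temporary_zip_dir("/path/to/model_dir") as zip_path:
            size = os.path.getsize(zip_path)  # the ZIP exists only inside this block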
"""
# create a temp file for the ZIP
tmp = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
tmp.close()
try:
# write the ZIP
with zipfile.ZipFile(tmp.name, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
for root, dirs, files in os.walk(src_dir):
# skip sub-dirs starting with '.'
dirs[:] = [d for d in dirs if not d.startswith(".")]
for fname in files:
if fname.startswith("."):
continue
full_path = os.path.join(root, fname)
rel_path = os.path.relpath(full_path, start=src_dir)
zf.write(full_path, arcname=rel_path)
yield tmp.name
finally:
try:
os.remove(tmp.name)
except OSError:
pass
@_apis.login_required
def import_model(
*,
name: str,
version: str,
summary: str,
artifact_type: str | ArtifactType,
task_type: str | TaskType,
project_id: str | None = None,
project_name: str | None = None,
subproject_name: str | None = None,
organization_id: str | None = None,
class_labels: dict[str, int] | None = None,
model_object: Any | None = None,
    input_info: Any | None = None,  # NOTE: input_info appears to be used only by sklearn models
model_path: str | None = None,
use_internal_url: bool = False,
) -> Model:
"""Import a local model into Chariot.
For a previously exported Chariot model, model_path is the local path to the gzipped tar.
For a Huggingface model, model_path is either the local path to the directory or
the path to a zip file containing all of the model files.
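
    Example (illustrative; every value below is a placeholder):

        model = import_model(
            name="my-model",
            version="1.0.0",
            summary="Imported from a local archive",
            artifact_type="huggingface",
            task_type="image-classification",
            project_name="my-project",
            model_path="/path/to/model_dir_or_archive",
        )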
"""
project_id = _get_project_id_from_project_args(
project_id=project_id,
project_name=project_name,
subproject_name=subproject_name,
organization_id=organization_id,
)
if (model_path and model_object) or (not model_path and not model_object):
raise ValueError("Must supply model_path or model_object but not both")
if class_labels is not None and not isinstance(class_labels, dict):
raise ValueError("class_labels must be a dict")
if not model_object:
if model_path is None or not os.path.exists(model_path):
raise ValueError(f"model_path {model_path!r} does not exist")
    if not artifact_type:
        raise ValueError(f"Must supply artifact_type, one of: {ArtifactType.values()}")
if isinstance(artifact_type, str):
artifact_type = ArtifactType.get(artifact_type)
if isinstance(task_type, str):
task_type = TaskType.get(task_type)
if model_object is not None:
if artifact_type != ArtifactType.SKLEARN:
raise ValueError(
f"Unsupported model_object artifact_type: {artifact_type}. Import via model_path instead."
)
        # close the handle so joblib can (re)open the path (required on Windows)
        tf = tempfile.NamedTemporaryFile(delete=False, suffix=".joblib")
        tf.close()
        joblib.dump(model_object, tf.name)
        model_path = tf.name
assert model_path is not None
if os.path.isdir(model_path):
# NOTE(s.maddox): This code path is really only meant for Huggingface models,
# but there's no fundamental reason it can't be used for other model types.
with _temporary_zip_dir(model_path) as zip_path:
return _upload_and_create_model(
model_file_path=zip_path,
project_id=project_id,
name=name,
version=version,
summary=summary,
artifact_type=artifact_type.value,
task_type=task_type.value,
class_labels=class_labels,
input_info=input_info,
use_internal_url=use_internal_url,
)
return _upload_and_create_model(
model_file_path=model_path,
project_id=project_id,
name=name,
version=version,
summary=summary,
artifact_type=artifact_type.value,
task_type=task_type.value,
class_labels=class_labels,
input_info=input_info,
use_internal_url=use_internal_url,
)
def _upload_and_create_model(
*,
model_file_path: str,
project_id: str,
name: str,
version: str,
summary: str,
artifact_type: str,
task_type: str,
class_labels: dict[str, int] | None,
    input_info: Any | None = None,  # NOTE: input_info appears to be used only by sklearn models
use_internal_url: bool,
):
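    """Upload the model archive and create a model from the upload.

    Creates a multipart upload, PUTs each part to its signed URL, completes
    the upload, polls until extraction succeeds (raising if it fails or
    stalls for 120 seconds), then creates the model from the extracted
    upload.
    """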
# Create the upload
create_upload_data: ModelsUploadsCreateMultipartPostResponseData = (
_apis.models.models_api.models_uploads_create_multipart_post( # pyright: ignore [reportAttributeAccessIssue]
ModelsUploadsCreateMultipartPostRequest(
file_name=os.path.basename(model_file_path),
file_size=os.path.getsize(model_file_path),
project_id=project_id,
use_internal_url=use_internal_url,
),
)
).data
upload_id = create_upload_data.upload_id
log.debug(f"Upload ID: {upload_id}")
part_size = create_upload_data.part_size
signed_urls = create_upload_data.signed_urls
assert upload_id is not None
assert part_size is not None
assert signed_urls is not None
# Upload the model archive in parts
pool_manager = urllib3.PoolManager()
etags: list[str] = []
with open(model_file_path, "rb") as fp:
for idx, url in enumerate(signed_urls, start=1):
part = fp.read(part_size)
if not part:
break # no more data
resp = pool_manager.request(
"PUT",
url,
body=part,
# headers={'Content-Length': str(len(part))}
)
if resp.status != 200:
# TODO(s.maddox): retry with backoff?
raise RuntimeError(f"Upload failed for part {idx}: HTTP {resp.status}")
            # grab the ETag (case-insensitive)
etag = resp.headers.get("etag")
if etag is None:
raise RuntimeError(f"No ETag returned for part {idx}")
etags.append(etag)
# Complete the upload
_apis.models.models_api.models_uploads_complete_multipart_post( # pyright: ignore [reportAttributeAccessIssue]
ModelsUploadsCompleteMultipartPostRequest(
upload_id=upload_id,
etags=etags,
),
)
# Poll the upload until it's `extracted` or `extraction_error`, or until there are no progress updates for 120 seconds.
while True:
upload_response: ModelsUploadsIdGetResponse
upload_response = _apis.models.models_api.models_uploads_id_get(id=upload_id) # pyright: ignore [reportAttributeAccessIssue]
assert upload_response.data
upload: ModelsUploadsIdGetResponseData = upload_response.data
if upload.status == "extracted":
log.debug(f"Upload {upload_id} extracted paths: {upload.extracted_paths}")
break
if upload.status == "extraction_error":
raise RuntimeError(f"upload extraction error: {upload.extraction_error}")
assert upload.updated_at is not None
updated_at = datetime.fromtimestamp(upload.updated_at / 1000.0, tz=UTC)
if updated_at < datetime.now(tz=UTC) - timedelta(seconds=120):
raise RuntimeError(
f"Upload {upload_id} has not been updated for 120 seconds. "
f"Current status: {upload.status!r}"
)
log.debug(
f"Upload {upload_id} status: {upload.status!r}, extracted bytes: {upload.extracted_bytes} / {upload.total_bytes}"
)
time.sleep(2) # wait before polling again
metadata = upload.extracted_metadata
if metadata:
if metadata.artifact_type and metadata.artifact_type != artifact_type:
log.warning(
f"Specified artifact type, {artifact_type!r}, does not match the type from the archive, {metadata.artifact_type!r}."
f" Using the specified artifact type, {artifact_type!r}."
)
if metadata.task_type and metadata.task_type != task_type:
log.warning(
f"Specified task type, {task_type!r}, does not match the type from the archive, {metadata.task_type!r}."
f" Using the specified task type, {task_type!r}."
)
if class_labels is None and metadata.class_labels is not None:
# NOTE(s.maddox): Presumably, if the model archive contains class labels
# and the user didn't specify an override, then we should use the
# class labels from the archive.
class_labels = metadata.class_labels
# Create the model from the upload
model_data = _apis.models.models_api.projects_project_models_post( # pyright: ignore [reportAttributeAccessIssue]
project=project_id,
body=InputModel(
from_upload_id=upload_id,
name=name,
version=str(version),
summary=summary,
task_type=task_type,
artifact_type=artifact_type,
class_labels=class_labels,
model_config=metadata.pymodel_config if metadata else None,
input_info=input_info,
),
).data
return Model(metadata=model_data, start_server=False)
@dataclass
class S3Signature:
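    """Signature headers for an S3 request (see sign_s3_request)."""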
authorization: str
x_amz_date: str
@_apis.login_required
def sign_s3_request(url: str, method: str, headers: dict[str, str] | None = None) -> S3Signature:
"""Sign an S3 request for model download."""
data = _apis.models.models_api.sign_s3_request_post(
InputSignS3Request(url=url, method=method, headers=headers)
)
return S3Signature(authorization=data.authorization, x_amz_date=data.x_amz_date)
@_apis.login_required
def get_s3_endpoint() -> str:
"""Get S3 endpoint for model download."""
data = _apis.models.models_api.s3_endpoint_get()
return data.endpoint_url