import os
import tempfile
import time
import zipfile
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from typing import Any
import joblib
import urllib3
from chariot import _apis, mcp_setting
from chariot.config import getLogger
from chariot.models._utils import _upload_parts
from chariot.models.enum import ArtifactType, TaskType
from chariot.models.model import Model, get_model_by_id, join_url
from chariot.projects.projects import _get_project_id_from_project_args
from chariot_api._openapi.models.models import (
InputModel,
InputSignS3Request,
ModelsUploadsCompleteMultipartPostRequest,
ModelsUploadsCreateMultipartPostRequest,
ModelsUploadsCreateMultipartPostResponseData,
ModelsUploadsIdGetResponse,
ModelsUploadsIdGetResponseData,
)
__all__ = [
"import_model",
"upload_model_file",
]
log = getLogger(__name__)
@contextmanager
def _temporary_zip_dir(src_dir: str):
"""Zip a directory and return the path to the ZIP.
Context manager that zips `src_dir` (excluding dot-files/dirs),
yields the path to the temp ZIP, then removes it on exit.
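
Example (illustrative; the directory path and consumer are placeholders)::

    with _temporary_zip_dir("/path/to/model_dir") as zip_path:
        do_upload(zip_path)  # hypothetical consumer; the ZIP is removed on exit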
"""
# create a temp file for the ZIP
tmp = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
tmp.close()
try:
# write the ZIP
with zipfile.ZipFile(tmp.name, mode="w", compression=zipfile.ZIP_STORED) as zf:
for root, dirs, files in os.walk(src_dir):
# skip sub-dirs starting with '.'
dirs[:] = [d for d in dirs if not d.startswith(".")]
for fname in files:
if fname.startswith("."):
continue
full_path = os.path.join(root, fname)
rel_path = os.path.relpath(full_path, start=src_dir)
zf.write(full_path, arcname=rel_path)
yield tmp.name
finally:
try:
os.remove(tmp.name)
except OSError:
pass
@mcp_setting(file_based=True, mutating=True)
@_apis.login_required
def import_model(
*,
name: str,
version: str,
summary: str,
artifact_type: str | ArtifactType,
task_type: str | TaskType,
project_id: str | None = None,
project_name: str | None = None,
subproject_name: str | None = None,
organization_id: str | None = None,
class_labels: dict[str, int] | None = None,
model_object: Any | None = None,
input_info: Any | None = None, # input_info only used for sklearn models?
model_path: str | None = None,
use_internal_url: bool = False,
) -> Model:
"""Import a local model into Chariot.
For a previously exported Chariot model, model_path is the local path to the gzipped tar.
For a Hugging Face model, model_path is either the local path to the directory or
the path to a zip file containing all of the model files.
:param name: The name of the model.
:param version: The version of the model.
:param summary: The summary of the model.
:param artifact_type: The artifact type of the model; one of ArtifactType.values().
:param task_type: The task type describing the model's functionality.
:param project_id: The ID of the project.
:param project_name: The name of the project.
:param subproject_name: The name of the subproject, if any.
:param organization_id: The ID of the organization.
:param class_labels: The class labels of the model.
:param model_object: The model object.
:param input_info: The input info of the model.
:param model_path: The path to the model.
:param use_internal_url: Whether to use the internal URL.
:return: The imported model.
:rtype: Model
:raises ValueError: If neither or both of model_path and model_object are provided.
:raises ValueError: If class_labels is not a dict.
:raises ValueError: If artifact_type is not provided.
:raises ValueError: If model_path does not exist.
:raises ValueError: If a model with the same name and version already exists in the project.
:raises ValueError: If model_object is provided with an artifact_type other than sklearn.
:raises urllib3.exceptions.HTTPError: If the upload exceeds the maximum number of attempts.
:raises RuntimeError: If the upload extraction fails.
:raises RuntimeError: If the upload times out.
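
Example (illustrative; the names, types, and path below are placeholders)::

    model = import_model(
        name="my-model",
        version="1.0.0",
        summary="Example import of an exported model archive",
        artifact_type="sklearn",  # placeholder; see ArtifactType.values()
        task_type="classification",  # placeholder; see TaskType
        project_name="demo-project",
        model_path="/path/to/model.tar.gz",
    )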
"""
project_id = _get_project_id_from_project_args(
project_id=project_id,
project_name=project_name,
subproject_name=subproject_name,
organization_id=organization_id,
)
if (model_path and model_object) or (not model_path and not model_object):
raise ValueError("Must supply model_path or model_object but not both")
if class_labels is not None and not isinstance(class_labels, dict):
raise ValueError("class_labels must be a dict")
if not model_object:
if model_path is None or not os.path.exists(model_path):
raise ValueError(f"model_path {model_path!r} does not exist")
if not artifact_type:
raise ValueError(f"Must supply artifact_type: {ArtifactType.values()}")
models = _apis.models.models_api.models_get( # type: ignore [reportAttributeAccessIssue]
project_ids=project_id,
model_name=name,
model_version=version,
).data
if len(models) > 0:
raise ValueError(
f"Model with name {name!r} and version {version!r} already exists in project {project_id!r}."
)
if isinstance(artifact_type, str):
artifact_type = ArtifactType.get(artifact_type)
if isinstance(task_type, str):
task_type = TaskType.get(task_type)
if model_object is not None:
if artifact_type != ArtifactType.SKLEARN:
raise ValueError(
f"Unsupported model_object artifact_type: {artifact_type}. Import via model_path instead."
)
# Serialize the in-memory model object to a temporary joblib file for upload
tf = tempfile.NamedTemporaryFile(delete=False, suffix=".joblib")
tf.close()
joblib.dump(model_object, tf.name)
model_path = tf.name
assert model_path is not None
if os.path.isdir(model_path):
# NOTE(s.maddox): This code path is really only meant for Huggingface models,
# but there's no fundamental reason it can't be used for other model types.
with _temporary_zip_dir(model_path) as zip_path:
return _upload_and_create_model(
model_file_path=zip_path,
project_id=project_id,
name=name,
version=version,
summary=summary,
artifact_type=artifact_type.value,
task_type=task_type.value,
class_labels=class_labels,
input_info=input_info,
use_internal_url=use_internal_url,
)
return _upload_and_create_model(
model_file_path=model_path,
project_id=project_id,
name=name,
version=version,
summary=summary,
artifact_type=artifact_type.value,
task_type=task_type.value,
class_labels=class_labels,
input_info=input_info,
use_internal_url=use_internal_url,
)
def _upload_and_create_model(
*,
model_file_path: str,
project_id: str,
name: str,
version: str,
summary: str,
artifact_type: str,
task_type: str,
class_labels: dict[str, int] | None,
input_info: Any | None = None, # input_info only used for sklearn models?
use_internal_url: bool,
) -> Model:
"""Upload a model file to the Chariot server and create a model from the upload."""
# Create the upload
create_upload_data: ModelsUploadsCreateMultipartPostResponseData = (
_apis.models.models_api.models_uploads_create_multipart_post( # pyright: ignore [reportAttributeAccessIssue]
ModelsUploadsCreateMultipartPostRequest(
file_name=os.path.basename(model_file_path),
file_size=os.path.getsize(model_file_path),
project_id=project_id,
use_internal_url=use_internal_url,
),
)
).data
upload_id = create_upload_data.upload_id
log.debug(f"Upload ID: {upload_id}")
part_size = create_upload_data.part_size
signed_urls = create_upload_data.signed_urls
assert upload_id is not None
assert part_size is not None
assert signed_urls is not None
# Upload the model archive in parts
etags = _upload_parts(model_file_path, signed_urls, part_size)
# Complete the upload
max_attempts = 10
for attempt in range(1, max_attempts + 1):
try:
_apis.models.models_api.models_uploads_complete_multipart_post( # pyright: ignore [reportAttributeAccessIssue]
ModelsUploadsCompleteMultipartPostRequest(
upload_id=upload_id,
etags=etags,
),
)
break
except urllib3.exceptions.HTTPError as ex:
if attempt == max_attempts:
raise ex
log.warning(
f"Retrying complete-multipart (attempt {attempt}/{max_attempts}) due to {ex}"
)
continue
# Poll the upload until it's `extracted` or `extraction_error`, or until there are no progress updates for 120 seconds.
while True:
upload_response: ModelsUploadsIdGetResponse
upload_response = _apis.models.models_api.models_uploads_id_get(id=upload_id) # pyright: ignore [reportAttributeAccessIssue]
assert upload_response.data
upload: ModelsUploadsIdGetResponseData = upload_response.data
if upload.status == "extracted":
log.debug(f"Upload {upload_id} extracted paths: {upload.extracted_paths}")
break
if upload.status == "extraction_error":
raise RuntimeError(f"upload extraction error: {upload.extraction_error}")
assert upload.updated_at is not None
updated_at = datetime.fromtimestamp(upload.updated_at / 1000.0, tz=UTC)
if updated_at < datetime.now(tz=UTC) - timedelta(seconds=120):
raise RuntimeError(
f"Upload {upload_id} has not been updated for 120 seconds. "
f"Current status: {upload.status!r}"
)
log.debug(
f"Upload {upload_id} status: {upload.status!r}, extracted bytes: {upload.extracted_bytes} / {upload.total_bytes}"
)
time.sleep(2) # wait before polling again
metadata = upload.extracted_metadata
if metadata:
if metadata.artifact_type and metadata.artifact_type != artifact_type:
log.warning(
f"Specified artifact type, {artifact_type!r}, does not match the type from the archive, {metadata.artifact_type!r}."
f" Using the specified artifact type, {artifact_type!r}."
)
if metadata.task_type and metadata.task_type != task_type:
log.warning(
f"Specified task type, {task_type!r}, does not match the type from the archive, {metadata.task_type!r}."
f" Using the specified task type, {task_type!r}."
)
if class_labels is None and metadata.class_labels is not None:
# NOTE(s.maddox): Presumably, if the model archive contains class labels
# and the user didn't specify an override, then we should use the
# class labels from the archive.
class_labels = metadata.class_labels
# Create the model from the upload
model_data = _apis.models.models_api.projects_project_models_post( # pyright: ignore [reportAttributeAccessIssue]
project=project_id,
body=InputModel(
from_upload_id=upload_id,
name=name,
version=str(version),
summary=summary,
task_type=task_type,
artifact_type=artifact_type,
class_labels=class_labels,
model_config=metadata.pymodel_config if metadata else None,
input_info=input_info,
),
).data
return Model(metadata=model_data, start_server=False)
@dataclass
class S3Signature:
authorization: str
x_amz_date: str
@_apis.login_required
def sign_s3_request(url: str, method: str, headers: dict[str, str] | None = None) -> S3Signature:
"""Sign an S3 request for model download."""
data = _apis.models.models_api.sign_s3_request_post(
InputSignS3Request(url=url, method=method, headers=headers)
)
return S3Signature(authorization=data.authorization, x_amz_date=data.x_amz_date)
@_apis.login_required
def get_s3_endpoint() -> str:
"""Get S3 endpoint for model download."""
data = _apis.models.models_api.s3_endpoint_get()
return data.endpoint_url
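# Illustrative sketch (assumption, not part of the module's existing flow): the two
# helpers above can be combined to fetch a model artifact directly from S3. The
# object path and header casing below are placeholders.
#
#     endpoint = get_s3_endpoint()
#     url = f"{endpoint}/path/to/model/artifact"
#     sig = sign_s3_request(url, "GET")
#     resp = urllib3.PoolManager().request(
#         "GET", url,
#         headers={"Authorization": sig.authorization, "x-amz-date": sig.x_amz_date},
#     )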
@mcp_setting(file_based=True, mutating=True)
@_apis.login_required
def upload_model_file(model_id: str, filename: str, data: bytes) -> str:
"""Upload a file to the model.
NOTE: Currently only supports upload of README.md and model drift detectors.
:param model_id: The ID of the model to upload the file to.
:type model_id: str
:param filename: The name to store the file under (e.g. "README.md").
:type filename: str
:param data: The raw bytes of the file to upload.
:type data: bytes
:return: The storage URL of the uploaded file.
:rtype: str
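
Example (illustrative; the model ID is a placeholder)::

    url = upload_model_file(
        model_id="0123456789abcdef",
        filename="README.md",
        data=b"# My Model\n\nUsage notes go here.",
    )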
"""
model = get_model_by_id(model_id)
data_length = len(data)
# POST - get pre-signed URL and upload ID
request_response = _apis.models.models_api.models_id_files_upload_request_get(
id=model_id,
path=filename,
filesize=data_length,
multipart=False,
)
if len(request_response) == 0:
raise RuntimeError(f"Upload request failed")
presigned_url = request_response[0].url
if presigned_url is None:
raise RuntimeError(f"No presigned URL returned for file {filename}")
method = request_response[0].method
if method is None:
raise RuntimeError(f"Upload request did not return a method")
# PUT - upload file bytes, get ETag
http = urllib3.PoolManager()
headers = {"Content-Length": str(data_length)}
upload_response = http.request(
method,
presigned_url,
body=data,
headers=headers,
)
if upload_response.status != 200:
raise RuntimeError(f"failed to upload file with response: {upload_response}")
return join_url(model._meta.uri, filename)