import os
import tempfile
import time
import warnings
import zipfile
from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from pathlib import Path
from typing import Any
import joblib
import urllib3
from chariot import _apis, mcp_setting
from chariot.config import getLogger
from chariot.models._utils import _upload_parts
from chariot.models.enum import ArtifactType, TaskType
from chariot.models.isvc_settings import EngineVersionSelector
from chariot.models.model import Model, get_model_by_id, join_url
from chariot.projects.projects import _get_project_id_from_project_args
from chariot_api._openapi.models.models import (
InputModel,
InputSignS3Request,
ModelsUploadsCompleteMultipartPostRequest,
ModelsUploadsCreateMultipartPostRequest,
ModelsUploadsCreateMultipartPostResponseData,
ModelsUploadsIdGetResponse,
ModelsUploadsIdGetResponseData,
)
__all__ = [
"import_model",
"upload_model_file",
]
log = getLogger(__name__)
@contextmanager
def _temporary_zip_dir(src_dir: str):
    """Zip ``src_dir`` into a temporary ZIP file and yield the ZIP's path.

    Dot-prefixed files and directories are excluded. The archive uses
    ``ZIP_STORED`` (no compression). The temporary file is removed on exit.
    """
    # Reserve a temp path for the archive; close the handle so ZipFile can
    # reopen it by name on any platform.
    handle = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
    handle.close()
    zip_path = handle.name
    try:
        with zipfile.ZipFile(zip_path, mode="w", compression=zipfile.ZIP_STORED) as archive:
            for current_dir, subdirs, filenames in os.walk(src_dir):
                # Prune hidden sub-directories in place so os.walk skips them.
                subdirs[:] = [name for name in subdirs if not name.startswith(".")]
                for filename in filenames:
                    if filename.startswith("."):
                        continue
                    abs_path = os.path.join(current_dir, filename)
                    archive.write(abs_path, arcname=os.path.relpath(abs_path, start=src_dir))
        yield zip_path
    finally:
        try:
            os.remove(zip_path)
        except OSError:
            # Best-effort cleanup; ignore a file that is already gone.
            pass
# [docs]  (Sphinx source-link artifact from documentation extraction)
@mcp_setting(file_based=True, mutating=True)
@_apis.login_required
def import_model(
    *,
    name: str,
    version: str,
    summary: str = "",
    artifact_type: str | ArtifactType,
    task_type: str | TaskType,
    project_id: str | None = None,
    project_name: str | None = None,
    subproject_name: str | None = None,
    organization_id: str | None = None,
    class_labels: dict[str, int] | None = None,
    model_object: Any | None = None,
    input_info: Any | None = None,  # input_info only used for sklearn models?
    model_path: str | None = None,
    huggingface_model_id: str | None = None,
    huggingface_revision: str | None = None,
    huggingface_allow_patterns: list[str] | None = None,
    huggingface_ignore_patterns: list[str] | None = None,
    use_internal_url: bool = False,
) -> Model:
    """Import a local model into Chariot or directly from Hugging Face Hub.

    For a previously exported Chariot model, model_path is the local path to the gzipped tar.
    For a Hugging Face model, you can either:
    - Provide model_path as the local path to a directory or `.zip` file containing model files, OR
    - Provide huggingface_model_id to import directly from Hugging Face Hub (e.g., "mistralai/Ministral-3-8B-Instruct-2512")

    :param name: The name of the model.
    :param version: The version of the model.
    :param summary: The summary of the model.
    :param artifact_type: The type of the model.
    :param task_type: The model's functionality.
    :param project_id: The ID of the project.
    :param project_name: The name of the project.
    :param subproject_name: The name of the subproject.
    :param organization_id: The ID of the organization.
    :param class_labels: The class labels of the model.
    :param model_object: The model object.
    :param input_info: The input info of the model.
    :param model_path: The path to the model file or directory.
    :param huggingface_model_id: The Hugging Face model ID (e.g., "mistralai/Ministral-3-8B-Instruct-2512").
    :param huggingface_revision: The Hugging Face model revision/branch (default: "main").
    :param huggingface_allow_patterns: Glob patterns to filter files to import (e.g., ["*.safetensors", "*.json"]).
    :param huggingface_ignore_patterns: Glob patterns to exclude files from import (e.g., ["*.bin"]).
    :param use_internal_url: Whether to use the internal URL.
    :return: The imported model.
    :rtype: Model
    :raises ValueError: If neither model_path, model_object, nor huggingface_model_id is provided.
    :raises ValueError: If both model_path/model_object and huggingface_model_id are provided.
    :raises ValueError: If the class_labels is not a dict.
    :raises ValueError: If the artifact_type is not provided.
    :raises ValueError: If the model_path does not exist.
    :raises ValueError: If the model_object is not a model object.
    :raises ValueError: If the input_info is not a dict.
    :raises urllib3.exceptions.HTTPError: If the upload exceeds the maximum number of attempts.
    :raises RuntimeError: If the upload extraction fails.
    :raises RuntimeError: If the upload times out.
    """
    project_id = _get_project_id_from_project_args(
        project_id=project_id,
        project_name=project_name,
        subproject_name=subproject_name,
        organization_id=organization_id,
    )
    # Validate mutually exclusive upload sources
    upload_sources = sum(
        [
            bool(model_path),
            bool(model_object),
            bool(huggingface_model_id),
        ]
    )
    if upload_sources == 0:
        raise ValueError("Must supply one of: model_path, model_object, or huggingface_model_id")
    if upload_sources > 1:
        raise ValueError(
            "Must supply only one of: model_path, model_object, or huggingface_model_id"
        )
    if class_labels is not None and not isinstance(class_labels, dict):
        raise ValueError("class_labels must be a dict")
    if model_path and not os.path.exists(model_path):
        raise ValueError(f"model_path {model_path!r} does not exist")
    if not artifact_type:
        raise ValueError(f"Must supply artifact_type: {ArtifactType.values()}")
    # Refuse to create a duplicate of an existing name/version in the project.
    models = _apis.models.models_api.models_get(  # type: ignore [reportAttributeAccessIssue]
        project_ids=project_id,
        model_name=name,
        model_version=version,
    ).data
    if len(models) > 0:
        raise ValueError(
            f"Model with name {name!r} and version {version!r} already exists in project {project_id!r}."
        )
    if isinstance(artifact_type, str):
        artifact_type = ArtifactType.get(artifact_type)
    if isinstance(task_type, str):
        task_type = TaskType.get(task_type)
    if huggingface_model_id is not None:
        return _import_from_huggingface_hub_and_create_model(
            huggingface_model_id=huggingface_model_id,
            huggingface_revision=huggingface_revision,
            huggingface_allow_patterns=huggingface_allow_patterns,
            huggingface_ignore_patterns=huggingface_ignore_patterns,
            project_id=project_id,
            name=name,
            version=version,
            summary=summary,
            artifact_type=artifact_type.value,
            task_type=task_type.value,
            class_labels=class_labels,
            input_info=input_info,
        )
    if model_object is not None:
        if artifact_type != ArtifactType.SKLEARN:
            raise ValueError(
                f"Unsupported model_object artifact_type: {artifact_type}. Import via model_path instead."
            )
        # Serialize the sklearn model object to a temporary joblib file and
        # remove that file once the upload finishes (previously it was leaked).
        tf = tempfile.NamedTemporaryFile(delete=False, suffix=".joblib")
        tf.close()  # close the handle so joblib can reopen the path on all platforms
        try:
            joblib.dump(model_object, tf.name)
            return _upload_and_create_model(
                model_file_path=tf.name,
                project_id=project_id,
                name=name,
                version=version,
                summary=summary,
                artifact_type=artifact_type.value,
                task_type=task_type.value,
                class_labels=class_labels,
                input_info=input_info,
                use_internal_url=use_internal_url,
            )
        finally:
            try:
                os.remove(tf.name)
            except OSError:
                pass
    assert model_path is not None
    if os.path.isdir(model_path):
        # NOTE(s.maddox): This code path is really only meant for Huggingface models,
        # but there's no fundamental reason it can't be used for other model types.
        with _temporary_zip_dir(model_path) as zip_path:
            return _upload_and_create_model(
                model_file_path=zip_path,
                project_id=project_id,
                name=name,
                version=version,
                summary=summary,
                artifact_type=artifact_type.value,
                task_type=task_type.value,
                class_labels=class_labels,
                input_info=input_info,
                use_internal_url=use_internal_url,
            )
    return _upload_and_create_model(
        model_file_path=model_path,
        project_id=project_id,
        name=name,
        version=version,
        summary=summary,
        artifact_type=artifact_type.value,
        task_type=task_type.value,
        class_labels=class_labels,
        input_info=input_info,
        use_internal_url=use_internal_url,
    )
def _upload_and_create_model(
    *,
    model_file_path: str,
    project_id: str,
    name: str,
    version: str,
    summary: str = "",
    artifact_type: str,
    task_type: str,
    class_labels: dict[str, int] | None,
    input_info: Any | None = None,  # input_info only used for sklearn models?
    use_internal_url: bool,
) -> Model:
    """Upload a model file to the Chariot server and create a model from the upload.

    Creates a multipart upload, uploads the file in pre-signed parts, completes
    the upload (retrying transient HTTP errors), then delegates to
    ``_poll_upload_and_create_model`` to wait for extraction and register the model.

    :raises urllib3.exceptions.HTTPError: If completing the upload still fails
        after the maximum number of attempts.
    """
    # Create the upload and obtain the part size plus pre-signed part URLs.
    create_upload_data: ModelsUploadsCreateMultipartPostResponseData = (
        _apis.models.models_api.models_uploads_create_multipart_post(  # pyright: ignore [reportAttributeAccessIssue]
            ModelsUploadsCreateMultipartPostRequest(
                file_name=os.path.basename(model_file_path),
                file_size=os.path.getsize(model_file_path),
                project_id=project_id,
                use_internal_url=use_internal_url,
            ),
        )
    ).data
    upload_id = create_upload_data.upload_id
    log.debug(f"Upload ID: {upload_id}")
    part_size = create_upload_data.part_size
    signed_urls = create_upload_data.signed_urls
    assert upload_id is not None
    assert part_size is not None
    assert signed_urls is not None
    # Upload the model archive in parts
    etags = _upload_parts(model_file_path, signed_urls, part_size)
    # Complete the upload, retrying transient HTTP errors up to max_attempts.
    max_attempts = 10
    for attempt in range(1, max_attempts + 1):
        try:
            _apis.models.models_api.models_uploads_complete_multipart_post(  # pyright: ignore [reportAttributeAccessIssue]
                ModelsUploadsCompleteMultipartPostRequest(
                    upload_id=upload_id,
                    etags=etags,
                ),
            )
            break
        except urllib3.exceptions.HTTPError as ex:
            if attempt == max_attempts:
                # Bare raise keeps the original exception and traceback intact.
                raise
            log.warning(
                f"Retrying complete-multipart (attempt {attempt}/{max_attempts}) due to {ex}"
            )
            continue
    return _poll_upload_and_create_model(
        upload_id=upload_id,
        project_id=project_id,
        name=name,
        version=version,
        summary=summary,
        artifact_type=artifact_type,
        task_type=task_type,
        class_labels=class_labels,
        input_info=input_info,
    )
def _poll_upload_and_create_model(
    *,
    upload_id: str,
    project_id: str,
    name: str,
    version: str,
    summary: str = "",
    artifact_type: str,
    task_type: str,
    class_labels: dict[str, int] | None,
    input_info: Any | None = None,
    stall_timeout_seconds: int = 120,
    poll_interval_seconds: float = 2.0,
) -> Model:
    """Poll an upload until extracted, then create the model from it.

    Args:
        upload_id: The upload ID to poll
        project_id: The project ID
        name: Model name
        version: Model version
        summary: Model summary
        artifact_type: Artifact type (e.g., "HUGGINGFACE")
        task_type: Task type (e.g., "TEXT_CLASSIFICATION")
        class_labels: Optional class labels
        input_info: Optional input info
        stall_timeout_seconds: Abort if the upload has no progress updates
            for this many seconds
        poll_interval_seconds: Delay between status polls

    Returns:
        The created Model

    Raises:
        RuntimeError: If extraction fails or the upload stalls.
    """
    # Poll the upload until it's `extracted` or `extraction_error`, or until
    # there are no progress updates for `stall_timeout_seconds`.
    while True:
        upload_response: ModelsUploadsIdGetResponse
        upload_response = _apis.models.models_api.models_uploads_id_get(id=upload_id)  # pyright: ignore [reportAttributeAccessIssue]
        assert upload_response.data
        upload: ModelsUploadsIdGetResponseData = upload_response.data
        if upload.status == "extracted":
            log.debug(f"Upload {upload_id} extracted paths: {upload.extracted_paths}")
            break
        if upload.status == "extraction_error":
            raise RuntimeError(f"upload extraction error: {upload.extraction_error}")
        assert upload.updated_at is not None
        # `updated_at` is epoch milliseconds; compare in UTC to detect a stall.
        updated_at = datetime.fromtimestamp(upload.updated_at / 1000.0, tz=UTC)
        if updated_at < datetime.now(tz=UTC) - timedelta(seconds=stall_timeout_seconds):
            raise RuntimeError(
                f"Upload {upload_id} has not been updated for {stall_timeout_seconds} seconds. "
                f"Current status: {upload.status!r}"
            )
        log.debug(
            f"Upload {upload_id} status: {upload.status!r}, extracted bytes: {upload.extracted_bytes} / {upload.total_bytes}"
        )
        time.sleep(poll_interval_seconds)  # wait before polling again
    metadata = upload.extracted_metadata
    if metadata:
        # User-specified types win over whatever the archive declares; warn on mismatch.
        if metadata.artifact_type and metadata.artifact_type != artifact_type:
            log.warning(
                f"Specified artifact type, {artifact_type!r}, does not match the type from the archive, {metadata.artifact_type!r}."
                f" Using the specified artifact type, {artifact_type!r}."
            )
        if metadata.task_type and metadata.task_type != task_type:
            log.warning(
                f"Specified task type, {task_type!r}, does not match the type from the archive, {metadata.task_type!r}."
                f" Using the specified task type, {task_type!r}."
            )
        if class_labels is None and metadata.class_labels is not None:
            # NOTE(s.maddox): Presumably, if the model archive contains class labels
            # and the user didn't specify an override, then we should use the
            # class labels from the archive.
            class_labels = metadata.class_labels
    # Create the model from the upload
    model_data = _apis.models.models_api.models_post(  # pyright: ignore [reportAttributeAccessIssue]
        body=InputModel(
            from_upload_id=upload_id,
            name=name,
            version=str(version),
            summary=summary,
            project_id=project_id,
            task_type=task_type,
            artifact_type=artifact_type,
            class_labels=class_labels,
            model_config=metadata.pymodel_config if metadata else None,
            input_info=input_info,
        ),
    ).data
    return Model(metadata=model_data, start_server=False)
def _import_from_huggingface_hub_and_create_model(
    *,
    huggingface_model_id: str,
    huggingface_revision: str | None,
    huggingface_allow_patterns: list[str] | None,
    huggingface_ignore_patterns: list[str] | None,
    project_id: str,
    name: str,
    version: str,
    summary: str = "",
    artifact_type: str,
    task_type: str,
    class_labels: dict[str, int] | None,
    input_info: Any | None = None,
) -> Model:
    """Import a model directly from Hugging Face Hub and create a model from the upload."""
    # Imported lazily, mirroring the module's original layout.
    from chariot_api._openapi.models.models import (
        ModelsUploadsImportFromHuggingFaceHubPostRequest,
    )

    # Ask the server to pull the model from the Hub; defaults: revision "main",
    # no allow/ignore filtering.
    request = ModelsUploadsImportFromHuggingFaceHubPostRequest(
        project_id=project_id,
        huggingface_model_id=huggingface_model_id,
        revision=huggingface_revision or "main",
        allow_patterns=huggingface_allow_patterns or [],
        ignore_patterns=huggingface_ignore_patterns or [],
    )
    import_response = _apis.models.models_api.models_uploads_import_from_huggingface_hub_post(  # pyright: ignore [reportAttributeAccessIssue]
        request
    )
    upload_id = import_response.data.upload_id
    log.debug(f"Hugging Face import Upload ID: {upload_id}")
    assert upload_id is not None
    # Use the shared polling and model creation logic
    return _poll_upload_and_create_model(
        upload_id=upload_id,
        project_id=project_id,
        name=name,
        version=version,
        summary=summary,
        artifact_type=artifact_type,
        task_type=task_type,
        class_labels=class_labels,
        input_info=input_info,
    )
@dataclass
class S3Signature:
    """Signature headers returned by ``sign_s3_request`` for authenticating an S3 request."""

    # Value for the HTTP "Authorization" header.
    authorization: str
    # Value for the "x-amz-date" header.
    x_amz_date: str
@_apis.login_required
def sign_s3_request(url: str, method: str, headers: dict[str, str] | None = None) -> S3Signature:
    """Sign an S3 request for model download."""
    # Delegate signing to the server-side endpoint and repackage the result.
    response = _apis.models.models_api.sign_s3_request_post(
        InputSignS3Request(url=url, method=method, headers=headers)
    )
    return S3Signature(
        authorization=response.authorization,
        x_amz_date=response.x_amz_date,
    )
@_apis.login_required
def get_s3_endpoint() -> str:
    """Get S3 endpoint for model download."""
    response = _apis.models.models_api.s3_endpoint_get()
    return response.endpoint_url
# [docs]  (Sphinx source-link artifact from documentation extraction)
@mcp_setting(file_based=True, mutating=True)
@_apis.login_required
def upload_model_file(model_id: str, filename: str, data: bytes) -> str:
    """Upload a file to the model.

    NOTE: Currently only supports upload of README.md and model drift detectors.

    :param model_id: The model where the file will be stored.
    :type model_id: str
    :param filename: The destination filename within the model's storage.
    :type filename: str
    :param data: The raw file contents to upload.
    :type data: bytes
    :return: The storage url.
    :rtype: str
    :raises RuntimeError: If the upload request returns no URL/method or the
        upload itself does not return HTTP 200.
    """
    model = get_model_by_id(model_id)
    data_length = len(data)
    # GET - obtain a pre-signed URL for a single-part upload
    request_response = _apis.models.models_api.models_id_files_upload_request_get(
        id=model_id,
        path=filename,
        filesize=data_length,
        multipart=False,
    )
    if len(request_response) == 0:
        raise RuntimeError("Upload request failed")
    presigned_url = request_response[0].url
    if presigned_url is None:
        raise RuntimeError(f"No presigned URL returned for file {filename!r}")
    method = request_response[0].method
    if method is None:
        raise RuntimeError("Upload request did not return a method")
    # PUT - upload the file bytes to the pre-signed URL
    http = urllib3.PoolManager()
    headers = {"Content-Length": str(data_length)}
    upload_response = http.request(
        method,
        presigned_url,
        body=data,
        headers=headers,
    )
    if upload_response.status != 200:
        raise RuntimeError(f"failed to upload file with response: {upload_response}")
    return join_url(model._meta.uri, filename)
@contextmanager
def temporary_zipfile_from_contents(contents: dict[str, str]) -> Generator[zipfile.ZipFile]:
    """Yield a temporary, read-only ZipFile built from filename -> contents mappings.

    Entries are written uncompressed (``ZIP_STORED``); the backing file is
    deleted when the context exits.
    """
    with tempfile.NamedTemporaryFile(suffix=".zip") as backing:
        # Write all entries first, then reopen the archive read-only for the caller.
        with zipfile.ZipFile(backing.name, mode="w", compression=zipfile.ZIP_STORED) as writer:
            for entry_name, entry_text in contents.items():
                writer.writestr(entry_name, entry_text)
        with zipfile.ZipFile(backing.name, "r") as reader:
            yield reader
@contextmanager
def temporary_zipfile_from_path(path: str | Path) -> Generator[zipfile.ZipFile]:
"""Create a temporary zipfile from a given path with ZIP_STORED compression."""
path = Path(path)
if not path.exists():
raise FileNotFoundError(f"Path does not exist: {path}")
with tempfile.NamedTemporaryFile(suffix=".zip") as temp_zip:
with zipfile.ZipFile(temp_zip.name, "w", compression=zipfile.ZIP_STORED) as zip_file:
if path.is_file():
zip_file.write(path, path.name)
else:
for file_path in path.rglob("*"):
if file_path.is_file():
arcname = file_path.relative_to(path)
zip_file.write(file_path, arcname)
with zipfile.ZipFile(temp_zip.name, "r") as zip_file:
yield zip_file
def _any_compressed(zip_file: zipfile.ZipFile) -> bool:
    """Return True if any file entry in the archive uses compression (i.e. is not ZIP_STORED)."""
    # Directory entries are skipped: only real file members matter.
    return any(
        not info.is_dir() and info.compress_type != zipfile.ZIP_STORED
        for info in zip_file.infolist()
    )
@mcp_setting(mutating=True, file_based=True)
def import_model_with_engine(
    name: str,
    version: str,
    project_id: str,
    task_type: str | TaskType,
    engine_version_selector: EngineVersionSelector,
    model_data: zipfile.ZipFile,
    summary: str = "",
    class_labels: dict[str, int] | None = None,
) -> Model:
    """Upload a model using an inference engine.

    Create a model from a zip file::

        with zipfile.ZipFile("my_file.zip") as zip_file:
            import_model_with_engine(...)

    Create a model from a directory::

        with temporary_zipfile_from_path("/my/model") as zip_file:
            import_model_with_engine(...)

    Create a model from a mapping of filenames to string contents::

        with temporary_zipfile_from_contents({"README.md": "example"}) as zip_file:
            import_model_with_engine(...)

    :param name: The model name.
    :type name: str
    :param version: The model version.
    :type version: str
    :param summary: The model summary.
    :type summary: str
    :param project_id: The project ID for the model.
    :type project_id: str
    :param task_type: The task type of the model.
    :type task_type: str | TaskType
    :param engine_version_selector: The engine to use for the model.
    :type engine_version_selector: EngineVersionSelector
    :param model_data: An open zip archive backed by a real file on disk; its
        ``filename`` attribute is used as the upload path.
    :type model_data: zipfile.ZipFile
    :param class_labels: class labels for the model, if needed.
    :type class_labels: dict[str, int] | None = None
    :return: The newly uploaded model.
    :rtype: Model
    :raises ValueError: If ``model_data`` has no backing filename.
    """
    if model_data.filename is None:
        # This cannot happen when using the invocations in the docstring
        raise ValueError("model_data zip file has no filename")
    if _any_compressed(model_data):
        # stacklevel=2 attributes the warning to the caller, not this helper.
        warnings.warn(
            "to ensure fast extraction, zip files should be created without compression",
            stacklevel=2,
        )
    model = import_model(
        project_id=project_id,
        name=name,
        summary=summary,
        version=version,
        task_type=task_type,
        artifact_type=ArtifactType.CUSTOM_ENGINE,
        model_path=model_data.filename,
        class_labels=class_labels,
    )
    model.set_inference_server_settings({"engine_selector": engine_version_selector})
    return model