Source code for chariot.models.upload

import os
import tempfile
import time
import zipfile
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from typing import Any

import joblib
import urllib3

from chariot import _apis
from chariot.config import getLogger
from chariot.models.enum import ArtifactType, TaskType
from chariot.models.model import Model
from chariot.projects.projects import _get_project_id_from_project_args
from chariot_api._openapi.models.models import (
    InputModel,
    InputSignS3Request,
    ModelsUploadsCompleteMultipartPostRequest,
    ModelsUploadsCreateMultipartPostRequest,
    ModelsUploadsCreateMultipartPostResponseData,
    ModelsUploadsIdGetResponse,
    ModelsUploadsIdGetResponseData,
)

__all__ = ["import_model"]

log = getLogger(__name__)


@contextmanager
def _temporary_zip_dir(src_dir: str):
    """Zip a directory and return the path to the ZIP.

    Context manager that zips `src_dir` (excluding dot‐files/dirs),
    yields the path to the temp ZIP, then removes it on exit.
    """
    # create a temp file for the ZIP
    tmp = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
    tmp.close()

    try:
        # write the ZIP
        with zipfile.ZipFile(tmp.name, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
            for root, dirs, files in os.walk(src_dir):
                # skip sub-dirs starting with '.'
                dirs[:] = [d for d in dirs if not d.startswith(".")]
                for fname in files:
                    if fname.startswith("."):
                        continue
                    full_path = os.path.join(root, fname)
                    rel_path = os.path.relpath(full_path, start=src_dir)
                    zf.write(full_path, arcname=rel_path)

        yield tmp.name

    finally:
        try:
            os.remove(tmp.name)
        except OSError:
            pass

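# A minimal usage sketch for the helper above (the directory path is
# hypothetical; the ZIP is removed automatically when the block exits):
#
#     with _temporary_zip_dir("/path/to/model_dir") as zip_path:
#         size = os.path.getsize(zip_path)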

@_apis.login_required
def import_model(
    *,
    name: str,
    version: str,
    summary: str,
    artifact_type: str | ArtifactType,
    task_type: str | TaskType,
    project_id: str | None = None,
    project_name: str | None = None,
    subproject_name: str | None = None,
    organization_id: str | None = None,
    class_labels: dict[str, int] | None = None,
    model_object: Any | None = None,
    input_info: Any | None = None,  # input_info only used for sklearn models?
    model_path: str | None = None,
    use_internal_url: bool = False,
) -> Model:
    """Import a local model into Chariot.

    For a previously exported Chariot model, model_path is the local path to
    the gzipped tar. For a Huggingface model, model_path is either the local
    path to the directory or the path to a zip file containing all of the
    model files.
    """
    project_id = _get_project_id_from_project_args(
        project_id=project_id,
        project_name=project_name,
        subproject_name=subproject_name,
        organization_id=organization_id,
    )

    if (model_path and model_object) or (not model_path and not model_object):
        raise ValueError("Must supply model_path or model_object but not both")
    if class_labels is not None and not isinstance(class_labels, dict):
        raise ValueError("class_labels must be a dict")
    if not model_object:
        if model_path is None or not os.path.exists(model_path):
            raise ValueError(f"model_path {model_path!r} does not exist")
    if not artifact_type:
        raise ValueError(f"Must supply artifact_type: {ArtifactType.values()}")

    if isinstance(artifact_type, str):
        artifact_type = ArtifactType.get(artifact_type)
    if isinstance(task_type, str):
        task_type = TaskType.get(task_type)

    if model_object is not None:
        if artifact_type != ArtifactType.SKLEARN:
            raise ValueError(
                f"Unsupported model_object artifact_type: {artifact_type}. Import via model_path instead."
            )
        tf = tempfile.NamedTemporaryFile(delete=False, suffix=".joblib")
        joblib.dump(model_object, tf.name)
        model_path = tf.name

    assert model_path is not None

    if os.path.isdir(model_path):
        # NOTE(s.maddox): This code path is really only meant for Huggingface models,
        # but there's no fundamental reason it can't be used for other model types.
        with _temporary_zip_dir(model_path) as zip_path:
            return _upload_and_create_model(
                model_file_path=zip_path,
                project_id=project_id,
                name=name,
                version=version,
                summary=summary,
                artifact_type=artifact_type.value,
                task_type=task_type.value,
                class_labels=class_labels,
                input_info=input_info,
                use_internal_url=use_internal_url,
            )

    return _upload_and_create_model(
        model_file_path=model_path,
        project_id=project_id,
        name=name,
        version=version,
        summary=summary,
        artifact_type=artifact_type.value,
        task_type=task_type.value,
        class_labels=class_labels,
        input_info=input_info,
        use_internal_url=use_internal_url,
    )

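# A usage sketch for the sklearn `model_object` path (the project name, class
# labels, and task type string below are hypothetical — substitute values your
# deployment defines):
#
#     from sklearn.linear_model import LogisticRegression
#
#     clf = LogisticRegression().fit(X_train, y_train)
#     model = import_model(
#         name="my-classifier",
#         version="0.0.1",
#         summary="Logistic regression imported via joblib",
#         artifact_type=ArtifactType.SKLEARN,
#         task_type="classification",          # hypothetical TaskType value
#         project_name="my-project",           # hypothetical
#         class_labels={"cat": 0, "dog": 1},
#         model_object=clf,
#     )
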
def _upload_and_create_model(
    *,
    model_file_path: str,
    project_id: str,
    name: str,
    version: str,
    summary: str,
    artifact_type: str,
    task_type: str,
    class_labels: dict[str, int] | None,
    input_info: Any | None = None,  # input_info only used for sklearn models?
    use_internal_url: bool,
):
    # Create the upload
    create_upload_data: ModelsUploadsCreateMultipartPostResponseData = (
        _apis.models.models_api.models_uploads_create_multipart_post(  # pyright: ignore [reportAttributeAccessIssue]
            ModelsUploadsCreateMultipartPostRequest(
                file_name=os.path.basename(model_file_path),
                file_size=os.path.getsize(model_file_path),
                project_id=project_id,
                use_internal_url=use_internal_url,
            ),
        )
    ).data
    upload_id = create_upload_data.upload_id
    log.debug(f"Upload ID: {upload_id}")
    part_size = create_upload_data.part_size
    signed_urls = create_upload_data.signed_urls
    assert upload_id is not None
    assert part_size is not None
    assert signed_urls is not None

    # Upload the model archive in parts
    pool_manager = urllib3.PoolManager()
    etags: list[str] = []
    with open(model_file_path, "rb") as fp:
        for idx, url in enumerate(signed_urls, start=1):
            part = fp.read(part_size)
            if not part:
                break  # no more data
            resp = pool_manager.request(
                "PUT",
                url,
                body=part,
                # headers={'Content-Length': str(len(part))}
            )
            if resp.status != 200:
                # TODO(s.maddox): retry with backoff?
                raise RuntimeError(f"Upload failed for part {idx}: HTTP {resp.status}")
            # grab the ETag (case-insensitive)
            etag = resp.headers.get("etag")
            if etag is None:
                raise RuntimeError(f"No ETag returned for part {idx}")
            etags.append(etag)

    # Complete the upload
    _apis.models.models_api.models_uploads_complete_multipart_post(  # pyright: ignore [reportAttributeAccessIssue]
        ModelsUploadsCompleteMultipartPostRequest(
            upload_id=upload_id,
            etags=etags,
        ),
    )

    # Poll the upload until it's `extracted` or `extraction_error`, or until
    # there are no progress updates for 120 seconds.
    while True:
        upload_response: ModelsUploadsIdGetResponse
        upload_response = _apis.models.models_api.models_uploads_id_get(id=upload_id)  # pyright: ignore [reportAttributeAccessIssue]
        assert upload_response.data
        upload: ModelsUploadsIdGetResponseData = upload_response.data
        if upload.status == "extracted":
            log.debug(f"Upload {upload_id} extracted paths: {upload.extracted_paths}")
            break
        if upload.status == "extraction_error":
            raise RuntimeError(f"upload extraction error: {upload.extraction_error}")
        assert upload.updated_at is not None
        updated_at = datetime.fromtimestamp(upload.updated_at / 1000.0, tz=UTC)
        if updated_at < datetime.now(tz=UTC) - timedelta(seconds=120):
            raise RuntimeError(
                f"Upload {upload_id} has not been updated for 120 seconds. "
                f"Current status: {upload.status!r}"
            )
        log.debug(
            f"Upload {upload_id} status: {upload.status!r}, "
            f"extracted bytes: {upload.extracted_bytes} / {upload.total_bytes}"
        )
        time.sleep(2)  # wait before polling again

    metadata = upload.extracted_metadata
    if metadata:
        if metadata.artifact_type and metadata.artifact_type != artifact_type:
            log.warning(
                f"Specified artifact type, {artifact_type!r}, does not match the type from the archive, {metadata.artifact_type!r}."
                f" Using the specified artifact type, {artifact_type!r}."
            )
        if metadata.task_type and metadata.task_type != task_type:
            log.warning(
                f"Specified task type, {task_type!r}, does not match the type from the archive, {metadata.task_type!r}."
                f" Using the specified task type, {task_type!r}."
            )
        if class_labels is None and metadata.class_labels is not None:
            # NOTE(s.maddox): Presumably, if the model archive contains class labels
            # and the user didn't specify an override, then we should use the
            # class labels from the archive.
            class_labels = metadata.class_labels

    # Create the model from the upload
    model_data = _apis.models.models_api.projects_project_models_post(  # pyright: ignore [reportAttributeAccessIssue]
        project=project_id,
        body=InputModel(
            from_upload_id=upload_id,
            name=name,
            version=str(version),
            summary=summary,
            task_type=task_type,
            artifact_type=artifact_type,
            class_labels=class_labels,
            model_config=metadata.pymodel_config if metadata else None,
            input_info=input_info,
        ),
    ).data
    return Model(metadata=model_data, start_server=False)


@dataclass
class S3Signature:
    authorization: str
    x_amz_date: str


@_apis.login_required
def sign_s3_request(url: str, method: str, headers: dict[str, str] | None = None) -> S3Signature:
    """Sign an S3 request for model download."""
    data = _apis.models.models_api.sign_s3_request_post(
        InputSignS3Request(url=url, method=method, headers=headers)
    )
    return S3Signature(authorization=data.authorization, x_amz_date=data.x_amz_date)


@_apis.login_required
def get_s3_endpoint() -> str:
    """Get S3 endpoint for model download."""
    data = _apis.models.models_api.s3_endpoint_get()
    return data.endpoint_url

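# A download sketch combining the two helpers above (the header names are an
# assumption inferred from the S3Signature fields, and the object path is
# hypothetical):
#
#     endpoint = get_s3_endpoint()
#     url = f"{endpoint}/<bucket>/<key>"  # hypothetical object path
#     sig = sign_s3_request(url, "GET")
#     resp = urllib3.PoolManager().request(
#         "GET",
#         url,
#         headers={"Authorization": sig.authorization, "x-amz-date": sig.x_amz_date},
#     )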