Source code for chariot.models.upload

import copy
import io
import json
import os
import re
import tarfile
import tempfile
from collections.abc import Iterator
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Any, TypeAlias

import joblib

from chariot import _apis
from chariot.config import getLogger
from chariot.models import ModelDoesNotExistError
from chariot.models._chariot_storage import upload_files_via_stream
from chariot.models.enum import ArtifactType, TaskType
from chariot.models.model import Model, task_from_teddy_config
from chariot.projects.projects import _get_project_id_from_project_args
from chariot_api._openapi.models.models import InputModel, InputSignS3Request

__all__ = ["import_model"]

log = getLogger(__name__)


@_apis.login_required
def _create_model(project_id, payload):
    """Creates metadata for model"""
    log.info(f"POST /models project_id={project_id}")
    data = _apis.models.models_api.projects_project_models_post(
        project=project_id, body=payload
    ).data
    return Model(metadata=data, start_server=False)


def _generate_create_chariot_model_payload(
    name: str,
    version: str,
    summary: str,
    task_type: str,
    model_conf: dict,
    artifact_type: ArtifactType,
    training_run: str | None = None,
    training_run_checkpoint: str | None = None,
    training_run_project: str | None = None,
    fork_model_id: str | None = None,
) -> dict:
    if artifact_type not in ArtifactType:
        raise ValueError(f"Unknown artifact_type: {artifact_type!r}")

    return InputModel(
        name=name,
        version=str(version),
        summary=summary,
        task_type=task_type,
        artifact_type=artifact_type.value,
        model_config=model_conf,
        training_run=training_run,
        training_run_checkpoint=training_run_checkpoint,
        training_run_project=training_run_project,
        fork_model_id=fork_model_id,
    ).to_dict()


def _generate_create_non_chariot_model_payload(
    name: str,
    version: str,
    summary: str,
    task_type: TaskType,
    artifact_type: ArtifactType,
    class_labels: dict | None = None,
    **kwargs,
) -> dict:
    """Generates the payload to be used when creating a non Chariot model"""
    return InputModel(
        name=name,
        version=version,
        summary=summary,
        artifact_type=artifact_type.value,
        task_type=task_type.value,
        class_labels=class_labels,
        **kwargs,
    ).to_dict()


@_apis.login_required
def _add_model_with_teddy_config_to_models_service(
    teddy_model_conf: dict,
    summary: str,
    name: str,
    version: str,
    project_id: str,
    task_type: str | None = None,
    artifact_type: ArtifactType = ArtifactType.CHARIOT,
    training_run: str | None = None,
    training_run_checkpoint: str | None = None,
    training_run_project: str | None = None,
) -> Model:
    if task_type is None:
        task_type = task_from_teddy_config(teddy_model_conf).value
    payload = _generate_create_chariot_model_payload(
        model_conf=teddy_model_conf,
        summary=summary,
        name=name,
        version=version,
        task_type=task_type,
        artifact_type=artifact_type,
        training_run=training_run,
        training_run_checkpoint=training_run_checkpoint,
        training_run_project=training_run_project,
    )

    return _create_model(project_id, payload)


def _handle_upload_model_exception(exc, model_id, project_id):
    log.error(f"Error when uploading file {exc} -- deleting {model_id}")
    _apis.models.models_api.projects_project_models_id_delete(project=project_id, id=model_id)
    raise exc


def _process_file(
    stack: ExitStack | None, file_path: str, base: str
) -> tuple[dict[str, dict[str, int]], dict[str, Any]]:
    """Process the specified file from its file path and return the size and file handles.

    Parameters
    ----------
    stack:
        An optional ExitStack instance for resource management.
    file_path:
        Path to the file to be processed.
    base:
        The base directory path, which is stripped from the file path to form the relative file name.

    Returns
    -------
    files:
        Dictionary of file names to upload and their sizes. Sent as
        `body.files` to the model catalog service. Expected format `{file_name: {"file_size": file_size}}`
    file_handles:
        Dictionary of file handles opened as 'rb'. Expected format `{file_name: file_handle}`

    """
    file_name = os.path.relpath(file_path, base)
    file_size = os.path.getsize(file_path)
    if stack:
        file_handle = stack.enter_context(open(file_path, "rb"))
    else:
        file_handle = open(file_path, "rb")
    return {file_name: {"file_size": file_size}}, {file_name: file_handle}


def _process_directory_files(
    directory_path: str, stack: ExitStack | None = None
) -> tuple[dict[str, dict[str, int]], dict[str, Any]]:
    """Process the files in a directory and return their sizes and file handles.

    Parameters
    ----------
    directory_path:
        Path to directory containing the files to be processed.
    stack:
        An optional ExitStack instance for resource management.

    Returns
    -------
    files: dict
        Dictionary of file names to upload and their sizes. Sent as
        `body.files` to the model catalog service. Expected format `{file_name: {"file_size": file_size}}`
    file_handles: dict
        Dictionary of file handles opened as 'rb'. Expected format `{file_name: file_handle}`

    """
    files = {}
    file_handles = {}
    for root, dirs, filenames in os.walk(directory_path):
        # os.walk recursively traverses the subdirectories listed in dirs.
        # By modifying the dirs list in place to not contain directories
        # that start with "." we do not visit any hidden directories.
        dirs[:] = [d for d in dirs if d[0] != "."]
        for f in filter_files(filenames):
            file_path = os.path.join(root, f)
            file_data, file_handle = _process_file(stack, file_path, directory_path)
            files.update(file_data)
            file_handles.update(file_handle)
    return files, file_handles
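
# Pruning sketch (hypothetical tree): because `dirs` is pruned in place, a tree
# such as
#   model_dir/
#       weights/model.pth
#       .git/objects/...
# yields only {"weights/model.pth": {...}}: nothing under ".git" is walked, and
# hidden files themselves are additionally excluded by filter_files.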


def _process_tar_files(
    tar: tarfile.TarFile, filenames: Iterator[str]
) -> tuple[dict[str, dict[str, int]], dict[str, Any]]:
    """Process the specified files from a file archive and return their sizes and file handles.

    Parameters
    ----------
    tar:
        Tarfile object containing files to be processed.
    filenames:
        Iterator for filenames to be processed.

    Returns
    -------
    files: dict
        Dictionary of file names to upload and their sizes. Sent as
        `body.files` to the model catalog service. Expected format `{file_name: {"file_size": file_size}}`
    file_handles: dict
        Dictionary of file names and their in-memory binary streams. Expected format `{file_name: file_handle}`

    """
    files = {}
    file_handles = {}
    size_map = {t.name: t.size for t in tar.getmembers()}

    # Construct a common path prefix to remove
    first_name = min(size_map.keys()).split("/")
    last_name = max(size_map.keys()).split("/")
    base = None
    if len(first_name) > 1 or len(last_name) > 1:
        for fn, ln in zip(first_name, last_name):
            if fn != ln:
                break
            base = os.path.join(base, fn) if base else fn
        if base:
            log.info(f"stripping leading common prefix {base}")

    for filename in filenames:
        try:
            member = tar.getmember(filename)

            log.info(f"processing {member}")
            if not member.isfile():
                log.info(f"{member.name} is not a file, skipping.")
                continue
            with tar.extractfile(filename) as extracted_file:
                filename_rel = os.path.relpath(filename, base)
                file_handles[filename_rel] = io.BytesIO(extracted_file.read())
                files[filename_rel] = {"file_size": size_map[filename]}
        except KeyError:
            raise ValueError(f"file {filename} was not found in {size_map.keys()}.")
    return files, file_handles
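
# Prefix-stripping sketch (hypothetical archive): for member names
#   "my-model/config.json" and "my-model/weights/model.pth"
# the lexicographically smallest and largest names share the leading component
# "my-model" (any prefix common to the min and max is common to every name in
# between), so the returned keys are "config.json" and "weights/model.pth".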


_ChariotModelConfigDict: TypeAlias = dict[str, Any]


def _extract_model_config_dict_from_chariot_model_tgz(
    chariot_model_tgz_path: str,
) -> _ChariotModelConfigDict:
    with tarfile.open(chariot_model_tgz_path, "r:gz") as tar:
        # first, check for .chariot/model.json at any depth
        for member in tar.getmembers():
            if member.name == ".chariot/model.json" or member.name.endswith("/.chariot/model.json"):
                with tar.extractfile(member) as fp:
                    return json.load(fp)["model_config"]
        # if that doesn't exist, then check for the legacy config.json
        for member in tar.getmembers():
            if member.name == "config.json":
                with tar.extractfile(member) as fp:
                    return json.load(fp)
    raise ValueError(f"Could not find '.chariot/model.json' in {chariot_model_tgz_path!r}.")


def _extract_state_dict_paths(model_config_dict: _ChariotModelConfigDict) -> list[str]:
    state_dict_paths = []

    def _extract_state_dict_paths_recursive(d: dict):
        for k, v in d.items():
            if k == "#state_dict_path":
                state_dict_paths.append(v)
            if isinstance(v, dict):
                _extract_state_dict_paths_recursive(v)

    _extract_state_dict_paths_recursive(model_config_dict)
    return state_dict_paths
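
# Behavior sketch (hypothetical config): the walk visits nested dicts
# depth-first, so
#   >>> _extract_state_dict_paths(
#   ...     {"model": {"#state_dict_path": "ckpt/net.pth",
#   ...                "backbone": {"#state_dict_path": "ckpt/bb.pth"}}}
#   ... )
#   ['ckpt/net.pth', 'ckpt/bb.pth']
# Note that dicts nested inside lists are not visited.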


def _replace_state_dict_paths(
    model_config_dict: _ChariotModelConfigDict, subs: dict[str, str]
) -> _ChariotModelConfigDict:
    def _replace_state_dict_paths_recursive(d: dict):
        for k, v in d.items():
            if k == "#state_dict_path":
                d[k] = subs.get(v, v)
            if isinstance(v, dict):
                _replace_state_dict_paths_recursive(v)

    model_config_dict = copy.deepcopy(model_config_dict)
    _replace_state_dict_paths_recursive(model_config_dict)
    return model_config_dict
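
# Companion sketch: _replace_state_dict_paths rewrites the same keys on a deep
# copy, leaving paths absent from `subs` untouched:
#   >>> _replace_state_dict_paths(
#   ...     {"#state_dict_path": "ckpt/net.pth"}, {"ckpt/net.pth": "model.pth"}
#   ... )
#   {'#state_dict_path': 'model.pth'}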


def _import_chariot_model(
    tar_path: str,
    project_id: str,
    name: str,
    version: str,
    summary: str,
    task_type: TaskType,
    artifact_type: ArtifactType,
    use_internal_url: bool,
) -> Model:
    assert artifact_type in {ArtifactType.CHARIOT, ArtifactType.NEURALMAGIC}

    model_config_dict = _extract_model_config_dict_from_chariot_model_tgz(tar_path)
    state_dict_paths = _extract_state_dict_paths(model_config_dict)
    if not state_dict_paths:
        raise ValueError(f"Found zero model weight files to upload in {tar_path!r}")
    if len(state_dict_paths) > 1:
        raise ValueError(
            f"Found multiple model weight files to upload in {tar_path!r}; expected only one"
        )
    state_dict_path = state_dict_paths[0]
    file_name = "model.pth"
    model_config_dict = _replace_state_dict_paths(model_config_dict, {state_dict_path: file_name})

    model = _add_model_with_teddy_config_to_models_service(
        task_type=task_type.value,
        artifact_type=artifact_type,
        teddy_model_conf=model_config_dict,
        summary=summary,
        name=name,
        version=version,
        project_id=project_id,
    )
    log.info(f"created model {model.id}")
    if not model._meta.artifact_type:
        # in case the Create did not return a complete model
        model = Model(project_id=project_id, id=model.id, start_server=False)

    with ExitStack() as stack:
        (
            files,
            file_handles,
        ) = _get_file_and_file_handles_for_non_chariot_tar(
            model_path=tar_path,
            ignore_patterns=["config.json", ".chariot/model.json"],
            allow_patterns=None,
            expected_file_count=None,
        )

        # model.pth handling: only one file with fixed name
        model_pth_filenames = [
            key
            for key in files.keys()
            if key == file_name or key.endswith(f"/{file_name}") or key == "net.pth"
        ]
        if not model_pth_filenames:
            raise ValueError(f"Could not find {file_name} in {tar_path!r}")
        if len(model_pth_filenames) > 1:
            raise ValueError(f"Found multiple {file_name} files in {tar_path!r}")
        key = model_pth_filenames[0]
        if key != file_name:
            # NOTE: for NeuralMagic models the weights file may be named net.pth; we rename it to model.pth
            files[file_name] = files[key]
            del files[key]
            file_handles[file_name] = file_handles[key]
            del file_handles[key]

        # All opened files will automatically be closed at the end of
        # the with statement, even if attempts to open files later
        # in the list raise an exception
        try:
            log.info(f"uploading {files}")
            resp = upload_files_via_stream(
                files,
                file_handles,
                model,
                use_internal_url=use_internal_url,
            )
            log.info(f"Succesfully uploaded '{resp.uri}' for model {resp.name}:{resp.version}")
        except Exception as exc:
            _handle_upload_model_exception(exc, model.id, project_id)

    data = _apis.models.models_api.models_get(id=model.id).data
    if len(data) == 0:
        raise ModelDoesNotExistError(model.id, version)
    return Model(metadata=data[0], start_server=False)


def _init_patterns(
    allow_patterns: list[str] | list[re.Pattern] | str | None = None,
    ignore_patterns: list[str] | list[re.Pattern] | str | None = None,
) -> tuple[list[str] | list[re.Pattern], list[str] | list[re.Pattern]]:
    """Normalize pattern arguments: None becomes an empty list and a bare string becomes a single-element list."""
    if not allow_patterns:
        allow_patterns = []
    if not ignore_patterns:
        ignore_patterns = []
    if isinstance(allow_patterns, str):
        allow_patterns = [allow_patterns]
    if isinstance(ignore_patterns, str):
        ignore_patterns = [ignore_patterns]
    return allow_patterns, ignore_patterns
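
# Normalization sketch (illustrative inputs): strings are promoted to
# single-element lists and missing arguments become empty lists:
#   >>> _init_patterns(allow_patterns="weights/.*", ignore_patterns=None)
#   (['weights/.*'], [])
# Callers that need compiled patterns (see
# _get_file_and_file_handles_for_non_chariot_tar) re.compile the results.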


def _is_private_file(filepath: str):
    return os.path.basename(filepath).startswith(".")


def _filter_members(
    members: list[str],
    allow_patterns: list[re.Pattern] | None = None,
    ignore_patterns: list[re.Pattern] | None = None,
) -> Iterator[str]:
    allow_patterns, ignore_patterns = _init_patterns(
        allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
    )

    def is_allowed(member):
        if _is_private_file(member):
            return False
        if ignore_patterns and any(pattern.match(member) for pattern in ignore_patterns):
            return False
        if allow_patterns and not any(pattern.match(member) for pattern in allow_patterns):
            return False
        return True

    log.info(
        f"filtering members: {members}, allow patterns: {allow_patterns}, ignore_patterns: {ignore_patterns}, and ignoring private files"
    )
    return filter(is_allowed, members)
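
# Filtering sketch (hypothetical member names): patterns are applied with
# re.Pattern.match, i.e. anchored at the start of the member name:
#   >>> list(_filter_members(
#   ...     ["weights/model.pth", "notes.txt", "weights/.DS_Store"],
#   ...     allow_patterns=[re.compile(r"weights/")],
#   ... ))
#   ['weights/model.pth']
# "notes.txt" fails the allow patterns, and "weights/.DS_Store" is dropped as a
# private (dot-prefixed) file before any pattern is consulted.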


def _get_file_and_file_handles_for_non_chariot_tar(
    model_path: str,
    ignore_patterns: str | list[str] | None = None,
    allow_patterns: str | list[str] | None = None,
    expected_file_count: int | None = None,
) -> tuple[dict[str, dict[str, int]], dict[str, Any]]:
    """Gets the files and file handles for the files within a gzipped tarfile"""
    allow_patterns, ignore_patterns = _init_patterns(
        allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
    )
    allow_patterns = [re.compile(p) for p in allow_patterns]
    ignore_patterns = [re.compile(p) for p in ignore_patterns]

    with tarfile.open(model_path, "r:gz") as tar:
        filtered_members = _filter_members(
            tar.getnames(),
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
        )
        files, file_handles = _process_tar_files(tar, filtered_members)
        if expected_file_count and len(files) != expected_file_count:
            raise ValueError(
                f"Unexpected number of files found. Expected {expected_file_count} file(s) but found {list(files.keys())}"
            )
        return files, file_handles


def _get_file_and_file_handles_for_file(
    stack: ExitStack, model_path: str
) -> tuple[dict[str, dict[str, int]], dict[str, Any]]:
    """Gets the files and file handles given the path to a model file"""
    return _process_file(stack, model_path, os.path.dirname(model_path))


def _get_file_and_file_handles_for_folder(
    stack: ExitStack, model_path: str
) -> tuple[dict[str, dict[str, int]], dict[str, Any]]:
    model_path = os.path.join(model_path, "")
    return _process_directory_files(model_path, stack)


def _upload_non_chariot_model_to_s3(
    artifact_type: ArtifactType,
    model_path: str,
    model: Model,
    use_internal_url: bool,
    allow_patterns: list[str] | None = None,
    ignore_patterns: list[str] | None = None,
):
    allow_patterns, ignore_patterns = _init_patterns(
        allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
    )
    expected_file_count = None

    with ExitStack() as stack:
        if os.path.isfile(model_path) and tarfile.is_tarfile(model_path):
            log.info(f"artifact type: {artifact_type} and is a tarfile")
            (
                files,
                file_handles,
            ) = _get_file_and_file_handles_for_non_chariot_tar(
                model_path=model_path,
                ignore_patterns=ignore_patterns,
                allow_patterns=allow_patterns,
                expected_file_count=expected_file_count,
            )
        elif os.path.isfile(model_path) and not tarfile.is_tarfile(model_path):
            log.info(f"artifact type: {artifact_type} and is not a tarfile")
            if artifact_type == ArtifactType.HUGGINGFACE:
                raise ValueError(
                    "For huggingface models, supply a directory or tar containing the contents to upload"
                )

            files, file_handles = _get_file_and_file_handles_for_file(stack, model_path)
        else:
            log.info(f"artifact type: {artifact_type} and is a folder")
            files, file_handles = _get_file_and_file_handles_for_folder(stack, model_path)

        # All opened files will automatically be closed at the end of
        # the with statement, even if attempts to open files later
        # in the list raise an exception
        log.info(f"uploading {files}")
        resp = upload_files_via_stream(
            files,
            file_handles,
            model,
            use_internal_url=use_internal_url,
        )
        log.info(f"Succesfully uploaded '{resp.uri}' for model {resp.name}:{resp.version}")


def _handle_model_object(
    artifact_type: ArtifactType,
    model_path: str | None = None,
    model_object: Any | None = None,
) -> tuple[str, tempfile._TemporaryFileWrapper | None]:
    """Handles creating temporary file to store model_object if artifact is sklearn

    Parameters
    ----------
    artifact_type:
        Model artifact type.
    model_path:
        File path to the model, or None if the model is not on disk.
    model_object:
        In-memory model object, or None if the model is on disk.

    Returns
    -------
    model_path
        File path to model.
    tf
        Temporary file to store model if it was previously in memory, or None if a model file path already exists.

    """
    if model_path:
        return model_path, None

    if artifact_type == ArtifactType.SKLEARN:
        tf = tempfile.NamedTemporaryFile(delete=False, suffix=".joblib")
        joblib.dump(model_object, tf.name)
        model_path = tf.name
        return model_path, tf

    raise ValueError(
        f"Unsupported {artifact_type} model object import. Import via model_path instead."
    )
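
# Usage sketch (hedged; the estimator and training data are illustrative):
# importing an in-memory sklearn model serializes it to a temporary ".joblib"
# file first:
#   clf = sklearn.ensemble.RandomForestClassifier().fit(X, y)
#   model_path, tf = _handle_model_object(ArtifactType.SKLEARN, model_object=clf)
# The caller is responsible for os.unlink(tf.name) once the upload finishes,
# as _import_non_chariot_model does in its finally block.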


def _import_non_chariot_model(
    name: str,
    version: str,
    summary: str,
    task_type: TaskType,
    artifact_type: ArtifactType,
    project_id: str | None = None,
    class_labels: dict | None = None,
    model_object: Any | None = None,
    model_path: str | None = None,
    use_internal_url: bool = False,
    **kwargs,
) -> Model:
    input_model_dict = _generate_create_non_chariot_model_payload(
        name=name,
        version=version,
        summary=summary,
        artifact_type=artifact_type,
        task_type=task_type,
        class_labels=class_labels,
        **kwargs,
    )

    model = _create_model(project_id, payload=input_model_dict)
    log.info(f"created model {model.id}")

    tf = None  # ensure tf is bound for the finally block even if _handle_model_object raises
    try:
        model_path, tf = _handle_model_object(artifact_type, model_path, model_object)
        _upload_non_chariot_model_to_s3(artifact_type, model_path, model, use_internal_url)
    except Exception as exc:
        _handle_upload_model_exception(exc, model.id, project_id)
    finally:
        if tf:
            os.unlink(tf.name)

    data = _apis.models.models_api.projects_project_models_id_get(
        project=project_id, id=model.id
    ).data

    return Model(metadata=data, start_server=False)


def check_should_upload_file(filename: str) -> bool:
    """Check whether file should be uploaded as part of model upload or not"""
    return not filename.startswith(".")


def filter_files(filenames: list[str]) -> Iterator[str]:
    """Return only files that should be uploaded as part of model upload"""
    return filter(check_should_upload_file, filenames)
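
# Minimal sketch of the filtering behavior (hypothetical file names):
#   >>> list(filter_files(["model.pth", "config.json", ".DS_Store"]))
#   ['model.pth', 'config.json']
# Dot-prefixed files such as ".DS_Store" are treated as private and excluded.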


def import_model(
    *,
    name: str,
    version: str,
    summary: str,
    artifact_type: str | ArtifactType,
    task_type: str | TaskType,
    project_id: str | None = None,
    project_name: str | None = None,
    subproject_name: str | None = None,
    organization_id: str | None = None,
    class_labels: dict | None = None,
    model_object: Any | None = None,
    model_path: str | None = None,
    use_internal_url: bool = False,
    **kwargs,
) -> Model:
    """Import a local model into Chariot.

    For a previously exported Chariot model, model_path is the local path to the gzipped tar.
    """
    project_id = _get_project_id_from_project_args(
        project_id=project_id,
        project_name=project_name,
        subproject_name=subproject_name,
        organization_id=organization_id,
    )

    if (model_path and model_object) or (not model_path and not model_object):
        raise ValueError("Must supply model_path or model_object but not both")
    if class_labels is not None and not isinstance(class_labels, dict):
        raise ValueError("class_labels must be a dict")
    if not model_object:
        if model_path is None or not os.path.exists(model_path):
            raise ValueError(f"model_path {model_path} does not exist")
    if not artifact_type:
        raise ValueError(f"Must supply artifact_type: {ArtifactType.values()}")
    if isinstance(artifact_type, str):
        artifact_type = ArtifactType.get(artifact_type)
    if isinstance(task_type, str):
        task_type = TaskType.get(task_type)

    # TODO should neural magic have their own import function or follow the chariot path?
    if artifact_type in {ArtifactType.CHARIOT, ArtifactType.NEURALMAGIC}:
        return _import_chariot_model(
            tar_path=model_path,
            project_id=project_id,
            name=name,
            version=version,
            summary=summary,
            task_type=task_type,
            artifact_type=artifact_type,
            use_internal_url=use_internal_url,
        )
    if artifact_type in {
        ArtifactType.SKLEARN,
        ArtifactType.PYTORCH,
        ArtifactType.ONNX,
        ArtifactType.HUGGINGFACE,
    }:
        return _import_non_chariot_model(
            name=name,
            version=version,
            summary=summary,
            task_type=task_type,
            artifact_type=artifact_type,
            project_id=project_id,
            class_labels=class_labels,
            model_object=model_object,
            model_path=model_path,
            use_internal_url=use_internal_url,
            **kwargs,
        )
    # TODO(spollack) other uploads
    raise NotImplementedError(f"TODO Implement {artifact_type}")
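
# Usage sketch (names, versions, and paths are illustrative; the task_type and
# artifact_type strings must match chariot's TaskType and ArtifactType values):
#   model = import_model(
#       name="my-classifier",
#       version="0.0.1",
#       summary="Imported ONNX classifier",
#       artifact_type="onnx",
#       task_type="Image Classification",
#       project_name="my-project",
#       model_path="./classifier.onnx",
#   )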


@dataclass
class S3Signature:
    authorization: str
    x_amz_date: str


@_apis.login_required
def sign_s3_request(url: str, method: str, headers: dict[str, str] | None = None) -> S3Signature:
    """Signs an S3 request for model upload/download"""
    data = _apis.models.models_api.sign_s3_request_post(
        InputSignS3Request(url=url, method=method, headers=headers)
    )
    return S3Signature(authorization=data.authorization, x_amz_date=data.x_amz_date)


@_apis.login_required
def get_s3_endpoint() -> str:
    """Gets S3 endpoint for model upload/download"""
    data = _apis.models.models_api.s3_endpoint_get()
    return data.endpoint_url
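
# Usage sketch (hypothetical URL): the returned signature headers can be
# attached to a direct request against the object store:
#   sig = sign_s3_request("https://s3.example.com/models-bucket/key", "PUT")
#   headers = {"Authorization": sig.authorization, "x-amz-date": sig.x_amz_date}
# This pairs with get_s3_endpoint(), which returns the store's endpoint URL.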