Source code for chariot.models.model

import gzip
import json
import time
import warnings
from collections.abc import Sequence
from datetime import UTC, datetime
from enum import StrEnum
from http import HTTPStatus
from io import BytesIO
from typing import Any, Self

import requests
from deprecated import deprecated
from requests.adapters import HTTPAdapter
from urllib3 import Retry

from chariot import _apis
from chariot.config import get_bearer_token, getLogger, settings
from chariot.inference_store.metadata import CustomMetadata
from chariot.models import _kserve
from chariot.models.enum import (
    ArtifactType,
    InferenceEngine,
    TaskType,
    TaskTypesInferenceMethod,
)
from chariot.models.inference import ACTION_DOCSTRINGS, Action
from chariot.models.isvc_settings import (
    GPUDict,
    InferenceServerSettingsDict,
    VLLMConfigurationDict,
    create_isvc_settings,
    get_inference_server_settings,
    set_inference_server_settings,
)
from chariot.resources import Resource
from chariot_api._openapi.models.models.output_inference_service import (
    OutputInferenceService,
)
from chariot_api._openapi.models.models.output_model import OutputModel
from chariot_api._openapi.models.models.output_model_summary import OutputModelSummary
from chariot_api._openapi.models.rest import ApiException

__all__ = [
    "GPUDict",
    "InferenceServerSettingsDict",
    "InferenceServerStatus",
    "Model",
    "ModelDoesNotExistError",
    "VLLMConfigurationDict",
    "get_catalog",
    "get_model_by_id",
    "get_models",
    "iter_models",
    "ActionUnsupportedByCurrentModelError",
]

log = getLogger(__name__)

DEFAULT_MAX_RETRIES = 5


class identity(dict):
    @staticmethod
    def __missing__(key):
        return key


# TODO MOD-280 these should eventually not be hard-coded (teddy knows this mapping already)
def task_from_teddy_config(model_conf: dict) -> TaskType:
    class_to_task = {
        "TorchvisionDetector": TaskType.OBJECT_DETECTION,
        "CenterNetDetector": TaskType.OBJECT_DETECTION,
        "SimpleClassImageModel": TaskType.IMAGE_CLASSIFICATION,
        "EmbedModel": TaskType.IMAGE_EMBEDDING,
        "TorchvisionSemanticSegmentationModel": TaskType.IMAGE_SEGMENTATION,
        "SemanticSegmentationModel": TaskType.IMAGE_SEGMENTATION,
        "ImageClassificationModel": TaskType.IMAGE_CLASSIFICATION,
        "RetinaFaceDetector": TaskType.OBJECT_DETECTION,
        "VariationalAutoEncoder": TaskType.IMAGE_AUTOENCODER,
        "ClfTransformersModel": TaskType.TEXT_CLASSIFICATION,
        "NERTransformersModel": TaskType.TOKEN_CLASSIFICATION,
        "YOLO": TaskType.OBJECT_DETECTION,
    }
    return class_to_task[model_conf["#class_name"].split(".")[-1]]


# Values come directly from models-catalog.
# These should be kept in sync manually until we find a better way to handle this.
# https://github.com/Striveworks/chariot/blob/main/go/apps/models-catalog/pkg/constants/inference_server.go#L4-L11
class InferenceServerStatus(StrEnum):
    """The status of an inference server."""

    NULL = "null"  # means no inference server has been requested. This one does not exist in models-catalog.
    READY = "Ready"
    ERROR = "Error"
    SCALED_DOWN = "Scaled down"
    STARTING = "Starting"
    UPDATING = "Updating"
    NOTINITIALIZED = "Not initialized"
    PENDING = "Pending"
    UNKNOWN = "Unknown"
    MISCONFIGURED = "Misconfigured"

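# Illustrative sketch (not part of the original module): `Model.status` returns one of the
# InferenceServerStatus members above, so callers can branch on it directly. The model id
# below is hypothetical.
#
#     model = get_model_by_id("<model-id>")
#     if model.status is InferenceServerStatus.SCALED_DOWN:
#         model.wait_for_inference_server(timeout=300)
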
class ActionUnsupportedByCurrentModelError(Exception):
    pass

def create_inference_method(cls, action: str):
    """Curry of Model.infer. Used to dynamically add inference methods to Model."""

    def method(
        sample: Any,
        timeout: int = 60,
        verbose: bool = False,
        return_inference_id: bool = False,
        **inference_kwargs,
    ) -> dict:
        return cls.infer(
            action,
            sample,
            timeout=timeout,
            verbose=verbose,
            return_inference_id=return_inference_id,
            **inference_kwargs,
        )

    return method


def _get_session(**kwargs):
    adapter = HTTPAdapter(
        max_retries=Retry(
            total=kwargs.get("max_retries", DEFAULT_MAX_RETRIES),
            connect=kwargs.get("max_retries", DEFAULT_MAX_RETRIES),
            read=kwargs.get("max_retries", DEFAULT_MAX_RETRIES),
            allowed_methods=None,
            backoff_factor=1.5,
        )
    )
    session = requests.Session()
    session.mount("http://", adapter)
    session.mount("https://", adapter)
    session.cookies.set("access_token", get_bearer_token())
    return session


def _post_request(url, payload, artifact_type: ArtifactType, **kwargs):
    try:
        session = _get_session(**kwargs)
        preds = session.post(
            url,
            verify=settings.verify_ssl,
            json=payload,
        )
        preds.raise_for_status()
        if artifact_type == ArtifactType.HUGGINGFACE:
            try:
                buf = BytesIO(preds.content)
                gzip_f = gzip.GzipFile(fileobj=buf)
                content = gzip_f.read()
                return json.loads(content.decode())
            # TODO clean up when moving to teddy doing inference.
            # Right now it appears only text-classification and small token-classification don't gzip.
            except gzip.BadGzipFile:
                return preds.json()
        else:
            return preds.json()
    except requests.exceptions.HTTPError as e:
        raise RuntimeError(f"Inference request failed: {e} {e.response.text}")

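# Note (added for clarity; describes intent, not new behavior): _get_session builds a
# requests.Session that retries connect/read failures with backoff_factor=1.5 up to
# DEFAULT_MAX_RETRIES attempts and attaches the bearer token as the `access_token` cookie.
# A caller could tighten the retry budget per request, for example:
#
#     session = _get_session(max_retries=2)  # retried at most twice before failing
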
class ModelDoesNotExistError(Exception):
    def __init__(self, model_name: str, version: str | None = None):
        super().__init__(
            f"Model {model_name!r}{(' version ' + version) if version else ''} does not exist or you do not have access."
        )
        self.model_name = model_name
        self.version = version

class ModelUploadTimeoutError(Exception):
    pass

class Model(Resource):
    task_to_method = {
        T.value: getattr(TaskTypesInferenceMethod, T.name).value
        for T in TaskType
        if getattr(TaskTypesInferenceMethod, T.name).value
    }
    default_action_methods = ["predict"]

    def __init__(
        self,
        project_id: str | None = None,
        id: str | None = None,
        project_name: str | None = None,
        subproject_name: str | None = None,
        organization_id: str | None = None,
        name: str | None = None,
        version: str | None = None,
        metadata: Any = None,
        start_server: bool = True,
    ):
        super().__init__(
            project_id=project_id,
            id=id,
            project_name=project_name,
            subproject_name=subproject_name,
            organization_id=organization_id,
            name=name,
            version=version,
            metadata=metadata,
        )
        model_config = self._meta.pymodel_config
        self._artifact_type = ArtifactType.get(self._meta.artifact_type or "")
        if not self._meta.task_type and self._artifact_type == ArtifactType.CHARIOT:
            self.task = task_from_teddy_config(model_config)
        else:
            self.task = TaskType.get(self._meta.task_type or "")
        if start_server:
            self.start_inference_server()
        self._populate_class_inference_methods()  # auto-populate inference methods
        self._isvc = None
        self._use_internal_url = False  # whether to use the cluster DNS address
        self._internal_url = None
        self._external_url = None
        self._inference_engine = None

    @property
    def inference_url(self) -> str | None:
        """URL of the inference server's inference endpoint."""
        return self._internal_url if self._use_internal_url else self._external_url

    def _populate_class_inference_methods(self):
        """Set the available inference methods as class methods."""
        if self.actions is None:
            log.error("Unable to set inference methods")
        for action in self.actions:
            method = create_inference_method(self, action)
            method.__doc__ = ACTION_DOCSTRINGS.get(action, f"Run `{action}` on the model")
            method.__name__ = action
            setattr(self, method.__name__, method)

    @staticmethod
    def _get_resource_meta(project_id: str, id: str) -> OutputModel:
        data = _apis.models.models_api.models_get(project_ids=project_id, id=id).data
        if len(data) == 0:
            raise ModelDoesNotExistError(project_id, id)
        return data[0]

    @staticmethod
    def _get_id_by_name_and_version(project_id: str, name: str, version: str) -> str:
        kwargs = {"project_ids": project_id, "model_name": name}
        if version is not None:
            kwargs["model_version"] = version
        models = _apis.models.models_api.models_get(**kwargs).data
        if len(models) == 0:
            raise ModelDoesNotExistError(name, version)
        return models[0].id

    @property
    def status(self):
        status, _ = self._get_inference_server_status_and_resp()
        if status is not InferenceServerStatus.READY:
            log.debug(f"Inference service is not ready yet for {self}")
        else:
            log.debug(
                f"Inference server already exists for {self} and is currently in status {status.value}."
            )
        return status

    @property
    def architecture(self):
        return self._meta.architecture

    @property
    def name(self):
        return self._meta.name

    @property
    def name_slug(self):
        return self._meta.name_slug

    @property
    def version(self):
        return self._meta.version

    @property
    def created_at(self) -> datetime:
        return datetime.fromtimestamp(int(self._meta.created_at) / 1000, tz=UTC)

    @property
    def storage_status(self):
        return self._meta.storage_status

    @property
    def actions(self) -> list[str]:
        return self._meta.inference_methods or self.default_action_methods

    def __str__(self):
        return f"Model[{self.name} v{self.version} {self.id}]"

    @property
    def class_labels(self):
        if self._meta.class_labels is None or len(self._meta.class_labels) == 0:
            return None
        return self._meta.class_labels

    @property
    def inverse_class_labels(self):
        icl = identity()
        if self.class_labels is None or not isinstance(self.class_labels, dict):
            return icl
        icl.update({v: k for k, v in self.class_labels.items()})
        return icl

    def wait_for_upload(self, timeout=60, wait_interval=2) -> Self:
        """Wait for `timeout` seconds for this model to be uploaded.

        Parameters
        ----------
        timeout:
            Number in seconds to wait for model upload.
        wait_interval:
            How many seconds to wait after each query before trying again.

        Returns
        -------
        The Model object, for chaining

        Raises
        ------
        ModelUploadTimeoutError:
            If `storage_status` is not "uploaded" before `timeout`.
        """
        storage_status = self._meta.storage_status
        t_start = time.monotonic()
        while True:
            if storage_status == "uploaded":
                return self
            if storage_status == "failed":
                raise ModelUploadTimeoutError(
                    f"Model {self.id} has storage status 'failed', expected 'uploaded'"
                )
            if time.monotonic() > t_start + timeout:
                raise ModelUploadTimeoutError(
                    f"Model {self.id} has storage status '{storage_status}' after {timeout}s, expected 'uploaded'"
                )
            time.sleep(wait_interval)
            self._meta = self._get_resource_meta(project_id=self.project_id, id=self.id)
            storage_status = self._meta.storage_status

    def delete(self, hard_delete: bool = False):
        """Delete this model."""
        _apis.models.models_api.projects_project_models_id_delete(
            project=self.project_id, id=self._meta.id, hard_delete=hard_delete
        )
        log.info(f"Deleted {self}")

    @property
    @deprecated(
        reason="The `isvc_settings` property is deprecated and will be removed in a future release. "
        "Please use the `get_inference_server_settings` method instead."
    )
    def isvc_settings(self):
        """Get the current settings for this model's inference server.

        NOTE: This property is deprecated and will be removed in a future release.
        Please use the :meth:`~chariot.models.model.Model.get_inference_server_settings` method instead.
        """
        return self.get_inference_server_settings()

    @deprecated(
        reason="The update_isvc_settings method is deprecated and will be removed in a future release. "
        "Please use the `set_inference_server_settings` method instead."
    )
    def update_isvc_settings(self, settings):
        """Update settings for this model's inference server.

        NOTE: This method is deprecated and will be removed in a future release.
        Please use the :meth:`~chariot.models.model.Model.set_inference_server_settings` method instead.
        """
        self.set_inference_server_settings(settings)

    # add property for status of inference server
    def _get_inference_server_status_and_resp(
        self,
    ) -> tuple[InferenceServerStatus, OutputInferenceService | None]:
        # Returns an aggregate status of all inference server pods
        try:
            resp = _apis.models.models_api.projects_project_models_id_inferenceservices_get(
                project=self.project_id, id=self._meta.id
            ).data
        except ApiException as e:
            if e.status == 404:
                return InferenceServerStatus.NULL, None
            raise e
        self._isvc = resp
        return InferenceServerStatus(resp.inference_server_status), resp

    def get_inference_server_settings(self) -> InferenceServerSettingsDict:
        """Get the current settings for this model's inference server.

        Returns
        -------
        :class:`chariot.models.model.InferenceServerSettingsDict`
        """
        model_id = self._meta.id
        assert model_id
        return get_inference_server_settings(model_id)

    def set_inference_server_settings(self, settings: InferenceServerSettingsDict):
        """Set settings for this model's inference server.

        Parameters
        ----------
        settings: :class:`chariot.models.model.InferenceServerSettingsDict`
            Settings to apply to the inference server.
        """
        model_id = self._meta.id
        assert model_id
        set_inference_server_settings(model_id, settings)

    def start_inference_server(
        self,
        # Deprecated parameters:
        cpu: str | None = None,
        num_workers: int | None = None,
        memory: str | None = None,
        min_replicas: int | None = None,
        max_replicas: int | None = None,
        scale_metric: str | None = None,
        scale_target: int | None = None,
        scale_down_delay: int | str | None = None,
        gpu_count: int | None = None,
        gpu_type: str | None = None,
        edit_existing: bool | None = None,
        quantization_bits: int | None = None,
        huggingface_model_kwargs: dict | None = None,
        inference_engine: InferenceEngine | str | None = None,
        vllm_config: dict | None = None,
    ):
        """Create an inference server for the model object.

        NOTE: All parameters to this function are deprecated and will be removed in a future release.
        Please use the :meth:`chariot.models.model.Model.set_inference_server_settings` method to
        configure the inference server.

        Deprecated Parameters
        ---------------------
        cpu: str
            Number of cpus allocated to the inference server. This sets the cpu requests and limits
            of the kubernetes pod. See
            https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/#meaning-of-cpu
            for more detail.
        num_workers: int
            For artifact_type=Pytorch this value sets minWorkers, maxWorkers, and
            default_workers_per_model to the specified value. See
            https://pytorch.org/serve/configuration.html for more detail.
            For all other artifact types, sets the MLServer parallel_workers field to the specified
            value. See https://mlserver.readthedocs.io/en/latest/user-guide/parallel-inference.html
            for more details.
        memory: str
            Amount of memory allocated to the inference server. This sets the memory requests and
            limits of the kubernetes pod. See
            https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/#meaning-of-memory
            for more detail.
        min_replicas: int
            Minimum number of server replicas to scale down to. Defaults to 0.
        max_replicas: int
            Maximum number of server replicas to scale up to. Defaults to 1.
        scale_metric: str
            Metric to scale off of. Currently, only 'concurrency' and 'rps' are supported.
            Defaults to 'concurrency'.
            'concurrency': number of simultaneous open http connections.
            'rps': number of http requests per second averaged over 60 seconds.
        scale_target: int
            Threshold value that, once exceeded, will trigger a scale up event if max_replicas has
            not been reached.
        scale_down_delay: int | str | None
            The amount of time to wait after the scale_metric falls below the scale_target before
            scaling down, if min_replicas has not been reached. Can be an integer in seconds, a
            string with a number followed by 's', 'm', or 'h' for seconds, minutes, or hours
            respectively, or None to use the default value.
        gpu_count: int
            Number of gpus requested.
        gpu_type: str
            Type of gpu to request. If set, gpu_count must be greater than 0.
        edit_existing: bool
            If an inference server already exists, update it to contain the settings specified in
            this function call.
        quantization_bits: int
            Passing this parameter will trigger quantization for huggingface models; only 4 or 8 is
            currently supported. This is passed as a model kwarg to the inference server.
        huggingface_model_kwargs: dict
            Any parameter passed here will be passed as a model kwarg when creating a huggingface
            inference server.
        inference_engine: Optional[str]
            The inference engine to use. User-selectable runtimes enable models to run under a
            different inference_engine than the artifact type. The model must have been converted to
            that runtime with models export first. Passing nothing for this will result in running
            as the artifact type.
        vllm_config: Optional[dict]
            The configuration for the vLLM inference engine. Only valid when
            `inference_engine="vLLM"`. Please consult the Chariot docs for the available options for
            vLLM configs.
        """
        if (
            cpu is not None
            or num_workers is not None
            or memory is not None
            or min_replicas is not None
            or max_replicas is not None
            or scale_metric is not None
            or scale_target is not None
            or scale_down_delay is not None
            or gpu_count is not None
            or gpu_type is not None
            or edit_existing is not None
            or quantization_bits is not None
            or huggingface_model_kwargs is not None
            or inference_engine is not None
            or vllm_config is not None
        ):
            warnings.warn(
                "All parameters to `start_inference_server` are deprecated and will be removed in a future release. "
                "Please use the `set_inference_server_settings` method to configure the inference server.",
                DeprecationWarning,
                stacklevel=2,
            )
        settings = _translate_legacy_start_inference_server_params(
            cpu=cpu,
            num_workers=num_workers,
            memory=memory,
            min_replicas=min_replicas,
            max_replicas=max_replicas,
            scale_metric=scale_metric,
            scale_target=scale_target,
            scale_down_delay=scale_down_delay,
            gpu_count=gpu_count,
            gpu_type=gpu_type,
            quantization_bits=quantization_bits,
            huggingface_model_kwargs=huggingface_model_kwargs,
            inference_engine=inference_engine,
            vllm_config=vllm_config,
        )

        # Start inference server
        model_id = self._meta.id
        assert model_id
        status, _ = self._get_inference_server_status_and_resp()
        if status is InferenceServerStatus.NULL:
            if settings:
                create_isvc_settings(model_id, settings)
            log.info(f"Starting inference server for {self}")
            _apis.models.models_api.projects_project_models_id_inferenceservices_post(
                project=self.project_id, id=model_id, body={}
            )
        elif edit_existing:
            if settings:
                log.info(f"Patching inference server for {self}")
                create_isvc_settings(model_id, settings)
            else:
                log.info("No settings provided, skipping patching.")
        else:
            log.info(
                f"Inference server already exists for {self} with status {status.value}. "
                "To update the existing inference server use `set_inference_server_settings`."
            )

    def stop_inference_server(self):
        log.info(f"Stopping inference server for {self}")
        try:
            _apis.models.models_api.projects_project_models_id_inferenceservices_delete(
                project=self.project_id,
                id=self._meta.id,
                _request_timeout=15,
            )
        except ApiException as e:
            if e.reason == "Not Found":
                log.error("Inference server not found")
            else:
                log.error(e)
        self._internal_url = self._external_url = None
        self._inference_engine = None

    def wait_for_inference_server(
        self,
        timeout: int,
        verbose: bool = False,
        wait_interval: int = 1,
        internal_url: bool = False,
    ) -> OutputInferenceService:
        """Waits for the model's dedicated inference server to be running and ready to accept requests.
        Will scale the model up if it is scaled to zero.

        Parameters
        ----------
        timeout:
            Number in seconds to wait for inference server to spin up.
        verbose:
            Whether to enable more verbose logging.
        wait_interval:
            How many seconds to wait after each query of the get inference server status endpoint
            before trying again.
        internal_url:
            Set to `True` to use inference within cluster.

        Returns
        -------
        The OutputInferenceService object for the inference service if it exists, None otherwise.

        Raises
        ------
        ApiException:
            If the call to get inference server status does not return a status code that is 2XX or 404.
        RuntimeError:
            If the inference server failed to spin up or was not able to spin up within the timeout period.
        """
        start_time = time.time()
        first_iter = True
        while True:
            # _get_inference_server_status_and_resp will raise an ApiException if status code not 2XX or 404
            status, resp = self._get_inference_server_status_and_resp()
            log.debug(f"Status of inference server {status}")
            if resp is not None:
                match status:
                    case InferenceServerStatus.READY | InferenceServerStatus.UPDATING if resp.ready:
                        break
                    case InferenceServerStatus.ERROR:
                        raise RuntimeError(
                            f"Inference server for {self} failed with status {status}."
                        )
                    case InferenceServerStatus.SCALED_DOWN:
                        _apis.models.models_api.models_id_isvc_keepalive_get(self.id)
            if time.time() - start_time > timeout:
                raise RuntimeError(
                    f"Inference server for {self} failed to be ready within {timeout} seconds."
                )
            if verbose and first_iter:
                log.warning(
                    f"Inference server for {self} is not ready yet, "
                    f"will wait up to {timeout} seconds for it to spin up."
                )
            first_iter = False
            time.sleep(wait_interval)
        log.info(
            f"Inference server for {self} has status {status} and readiness value equal to {resp.ready}"
        )
        self._use_internal_url = internal_url
        self._internal_url = resp.internal_url
        self._external_url = resp.external_url
        self._inference_engine = resp.inference_engine
        return resp

    def infer(
        self,
        action: Action,
        sample: Any,
        timeout: int = 60,
        verbose: bool = False,
        url: str | None = None,
        custom_metadata: CustomMetadata | Sequence[CustomMetadata] | None = None,
        return_inference_id: bool = False,
        return_semantic_score: bool = False,
        score_threshold: float | None = None,
        **inference_kwargs,
    ) -> Any:
        """Run inference `action` on `sample`.

        This method posts data to the model's inference server and returns the results.
        The `actions` property lists the available actions for this model.

        The inference response id is returned when `return_inference_id` is true. An inference
        request may or may not be batched, but it must contain at least one input. As such, if
        inference storage is enabled, a small modification to the returned id is necessary. The
        lookup pattern within the inference store is `id-#` where `#` represents the index of the
        inference request input. For example, if an inference request with a batch of two inputs
        is provided, appending `-0` and `-1` to the id to get each inference from the
        inference-store will be required.
        """
        if action not in self.actions:
            raise ActionUnsupportedByCurrentModelError(
                f"This model's task type '{self.task}' does not support the '{action}' action, "
                f"please use one of the following actions: {self.actions}"
            )

        # TODO MOD-321 ONNX model inference with SDK
        if self._artifact_type == ArtifactType.ONNX:
            log.warning(
                "Attention: The Chariot SDK currently provides limited support for ONNX model predictions. "
                "If you encounter any issues during this operation, it's recommended to attempt the request directly through the API."
            )

        if self._inference_engine is None or (url is None and self.inference_url is None):
            self.wait_for_inference_server(timeout, verbose, internal_url=self._use_internal_url)

        is_openai_api = self._inference_engine == InferenceEngine.VLLM.value

        if not url:
            url = self.inference_url

        inputs = _kserve.create_inputs(
            self._meta.id or "",
            self._artifact_type,
            self.task,
            sample,
            action,
            custom_metadata,
            score_threshold=score_threshold,
            is_openai_api=is_openai_api,
            **inference_kwargs,
        )

        if is_openai_api:
            payload = inputs
        else:
            payload = {"inputs": inputs}

        data = _post_request(url, payload, self._artifact_type)

        return _kserve.handle_output(
            model_name=self._meta.name or "",
            artifact_type=self._artifact_type,
            task_type=self.task,
            action=action,
            inverse_class_labels=self.inverse_class_labels,
            return_enriched_inference=return_inference_id or return_semantic_score,
            inputs=inputs,
            output=data,
        )

    def exports_supported(self):
        """Return supported export modes for this model."""
        result = _apis.modelexport.default_api.export_types_models_v1_projects_project_id_models_model_id_supported_exports_get(
            project_id=self.project_id, model_id=self.id
        )
        return result["data"]

    def export_onnx_model(self, tarfile: str):
        """Download ONNX representation of the model into a tarfile."""
        if not tarfile.endswith("tgz") and not tarfile.endswith(".tar.gz"):
            raise ValueError(f'Tarfile "{tarfile}" should end with tgz or tar.gz')
        with open(tarfile, "wb") as f:
            f.write(
                _apis.modelexport.default_api.download_onnx_models_v1_projects_project_id_models_model_id_onnx_get_without_preload_content(
                    project_id=self.project_id,
                    model_id=self.id,
                ).data
            )
        log.info(f"saved ONNX model to {tarfile}")

    def download_model(self, tarfile: str):
        """Download the model as a tar.gz file."""
        if not tarfile.endswith("tgz") and not tarfile.endswith(".tar.gz"):
            raise ValueError(f'Tarfile "{tarfile}" should end with tgz or tar.gz')
        with open(tarfile, "wb") as f:
            f.write(
                _apis.models.models_api.models_id_download_get_without_preload_content(
                    id=self.id,
                    _request_timeout=300,
                ).data
            )
        log.info(f"saved model to {tarfile}")

    def files(self):
        """Recursive listing of all files for the model; returns [{ last_modified, name, size }]."""
        return [
            {
                "last_modified": entry.last_modified,
                "name": entry.name,
                "size": entry.size,
            }
            for entry in _apis.models.models_api.models_id_files_get(
                id=self.id, _request_timeout=300, recursive=True
            ).files
        ]

    def supported_and_existing_inference_engines(self):
        return _apis.modelexport.default_api.convert_models_v1_model_id_convert_get(
            model_id=self.id,
        ).data

    def convert_inference_engine(self, inference_engine, force_overwrite=False) -> tuple[int, str]:
        get_result = _apis.modelexport.default_api.convert_models_v1_model_id_convert_get(
            model_id=self.id,
        ).data
        if inference_engine not in get_result.eligible_runtime_conversions:
            return (
                HTTPStatus.BAD_REQUEST,
                f"model {self.id} cannot be converted to {inference_engine}, "
                f"supported conversions are {get_result.eligible_runtime_conversions}",
            )
        response = _apis.modelexport.default_api.convert_models_v1_model_id_convert_post(
            model_id=self.id,
            target_framework=inference_engine,
            force_overwrite=force_overwrite,
        )
        return response.code, response.message

    @_apis.login_required
    def fork(
        self,
        project_id: str,
        *_,
        name: str | None = None,
        summary: str | None = None,
        version: str | None = None,
    ) -> "Model":
        """Fork the model.

        Parameters
        ----------
        project_id : str
            The project to fork the model into.
        name : str, optional
            Optional name override.
        summary : str, optional
            Optional summary override.
        version : str, optional
            Optional model version override.

        Returns
        -------
        Model
            The new model fork.
        """
        # TODO (c.zaloom) - have to nest this in here because importing from upload.py causes circular import errors :(
        from chariot.models.upload import _create_model, _generate_create_chariot_model_payload

        payload = _generate_create_chariot_model_payload(
            summary=summary,
            name=name,
            version=version,
            task_type=self.task,
            artifact_type=self._artifact_type,
            model_conf=self._meta.pymodel_config,
            fork_model_id=self.id,
        )
        return _create_model(project_id, payload)

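# Illustrative usage sketch (not part of the original module; the project id, model name, and
# sample variable are assumptions):
#
#     models = get_models(project_id="<project-id>", model_name="my-detector")
#     model = models[0]
#     model.wait_for_inference_server(timeout=300)
#     detections = model.infer("detect", sample_image_bytes)  # `sample_image_bytes` is hypothetical
#
# Because _populate_class_inference_methods adds one method per supported action,
# `model.detect(sample_image_bytes)` would be equivalent to the explicit `infer` call above
# whenever "detect" appears in `model.actions`.
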
@_apis.login_required
def get_catalog(project_id: str, **kwargs) -> list[OutputModelSummary]:
    """get_catalog returns the model catalog matching the supplied keyword filters.
    See Chariot REST documentation for details.

    Parameters
    ----------
    project_id : str
        project_id for the models query
    """
    return _apis.models.models_api.catalog_get(project_ids=project_id, **kwargs).data

@_apis.login_required
def get_models(
    project_id: str | None = None,
    **kwargs,
) -> list[Model]:
    """get_models returns all models matching the supplied keyword filters.
    See Chariot REST documentation for details.
    """
    return [
        Model(metadata=m, start_server=False)
        for m in _apis.models.models_api.models_get(project_ids=project_id, **kwargs).data
    ]

@_apis.login_required
def iter_models(
    **kwargs,
):
    """iter_models returns an iterator over all models matching the supplied filters.
    See Chariot REST documentation for details.
    """
    if "sort" in kwargs:
        raise ValueError("the `sort` parameter is reserved by iter_models")
    if "after" in kwargs:
        raise ValueError("the `after` parameter is reserved by iter_models")
    after_id = None
    while True:
        models = [
            Model(metadata=m, start_server=False)
            for m in _apis.models.models_api.models_get(
                sort="id:asc", after=after_id, **kwargs
            ).data
        ]
        if not models:
            break
        after_id = models[-1].id
        yield from models

@_apis.login_required
def get_model_by_id(id: str) -> Model:
    """get_model_by_id returns the model matching the supplied id."""
    models = _apis.models.models_api.models_get(
        id=id,
    ).data
    if not models or id is None:
        raise ModelDoesNotExistError(id)
    return Model(metadata=models[0], start_server=False)

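# Illustrative sketch (not part of the original module): iterating a large catalog page by page
# instead of loading everything at once; the `project_ids` filter is forwarded to `models_get`
# as shown above.
#
#     for model in iter_models(project_ids="<project-id>"):
#         print(model.name, model.version, model.storage_status)
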
def _get_llm_configurations(
    huggingface_model_kwargs: dict | None = None,
    vllm_config: dict | None = None,
    quantization_bits: int | None = None,
    inference_engine: InferenceEngine | str | None = None,
):
    if huggingface_model_kwargs and vllm_config:
        raise ValueError(
            "Cannot specify both a huggingface_model_kwargs and a vllm_config, as Huggingface and vLLM are separate runtimes."
        )
    if inference_engine == InferenceEngine.VLLM and huggingface_model_kwargs:
        raise ValueError(
            "If using the vLLM inference engine, please pass a vllm_config instead of huggingface_model_kwargs"
        )
    if inference_engine != InferenceEngine.VLLM and vllm_config:
        raise ValueError("If using a vllm_config, please pass `inference_engine='vLLM'`")
    if inference_engine == InferenceEngine.VLLM and quantization_bits:
        vllm_config = vllm_config or {}
        if quantization_bits == 4:
            vllm_config["bits_and_bytes_4bit"] = True
        else:
            raise ValueError("Only 4-bit quantization is allowed for vLLM")
    elif quantization_bits:
        huggingface_model_kwargs = huggingface_model_kwargs or {}
        if quantization_bits == 4:
            huggingface_model_kwargs["load_in_4bit"] = True
        elif quantization_bits == 8:
            huggingface_model_kwargs["load_in_8bit"] = True
        else:
            raise ValueError(f"quantization bits can only be 8 or 4, got: {quantization_bits}")
    return huggingface_model_kwargs, vllm_config


def _translate_legacy_start_inference_server_params(
    cpu: str | None = None,
    num_workers: int | None = None,
    memory: str | None = None,
    min_replicas: int | None = None,
    max_replicas: int | None = None,
    scale_metric: str | None = None,
    scale_target: int | None = None,
    scale_down_delay: int | str | None = None,
    gpu_count: int | None = None,
    gpu_type: str | None = None,
    quantization_bits: int | None = None,
    huggingface_model_kwargs: dict | None = None,
    inference_engine: InferenceEngine | str | None = None,
    vllm_config: dict | None = None,
):
    # Translate legacy parameters to settings dict
    huggingface_model_kwargs, vllm_config = _get_llm_configurations(
        huggingface_model_kwargs=huggingface_model_kwargs,
        vllm_config=vllm_config,
        quantization_bits=quantization_bits,
        inference_engine=inference_engine,
    )

    translated_settings = {}

    # predictor_cpu: k8s_quantity > 0 (default: 1)
    if cpu is None:
        pass
    elif isinstance(cpu, int):
        translated_settings["predictor_cpu"] = str(cpu)
    elif isinstance(cpu, str):
        translated_settings["predictor_cpu"] = cpu
    else:
        raise ValueError("cpu must be an integer or a string")

    # predictor_memory: k8s_quantity > 0 (default: 4Gi)
    if memory is None:
        pass
    elif isinstance(memory, str):
        translated_settings["predictor_memory"] = memory
    else:
        raise ValueError("memory must be a string")

    # predictor_gpu: null | {product: string, count: int} (default: null)
    if bool(gpu_type) ^ bool(gpu_count):
        raise ValueError(
            "gpu_count and gpu_type must be specified together with gpu_count being greater than 0"
        )
    if gpu_type and gpu_count:
        translated_settings["predictor_gpu"] = {"product": gpu_type, "count": gpu_count}

    # predictor_min_replicas: int >=0, <=ReplicaLimit (default: 0)
    if min_replicas is None:
        pass
    elif isinstance(min_replicas, int):
        translated_settings["predictor_min_replicas"] = min_replicas
    else:
        raise ValueError("min_replicas must be an integer")

    # predictor_max_replicas: int >=1, <=ReplicaLimit (default: 1)
    if max_replicas is None:
        pass
    elif isinstance(max_replicas, int):
        translated_settings["predictor_max_replicas"] = max_replicas
    else:
        raise ValueError("max_replicas must be an integer")

    # predictor_scale_metric: "concurrency" | "rps" (default: "concurrency")
    if scale_metric is None:
        pass
    elif scale_metric in ["concurrency", "rps"]:
        translated_settings["predictor_scale_metric"] = scale_metric
    else:
        raise ValueError("scale_metric must be 'concurrency' or 'rps'")

    # predictor_scale_target: int > 0 (default: 5)
    if scale_target is None:
        pass
    elif isinstance(scale_target, int):
        translated_settings["predictor_scale_target"] = scale_target
    else:
        raise ValueError("scale_target must be an integer")

    # transformer_min_replicas: int >=0, <=ReplicaLimit (default: 0)
    if min_replicas is None:
        pass
    elif isinstance(min_replicas, int):
        translated_settings["transformer_min_replicas"] = min_replicas
    else:
        raise ValueError("min_replicas must be an integer")

    # transformer_max_replicas: int >=1, <=ReplicaLimit (default: 1)
    if max_replicas is None:
        pass
    elif isinstance(max_replicas, int):
        translated_settings["transformer_max_replicas"] = max_replicas
    else:
        raise ValueError("max_replicas must be an integer")

    # transformer_scale_metric: "concurrency" | "rps" (default: "concurrency")
    if scale_metric is None:
        pass
    elif scale_metric in ["concurrency", "rps"]:
        translated_settings["transformer_scale_metric"] = scale_metric
    else:
        raise ValueError("scale_metric must be 'concurrency' or 'rps'")

    # transformer_scale_target: int > 0 (default: 20)
    if scale_target is None:
        pass
    elif isinstance(scale_target, int):
        translated_settings["transformer_scale_target"] = scale_target
    else:
        raise ValueError("scale_target must be an integer")

    # scale_down_delay_seconds: int >= 0, <= 3600 (default: 600)
    if scale_down_delay is None:
        pass
    elif isinstance(scale_down_delay, int):
        translated_settings["scale_down_delay_seconds"] = scale_down_delay
    elif isinstance(scale_down_delay, str):
        if len(scale_down_delay) > 0:
            if scale_down_delay[-1] == "s":
                factor = 1
            elif scale_down_delay[-1] == "m":
                factor = 60
            elif scale_down_delay[-1] == "h":
                factor = 3600
            else:
                raise ValueError(
                    "if scale_down_delay is a string, it must be an integer followed by 's', 'm' or 'h'"
                )
            translated_settings["scale_down_delay_seconds"] = int(scale_down_delay[:-1]) * factor
    else:
        raise ValueError("scale_down_delay must be an integer or a string")

    # num_workers: int >=1, <= 100 (default: 1)
    if num_workers is None:
        pass
    elif isinstance(num_workers, int):
        translated_settings["num_workers"] = num_workers
    else:
        raise ValueError("num_workers must be an integer")

    # inference_engine: null | "ChariotPytorch" | "ChariotDeepSparse" | "vLLM" | "Huggingface" (default: null)
    if inference_engine is None:
        pass
    elif isinstance(inference_engine, str):
        translated_settings["inference_engine"] = inference_engine
    elif isinstance(inference_engine, InferenceEngine):
        translated_settings["inference_engine"] = inference_engine.value
    else:
        raise ValueError("inference_engine must be a string or an InferenceEngine")

    # huggingface_model_kwargs: null | {[key: string]: any} (default: null)
    if huggingface_model_kwargs is not None:
        translated_settings["huggingface_model_kwargs"] = huggingface_model_kwargs

    # vllm_configuration: null | {
    #   bitsandbytes_4bit: null | boolean,
    #   enable_prefix_caching: null | boolean,
    #   max_model_length: null | number,
    #   seed: null | number,
    # } (default: null)
    if vllm_config is not None:
        translated_settings["vllm_configuration"] = vllm_config

    return translated_settings
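

# Illustrative sketch (not part of the original module) of how the legacy-parameter translation
# above maps onto the settings dict consumed by `set_inference_server_settings`:
#
#     settings = _translate_legacy_start_inference_server_params(
#         cpu="2", memory="8Gi", min_replicas=1, max_replicas=3, scale_down_delay="10m"
#     )
#     # -> {"predictor_cpu": "2", "predictor_memory": "8Gi",
#     #     "predictor_min_replicas": 1, "predictor_max_replicas": 3,
#     #     "transformer_min_replicas": 1, "transformer_max_replicas": 3,
#     #     "scale_down_delay_seconds": 600}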