import gzip
import json
import time
import warnings
from collections.abc import Sequence
from datetime import UTC, datetime
from enum import StrEnum
from http import HTTPStatus
from io import BytesIO
from typing import Any, Self
import requests
from deprecated import deprecated
from requests.adapters import HTTPAdapter
from urllib3 import Retry
from chariot import _apis
from chariot.config import get_bearer_token, getLogger, settings
from chariot.inference_store.metadata import CustomMetadata
from chariot.models import _kserve
from chariot.models.enum import (
ArtifactType,
InferenceEngine,
TaskType,
TaskTypesInferenceMethod,
)
from chariot.models.inference import ACTION_DOCSTRINGS, Action
from chariot.models.isvc_settings import (
GPUDict,
InferenceServerSettingsDict,
VLLMConfigurationDict,
create_isvc_settings,
get_inference_server_settings,
set_inference_server_settings,
)
from chariot.resources import Resource
from chariot_api._openapi.models.models.output_inference_service import (
OutputInferenceService,
)
from chariot_api._openapi.models.models.output_model import OutputModel
from chariot_api._openapi.models.models.output_model_summary import OutputModelSummary
from chariot_api._openapi.models.rest import ApiException
__all__ = [
"GPUDict",
"InferenceServerSettingsDict",
"InferenceServerStatus",
"Model",
"ModelDoesNotExistError",
"VLLMConfigurationDict",
"get_catalog",
"get_model_by_id",
"get_models",
"iter_models",
"ActionUnsupportedByCurrentModelError",
]
log = getLogger(__name__)
DEFAULT_MAX_RETRIES = 5
class identity(dict):
@staticmethod
def __missing__(key):
return key
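# The identity mapping is used as a fallback for inverse class label lookups:
# a missing key is returned unchanged instead of raising KeyError.
# Illustrative sketch (not executed by this module):
#
#   labels = identity()
#   labels["cat"] = "0"
#   labels["cat"]   # -> "0"
#   labels["dog"]   # -> "dog" (unknown labels map to themselves)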
# TODO MOD-280 these should eventually not be hard-coded (teddy knows this mapping already)
def task_from_teddy_config(model_conf: dict) -> TaskType:
class_to_task = {
"TorchvisionDetector": TaskType.OBJECT_DETECTION,
"CenterNetDetector": TaskType.OBJECT_DETECTION,
"SimpleClassImageModel": TaskType.IMAGE_CLASSIFICATION,
"EmbedModel": TaskType.IMAGE_EMBEDDING,
"TorchvisionSemanticSegmentationModel": TaskType.IMAGE_SEGMENTATION,
"SemanticSegmentationModel": TaskType.IMAGE_SEGMENTATION,
"ImageClassificationModel": TaskType.IMAGE_CLASSIFICATION,
"RetinaFaceDetector": TaskType.OBJECT_DETECTION,
"VariationalAutoEncoder": TaskType.IMAGE_AUTOENCODER,
"ClfTransformersModel": TaskType.TEXT_CLASSIFICATION,
"NERTransformersModel": TaskType.TOKEN_CLASSIFICATION,
"YOLO": TaskType.OBJECT_DETECTION,
}
return class_to_task[model_conf["#class_name"].split(".")[-1]]
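# Illustrative sketch of the mapping above (the config value is hypothetical):
#
#   task_from_teddy_config({"#class_name": "teddy.detectors.YOLO"})
#   # -> TaskType.OBJECT_DETECTION, since only the final dotted component is matched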
# Values come directly from models-catalog.
# These should be kept in sync manually until we find a better way to handle this.
# https://github.com/Striveworks/chariot/blob/main/go/apps/models-catalog/pkg/constants/inference_server.go#L4-L11
class InferenceServerStatus(StrEnum):
"""The status of an inference server."""
NULL = "null" # means no inference server has been requested. This one does not exist in models-catalog.
READY = "Ready"
ERROR = "Error"
SCALED_DOWN = "Scaled down"
STARTING = "Starting"
UPDATING = "Updating"
NOTINITIALIZED = "Not initialized"
PENDING = "Pending"
UNKNOWN = "Unknown"
MISCONFIGURED = "Misconfigured"
class ActionUnsupportedByCurrentModelError(Exception):
pass
def create_inference_method(cls, action: str):
"""Curry of Model.infer. Used to dynamically add inference methods to Model."""
def method(
sample: Any,
timeout: int = 60,
verbose: bool = False,
return_inference_id: bool = False,
**inference_kwargs,
) -> dict:
return cls.infer(
action,
sample,
timeout=timeout,
verbose=verbose,
return_inference_id=return_inference_id,
**inference_kwargs,
)
return method
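# Sketch of how this is used: `_populate_class_inference_methods` attaches one such
# method per entry in `Model.actions`, so (assuming an instantiated `model` whose
# actions include "predict") the following two calls are equivalent:
#
#   model.predict(sample)
#   model.infer("predict", sample)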
def _get_session(**kwargs):
adapter = HTTPAdapter(
max_retries=Retry(
total=kwargs.get("max_retries", DEFAULT_MAX_RETRIES),
connect=kwargs.get("max_retries", DEFAULT_MAX_RETRIES),
read=kwargs.get("max_retries", DEFAULT_MAX_RETRIES),
allowed_methods=None,
backoff_factor=1.5,
)
)
session = requests.Session()
session.mount("http://", adapter)
session.mount("https://", adapter)
session.cookies.set("access_token", get_bearer_token())
return session
def _post_request(url, payload, artifact_type: ArtifactType, **kwargs):
try:
session = _get_session(**kwargs)
preds = session.post(
url,
verify=settings.verify_ssl,
json=payload,
)
preds.raise_for_status()
if artifact_type == ArtifactType.HUGGINGFACE:
try:
buf = BytesIO(preds.content)
gzip_f = gzip.GzipFile(fileobj=buf)
content = gzip_f.read()
return json.loads(content.decode())
# TODO clean up when moving to teddy doing inference.
            # Right now it appears only text-classification and small token-classification don't gzip.
except gzip.BadGzipFile:
return preds.json()
else:
return preds.json()
except requests.exceptions.HTTPError as e:
raise RuntimeError(f"Inference request failed: {e} {e.response.text}")
class ModelDoesNotExistError(Exception):
def __init__(self, model_name: str, version: str | None = None):
super().__init__(
f"Model {model_name!r}{(' version ' + version) if version else ''} does not exist or you do not have access."
)
self.model_name = model_name
self.version = version
class ModelUploadTimeoutError(Exception):
pass
class Model(Resource):
task_to_method = {
T.value: getattr(TaskTypesInferenceMethod, T.name).value
for T in TaskType
if getattr(TaskTypesInferenceMethod, T.name).value
}
default_action_methods = ["predict"]
def __init__(
self,
project_id: str | None = None,
id: str | None = None,
project_name: str | None = None,
subproject_name: str | None = None,
organization_id: str | None = None,
name: str | None = None,
version: str | None = None,
metadata: Any = None,
start_server: bool = True,
):
super().__init__(
project_id=project_id,
id=id,
project_name=project_name,
subproject_name=subproject_name,
organization_id=organization_id,
name=name,
version=version,
metadata=metadata,
)
model_config = self._meta.pymodel_config
self._artifact_type = ArtifactType.get(self._meta.artifact_type or "")
if not self._meta.task_type and self._artifact_type == ArtifactType.CHARIOT:
self.task = task_from_teddy_config(model_config)
else:
self.task = TaskType.get(self._meta.task_type or "")
if start_server:
self.start_inference_server()
self._populate_class_inference_methods() # auto-populate inference methods
self._isvc = None
        self._use_internal_url = False  # whether to use the cluster DNS address
self._internal_url = None
self._external_url = None
self._inference_engine = None
@property
def inference_url(self) -> str | None:
"""Url to inference server inference endpoint."""
return self._internal_url if self._use_internal_url else self._external_url
def _populate_class_inference_methods(self):
"""Set the available inference methods as class methods."""
        if self.actions is None:
            log.error("Unable to set inference methods")
            return
for action in self.actions:
method = create_inference_method(self, action)
method.__doc__ = ACTION_DOCSTRINGS.get(action, f"Run `{action}` on the model")
method.__name__ = action
setattr(self, method.__name__, method)
@staticmethod
def _get_resource_meta(project_id: str, id: str) -> OutputModel:
data = _apis.models.models_api.models_get(project_ids=project_id, id=id).data
if len(data) == 0:
raise ModelDoesNotExistError(project_id, id)
return data[0]
@staticmethod
def _get_id_by_name_and_version(project_id: str, name: str, version: str) -> str:
kwargs = {"project_ids": project_id, "model_name": name}
if version is not None:
kwargs["model_version"] = version
models = _apis.models.models_api.models_get(**kwargs).data
if len(models) == 0:
raise ModelDoesNotExistError(name, version)
return models[0].id
@property
def status(self):
status, _ = self._get_inference_server_status_and_resp()
if status is not InferenceServerStatus.READY:
log.debug(f"Inference service is not ready yet for {self}")
else:
log.debug(
f"Inference server already exists for {self} and is currently in status {status.value}."
)
return status
@property
def architecture(self):
return self._meta.architecture
@property
def name(self):
return self._meta.name
@property
def name_slug(self):
return self._meta.name_slug
@property
def version(self):
return self._meta.version
@property
def created_at(self) -> datetime:
return datetime.fromtimestamp(int(self._meta.created_at) / 1000, tz=UTC)
@property
def storage_status(self):
return self._meta.storage_status
@property
def actions(self) -> list[str]:
return self._meta.inference_methods or self.default_action_methods
def __str__(self):
return f"Model[{self.name} v{self.version} {self.id}]"
@property
def class_labels(self):
if self._meta.class_labels is None or len(self._meta.class_labels) == 0:
return None
return self._meta.class_labels
@property
def inverse_class_labels(self):
icl = identity()
if self.class_labels is None or not isinstance(self.class_labels, dict):
return icl
icl.update({v: k for k, v in self.class_labels.items()})
return icl
def wait_for_upload(self, timeout=60, wait_interval=2) -> Self:
"""Wait for `timeout` seconds for this model to be uploaded.
Parameters
----------
timeout:
Number in seconds to wait for model upload.
wait_interval:
How many seconds to wait after each query before trying again.
Returns
-------
The Model object, for chaining
Raises
------
ModelUploadTimeoutError:
If `storage_status` is not "uploaded" before `timeout`.
"""
storage_status = self._meta.storage_status
t_start = time.monotonic()
while True:
if storage_status == "uploaded":
return self
if storage_status == "failed":
raise ModelUploadTimeoutError(
f"Model {self.id} has storage status 'failed', expected 'uploaded'"
)
if time.monotonic() > t_start + timeout:
raise ModelUploadTimeoutError(
f"Model {self.id} has storage status '{storage_status}' after {timeout}s, expected 'uploaded'"
)
time.sleep(wait_interval)
self._meta = self._get_resource_meta(project_id=self.project_id, id=self.id)
storage_status = self._meta.storage_status
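    # Usage sketch (model creation itself happens elsewhere, e.g. chariot.models.upload):
    #
    #   model.wait_for_upload(timeout=120).start_inference_server()
    #
    # `wait_for_upload` returns `self`, so it can be chained as above.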
def delete(self, hard_delete: bool = False):
"""Delete this model."""
_apis.models.models_api.projects_project_models_id_delete(
project=self.project_id, id=self._meta.id, hard_delete=hard_delete
)
log.info(f"Deleted {self}")
@property
@deprecated(
reason="The `isvc_settings` property is deprecated and will be removed in a future release. "
"Please use the `get_inference_server_settings` method instead."
)
def isvc_settings(self):
"""Get the current settings for this model's inference server.
NOTE: This property is deprecated and will be removed in a future release.
Please use the :meth:`~chariot.models.model.Model.get_inference_server_settings` method instead.
"""
return self.get_inference_server_settings()
@deprecated(
reason="The update_isvc_settings method is deprecated and will be removed in a future release. "
"Please use the `set_inference_server_settings` method instead."
)
def update_isvc_settings(self, settings):
"""Updatesettings for this model's inference server.
NOTE: This method is deprecated and will be removed in a future release.
Please use the :meth:`~chariot.models.model.Model.set_inference_server_settings` method instead.
"""
self.set_inference_server_settings(settings)
# add property for status of inference server
def _get_inference_server_status_and_resp(
self,
) -> tuple[InferenceServerStatus, OutputInferenceService | None]:
        # Returns an aggregate status of all inference server pods
try:
resp = _apis.models.models_api.projects_project_models_id_inferenceservices_get(
project=self.project_id, id=self._meta.id
).data
except ApiException as e:
if e.status == 404:
return InferenceServerStatus.NULL, None
raise e
self._isvc = resp
return InferenceServerStatus(resp.inference_server_status), resp
def get_inference_server_settings(self) -> InferenceServerSettingsDict:
"""Get the current settings for this model's inference server.
Returns
-------
:class:`chariot.models.model.InferenceServerSettingsDict`
"""
model_id = self._meta.id
assert model_id
return get_inference_server_settings(model_id)
def set_inference_server_settings(self, settings: InferenceServerSettingsDict):
"""Set settings for this model's inference server.
Parameters
----------
settings: :class:`chariot.models.model.InferenceServerSettingsDict`
Settings to apply to the inference server.
"""
model_id = self._meta.id
assert model_id
set_inference_server_settings(model_id, settings)
def start_inference_server(
self,
# Deprecated parameters:
cpu: str | None = None,
num_workers: int | None = None,
memory: str | None = None,
min_replicas: int | None = None,
max_replicas: int | None = None,
scale_metric: str | None = None,
scale_target: int | None = None,
scale_down_delay: int | str | None = None,
gpu_count: int | None = None,
gpu_type: str | None = None,
edit_existing: bool | None = None,
quantization_bits: int | None = None,
huggingface_model_kwargs: dict | None = None,
inference_engine: InferenceEngine | str | None = None,
vllm_config: dict | None = None,
):
"""Create an inference server for the model object.
NOTE: All parameters to this function are deprecated and will be removed in a future release.
Please use the :meth:`chariot.models.model.Model.set_inference_server_settings` method to configure the inference server.
Deprecated Parameters
---------------------
cpu: str
Number of cpus allocated to inference server. This sets cpu requests and limits of the kubernetes pod.
See https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/#meaning-of-cpu for more detail.
num_workers: int
For artifact_type=Pytorch this value sets minWorkers, maxWorkers, default_workers_per_model to the specified value.
See https://pytorch.org/serve/configuration.html for more detail.
            For all other artifact types, sets the MLServer parallel_workers field to the specified value.
See https://mlserver.readthedocs.io/en/latest/user-guide/parallel-inference.html for more details.
memory: str
Amount of memory allocated to the inference server. This sets the memory requests and limits of the kubernetes pod.
See https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/#meaning-of-memory for more detail.
min_replicas: int
Minimum number of server replicas to scale down to. Defaults to 0.
max_replicas: int
Maximum number of server replicas to scale up to. Defaults to 1.
scale_metric: str
Metric to scale off of. Currently, only 'concurrency' and 'rps' are supported. Defaults to 'concurrency'.
'concurrency': number of simultaneous open http connections.
'rps': number of http requests per second averaged over 60 seconds.
scale_target: int
Threshold value that, once exceeded, will trigger a scale up event if max_replicas has not been reached.
scale_down_delay: int | str | None
            The amount of time to wait after the scale_metric falls below the scale_target before scaling down if min_replicas has not been reached.
            Can be an integer in seconds, a string with a number followed by 's', 'm', or 'h' for seconds, minutes, or hours respectively,
or None to use the default value.
gpu_count: int
            Number of GPUs requested.
gpu_type: str
Type of gpu to request. If set, gpu_count must be greater than 0
edit_existing: bool
If inference server already exists, will update it to contain the setting specified in this function call
quantization_bits: int
            Passing this parameter will trigger quantization for huggingface models; only 4- or 8-bit quantization is currently supported.
This is passed as a model kwargs to the inference server.
huggingface_model_kwargs: dict
Any parameter passed here will be passed as a model kwargs when creating a huggingface inference server.
inference_engine: Optional[str]
            The inference engine to use. User-selectable runtimes enable models to run under a different inference engine than the artifact type.
            The model must have been converted to that runtime with model export first. Passing nothing for this will result in
running as the artifact type.
vllm_config: Optional[dict]
The configuration for the vLLM inference engine. Only valid when `inference_engine="vLLM"`. Please consult the Chariot docs for the available options for vLLM configs.
"""
if (
cpu is not None
or num_workers is not None
or memory is not None
or min_replicas is not None
or max_replicas is not None
or scale_metric is not None
or scale_target is not None
or scale_down_delay is not None
or gpu_count is not None
or gpu_type is not None
or edit_existing is not None
or quantization_bits is not None
or huggingface_model_kwargs is not None
or inference_engine is not None
or vllm_config is not None
):
warnings.warn(
"All parameters to `start_inference_server` are deprecated and will be removed in a future release. "
"Please use the `set_inference_server_settings` method to configure the inference server.",
DeprecationWarning,
stacklevel=2,
)
settings = _translate_legacy_start_inference_server_params(
cpu=cpu,
num_workers=num_workers,
memory=memory,
min_replicas=min_replicas,
max_replicas=max_replicas,
scale_metric=scale_metric,
scale_target=scale_target,
scale_down_delay=scale_down_delay,
gpu_count=gpu_count,
gpu_type=gpu_type,
quantization_bits=quantization_bits,
huggingface_model_kwargs=huggingface_model_kwargs,
inference_engine=inference_engine,
vllm_config=vllm_config,
)
# Start inference server
model_id = self._meta.id
assert model_id
status, _ = self._get_inference_server_status_and_resp()
if status is InferenceServerStatus.NULL:
if settings:
create_isvc_settings(model_id, settings)
log.info(f"Starting inference server for {self}")
_apis.models.models_api.projects_project_models_id_inferenceservices_post(
project=self.project_id, id=model_id, body={}
)
elif edit_existing:
if settings:
log.info(f"Patching inference server for {self}")
create_isvc_settings(model_id, settings)
else:
log.info("No settings provided, skipping patching.")
else:
log.info(
f"Inference server already exists for {self} with status {status.value}. "
"To update the existing inference server use `set_inference_server_settings`."
)
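    # Usage sketch of the non-deprecated path: configure the server via
    # `set_inference_server_settings`, then start it with no arguments.
    # The keys below mirror those produced by
    # `_translate_legacy_start_inference_server_params`; the values are illustrative only.
    #
    #   model.set_inference_server_settings(
    #       {
    #           "predictor_cpu": "2",
    #           "predictor_memory": "4Gi",
    #           "predictor_min_replicas": 0,
    #           "predictor_max_replicas": 1,
    #       }
    #   )
    #   model.start_inference_server()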
def stop_inference_server(self):
log.info(f"Stopping inference server for {self}")
try:
_apis.models.models_api.projects_project_models_id_inferenceservices_delete(
project=self.project_id,
id=self._meta.id,
_request_timeout=15,
)
except ApiException as e:
if e.reason == "Not Found":
log.error("Inference server not found")
else:
log.error(e)
self._internal_url = self._external_url = None
self._inference_engine = None
def wait_for_inference_server(
self,
timeout: int,
verbose: bool = False,
wait_interval: int = 1,
internal_url: bool = False,
) -> OutputInferenceService:
"""Waits for the model's dedicated inference server to be running and ready to accept requests.
        Will scale the model up if it has scaled to zero.
Parameters
----------
timeout:
Number in seconds to wait for inference server to spin up.
verbose:
Whether to enable more verbose logging.
wait_interval:
How many seconds to wait after each query of the get inference server status endpoint before trying again.
internal_url:
Set to `True` to use inference within cluster.
Returns
-------
        The OutputInferenceService object for the inference service once it is ready.
Raises
------
ApiException:
If the call to get inference server status does not return a status code that is 2XX or 404.
RuntimeError:
If the inference server failed to spin up or was not able to spin up within the timeout period.
"""
start_time = time.time()
first_iter = True
while True:
            # _get_inference_server_status_and_resp will raise an ApiException if the status code is not 2XX or 404
status, resp = self._get_inference_server_status_and_resp()
log.debug(f"Status of inference server {status}")
if resp is not None:
match status:
case InferenceServerStatus.READY | InferenceServerStatus.UPDATING if resp.ready:
break
case InferenceServerStatus.ERROR:
raise RuntimeError(
f"Inference server for {self} failed with status {status}."
)
case InferenceServerStatus.SCALED_DOWN:
_apis.models.models_api.models_id_isvc_keepalive_get(self.id)
if time.time() - start_time > timeout:
raise RuntimeError(
f"Inference server for {self} failed to be ready within {timeout} seconds."
)
if verbose and first_iter:
log.warning(
f"Inference server for {self} is not ready yet, "
f"will wait up to {timeout} seconds for it to spin up."
)
first_iter = False
time.sleep(wait_interval)
log.info(
f"Inference server for {self} has status {status} and readiness value equal to {resp.ready}"
)
self._use_internal_url = internal_url
self._internal_url = resp.internal_url
self._external_url = resp.external_url
self._inference_engine = resp.inference_engine
return resp
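    # Usage sketch (timeout value is illustrative):
    #
    #   isvc = model.wait_for_inference_server(timeout=300, verbose=True)
    #   print(isvc.external_url, isvc.inference_engine)
    #
    # After this call, `model.inference_url` points at the internal or external URL
    # depending on the `internal_url` flag.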
def infer(
self,
action: Action,
sample: Any,
timeout: int = 60,
verbose: bool = False,
url: str | None = None,
custom_metadata: CustomMetadata | Sequence[CustomMetadata] | None = None,
return_inference_id: bool = False,
return_semantic_score: bool = False,
score_threshold: float | None = None,
**inference_kwargs,
) -> Any:
"""Run inference `action` on `sample`.
This method posts data to the model's inference server and returns the results. The
`actions` property lists the available actions for this model.
The inference response id is returned when `return_inference_id` is true. An inference
request may or may not be batched, but it must contain at least one input. As such, if
inference storage is enabled, a small modification to the returned id is necessary. The
lookup pattern within the inference store is `id-#` where `#` represents the index of the
inference request input. For example, if an inference request with a batch of two inputs
is provided, appending `-0` and `-1` to the id to get each inference from
the inference-store will be required.
"""
if action not in self.actions:
raise ActionUnsupportedByCurrentModelError(
f"This model's task type '{self.task}' does not supports the '{action}' action, "
f" please use one of the following actions: {self.actions}"
)
# TODO MOD-321 ONNX model inference with SDK
if self._artifact_type == ArtifactType.ONNX:
log.warning(
"Attention: The Chariot SDK currently provides limited support for ONNX model predictions. "
"If you encounter any issues during this operation, it's recommended to attempt the request directly through the API."
)
if self._inference_engine is None or (url is None and self.inference_url is None):
self.wait_for_inference_server(timeout, verbose, internal_url=self._use_internal_url)
is_openai_api = self._inference_engine == InferenceEngine.VLLM.value
if not url:
url = self.inference_url
inputs = _kserve.create_inputs(
self._meta.id or "",
self._artifact_type,
self.task,
sample,
action,
custom_metadata,
score_threshold=score_threshold,
is_openai_api=is_openai_api,
**inference_kwargs,
)
if is_openai_api:
payload = inputs
else:
payload = {"inputs": inputs}
data = _post_request(url, payload, self._artifact_type)
return _kserve.handle_output(
model_name=self._meta.name or "",
artifact_type=self._artifact_type,
task_type=self.task,
action=action,
inverse_class_labels=self.inverse_class_labels,
return_enriched_inference=return_inference_id or return_semantic_score,
inputs=inputs,
output=data,
)
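    # Usage sketch (the sample and action are illustrative; `actions` lists what the model supports):
    #
    #   result = model.infer("predict", sample, return_inference_id=True)
    #
    # If inference storage is enabled and the request contained a single input, the
    # corresponding inference-store record is looked up as f"{inference_id}-0", per the
    # `id-#` pattern described in the docstring above.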
def exports_supported(self):
"""Return supported export modes for this model"""
result = _apis.modelexport.default_api.export_types_models_v1_projects_project_id_models_model_id_supported_exports_get(
project_id=self.project_id, model_id=self.id
)
return result["data"]
def export_onnx_model(self, tarfile: str):
"""Download ONNX representation of the model into a tarfile"""
if not tarfile.endswith("tgz") and not tarfile.endswith(".tar.gz"):
raise ValueError(f'Tarfile "{tarfile}" should end with tgz or tar.gz')
with open(tarfile, "wb") as f:
f.write(
_apis.modelexport.default_api.download_onnx_models_v1_projects_project_id_models_model_id_onnx_get_without_preload_content(
project_id=self.project_id,
model_id=self.id,
).data
)
log.info(f"saved ONNX model to {tarfile}")
def download_model(self, tarfile: str):
"""Download the model as a tar.gz file"""
if not tarfile.endswith("tgz") and not tarfile.endswith(".tar.gz"):
raise ValueError(f'Tarfile "{tarfile}" should end with tgz or tar.gz')
with open(tarfile, "wb") as f:
f.write(
_apis.models.models_api.models_id_download_get_without_preload_content(
id=self.id,
_request_timeout=300,
).data
)
log.info(f"saved model to {tarfile}")
def files(self):
"""Recursive listing of all files for model, returns [{ last_modified, name, size }]"""
return [
{
"last_modified": entry.last_modified,
"name": entry.name,
"size": entry.size,
}
for entry in _apis.models.models_api.models_id_files_get(
id=self.id, _request_timeout=300, recursive=True
).files
]
def supported_and_existing_inference_engines(self):
return _apis.modelexport.default_api.convert_models_v1_model_id_convert_get(
model_id=self.id,
).data
def convert_inference_engine(self, inference_engine, force_overwrite=False) -> tuple[int, str]:
get_result = _apis.modelexport.default_api.convert_models_v1_model_id_convert_get(
model_id=self.id,
).data
if inference_engine not in get_result.eligible_runtime_conversions:
return (
HTTPStatus.BAD_REQUEST,
"model {self.id} cannot be converted to {runtime}, suppported conversions are {get_result.eligible_runtime_conversions}",
)
response = _apis.modelexport.default_api.convert_models_v1_model_id_convert_post(
model_id=self.id,
target_framework=inference_engine,
force_overwrite=force_overwrite,
)
return response.code, response.message
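    # Usage sketch: query eligible conversions first, then convert. The engine name below is
    # illustrative; valid targets come from `eligible_runtime_conversions`.
    #
    #   engines = model.supported_and_existing_inference_engines()
    #   code, message = model.convert_inference_engine("vLLM")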
@_apis.login_required
def fork(
self,
project_id: str,
*_,
name: str | None = None,
summary: str | None = None,
version: str | None = None,
) -> "Model":
"""Fork the model.
Parameters
----------
project_id : str
The project to fork the model into.
name : str, optional
Optional name override.
summary : str, optional
Optional summary override.
version : str, optional
Optional model version override.
Returns
-------
Model
The new model fork.
"""
        # TODO (c.zaloom) - have to nest this in here because importing from upload.py causes circular import errors :(
from chariot.models.upload import _create_model, _generate_create_chariot_model_payload
payload = _generate_create_chariot_model_payload(
summary=summary,
name=name,
version=version,
task_type=self.task,
artifact_type=self._artifact_type,
model_conf=self._meta.pymodel_config,
fork_model_id=self.id,
)
return _create_model(project_id, payload)
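    # Usage sketch (ids and names are placeholders):
    #
    #   new_model = model.fork("target-project-id", name="my-model-copy", version="2.0.0")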
@_apis.login_required
def get_catalog(project_id: str, **kwargs) -> list[OutputModelSummary]:
"""get_catalog returns the model catalog matching the supplied keyword filters.
See Chariot REST documentation for details.
    Parameters
    ----------
project_id : str
project_id for models query
"""
return _apis.models.models_api.catalog_get(project_ids=project_id, **kwargs).data
@_apis.login_required
def get_models(
project_id: str | None = None,
**kwargs,
) -> list[Model]:
"""get_models returns all models matching the supplied keyword filters.
See Chariot REST documentation for details.
"""
return [
Model(metadata=m, start_server=False)
for m in _apis.models.models_api.models_get(project_ids=project_id, **kwargs).data
]
@_apis.login_required
def iter_models(
**kwargs,
):
"""iter_models returns an iterator over all models matching the supplied filters.
See Chariot REST documentation for details.
"""
if "sort" in kwargs:
raise ValueError("the `sort` parameter is reserved by iter_models")
if "after" in kwargs:
raise ValueError("the `after` parameter is reserved by iter_models")
after_id = None
while True:
models = [
Model(metadata=m, start_server=False)
for m in _apis.models.models_api.models_get(
sort="id:asc", after=after_id, **kwargs
).data
]
if not models:
break
after_id = models[-1].id
yield from models
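# Usage sketch (the filter kwarg is illustrative):
#
#   for model in iter_models(project_ids="my-project-id"):
#       print(model.name, model.version)
#
# Pagination is handled internally by sorting on id and passing the last seen id as `after`.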
@_apis.login_required
def get_model_by_id(id: str) -> Model:
"""get_model_by_id returns the model matching the supplied id."""
models = _apis.models.models_api.models_get(
id=id,
).data
if not models or id is None:
raise ModelDoesNotExistError(id)
return Model(metadata=models[0], start_server=False)
def _get_llm_configurations(
huggingface_model_kwargs: dict | None = None,
vllm_config: dict | None = None,
quantization_bits: int | None = None,
inference_engine: InferenceEngine | str | None = None,
):
if huggingface_model_kwargs and vllm_config:
raise ValueError(
"Cannot specify both a huggingface_model_kwargs and a vllm_config, as Huggingface and vLLM are separate runtimes."
)
if inference_engine == InferenceEngine.VLLM and huggingface_model_kwargs:
raise ValueError(
"If using the vLLM inferene engine, please pass a vllm_config instead of huggingface_model_kwargs"
)
if inference_engine != InferenceEngine.VLLM and vllm_config:
raise ValueError("If using a vllm_config, please pass `inference_engine='vLLM'`")
if inference_engine == InferenceEngine.VLLM and quantization_bits:
vllm_config = vllm_config or {}
if quantization_bits == 4:
vllm_config["bits_and_bytes_4bit"] = True
else:
raise ValueError("Only 4-bit quantization is allowed for vLLM")
elif quantization_bits:
huggingface_model_kwargs = huggingface_model_kwargs or {}
if quantization_bits == 4:
huggingface_model_kwargs["load_in_4bit"] = True
elif quantization_bits == 8:
huggingface_model_kwargs["load_in_8bit"] = True
else:
raise ValueError(f"quantization bits can only be 8 or 4, got: {quantization_bits}")
return huggingface_model_kwargs, vllm_config
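# Illustrative sketch of the translation above:
#
#   _get_llm_configurations(quantization_bits=4, inference_engine=InferenceEngine.VLLM)
#   # -> (None, {"bits_and_bytes_4bit": True})
#
#   _get_llm_configurations(quantization_bits=8)
#   # -> ({"load_in_8bit": True}, None)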
def _translate_legacy_start_inference_server_params(
cpu: str | None = None,
num_workers: int | None = None,
memory: str | None = None,
min_replicas: int | None = None,
max_replicas: int | None = None,
scale_metric: str | None = None,
scale_target: int | None = None,
scale_down_delay: int | str | None = None,
gpu_count: int | None = None,
gpu_type: str | None = None,
quantization_bits: int | None = None,
huggingface_model_kwargs: dict | None = None,
inference_engine: InferenceEngine | str | None = None,
vllm_config: dict | None = None,
):
# Translate legacy parameters to settings dict
huggingface_model_kwargs, vllm_config = _get_llm_configurations(
huggingface_model_kwargs=huggingface_model_kwargs,
vllm_config=vllm_config,
quantization_bits=quantization_bits,
inference_engine=inference_engine,
)
translated_settings = {}
# predictor_cpu: k8s_quantity > 0 (default: 1)
if cpu is None:
pass
elif isinstance(cpu, int):
translated_settings["predictor_cpu"] = str(cpu)
elif isinstance(cpu, str):
translated_settings["predictor_cpu"] = cpu
else:
raise ValueError("cpu must be an integer or a string")
# predictor_memory: k8s_quantity > 0 (default: 4Gi)
if memory is None:
pass
elif isinstance(memory, str):
translated_settings["predictor_memory"] = memory
else:
raise ValueError("memory must be a string")
# predictor_gpu: null | {product: string, count: int} (default: null)
if bool(gpu_type) ^ bool(gpu_count):
raise ValueError(
"gpu_count and gpu_type must be specified together with gpu_count being greater than 0"
)
if gpu_type and gpu_count:
translated_settings["predictor_gpu"] = {"product": gpu_type, "count": gpu_count}
# predictor_min_replicas: int >=0, <=ReplicaLimit (default: 0)
if min_replicas is None:
pass
elif isinstance(min_replicas, int):
translated_settings["predictor_min_replicas"] = min_replicas
else:
raise ValueError("min_replicas must be an integer")
# predictor_max_replicas: int >=1, <=ReplicaLimit (default: 1)
if max_replicas is None:
pass
elif isinstance(max_replicas, int):
translated_settings["predictor_max_replicas"] = max_replicas
else:
raise ValueError("max_replicas must be an integer")
    # predictor_scale_metric: "concurrency" | "rps" (default: "concurrency")
if scale_metric is None:
pass
elif scale_metric in ["concurrency", "rps"]:
translated_settings["predictor_scale_metric"] = scale_metric
else:
raise ValueError("scale_metric must be 'concurrency' or 'rps'")
# predictor_scale_target: int > 0 (default: 5)
if scale_target is None:
pass
elif isinstance(scale_target, int):
translated_settings["predictor_scale_target"] = scale_target
else:
raise ValueError("scale_target must be an integer")
# transformer_min_replicas: int >=0, <=ReplicaLimit (default: 0)
if min_replicas is None:
pass
elif isinstance(min_replicas, int):
translated_settings["transformer_min_replicas"] = min_replicas
else:
raise ValueError("min_replicas must be an integer")
# transformer_max_replicas: int >=1, <=ReplicaLimit (default: 1)
if max_replicas is None:
pass
elif isinstance(max_replicas, int):
translated_settings["transformer_max_replicas"] = max_replicas
else:
raise ValueError("max_replicas must be an integer")
    # transformer_scale_metric: "concurrency" | "rps" (default: "concurrency")
if scale_metric is None:
pass
elif scale_metric in ["concurrency", "rps"]:
translated_settings["transformer_scale_metric"] = scale_metric
else:
raise ValueError("scale_metric must be 'concurrency' or 'rps'")
# transformer_scale_target: int > 0 (default: 20)
if scale_target is None:
pass
elif isinstance(scale_target, int):
translated_settings["transformer_scale_target"] = scale_target
else:
raise ValueError("scale_target must be an integer")
# scale_down_delay_seconds: int >= 0, <= 3600 (default: 600)
if scale_down_delay is None:
pass
elif isinstance(scale_down_delay, int):
translated_settings["scale_down_delay_seconds"] = scale_down_delay
elif isinstance(scale_down_delay, str):
if len(scale_down_delay) > 0:
if scale_down_delay[-1] == "s":
factor = 1
elif scale_down_delay[-1] == "m":
factor = 60
elif scale_down_delay[-1] == "h":
factor = 3600
else:
raise ValueError(
"if scale_down_delay is a string, it must be an integer followed by 's', 'm' or 'h'"
)
translated_settings["scale_down_delay_seconds"] = int(scale_down_delay[:-1]) * factor
else:
raise ValueError("scale_down_delay must be an integer or a string")
# num_workers: int >=1, <= 100 (default: 1)
if num_workers is None:
pass
elif isinstance(num_workers, int):
translated_settings["num_workers"] = num_workers
else:
raise ValueError("num_workers must be an integer")
# inference_engine: null | "ChariotPytorch" | "ChariotDeepSparse" | "vLLM" | "Huggingface" (default: null)
if inference_engine is None:
pass
elif isinstance(inference_engine, str):
translated_settings["inference_engine"] = inference_engine
elif isinstance(inference_engine, InferenceEngine):
translated_settings["inference_engine"] = inference_engine.value
else:
raise ValueError("inference_engine must be a string or an InferenceEngine")
# huggingface_model_kwargs: null | {[key: string]: any} (default: null)
if huggingface_model_kwargs is not None:
translated_settings["huggingface_model_kwargs"] = huggingface_model_kwargs
# vllm_configuration: null | {
# bitsandbytes_4bit: null | boolean,
# enable_prefix_caching: null | boolean,
# max_model_length: null | number,
# seed: null | number,
# } (default: null)
if vllm_config is not None:
translated_settings["vllm_configuration"] = vllm_config
return translated_settings
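# Illustrative sketch of the legacy-parameter translation (values are examples only):
#
#   _translate_legacy_start_inference_server_params(
#       cpu="2", memory="8Gi", gpu_count=1, gpu_type="NVIDIA-A100", scale_down_delay="10m"
#   )
#   # -> {
#   #     "predictor_cpu": "2",
#   #     "predictor_memory": "8Gi",
#   #     "predictor_gpu": {"product": "NVIDIA-A100", "count": 1},
#   #     "scale_down_delay_seconds": 600,
#   # }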