from dataclasses import dataclass
from datetime import datetime
from typing import NotRequired, TypedDict
from chariot import _apis
from chariot_api._openapi.models import models as models_api
from chariot_api._openapi.models.api.inference_engine_api import InferenceEngineApi
# [docs]  -- Sphinx source-link artifact from HTML extraction; commented out (a bare `[docs]` expression would raise NameError at import)
@dataclass
class InferenceEngine:
    """An inference engine.

    Instances are built from inference-engine API responses
    (see ``_get_supported_engines``).
    """

    # Unique identifier of the engine.
    engine_id: str
    # Identifier of the project that owns the engine.
    project_id: str
    # Human-readable engine name.
    name: str
# [docs]  -- Sphinx source-link artifact from HTML extraction; commented out
class EnvVar(TypedDict):
    """An environment variable to be set in the engine container."""

    # Environment variable name.
    name: str
    # Environment variable value; may be omitted.
    value: NotRequired[str]
# [docs]  -- Sphinx source-link artifact from HTML extraction; commented out
class ServingArg(TypedDict):
    """A configurable parameter for this engine."""

    # Parameter name.
    name: str
    # Parameter type, as a string understood by the engine API.
    type: str
    # Optional human-readable label for UIs.
    display_text: NotRequired[str]
    # Optional default value, serialized as a string.
    default_value: NotRequired[str]
# [docs]  -- Sphinx source-link artifact from HTML extraction; commented out
class InferenceKwarg(TypedDict):
    """A parameter to send to this engine's inference method."""

    # Keyword-argument name.
    name: str
    # Parameter type, as a string understood by the engine API.
    type: str
    # Optional human-readable description.
    description: NotRequired[str]
# [docs]  -- Sphinx source-link artifact from HTML extraction; commented out
@dataclass
class InferenceEngineVersion:
    """An inference engine version.

    Instances are built from inference-engine API responses
    (see ``_get_engine_versions_with_count``). Field meanings mirror the
    parameters accepted by ``_create_engine_version``.
    """

    # Unique identifier of this engine version.
    engine_version_id: str
    # Identifier of the engine this version belongs to.
    engine_id: str
    # Identifier of the user who created this version.
    user_id: str
    # Identifier of the project that owns the engine.
    project_id: str
    # Creation timestamp.
    created_at: datetime
    # Engine name.
    name: str
    # Version string for this engine version.
    version: str
    # Secret used to pull the container images, if required.
    container_registry_secret: str
    # Container image used for CPU serving.
    container_image_cpu: str
    # Container image used for GPU serving.
    container_image_gpu: str
    # Container entrypoint override.
    entrypoint: list[str]
    # Container command override.
    command: list[str]
    # Environment variables set in the engine container.
    env: list[EnvVar]
    # Configurable serving parameters for this engine.
    serving_args: list[ServingArg]
    # Whether the engine accepts arbitrary serving keyword arguments.
    serving_kwargs_supported: bool
    # Parameters accepted by this engine's inference method.
    inference_kwargs: list[InferenceKwarg]
    # Documentation text for the engine version.
    documentation: str
    # Icon for the engine version.
    icon: str
    # Whether this version is the engine's default version.
    is_default: bool
# [docs]  -- Sphinx source-link artifact from HTML extraction; commented out
@dataclass
class EngineSelector:
    """The data needed to specify a unique inference engine.

    Used by ``_add_supported_engine`` / ``_remove_supported_engine`` as an
    alternative to passing an engine id directly.
    """

    # Organization that owns the engine.
    org_name: str
    # Project (within the organization) that owns the engine.
    project_name: str
    # Engine name within the project.
    engine_name: str
def _inference_engine_api() -> InferenceEngineApi:
    """Return the shared inference-engine API client."""
    client: InferenceEngineApi = _apis.models.inference_engine_api  # type: ignore
    return client
def _create_engine_version(
    name: str,
    project_id: str,
    version: str,
    container_registry_secret: str | None = None,
    container_image_cpu: str | None = None,
    container_image_gpu: str | None = None,
    entrypoint: list[str] | None = None,
    command: list[str] | None = None,
    env: list[EnvVar] | None = None,
    serving_args: list[ServingArg] | None = None,
    serving_kwargs_supported: bool = False,
    inference_kwargs: list[InferenceKwarg] | None = None,
    documentation: str | None = None,
    icon: str | None = None,
    is_default: bool = False,
) -> str:
    """Create an inference engine version and return its id.

    Args:
        name: Engine name.
        project_id: Id of the project that will own the engine.
        version: Version string for the new engine version.
        container_registry_secret: Secret used to pull the container images.
        container_image_cpu: Container image used for CPU serving.
        container_image_gpu: Container image used for GPU serving.
        entrypoint: Container entrypoint override.
        command: Container command override.
        env: Environment variables to set in the engine container.
        serving_args: Configurable serving parameters for this engine.
        serving_kwargs_supported: Whether the engine accepts arbitrary
            serving keyword arguments.
        inference_kwargs: Parameters accepted by the engine's inference method.
        documentation: Documentation text for the engine version.
        icon: Icon for the engine version.
        is_default: Whether the new version becomes the engine's default.

    Returns:
        The id of the newly created engine version.
    """
    # Convert the TypedDict payloads into their OpenAPI model equivalents.
    openapi_envs = [models_api.EnvVar(**env_var) for env_var in env or []]
    openapi_serving_args = [
        models_api.ServingArg(**serving_arg) for serving_arg in serving_args or []
    ]
    openapi_inference_kwargs = [
        models_api.InferenceKwarg(**inference_kwarg) for inference_kwarg in inference_kwargs or []
    ]
    response = _inference_engine_api().inference_engines_post(
        body=models_api.CreateEnginePostRequest(
            name=name,
            project_id=project_id,
            version=version,
            documentation=documentation,
            icon=icon,
            entrypoint=entrypoint or [],
            command=command or [],
            env=openapi_envs,
            container_image_cpu=container_image_cpu,
            container_image_gpu=container_image_gpu,
            container_registry_secret=container_registry_secret,
            inference_kwargs=openapi_inference_kwargs,
            serving_args=openapi_serving_args,
            serving_kwargs_supported=serving_kwargs_supported,
            is_default=is_default,
        )
    )
    return response.data.id
def _delete_engine_version(engine_version_id: str) -> str:
    """Delete an inference engine version.

    Args:
        engine_version_id: Id of the engine version to delete.

    Returns:
        The response from the delete endpoint.
    """
    return _inference_engine_api().inference_engines_version_id_delete(version_id=engine_version_id)
def _get_engine_versions_with_count(
    project_ids: list[str] | None = None,
    engine_names: list[str] | None = None,
    engine_ids: list[str] | None = None,
    engine_version_ids: list[str] | None = None,
    default_versions_only: bool = False,
    limit: int = 100,
    offset: int = 0,
) -> tuple[list[InferenceEngineVersion], int]:
    """Get inference engine versions together with a total record count.

    Args:
        project_ids: Restrict results to these project ids.
        engine_names: Restrict results to these engine names.
        engine_ids: Restrict results to these engine ids.
        engine_version_ids: Restrict results to these engine version ids.
        default_versions_only: Only return each engine's default version.
        limit: Maximum number of results to return.
        offset: Number of results to skip (for pagination).

    Returns:
        A tuple of (engine versions for this page, total matching record
        count across all pages).
    """
    # List filters are sent to the API as comma-separated query strings;
    # empty/None filters are omitted entirely.
    # NOTE: the previous `or []` fallback on the response was dead code -- a
    # `[]` fallback would still fail at `.headers`/`.data` -- so it was removed.
    response = _inference_engine_api().inference_engines_get_with_http_info(
        project_ids=",".join(project_ids) if project_ids else None,
        engine_names=",".join(engine_names) if engine_names else None,
        engine_ids=",".join(engine_ids) if engine_ids else None,
        engine_version_ids=",".join(engine_version_ids) if engine_version_ids else None,
        default_versions_only=default_versions_only,
        limit=limit,
        offset=offset,
    )
    # The total count (across all pages) is reported via a response header.
    count = int(response.headers.get("X-Record-Count", 0)) if response.headers else 0
    engine_versions = [
        InferenceEngineVersion(**api_engine_version.to_dict())
        for api_engine_version in response.data.data
    ]
    return engine_versions, count
def _get_engine_versions(
    project_ids: list[str] | None = None,
    engine_names: list[str] | None = None,
    engine_ids: list[str] | None = None,
    engine_version_ids: list[str] | None = None,
    default_versions_only: bool = False,
    limit: int = 100,
    offset: int = 0,
) -> list[InferenceEngineVersion]:
    """Get inference engine versions.

    Convenience wrapper around :func:`_get_engine_versions_with_count` that
    discards the total record count. See that function for parameter details.

    Returns:
        The engine versions for this page of results.
    """
    # Forward by keyword so this stays correct if the callee's parameter
    # order ever changes.
    engine_versions, _ = _get_engine_versions_with_count(
        project_ids=project_ids,
        engine_names=engine_names,
        engine_ids=engine_ids,
        engine_version_ids=engine_version_ids,
        default_versions_only=default_versions_only,
        limit=limit,
        offset=offset,
    )
    return engine_versions
def _set_default_engine_version(engine_version_id: str) -> None:
    """Set or clear the default version for an engine.

    Args:
        engine_version_id: Id of the engine version to make (or unmake)
            the default.
    """
    api = _inference_engine_api()
    api.inference_engines_version_id_default_put(version_id=engine_version_id)
def _add_supported_engine(
model_id: str,
selector: str | EngineSelector,
):
"""Add a supported inference engine to a model."""
match selector:
case EngineSelector(org_name=org_name, project_name=project_name, engine_name=engine_name):
engine_id = None
case str(engine_id):
org_name, project_name, engine_name = None, None, None
_inference_engine_api().models_model_id_inference_engines_post(
model_id=model_id,
body=models_api.AddSupportedInferenceEnginePostRequest(
engine_id=engine_id,
org_name=org_name,
project_name=project_name,
engine_name=engine_name,
),
)
def _remove_supported_engine(
model_id: str,
selector: str | EngineSelector,
):
"""Remove a supported inference engine to a model."""
match selector:
case EngineSelector(org_name=org_name, project_name=project_name, engine_name=engine_name):
engine_id = None
case str(engine_id):
org_name, project_name, engine_name = None, None, None
_inference_engine_api().models_model_id_inference_engines_delete(
model_id=model_id,
body=models_api.RemoveSupportedInferenceEngine(
engine_id=engine_id,
org_name=org_name,
project_name=project_name,
engine_name=engine_name,
),
)
def _get_supported_engines(model_id: str) -> list[InferenceEngine]:
    """Get the supported inference engines for a model.

    Args:
        model_id: Id of the model to query.

    Returns:
        The inference engines the model supports (empty if none).
    """
    response = _inference_engine_api().models_model_id_inference_engines_get(model_id=model_id)
    # The response payload may be empty; treat that as "no engines".
    api_engines = response.data or []
    return [InferenceEngine(**api_engine.to_dict()) for api_engine in api_engines]