from dataclasses import asdict, dataclass
from datetime import datetime
from typing import NotRequired, TypedDict
from chariot import _apis
from chariot import mcp_setting
from chariot_api._openapi.models import models as models_api
from chariot_api._openapi.models.api.inference_engine_api import InferenceEngineApi
from chariot_api._openapi.models.exceptions import NotFoundException
# Public API of this module: the names exported by a star-import.
# NOTE(review): InferenceEngineVersion is returned by several of the functions
# listed below but is not exported here -- confirm whether that is intentional.
__all__ = [
"EngineSelector",
"EnvVar",
"InferenceEngine",
"InferenceKwarg",
"ReadinessProbe",
"ServingArg",
"add_supported_engine",
"create_engine_version",
"delete_engine_version",
"get_engine_version_by_selector",
"get_engine_versions",
"get_engine_versions_with_count",
"get_supported_engines",
"remove_supported_engine",
"set_default_engine_version",
]
[docs]
@dataclass
class InferenceEngine:
    """Identifying fields of an inference engine.

    ``get_supported_engines`` builds instances of this class from the
    dictionary form of each engine returned by the API.
    """

    engine_id: str
    project_id: str
    name: str
[docs]
class EnvVar(TypedDict):
"""An environment variable to be set in the engine container."""
name: str
value: NotRequired[str]
[docs]
class ServingArg(TypedDict):
"""A configurable parameter for this engine."""
name: str
type: str
display_text: NotRequired[str]
default_value: NotRequired[str]
[docs]
class InferenceKwarg(TypedDict):
"""A parameter to send to this engine's inference method."""
name: str
type: str
description: NotRequired[str]
[docs]
@dataclass
class ReadinessProbe:
    """Readiness probe settings for an engine version.

    ``create_engine_version`` converts an instance with ``asdict`` and passes
    the fields as keyword arguments to ``models_api.ReadinessProbe``.
    All fields are required.
    """

    # Where the probe is sent.
    path: str
    port: int
    # Timing fields (named in seconds) and success/failure thresholds.
    initial_delay_seconds: int
    timeout_seconds: int
    period_seconds: int
    success_threshold: int
    failure_threshold: int
@dataclass
class InferenceEngineVersion:
    """One version of an inference engine, as returned by the lookup functions.

    Instances are produced by ``get_engine_versions_with_count``,
    ``get_engine_versions`` and ``get_engine_version_by_selector`` from the
    dictionary form of the API model. All fields are required.
    """

    # Identity and ownership.
    engine_version_id: str
    engine_id: str
    user_id: str
    project_id: str
    created_at: datetime
    name: str
    version: str
    # Container configuration.
    container_registry_secret: str
    container_image_cpu: str
    container_image_gpu: str
    entrypoint: list[str]
    command: list[str]
    env: list[EnvVar]
    # Serving and inference parameters.
    serving_args: list[ServingArg]
    serving_kwargs_supported: bool
    inference_kwargs: list[InferenceKwarg]
    # Presentation and routing.
    documentation: str
    icon: str
    container_root_relative_base_url: str
    readiness_probe: ReadinessProbe
    is_default: bool
[docs]
@dataclass
class EngineSelector:
    """Organization, project and engine names that together pick out one engine.

    ``add_supported_engine`` and ``remove_supported_engine`` accept an
    instance of this class as an alternative to an engine ID string.
    """

    org_name: str
    project_name: str
    engine_name: str
def _inference_engine_api() -> InferenceEngineApi:
    """Return the inference-engine API client held on ``_apis.models``."""
    client = _apis.models.inference_engine_api  # type: ignore
    return client
[docs]
@mcp_setting(mutating=True)
def create_engine_version(
    name: str,
    project_id: str,
    version: str,
    readiness_probe: ReadinessProbe,
    container_registry_secret: str | None = None,
    container_image_cpu: str | None = None,
    container_image_gpu: str | None = None,
    entrypoint: list[str] | None = None,
    command: list[str] | None = None,
    env: list[EnvVar] | None = None,
    serving_args: list[ServingArg] | None = None,
    serving_kwargs_supported: bool = False,
    inference_kwargs: list[InferenceKwarg] | None = None,
    documentation: str | None = None,
    icon: str | None = None,
    container_root_relative_base_url: str = "/",
    is_default: bool = False,
) -> str:
    """Create an inference engine version through the inference-engine API.

    The arguments are packed into a ``models_api.CreateEnginePostRequest``
    and posted with ``inference_engines_post``.

    Args:
        name: Engine name, forwarded unchanged.
        project_id: Project ID, forwarded unchanged.
        version: Version string, forwarded unchanged.
        readiness_probe: Converted with ``asdict`` into a
            ``models_api.ReadinessProbe``.
        container_registry_secret: Forwarded unchanged.
        container_image_cpu: Forwarded unchanged.
        container_image_gpu: Forwarded unchanged.
        entrypoint: Sent as an empty list when omitted or empty.
        command: Sent as an empty list when omitted or empty.
        env: Each mapping is expanded into a ``models_api.EnvVar``.
        serving_args: Each mapping is expanded into a ``models_api.ServingArg``.
        serving_kwargs_supported: Forwarded unchanged.
        inference_kwargs: Each mapping is expanded into a
            ``models_api.InferenceKwarg``.
        documentation: Forwarded unchanged.
        icon: Forwarded unchanged.
        container_root_relative_base_url: Forwarded unchanged.
        is_default: Forwarded unchanged.

    Returns:
        The ``id`` attribute of the response's ``data``.
    """
    # Convert the plain-dict inputs into their OpenAPI model counterparts.
    env_models = []
    for env_var in env or []:
        env_models.append(models_api.EnvVar(**env_var))

    serving_arg_models = []
    for serving_arg in serving_args or []:
        serving_arg_models.append(models_api.ServingArg(**serving_arg))

    inference_kwarg_models = []
    for inference_kwarg in inference_kwargs or []:
        inference_kwarg_models.append(models_api.InferenceKwarg(**inference_kwarg))

    api = _inference_engine_api()
    request = models_api.CreateEnginePostRequest(
        name=name,
        project_id=project_id,
        version=version,
        documentation=documentation,
        icon=icon,
        entrypoint=entrypoint or [],
        command=command or [],
        env=env_models,
        container_image_cpu=container_image_cpu,
        container_image_gpu=container_image_gpu,
        container_registry_secret=container_registry_secret,
        inference_kwargs=inference_kwarg_models,
        serving_args=serving_arg_models,
        serving_kwargs_supported=serving_kwargs_supported,
        container_root_relative_base_url=container_root_relative_base_url,
        readiness_probe=models_api.ReadinessProbe(**asdict(readiness_probe)),
        is_default=is_default,
    )
    response = api.inference_engines_post(body=request)
    return response.data.id
[docs]
@mcp_setting(mutating=True)
def delete_engine_version(engine_version_id: str) -> str:
    """Delete the inference engine version with the given ID.

    Returns:
        Whatever the API's ``inference_engines_version_id_delete`` call returns.
    """
    api = _inference_engine_api()
    result = api.inference_engines_version_id_delete(version_id=engine_version_id)
    return result
[docs]
def get_engine_versions_with_count(
    project_ids: list[str] | None = None,
    engine_names: list[str] | None = None,
    engine_ids: list[str] | None = None,
    engine_version_ids: list[str] | None = None,
    default_versions_only: bool = False,
    limit: int = 100,
    offset: int = 0,
) -> tuple[list[InferenceEngineVersion], int]:
    """Get inference engine versions together with a total count.

    Args:
        project_ids: Sent as one comma-joined string; omitted when empty.
        engine_names: Sent as one comma-joined string; omitted when empty.
        engine_ids: Sent as one comma-joined string; omitted when empty.
        engine_version_ids: Sent as one comma-joined string; omitted when empty.
        default_versions_only: Forwarded unchanged.
        limit: Forwarded unchanged.
        offset: Forwarded unchanged.

    Returns:
        A ``(engine_versions, count)`` tuple. ``count`` is read from the
        ``X-Record-Count`` response header and is 0 when the header, or the
        headers as a whole, are missing. A falsy response yields ``([], 0)``.
    """

    def _csv(values: list[str] | None) -> str | None:
        # List filters are sent as a single comma-separated string.
        return ",".join(values) if values else None

    response = _inference_engine_api().inference_engines_get_with_http_info(
        project_ids=_csv(project_ids),
        engine_names=_csv(engine_names),
        engine_ids=_csv(engine_ids),
        engine_version_ids=_csv(engine_version_ids),
        default_versions_only=default_versions_only,
        limit=limit,
        offset=offset,
    )
    if not response:
        # The old code substituted `[]` for a falsy response and then read
        # `.headers` on that list, which raised AttributeError.
        return [], 0

    headers = response.headers
    count = int(headers.get("X-Record-Count", 0)) if headers else 0

    def make_engine_version(d):
        # `to_dict()` yields the probe as a nested mapping; rebuild the dataclass.
        readiness_probe_dict = d.pop("readiness_probe", {})
        readiness_probe = ReadinessProbe(**readiness_probe_dict)
        return InferenceEngineVersion(readiness_probe=readiness_probe, **d)

    # Treat a missing payload or missing record list as "no results".
    page = response.data
    records = (page.data if page is not None else None) or []
    return [make_engine_version(record.to_dict()) for record in records], count
[docs]
def get_engine_versions(
    project_ids: list[str] | None = None,
    engine_names: list[str] | None = None,
    engine_ids: list[str] | None = None,
    engine_version_ids: list[str] | None = None,
    default_versions_only: bool = False,
    limit: int = 100,
    offset: int = 0,
):
    """Get inference engine versions, without the total count.

    Calls ``get_engine_versions_with_count`` with the same arguments and
    returns only the list part of its result.
    """
    result = get_engine_versions_with_count(
        project_ids=project_ids,
        engine_names=engine_names,
        engine_ids=engine_ids,
        engine_version_ids=engine_version_ids,
        default_versions_only=default_versions_only,
        limit=limit,
        offset=offset,
    )
    versions, _count = result
    return versions
[docs]
@mcp_setting(mutating=True)
def set_default_engine_version(engine_version_id: str):
    """Set or clear the default version for an engine.

    Calls the API's ``inference_engines_version_id_default_put`` endpoint for
    ``engine_version_id`` and discards its result.
    NOTE(review): when the default is cleared rather than set is decided by
    the API and is not visible in this function -- confirm.
    """
    api = _inference_engine_api()
    api.inference_engines_version_id_default_put(version_id=engine_version_id)
[docs]
@mcp_setting(mutating=True)
def add_supported_engine(
model_id: str,
selector: str | EngineSelector,
):
"""Add a supported inference engine to a model."""
match selector:
case EngineSelector(org_name=org_name, project_name=project_name, engine_name=engine_name):
engine_id = None
case str(engine_id):
org_name, project_name, engine_name = None, None, None
_inference_engine_api().models_model_id_inference_engines_post(
model_id=model_id,
body=models_api.AddSupportedInferenceEnginePostRequest(
engine_id=engine_id,
org_name=org_name,
project_name=project_name,
engine_name=engine_name,
),
)
[docs]
@mcp_setting(mutating=True)
def remove_supported_engine(
model_id: str,
selector: str | EngineSelector,
):
"""Remove a supported inference engine to a model."""
match selector:
case EngineSelector(org_name=org_name, project_name=project_name, engine_name=engine_name):
engine_id = None
case str(engine_id):
org_name, project_name, engine_name = None, None, None
_inference_engine_api().models_model_id_inference_engines_delete(
model_id=model_id,
body=models_api.RemoveSupportedInferenceEngine(
engine_id=engine_id,
org_name=org_name,
project_name=project_name,
engine_name=engine_name,
),
)
[docs]
def get_supported_engines(model_id: str) -> list[InferenceEngine]:
    """List the inference engines supported by the given model.

    Each entry of the response's ``data`` is converted with ``to_dict`` and
    expanded into an ``InferenceEngine``. Falsy ``data`` gives an empty list.
    """
    response = _inference_engine_api().models_model_id_inference_engines_get(
        model_id=model_id
    )
    engines = []
    for api_engine in response.data or []:
        engines.append(InferenceEngine(**api_engine.to_dict()))
    return engines
[docs]
def get_engine_version_by_selector(
    org_name: str, project_name: str, engine_name: str, version: str | None = None
) -> InferenceEngineVersion | None:
    """Resolve an engine selector to an engine version.

    Args:
        org_name: Organization name, forwarded unchanged.
        project_name: Project name, forwarded unchanged.
        engine_name: Engine name, forwarded unchanged.
        version: Optional version string, forwarded unchanged.

    Returns:
        The matching ``InferenceEngineVersion``, or ``None`` when the API
        raises ``NotFoundException``.
    """
    try:
        engine_version = (
            _inference_engine_api()
            .inference_engines_selector_get(
                org_name=org_name,
                project_name=project_name,
                engine_name=engine_name,
                version=version,
            )
            .data
        )
    except NotFoundException:
        return None

    fields = engine_version.to_dict()
    # Match get_engine_versions_with_count: rebuild the nested probe mapping
    # as a ReadinessProbe so the field agrees with its declared type.
    probe = fields.get("readiness_probe")
    if isinstance(probe, dict):
        fields["readiness_probe"] = ReadinessProbe(**probe)
    return InferenceEngineVersion(**fields)
# For backwards compatibility: underscore-prefixed aliases bound to the same
# function objects as the public names defined above (presumably the names
# these functions were previously exposed under -- confirm before removing).
_create_engine_version = create_engine_version
_delete_engine_version = delete_engine_version
_get_engine_versions_with_count = get_engine_versions_with_count
_get_engine_versions = get_engine_versions
_set_default_engine_version = set_default_engine_version
_add_supported_engine = add_supported_engine
_remove_supported_engine = remove_supported_engine
_get_supported_engines = get_supported_engines
_get_engine_version_by_selector = get_engine_version_by_selector