from dataclasses import asdict, dataclass
from datetime import datetime
from typing import Literal, NotRequired, TypedDict
from chariot import _apis, mcp_setting
from chariot_api._openapi.models import models as models_api
from chariot_api._openapi.models.api.inference_engine_api import InferenceEngineApi
from chariot_api._openapi.models.exceptions import NotFoundException
# Public API of this module (what ``from <module> import *`` exposes).
__all__ = [
    # Data types
    "EngineSelector",
    "EnvVar",
    "InferenceEngine",
    # Exported because the public ``get_engine_version*`` functions return it,
    # so callers need the name for annotations and isinstance checks.
    "InferenceEngineVersion",
    "ReadinessProbe",
    "EnvVarSchema",
    # Functions
    "add_supported_engine",
    "create_engine_version",
    "delete_engine_version",
    "get_engine_version_by_selector",
    "get_engine_versions",
    "get_engine_versions_with_count",
    "get_supported_engines",
    "remove_supported_engine",
    "set_default_engine_version",
]
[docs]
@dataclass
class InferenceEngine:
    """Identity of an inference engine.

    Built in this module from the dictionaries of the API's
    "supported engines" response (see ``get_supported_engines``).

    Attributes:
        engine_id: Id of the engine.
        project_id: Id of the project associated with the engine.
        name: Name of the engine.
    """

    engine_id: str
    project_id: str
    name: str
[docs]
class EnvVar(TypedDict):
"""An environment variable to be set in the engine container."""
name: str
value: NotRequired[str]
[docs]
class EnvVarSchema(TypedDict):
"""A configurable parameter for this engine."""
name: str
type: Literal["string", "float", "int", "bool"]
display_text: NotRequired[str]
default_value: NotRequired[str]
units: NotRequired[str]
required: NotRequired[bool]
[docs]
@dataclass
class ReadinessProbe:
    """Readiness probe settings for an engine version.

    ``create_engine_version`` converts an instance with ``asdict`` and passes
    the result to the API's ``ReadinessProbe`` model, so the field names here
    must match that model's keyword arguments.

    Attributes:
        path: Path that is probed.
        port: Port that is probed.
        initial_delay_seconds: Initial delay, in seconds.
        timeout_seconds: Timeout, in seconds.
        period_seconds: Period, in seconds.
        success_threshold: Success threshold.
        failure_threshold: Failure threshold.
    """

    path: str
    port: int
    initial_delay_seconds: int
    timeout_seconds: int
    period_seconds: int
    success_threshold: int
    failure_threshold: int
@dataclass
class InferenceEngineVersion:
    """One version of an inference engine, as returned by the API.

    Instances are built in this module from the dictionary form of the API's
    engine-version payload, so the field names mirror that payload's keys.

    Attributes:
        engine_version_id: Id of this engine version.
        engine_id: Id of the engine this version belongs to.
        user_id: Id of the associated user.
        project_id: Id of the associated project.
        created_at: Creation timestamp.
        name: Engine name.
        version: Version string.
        container_registry_secret: Container registry secret.
        container_image_cpu: CPU container image.
        container_image_gpu: GPU container image.
        entrypoint: Container entrypoint.
        command: Container command.
        env: Environment variables set in the engine container.
        predictor_env_schema: Configurable parameters of the engine.
        enforce_predictor_env_schema: Whether the schema is enforced.
        documentation: Documentation text.
        icon: Icon.
        container_root_relative_base_url: Base URL relative to the container
            root.
        readiness_probe: Readiness probe settings.
        is_default: Whether this version is the engine's default.
        inference_protocol: Inference protocol, or ``None`` when the API does
            not report one.
    """

    engine_version_id: str
    engine_id: str
    user_id: str
    project_id: str
    created_at: datetime
    name: str
    version: str
    container_registry_secret: str
    container_image_cpu: str
    container_image_gpu: str
    entrypoint: list[str]
    command: list[str]
    env: list[EnvVar]
    predictor_env_schema: list[EnvVarSchema]
    enforce_predictor_env_schema: bool
    documentation: str
    icon: str
    container_root_relative_base_url: str
    readiness_probe: ReadinessProbe
    is_default: bool
    # Defaults to None so that an API payload omitting this key can still be
    # unpacked into the dataclass without a missing-argument TypeError.
    inference_protocol: str | None = None
[docs]
@dataclass
class EngineSelector:
"""The data needed to specify a unique inference engine."""
org_name: str
project_name: str
engine_name: str
def _inference_engine_api() -> InferenceEngineApi:
    """Return the inference engine API client exposed on ``_apis.models``.

    The attribute is looked up on every call rather than cached in this module.
    """
    return _apis.models.inference_engine_api  # type: ignore
[docs]
@mcp_setting(mutating=True)
def create_engine_version(
    name: str,
    project_id: str,
    version: str,
    readiness_probe: ReadinessProbe,
    container_registry_secret: str | None = None,
    container_image_cpu: str | None = None,
    container_image_gpu: str | None = None,
    entrypoint: list[str] | None = None,
    command: list[str] | None = None,
    env: list[EnvVar] | None = None,
    predictor_env_schema: list[EnvVarSchema] | None = None,
    enforce_predictor_env_schema: bool = False,
    documentation: str | None = None,
    icon: str | None = None,
    container_root_relative_base_url: str = "/",
    inference_protocol: str | None = None,
    is_default: bool = False,
) -> str:
    """Create an inference engine version.

    Builds a ``CreateEnginePostRequest`` from the arguments and posts it
    through the inference engine API.

    Args:
        name: Engine name sent in the request.
        project_id: Project id sent in the request.
        version: Version string sent in the request.
        readiness_probe: Probe settings; converted with ``asdict`` into the
            API's ``ReadinessProbe`` model.
        container_registry_secret: Forwarded unchanged.
        container_image_cpu: Forwarded unchanged.
        container_image_gpu: Forwarded unchanged.
        entrypoint: Container entrypoint; ``None`` is sent as an empty list.
        command: Container command; ``None`` is sent as an empty list.
        env: Environment variables; each entry is unpacked into the API's
            ``EnvVar`` model. ``None`` is sent as an empty list.
        predictor_env_schema: Configurable parameters; each entry is unpacked
            into the API's ``EnvVarSchema`` model. ``None`` is sent as an
            empty list.
        enforce_predictor_env_schema: Forwarded unchanged.
        documentation: Forwarded unchanged.
        icon: Forwarded unchanged.
        container_root_relative_base_url: Forwarded unchanged.
        inference_protocol: Forwarded unchanged.
        is_default: Forwarded unchanged.

    Returns:
        The ``id`` found on the ``data`` of the API response.
    """
    # The API models take keyword arguments, so each TypedDict is unpacked.
    api_env = [models_api.EnvVar(**item) for item in (env or [])]
    api_schema = [models_api.EnvVarSchema(**item) for item in (predictor_env_schema or [])]

    api = _inference_engine_api()
    request = models_api.CreateEnginePostRequest(
        name=name,
        project_id=project_id,
        version=version,
        documentation=documentation,
        icon=icon,
        entrypoint=entrypoint or [],
        command=command or [],
        env=api_env,
        container_image_cpu=container_image_cpu,
        container_image_gpu=container_image_gpu,
        container_registry_secret=container_registry_secret,
        predictor_env_schema=api_schema,
        enforce_predictor_env_schema=enforce_predictor_env_schema,
        container_root_relative_base_url=container_root_relative_base_url,
        readiness_probe=models_api.ReadinessProbe(**asdict(readiness_probe)),
        inference_protocol=inference_protocol,
        is_default=is_default,
    )
    response = api.inference_engines_post(body=request)
    return response.data.id
[docs]
@mcp_setting(mutating=True)
def delete_engine_version(engine_version_id: str) -> str:
    """Delete the inference engine version with the given id.

    Args:
        engine_version_id: Id passed to the API as ``version_id``.

    Returns:
        Whatever the API's delete call returns, passed through unchanged.
    """
    # NOTE(review): annotated as ``str`` but the API result is returned as-is;
    # confirm the generated client really returns a string here.
    api = _inference_engine_api()
    return api.inference_engines_version_id_delete(version_id=engine_version_id)
[docs]
def get_engine_versions_with_count(
    project_ids: list[str] | None = None,
    engine_names: list[str] | None = None,
    engine_ids: list[str] | None = None,
    engine_version_ids: list[str] | None = None,
    default_versions_only: bool = False,
    limit: int = 100,
    offset: int = 0,
) -> tuple[list[InferenceEngineVersion], int]:
    """Get inference engine versions together with a total count.

    Args:
        project_ids: Project ids to filter by; sent as one comma-joined
            string, or omitted when empty/``None``.
        engine_names: Engine names to filter by (same encoding).
        engine_ids: Engine ids to filter by (same encoding).
        engine_version_ids: Engine version ids to filter by (same encoding).
        default_versions_only: Forwarded unchanged to the API.
        limit: Forwarded unchanged to the API.
        offset: Forwarded unchanged to the API.

    Returns:
        A ``(engine_versions, count)`` tuple. ``count`` is read from the
        response's ``X-Record-Count`` header and is ``0`` when the header (or
        the whole response) is missing.
    """

    def comma_joined(values: list[str] | None) -> str | None:
        # The API takes list filters as a single comma-separated string.
        return ",".join(values) if values else None

    def make_engine_version(d: dict) -> InferenceEngineVersion:
        # The payload carries the probe as a plain dict and the timestamp as
        # an ISO string; convert both to the types the dataclass declares.
        readiness_probe_dict = d.pop("readiness_probe", {})
        readiness_probe = ReadinessProbe(**readiness_probe_dict)
        created_at = datetime.fromisoformat(d.pop("created_at"))
        return InferenceEngineVersion(readiness_probe=readiness_probe, created_at=created_at, **d)

    response = _inference_engine_api().inference_engines_get_with_http_info(
        project_ids=comma_joined(project_ids),
        engine_names=comma_joined(engine_names),
        engine_ids=comma_joined(engine_ids),
        engine_version_ids=comma_joined(engine_version_ids),
        default_versions_only=default_versions_only,
        limit=limit,
        offset=offset,
    )
    if not response:
        # Previously an empty list was substituted here, which then failed on
        # ``.headers``; an empty result is the meaningful fallback.
        return [], 0

    # HTTP header names are case-insensitive, so normalise before the lookup.
    headers = {str(key).lower(): value for key, value in (response.headers or {}).items()}
    count = int(headers.get("x-record-count", 0))

    payload = response.data
    api_engine_versions = (payload.data if payload is not None else None) or []

    return [
        # ``inference_protocol`` may be absent from the payload; default it so
        # the dataclass constructor always receives the key.
        make_engine_version({"inference_protocol": None, **api_engine_version.to_dict()})
        for api_engine_version in api_engine_versions
    ], count
[docs]
def get_engine_versions(
    project_ids: list[str] | None = None,
    engine_names: list[str] | None = None,
    engine_ids: list[str] | None = None,
    engine_version_ids: list[str] | None = None,
    default_versions_only: bool = False,
    limit: int = 100,
    offset: int = 0,
) -> list[InferenceEngineVersion]:
    """Get inference engine versions, without the total count.

    Thin wrapper around ``get_engine_versions_with_count`` that forwards every
    argument and discards the count.

    Args:
        project_ids: Project ids to filter by.
        engine_names: Engine names to filter by.
        engine_ids: Engine ids to filter by.
        engine_version_ids: Engine version ids to filter by.
        default_versions_only: Forwarded unchanged.
        limit: Forwarded unchanged.
        offset: Forwarded unchanged.

    Returns:
        The list of engine versions from ``get_engine_versions_with_count``.
    """
    versions_and_count = get_engine_versions_with_count(
        project_ids=project_ids,
        engine_names=engine_names,
        engine_ids=engine_ids,
        engine_version_ids=engine_version_ids,
        default_versions_only=default_versions_only,
        limit=limit,
        offset=offset,
    )
    return versions_and_count[0]
[docs]
@mcp_setting(mutating=True)
def set_default_engine_version(engine_version_id: str) -> None:
    """Set or clear the default version for an engine.

    Issues the API's "default" PUT for the given engine version and discards
    the response.

    Args:
        engine_version_id: Id passed to the API as ``version_id``.
    """
    # NOTE(review): only a single PUT is visible here; whether it sets or
    # clears the default is presumably decided by the server. Confirm.
    api = _inference_engine_api()
    api.inference_engines_version_id_default_put(version_id=engine_version_id)
[docs]
@mcp_setting(mutating=True)
def add_supported_engine(
model_id: str,
selector: str | EngineSelector,
) -> None:
"""Add a supported inference engine to a model."""
match selector:
case EngineSelector(org_name=org_name, project_name=project_name, engine_name=engine_name):
engine_id = None
case str(engine_id):
org_name, project_name, engine_name = None, None, None
_inference_engine_api().models_model_id_inference_engines_post(
model_id=model_id,
body=models_api.AddSupportedInferenceEnginePostRequest(
engine_id=engine_id,
org_name=org_name,
project_name=project_name,
engine_name=engine_name,
),
)
[docs]
@mcp_setting(mutating=True)
def remove_supported_engine(
model_id: str,
selector: str | EngineSelector,
) -> None:
"""Remove a supported inference engine to a model."""
match selector:
case EngineSelector(org_name=org_name, project_name=project_name, engine_name=engine_name):
engine_id = None
case str(engine_id):
org_name, project_name, engine_name = None, None, None
_inference_engine_api().models_model_id_inference_engines_delete(
model_id=model_id,
body=models_api.RemoveSupportedInferenceEngine(
engine_id=engine_id,
org_name=org_name,
project_name=project_name,
engine_name=engine_name,
),
)
[docs]
def get_supported_engines(model_id: str) -> list[InferenceEngine]:
    """Get the inference engines a model supports.

    Args:
        model_id: Id of the model to query.

    Returns:
        One ``InferenceEngine`` per entry in the API response's ``data``; an
        empty list when ``data`` is empty or ``None``.
    """
    response = _inference_engine_api().models_model_id_inference_engines_get(model_id=model_id)
    engines: list[InferenceEngine] = []
    for api_engine in response.data or []:
        # Each API entry's dict keys are used as the dataclass field names.
        engines.append(InferenceEngine(**api_engine.to_dict()))
    return engines
[docs]
def get_engine_version_by_selector(
    org_name: str, project_name: str, engine_name: str, version: str | None = None
) -> InferenceEngineVersion | None:
    """Resolve an engine selector to an engine version.

    Args:
        org_name: Organization name sent to the API.
        project_name: Project name sent to the API.
        engine_name: Engine name sent to the API.
        version: Optional version string sent to the API.

    Returns:
        The matching ``InferenceEngineVersion``, or ``None`` when the API
        raises ``NotFoundException`` or returns no data.
    """
    try:
        response = _inference_engine_api().inference_engines_selector_get(
            org_name=org_name,
            project_name=project_name,
            engine_name=engine_name,
            version=version,
        )
    except NotFoundException:
        return None

    engine_version = response.data
    if engine_version is None:
        return None

    # Build the dataclass the same way get_engine_versions_with_count does:
    # default the optional protocol and convert nested values to the declared
    # field types. The isinstance guards leave already-converted values alone.
    fields = {"inference_protocol": None, **engine_version.to_dict()}

    readiness_probe = fields.get("readiness_probe")
    if isinstance(readiness_probe, dict):
        fields["readiness_probe"] = ReadinessProbe(**readiness_probe)

    created_at = fields.get("created_at")
    if isinstance(created_at, str):
        fields["created_at"] = datetime.fromisoformat(created_at)

    return InferenceEngineVersion(**fields)
# For backwards compatibility: underscore-prefixed aliases bound to the public
# functions above, so references to these names keep resolving. Each alias is
# the very same function object as its public counterpart.
_create_engine_version = create_engine_version
_delete_engine_version = delete_engine_version
_get_engine_versions_with_count = get_engine_versions_with_count
_get_engine_versions = get_engine_versions
_set_default_engine_version = set_default_engine_version
_add_supported_engine = add_supported_engine
_remove_supported_engine = remove_supported_engine
_get_supported_engines = get_supported_engines
_get_engine_version_by_selector = get_engine_version_by_selector