Source code for chariot.models.engines

from dataclasses import asdict, dataclass
from datetime import datetime
from typing import Literal, NotRequired, TypedDict

from chariot import _apis, mcp_setting
from chariot_api._openapi.models import models as models_api
from chariot_api._openapi.models.api.inference_engine_api import InferenceEngineApi
from chariot_api._openapi.models.exceptions import NotFoundException

# Public names of this module; this list controls what
# `from chariot.models.engines import *` exports.
# NOTE(review): `InferenceEngineVersion` is the return type of
# `get_engine_versions`, `get_engine_versions_with_count` and
# `get_engine_version_by_selector` but is not listed here — confirm whether
# leaving it out is intentional.
__all__ = [
    "EngineSelector",
    "EnvVar",
    "InferenceEngine",
    "ReadinessProbe",
    "EnvVarSchema",
    "add_supported_engine",
    "create_engine_version",
    "delete_engine_version",
    "get_engine_version_by_selector",
    "get_engine_versions",
    "get_engine_versions_with_count",
    "get_supported_engines",
    "remove_supported_engine",
    "set_default_engine_version",
]


@dataclass()
class InferenceEngine:
    """An inference engine.

    Instances are produced by ``get_supported_engines`` from the API's
    serialized engine records.
    """

    # Identifier of the engine.
    engine_id: str
    # Identifier of the project the engine record belongs to.
    project_id: str
    # Engine name.
    name: str
[docs] class EnvVar(TypedDict): """An environment variable to be set in the engine container.""" name: str value: NotRequired[str]
[docs] class EnvVarSchema(TypedDict): """A configurable parameter for this engine.""" name: str type: Literal["string", "float", "int", "bool"] display_text: NotRequired[str] default_value: NotRequired[str] units: NotRequired[str] required: NotRequired[bool]
@dataclass()
class ReadinessProbe:
    """Readiness probe configuration for an engine container.

    ``create_engine_version`` converts this object with ``asdict`` before
    sending it, so the field names here are the keys the API receives. Per
    their names, the ``*_seconds`` fields are durations in seconds.
    """

    path: str
    port: int
    initial_delay_seconds: int
    timeout_seconds: int
    period_seconds: int
    # Consecutive-result thresholds (integer counts).
    success_threshold: int
    failure_threshold: int
@dataclass()
class InferenceEngineVersion:
    """A single version of an inference engine, as returned by the engine APIs."""

    # Identity and ownership.
    engine_version_id: str
    engine_id: str
    user_id: str
    project_id: str
    created_at: datetime
    name: str
    version: str

    # Container configuration.
    container_registry_secret: str
    container_image_cpu: str
    container_image_gpu: str
    entrypoint: list[str]
    command: list[str]
    env: list[EnvVar]

    # Predictor parameter schema and whether it is enforced.
    predictor_env_schema: list[EnvVarSchema]
    enforce_predictor_env_schema: bool

    # Presentation.
    documentation: str
    icon: str

    container_root_relative_base_url: str
    readiness_probe: ReadinessProbe
    is_default: bool
    # May be None; callers converting API payloads supply None when the key is absent.
    inference_protocol: str | None
@dataclass()
class EngineSelector:
    """Identifies an inference engine by name rather than by ID.

    The three names together are the data needed to specify a unique engine.
    """

    org_name: str
    project_name: str
    engine_name: str
def _inference_engine_api() -> InferenceEngineApi:
    """Return the inference-engine API client exposed by ``_apis.models``."""
    client: InferenceEngineApi = _apis.models.inference_engine_api  # type: ignore
    return client
@mcp_setting(mutating=True)
def create_engine_version(
    name: str,
    project_id: str,
    version: str,
    readiness_probe: ReadinessProbe,
    container_registry_secret: str | None = None,
    container_image_cpu: str | None = None,
    container_image_gpu: str | None = None,
    entrypoint: list[str] | None = None,
    command: list[str] | None = None,
    env: list[EnvVar] | None = None,
    predictor_env_schema: list[EnvVarSchema] | None = None,
    enforce_predictor_env_schema: bool = False,
    documentation: str | None = None,
    icon: str | None = None,
    container_root_relative_base_url: str = "/",
    inference_protocol: str | None = None,
    is_default: bool = False,
) -> str:
    """Create a new inference engine version.

    Args:
        name: Engine name.
        project_id: ID of the project to create the engine version in.
        version: Version string for the new engine version.
        readiness_probe: Readiness probe config; its fields are copied into
            the API's ``ReadinessProbe`` model.
        container_registry_secret: Optional container registry secret.
        container_image_cpu: Optional CPU container image.
        container_image_gpu: Optional GPU container image.
        entrypoint: Container entrypoint; ``None`` is sent as an empty list.
        command: Container command; ``None`` is sent as an empty list.
        env: Environment variables for the engine container.
        predictor_env_schema: Configurable parameters accepted by the engine.
        enforce_predictor_env_schema: Value sent as the request's
            ``enforce_predictor_env_schema`` flag.
        documentation: Optional documentation text.
        icon: Optional icon.
        container_root_relative_base_url: Base URL relative to the container
            root; defaults to ``"/"``.
        inference_protocol: Optional inference protocol.
        is_default: Value sent as the request's ``is_default`` flag.

    Returns:
        The ``id`` from the API response data (presumably the new engine
        version's ID).
    """
    # Convert the TypedDict payloads into the generated OpenAPI models.
    api_env = [models_api.EnvVar(**item) for item in env or []]
    api_env_schema = [
        models_api.EnvVarSchema(**item) for item in predictor_env_schema or []
    ]

    api = _inference_engine_api()
    request = models_api.CreateEnginePostRequest(
        name=name,
        project_id=project_id,
        version=version,
        documentation=documentation,
        icon=icon,
        entrypoint=entrypoint or [],
        command=command or [],
        env=api_env,
        container_image_cpu=container_image_cpu,
        container_image_gpu=container_image_gpu,
        container_registry_secret=container_registry_secret,
        predictor_env_schema=api_env_schema,
        enforce_predictor_env_schema=enforce_predictor_env_schema,
        container_root_relative_base_url=container_root_relative_base_url,
        readiness_probe=models_api.ReadinessProbe(**asdict(readiness_probe)),
        inference_protocol=inference_protocol,
        is_default=is_default,
    )
    created = api.inference_engines_post(body=request)
    return created.data.id
@mcp_setting(mutating=True)
def delete_engine_version(engine_version_id: str) -> str:
    """Delete an inference engine version.

    Args:
        engine_version_id: ID of the engine version to delete.

    Returns:
        Whatever the API's version-delete call returns (annotated as ``str``).
    """
    api = _inference_engine_api()
    result = api.inference_engines_version_id_delete(version_id=engine_version_id)
    return result
def get_engine_versions_with_count(
    project_ids: list[str] | None = None,
    engine_names: list[str] | None = None,
    engine_ids: list[str] | None = None,
    engine_version_ids: list[str] | None = None,
    default_versions_only: bool = False,
    limit: int = 100,
    offset: int = 0,
) -> tuple[list[InferenceEngineVersion], int]:
    """Get inference engine versions together with a total record count.

    Each list argument is an optional filter. It is sent to the API as one
    comma-separated string, and an empty or ``None`` list is omitted.

    Args:
        project_ids: Optional project-ID filter.
        engine_names: Optional engine-name filter.
        engine_ids: Optional engine-ID filter.
        engine_version_ids: Optional engine-version-ID filter.
        default_versions_only: Passed through to the API unchanged.
        limit: Page size passed to the API.
        offset: Page offset passed to the API.

    Returns:
        A ``(versions, count)`` tuple. ``count`` is read from the
        ``X-Record-Count`` response header. It is ``0`` when the header is
        missing, and the result is ``([], 0)`` when the API returns no response.
    """

    def _csv(values: list[str] | None) -> str | None:
        # The API takes list filters as a single comma-separated string.
        return ",".join(values) if values else None

    def _to_engine_version(payload: dict) -> InferenceEngineVersion:
        # Convert serialized fields into the types the dataclass declares.
        readiness_probe = ReadinessProbe(**payload.pop("readiness_probe", {}))
        created_at = payload.pop("created_at")
        # Accept either an ISO-8601 string or an already-parsed datetime.
        if not isinstance(created_at, datetime):
            created_at = datetime.fromisoformat(created_at)
        return InferenceEngineVersion(
            readiness_probe=readiness_probe, created_at=created_at, **payload
        )

    response = _inference_engine_api().inference_engines_get_with_http_info(
        project_ids=_csv(project_ids),
        engine_names=_csv(engine_names),
        engine_ids=_csv(engine_ids),
        engine_version_ids=_csv(engine_version_ids),
        default_versions_only=default_versions_only,
        limit=limit,
        offset=offset,
    )
    # Treat a missing response as an empty result instead of dereferencing it.
    if not response:
        return [], 0

    headers = response.headers or {}
    # Try both spellings in case the header mapping is case-sensitive.
    count = int(headers.get("X-Record-Count", headers.get("x-record-count", 0)))

    body = response.data
    api_versions = (body.data if body is not None else None) or []
    return [
        # `inference_protocol` may be absent from the serialized payload.
        _to_engine_version({"inference_protocol": None, **api_version.to_dict()})
        for api_version in api_versions
    ], count
def get_engine_versions(
    project_ids: list[str] | None = None,
    engine_names: list[str] | None = None,
    engine_ids: list[str] | None = None,
    engine_version_ids: list[str] | None = None,
    default_versions_only: bool = False,
    limit: int = 100,
    offset: int = 0,
) -> list[InferenceEngineVersion]:
    """Get inference engine versions without the total count.

    Takes the same filters and pagination arguments as
    ``get_engine_versions_with_count`` and returns only its list of versions.
    """
    versions, _count = get_engine_versions_with_count(
        project_ids=project_ids,
        engine_names=engine_names,
        engine_ids=engine_ids,
        engine_version_ids=engine_version_ids,
        default_versions_only=default_versions_only,
        limit=limit,
        offset=offset,
    )
    return versions
@mcp_setting(mutating=True)
def set_default_engine_version(engine_version_id: str) -> None:
    """Set the default version for an engine.

    Issues a PUT to the version's "default" endpoint.
    NOTE(review): the previous docstring said this can also *clear* the
    default; whether that happens is decided server-side — confirm.

    Args:
        engine_version_id: ID of the engine version to mark as default.
    """
    api = _inference_engine_api()
    api.inference_engines_version_id_default_put(version_id=engine_version_id)
[docs] @mcp_setting(mutating=True) def add_supported_engine( model_id: str, selector: str | EngineSelector, ) -> None: """Add a supported inference engine to a model.""" match selector: case EngineSelector(org_name=org_name, project_name=project_name, engine_name=engine_name): engine_id = None case str(engine_id): org_name, project_name, engine_name = None, None, None _inference_engine_api().models_model_id_inference_engines_post( model_id=model_id, body=models_api.AddSupportedInferenceEnginePostRequest( engine_id=engine_id, org_name=org_name, project_name=project_name, engine_name=engine_name, ), )
[docs] @mcp_setting(mutating=True) def remove_supported_engine( model_id: str, selector: str | EngineSelector, ) -> None: """Remove a supported inference engine to a model.""" match selector: case EngineSelector(org_name=org_name, project_name=project_name, engine_name=engine_name): engine_id = None case str(engine_id): org_name, project_name, engine_name = None, None, None _inference_engine_api().models_model_id_inference_engines_delete( model_id=model_id, body=models_api.RemoveSupportedInferenceEngine( engine_id=engine_id, org_name=org_name, project_name=project_name, engine_name=engine_name, ), )
def get_supported_engines(model_id: str) -> list[InferenceEngine]:
    """Get the supported inference engines for a model.

    Args:
        model_id: ID of the model to query.

    Returns:
        One ``InferenceEngine`` per record in the response. The list is empty
        when the response carries no data.
    """
    response = _inference_engine_api().models_model_id_inference_engines_get(model_id=model_id)
    api_engines = response.data or []
    return [InferenceEngine(**api_engine.to_dict()) for api_engine in api_engines]
def get_engine_version_by_selector(
    org_name: str, project_name: str, engine_name: str, version: str | None = None
) -> InferenceEngineVersion | None:
    """Resolve an engine selector to a concrete engine version.

    Args:
        org_name: Organization name of the engine.
        project_name: Project name of the engine.
        engine_name: Name of the engine.
        version: Version to resolve. When ``None``, the API picks the version
            (presumably the engine's default — TODO confirm).

    Returns:
        The resolved ``InferenceEngineVersion``. Returns ``None`` when the API
        raises ``NotFoundException`` or returns no data.
    """
    try:
        api_version = (
            _inference_engine_api()
            .inference_engines_selector_get(
                org_name=org_name,
                project_name=project_name,
                engine_name=engine_name,
                version=version,
            )
            .data
        )
    except NotFoundException:
        return None
    if api_version is None:
        return None

    # Convert the serialized payload into the dataclass's declared field types:
    # - the nested readiness probe dict becomes a ReadinessProbe;
    # - created_at is parsed into a datetime if it is still a string;
    # - inference_protocol defaults to None when absent from the payload.
    payload = {"inference_protocol": None, **api_version.to_dict()}
    readiness_probe = ReadinessProbe(**payload.pop("readiness_probe", {}))
    created_at = payload.pop("created_at")
    if not isinstance(created_at, datetime):
        created_at = datetime.fromisoformat(created_at)
    return InferenceEngineVersion(
        readiness_probe=readiness_probe, created_at=created_at, **payload
    )
# Underscore-prefixed aliases, kept for backwards compatibility with callers
# that still use the old private names.
_create_engine_version = create_engine_version
_delete_engine_version = delete_engine_version
_get_engine_versions_with_count = get_engine_versions_with_count
_get_engine_versions = get_engine_versions
_set_default_engine_version = set_default_engine_version
_add_supported_engine = add_supported_engine
_remove_supported_engine = remove_supported_engine
_get_supported_engines = get_supported_engines
_get_engine_version_by_selector = get_engine_version_by_selector