# Source code for chariot.models.engines

from dataclasses import dataclass
from datetime import datetime
from typing import NotRequired, TypedDict

from chariot import _apis
from chariot_api._openapi.models import models as models_api
from chariot_api._openapi.models.api.inference_engine_api import InferenceEngineApi


@dataclass
class InferenceEngine:
    """An inference engine.

    Lightweight record built from the API response of the supported-engines
    endpoint (see ``_get_supported_engines``).
    """

    engine_id: str  # identifier of the engine
    project_id: str  # project the engine belongs to
    name: str  # engine name
class EnvVar(TypedDict):
    """An environment variable to be set in the engine container."""

    name: str  # variable name
    value: NotRequired[str]  # variable value; may be omitted
class ServingArg(TypedDict):
    """A configurable parameter for this engine."""

    name: str  # argument name
    type: str  # argument type (string form; exact vocabulary defined by the API)
    display_text: NotRequired[str]  # optional human-readable label
    default_value: NotRequired[str]  # optional default, serialized as a string
class InferenceKwarg(TypedDict):
    """A parameter to send to this engine's inference method."""

    name: str  # keyword-argument name
    type: str  # argument type (string form; exact vocabulary defined by the API)
    description: NotRequired[str]  # optional human-readable description
@dataclass
class InferenceEngineVersion:
    """An inference engine version.

    Mirrors one record returned by the engine-versions API (see
    ``_get_engine_versions_with_count``); field names match the API payload.
    """

    engine_version_id: str  # identifier of this version
    engine_id: str  # engine this version belongs to
    user_id: str  # user who created the version
    project_id: str  # owning project
    created_at: datetime  # creation timestamp
    name: str  # engine name
    version: str  # version string
    container_registry_secret: str  # secret used to pull the container images
    container_image_cpu: str  # CPU container image reference
    container_image_gpu: str  # GPU container image reference
    entrypoint: list[str]  # container entrypoint
    command: list[str]  # container command
    env: list[EnvVar]  # environment variables set in the container
    serving_args: list[ServingArg]  # configurable serving parameters
    serving_kwargs_supported: bool  # whether arbitrary serving kwargs are accepted
    inference_kwargs: list[InferenceKwarg]  # parameters accepted by inference
    documentation: str  # engine documentation text
    icon: str  # engine icon
    is_default: bool  # whether this is the engine's default version
@dataclass
class EngineSelector:
    """The data needed to specify a unique inference engine.

    Used as an alternative to a bare engine-id string when adding or removing
    a model's supported engines.
    """

    org_name: str  # organization that owns the engine
    project_name: str  # project within the organization
    engine_name: str  # engine name within the project
def _inference_engine_api() -> InferenceEngineApi: return _apis.models.inference_engine_api # type: ignore def _create_engine_version( name: str, project_id: str, version: str, container_registry_secret: str | None = None, container_image_cpu: str | None = None, container_image_gpu: str | None = None, entrypoint: list[str] | None = None, command: list[str] | None = None, env: list[EnvVar] | None = None, serving_args: list[ServingArg] | None = None, serving_kwargs_supported: bool = False, inference_kwargs: list[InferenceKwarg] | None = None, documentation: str | None = None, icon: str | None = None, is_default: bool = False, ) -> str: """Create an inference engine.""" openapi_envs = [models_api.EnvVar(**env_var) for env_var in env or []] openapi_serving_args = [ models_api.ServingArg(**serving_arg) for serving_arg in serving_args or [] ] openapi_inference_kwargs = [ models_api.InferenceKwarg(**inference_kwarg) for inference_kwarg in inference_kwargs or [] ] return ( _inference_engine_api() .inference_engines_post( body=models_api.CreateEnginePostRequest( name=name, project_id=project_id, version=version, documentation=documentation, icon=icon, entrypoint=entrypoint or [], command=command or [], env=openapi_envs, container_image_cpu=container_image_cpu, container_image_gpu=container_image_gpu, container_registry_secret=container_registry_secret, inference_kwargs=openapi_inference_kwargs, serving_args=openapi_serving_args, serving_kwargs_supported=serving_kwargs_supported, is_default=is_default, ) ) .data.id ) def _delete_engine_version(engine_version_id: str) -> str: """Delete an inference engine.""" return _inference_engine_api().inference_engines_version_id_delete(version_id=engine_version_id) def _get_engine_versions_with_count( project_ids: list[str] | None = None, engine_names: list[str] | None = None, engine_ids: list[str] | None = None, engine_version_ids: list[str] | None = None, default_versions_only: bool = False, limit: int = 100, offset: int = 0, ) -> 
tuple[list[InferenceEngineVersion], int]: """Get inference engines together with a total count.""" response = ( _inference_engine_api().inference_engines_get_with_http_info( project_ids=",".join(project_ids) if project_ids else None, engine_names=",".join(engine_names) if engine_names else None, engine_ids=",".join(engine_ids) if engine_ids else None, engine_version_ids=",".join(engine_version_ids) if engine_version_ids else None, default_versions_only=default_versions_only, limit=limit, offset=offset, ) or [] ) if not response.headers: count = 0 else: count = int(response.headers.get("X-Record-Count", 0)) return [ InferenceEngineVersion(**api_engine_version.to_dict()) for api_engine_version in response.data.data ], count def _get_engine_versions( project_ids: list[str] | None = None, engine_names: list[str] | None = None, engine_ids: list[str] | None = None, engine_version_ids: list[str] | None = None, default_versions_only: bool = False, limit: int = 100, offset: int = 0, ): """Get inference engines.""" engine_versions, _ = _get_engine_versions_with_count( project_ids, engine_names, engine_ids, engine_version_ids, default_versions_only, limit, offset, ) return engine_versions def _set_default_engine_version(engine_version_id: str): """Set or clear the default version for an engine.""" _inference_engine_api().inference_engines_version_id_default_put(version_id=engine_version_id) def _add_supported_engine( model_id: str, selector: str | EngineSelector, ): """Add a supported inference engine to a model.""" match selector: case EngineSelector(org_name=org_name, project_name=project_name, engine_name=engine_name): engine_id = None case str(engine_id): org_name, project_name, engine_name = None, None, None _inference_engine_api().models_model_id_inference_engines_post( model_id=model_id, body=models_api.AddSupportedInferenceEnginePostRequest( engine_id=engine_id, org_name=org_name, project_name=project_name, engine_name=engine_name, ), ) def _remove_supported_engine( 
model_id: str, selector: str | EngineSelector, ): """Remove a supported inference engine to a model.""" match selector: case EngineSelector(org_name=org_name, project_name=project_name, engine_name=engine_name): engine_id = None case str(engine_id): org_name, project_name, engine_name = None, None, None _inference_engine_api().models_model_id_inference_engines_delete( model_id=model_id, body=models_api.RemoveSupportedInferenceEngine( engine_id=engine_id, org_name=org_name, project_name=project_name, engine_name=engine_name, ), ) def _get_supported_engines(model_id: str) -> list[InferenceEngine]: """Get the supported inference engines for a model.""" return [ InferenceEngine(**engine.to_dict()) for engine in _inference_engine_api() .models_model_id_inference_engines_get(model_id=model_id) .data or [] ]