Source code for chariot.models.inference_server

from collections.abc import Sequence

from chariot import _apis
from chariot.models._http_utils import _post_http_request, join_url
from chariot.models._infer_v2 import RequestInput, create_inference_request
from chariot.models.enum import TaskType as ModelsTaskType
from chariot.models.exceptions import ActionUnsupportedByCurrentModelError
from chariot.models.inference import Action
from chariot_api._openapi.models.rest import ApiException
from inference_protocol import (
    HTTPResponse,
    InferenceResponse,
    Protocol,
    create_http_request,
    parse_http_response,
)
from inference_protocol import (
    TaskType as ProtocolTaskType,
)
from inference_protocol.entity import InferenceOutput
from inference_protocol.request_metadata import RequestMetadataInput


class InferenceServer:
    """A handle to a running inference server.

    Instances are normally produced by ``model.wait_for_inference_server(...)``;
    the handle carries the server URLs, the model's task/protocol metadata, and
    the set of inference actions the model supports.
    """

    def __init__(
        self,
        model_id: str,
        model_name: str,
        internal_base_url: str,
        external_base_url: str,
        use_internal_url: bool,
        protocol: str | None,
        task: ModelsTaskType,
        actions: list[str],
        class_labels: dict[str, int] | None,
    ):
        # Identity of the model this server hosts.
        self._model_id = model_id
        self._model_name = model_name
        # Two addressable endpoints; `use_internal_url` selects which one
        # `inference_base_url` returns.
        self._internal_base_url = internal_base_url
        self._external_base_url = external_base_url
        self._use_internal_url = use_internal_url
        self._task = task
        self._protocol = protocol
        # Actions the model supports; `infer` rejects anything else.
        self._actions = actions
        # Optional label-name -> index mapping, forwarded to response parsing.
        self._class_labels = class_labels
        # Flipped by `stop()`; a stopped handle refuses further inference.
        self._stopped = False

    @property
    def model_id(self) -> str:
        """Model ID associated with this inference server."""
        return self._model_id

    @property
    def model_name(self) -> str:
        """Model name associated with this inference server."""
        return self._model_name

    @property
    def task(self) -> ModelsTaskType:
        """Model task type."""
        return self._task

    @property
    def protocol(self) -> str | None:
        """Inference protocol used by this inference server."""
        return self._protocol

    @property
    def internal_base_url(self) -> str:
        """Internal (cluster) base URL for the server."""
        return self._internal_base_url

    @property
    def external_base_url(self) -> str:
        """External base URL for the server."""
        return self._external_base_url

    @property
    def inference_base_url(self) -> str:
        """Server base URL selected by `use_internal_url`."""
        return self._internal_base_url if self._use_internal_url else self._external_base_url

    def stop(self) -> None:
        """Stop the inference server.

        Deletes the inference service for this model via the models API and
        marks this handle as stopped. A "Not Found" response is tolerated
        (the service is already gone); any other API error is re-raised.
        """
        try:
            _apis.models.models_api.models_id_isvc_delete(
                id=self._model_id,
                _request_timeout=15,
            )
        except ApiException as e:
            if e.reason != "Not Found":
                raise
        self._stopped = True

    def infer(
        self,
        action: Action,
        input_: RequestInput,
        custom_metadata: RequestMetadataInput | Sequence[RequestMetadataInput] | None = None,
        score_threshold: float | None = None,
        timeout: int = 60,
        **inference_kwargs,  # NOTE: forwarded to create_inference_request; currently unused downstream
    ) -> InferenceResponse[InferenceOutput]:
        """Run inference `action` on `input_`, returning an `InferenceResponse`.

        Args:
            action: Inference action to perform; must be one of this model's
                supported actions.
            input_: The request payload.
            custom_metadata: Optional metadata attached to the request.
            score_threshold: Optional score cutoff forwarded to the request.
            timeout: HTTP request timeout in seconds.
            **inference_kwargs: Extra keyword arguments forwarded to
                `create_inference_request`.

        Raises:
            RuntimeError: If this handle was stopped, if the model has no
                protocol, or if the protocol is not supported.
            ActionUnsupportedByCurrentModelError: If `action` is not supported
                by this model's task type.
        """
        if self._stopped:
            raise RuntimeError(
                "This InferenceServer handle was stopped. "
                "Call `model.wait_for_inference_server(...)` to get a new handle."
            )
        if action not in self._actions:
            raise ActionUnsupportedByCurrentModelError(
                f"This model's task type '{self._task}' does not support the '{action}' action, "
                f"please use one of the following actions: {self._actions}"
            )
        if self.protocol is None:
            raise RuntimeError("Inference protocol required for inference")

        # TODO (c.zaloom) - replace with `self.protocol in Protocol` for python>=3.12
        try:
            protocol = Protocol(self.protocol)
        except ValueError as e:
            # Chain the cause so the unsupported-protocol error keeps context.
            raise RuntimeError(
                f"Inference protocol '{self.protocol}' is not currently supported."
            ) from e
        # Map the models-package task enum onto the protocol-package task enum
        # by value. (A previous version re-assigned `protocol`/`protocol_task`
        # here with the raw enum member, clobbering the validated values.)
        protocol_task = ProtocolTaskType(self.task.value)

        request = create_inference_request(
            protocol=protocol,
            task=self._task,
            action=action,
            input_=input_,
            custom_metadata=custom_metadata,
            score_threshold=score_threshold,
            model_name=self._model_name,
            **inference_kwargs,
        )
        http_request = create_http_request(protocol, protocol_task, request)
        response = _post_http_request(
            join_url(self.inference_base_url, http_request.url_suffix),
            http_request.body,
            headers=http_request.headers,
            timeout=timeout,
        )
        http_response = HTTPResponse(body=response.content, headers=dict(response.headers))
        return parse_http_response(
            protocol,
            protocol_task,
            http_response,
            self._class_labels,
        )