from collections.abc import Sequence
from chariot import _apis
from chariot.models._http_utils import _post_http_request, join_url
from chariot.models._infer_v2 import RequestInput, create_inference_request
from chariot.models.enum import TaskType as ModelsTaskType
from chariot.models.exceptions import ActionUnsupportedByCurrentModelError
from chariot.models.inference import Action
from chariot_api._openapi.models.rest import ApiException
from inference_protocol import (
HTTPResponse,
InferenceResponse,
Protocol,
create_http_request,
parse_http_response,
)
from inference_protocol import (
TaskType as ProtocolTaskType,
)
from inference_protocol.entity import InferenceOutput
from inference_protocol.request_metadata import RequestMetadataInput
class InferenceServer:
    """A handle to a running inference server for a Chariot model.

    Instances are expected to be created by the SDK (e.g. returned from
    ``model.wait_for_inference_server(...)``) rather than constructed
    directly by users.
    """

    def __init__(
        self,
        model_id: str,
        model_name: str,
        internal_base_url: str,
        external_base_url: str,
        use_internal_url: bool,
        protocol: str | None,
        task: ModelsTaskType,
        actions: list[str],
        class_labels: dict[str, int] | None,
    ):
        self._model_id = model_id
        self._model_name = model_name
        self._internal_base_url = internal_base_url
        self._external_base_url = external_base_url
        self._use_internal_url = use_internal_url
        self._task = task
        self._protocol = protocol
        self._actions = actions
        self._class_labels = class_labels
        # Set by stop(); infer() refuses to run on a stopped handle.
        self._stopped = False

    @property
    def model_id(self) -> str:
        """Model ID associated with this inference server."""
        return self._model_id

    @property
    def model_name(self) -> str:
        """Model name associated with this inference server."""
        return self._model_name

    @property
    def task(self) -> ModelsTaskType:
        """Model task type."""
        return self._task

    @property
    def protocol(self) -> str | None:
        """Inference protocol used by this inference server."""
        return self._protocol

    @property
    def internal_base_url(self) -> str:
        """Internal (cluster) base URL for the server."""
        return self._internal_base_url

    @property
    def external_base_url(self) -> str:
        """External base URL for the server."""
        return self._external_base_url

    @property
    def inference_base_url(self) -> str:
        """Server base URL selected by `use_internal_url`."""
        return self._internal_base_url if self._use_internal_url else self._external_base_url

    def stop(self) -> None:
        """Stop the inference server.

        A "Not Found" API response is treated as success (the server is
        already gone); any other `ApiException` is re-raised. On success
        the handle is marked stopped and can no longer be used to infer.
        """
        try:
            _apis.models.models_api.models_id_isvc_delete(
                id=self._model_id,
                _request_timeout=15,
            )
        except ApiException as e:
            # The server disappearing out from under us is fine — the
            # end state (no server) is what the caller asked for.
            if e.reason != "Not Found":
                raise
        self._stopped = True

    def infer(
        self,
        action: Action,
        input_: RequestInput,
        custom_metadata: RequestMetadataInput | Sequence[RequestMetadataInput] | None = None,
        score_threshold: float | None = None,
        timeout: int = 60,
        **inference_kwargs,  # NOTE: currently unused
    ) -> InferenceResponse[InferenceOutput]:
        """Run inference `action` on `input_`, returning an `InferenceResponse`.

        :param action: Inference action to perform; must be one of this
            model's supported actions.
        :param input_: Input payload for the request.
        :param custom_metadata: Optional request metadata (single item or
            a sequence), forwarded with the request.
        :param score_threshold: Optional score cutoff forwarded with the
            request.
        :param timeout: HTTP request timeout in seconds.
        :raises RuntimeError: If the handle was stopped, the server has no
            protocol, or the protocol is unsupported.
        :raises ActionUnsupportedByCurrentModelError: If `action` is not
            supported by this model's task type.
        """
        if self._stopped:
            raise RuntimeError(
                "This InferenceServer handle was stopped. "
                "Call `model.wait_for_inference_server(...)` to get a new handle."
            )
        if action not in self._actions:
            raise ActionUnsupportedByCurrentModelError(
                f"This model's task type '{self._task}' does not support the '{action}' action, "
                f"please use one of the following actions: {self._actions}"
            )
        if self.protocol is None:
            raise RuntimeError("Inference protocol required for inference")
        # TODO (c.zaloom) - replace with `self.protocol in Protocol` for python>=3.12
        try:
            protocol = Protocol(self.protocol)
        except ValueError as e:
            raise RuntimeError(
                f"Inference protocol '{self.protocol}' is not currently supported."
            ) from e
        # Map the model's task onto the protocol's task enum by value
        # (the two enums share value strings, not members).
        protocol_task = ProtocolTaskType(self.task.value)
        request = create_inference_request(
            protocol=protocol,
            task=self._task,
            action=action,
            input_=input_,
            custom_metadata=custom_metadata,
            score_threshold=score_threshold,
            model_name=self._model_name,
            **inference_kwargs,
        )
        http_request = create_http_request(protocol, protocol_task, request)
        response = _post_http_request(
            join_url(self.inference_base_url, http_request.url_suffix),
            http_request.body,
            headers=http_request.headers,
            timeout=timeout,
        )
        http_response = HTTPResponse(body=response.content, headers=dict(response.headers))
        return parse_http_response(
            protocol,
            protocol_task,
            http_response,
            self._class_labels,
        )