import requests
import chariot._apis
from chariot.inference_store import _utils, models
from chariot.inference_store.upload import upload_data
from chariot_api._openapi.inferencestore.exceptions import NotFoundException
__all__ = [
"create_inference",
"filter_inferences",
"count_inferences",
"get_inference",
"delete_inference",
"create_inference_storage_request",
"bulk_delete_inferences",
]
def create_inference(
model_id: str, request_body: models.NewInferenceStorageRequest
) -> models.Inference | None:
"""Create a new inference.
:param model_id: The model id.
:type model_id: str
:param request_body: The inference and metadata.
:type request_body: models.NewInferenceStorageRequest
:return: The inference.
:rtype: Optional[models.Inference]
"""
response = chariot._apis.inferencestore.inference_api.inference_model_id_post(
model_id=model_id, body=request_body.model_dump()
)
if not response.data:
return None
return _utils.convert_to_dataclass(response.data.model_dump(), models.Inference)
def filter_inferences(
model_id: str, request_body: models.NewGetInferencesRequest
) -> list[models.Inference]:
"""Get inferences for a model, optionally matching a series of filters. Each record
returned corresponds to an inference request response pair. Inference responses are
stored in the model's native format.
For example, a record returned might have the following inference blob:
{"detection_boxes": [[10, 10, 20, 20], [30, 30, 40, 40]], "detection_scores": [0.9, 0.95], "detection_labels": ["cat", "dog"]}
:param model_id: The model id.
:type model_id: str
:param request_body: The inference filter to apply upon the full collection of inferences stored for the model.
:type request_body: models.NewGetInferencesRequest
:return: The collection of inferences that met the filter criteria.
:rtype: List[models.Inference]
"""
response = chariot._apis.inferencestore.inference_api.inference_model_id_filter_post(
model_id=model_id, body=request_body.model_dump()
)
if not response.data:
return []
return [_utils.convert_to_dataclass(d.model_dump(), models.Inference) for d in response.data]
def count_inferences(model_id: str, request_body: models.NewGetInferencesRequest) -> int:
"""Get the count of all inferences for a model, optionally matching a series of filters.
:param model_id: The model id.
:type model_id: str
:param request_body: The inference filter to apply upon the full collection of inferences stored for the model.
:type request_body: models.NewGetInferencesRequest
:return: The record count.
:rtype: int
"""
response = (
chariot._apis.inferencestore.inference_api.inference_model_id_filter_post_with_http_info(
model_id=model_id, body=request_body.model_dump()
)
)
    return int(response.headers.get("X-Record-Count", 0))
def get_inference(model_id: str, inference_id: str) -> models.Inference | None:
"""Get details about a single inference.
:param model_id: The model id.
:type model_id: str
:param inference_id: The inference id.
:type inference_id: str
:return: The inference.
:rtype: Optional[models.Inference]
"""
response = chariot._apis.inferencestore.inference_api.inference_model_id_inference_id_get(
model_id=model_id, inference_id=inference_id
)
if not response.data:
return None
return _utils.convert_to_dataclass(response.data.model_dump(), models.Inference)
def delete_inference(model_id: str, inference_id: str) -> str | None:
"""Delete an inference and all associated metadata.
:param model_id: The model id.
:type model_id: str
:param inference_id: The inference id.
:type inference_id: str
:return: The deleted inference id.
:rtype: Optional[str]
"""
response = chariot._apis.inferencestore.inference_api.inference_model_id_inference_id_delete(
model_id=model_id, inference_id=inference_id
)
if not response.data:
return None
return response.data
def create_inference_storage_request(
model_id: str,
inference_id: str,
data: str | bytes | None,
metadata: models.NewInferenceAndMetadataCollection,
is_protected: bool = False,
) -> models.NewInferenceStorageRequest:
"""Create an inference via the API. This function will upload the supplied data
to blob storage if the data field is specified.
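
    For example, raw image bytes can be uploaded alongside the metadata like this (a sketch;
    ``istore`` is assumed to be an alias for this module, and ``my_metadata`` and the ``MY_*``
    names are placeholders)::

        with open("image.png", "rb") as f:
            image_bytes = f.read()
        request = istore.create_inference_storage_request(
            MY_MODEL_ID,
            MY_INFERENCE_ID,
            data=image_bytes,
            metadata=my_metadata,  # a models.NewInferenceAndMetadataCollection
        )
        inference = istore.create_inference(MY_MODEL_ID, request)
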
:param model_id: The model id.
:type model_id: str
:param inference_id: The inference id.
:type inference_id: str
:param data: Optionally, the data inferred upon.
:type data: Optional[Union[str, bytes]]
:param metadata: The inference and metadata.
:type metadata: models.NewInferenceAndMetadataCollection
:param is_protected: Whether the inference and its associated data should be protected from deletion by retention policy
:type is_protected: bool
    :return: The inference storage request, which can be passed to ``create_inference``.
    :rtype: models.NewInferenceStorageRequest
"""
    # Upload the raw data (if any) first so its storage key can be attached to the request.
    data_storage_key = None
    if data is not None:
        udata = upload_data(model_id)
        upload_data_response = requests.put(url=udata.data_presigned_url, data=data)
        if upload_data_response.status_code != 200:
            raise Exception(f"Unable to upload data file: {upload_data_response.text}")
        data_storage_key = udata.data_storage_key
    return models.NewInferenceStorageRequest(
        model_id=model_id,
        inference_id=inference_id,
        data=metadata,
        data_storage_key=data_storage_key,
        is_protected=is_protected,
    )
def bulk_delete_inferences(model_id: str, filter_: models.BaseInferenceFilter) -> list[str | None]:
"""Delete inferences and all associated metadata.
For example, this can be used to delete all inferences for a given snapshot/split, like this::
deleted_inference_ids = istore.bulk_delete_inferences(
MY_MODEL_ID,
istore.BaseInferenceFilter(
metadata_filter=[
istore.MetadataFilter(
key="dataset_snapshot_id",
type=istore.MetadataFilterType.STRING,
operator=istore.MetadataFilterOperator.EQUAL,
value=MY_SNAPSHOT_ID,
),
istore.MetadataFilter(
key="dataset_snapshot_split",
type=istore.MetadataFilterType.STRING,
operator=istore.MetadataFilterOperator.EQUAL,
value=MY_SNAPSHOT_SPLIT,
),
]
),
)
:param model_id: The model id.
:type model_id: str
    :param filter_: A filter describing which inferences to delete.
    :type filter_: models.BaseInferenceFilter
:return: The deleted inference ids.
:rtype: list[str | None]
"""
deleted_inference_ids = []
while True:
inference_filter = models.NewGetInferencesRequest(
filters=filter_, pagination=models.Pagination(limit=100)
)
try:
data = filter_inferences(model_id=model_id, request_body=inference_filter)
except NotFoundException:
break
n = len(data)
if n == 0:
break
for inf in data:
deleted_inference_id = delete_inference(
model_id=inf.model_id, inference_id=inf.inference_id
)
deleted_inference_ids.append(deleted_inference_id)
return deleted_inference_ids