Source code for chariot.inference_store.inference

import requests

import chariot._apis
from chariot.inference_store import _utils, models
from chariot.inference_store.upload import upload_data
from chariot_api._openapi.inferencestore.exceptions import NotFoundException

# Public names exported by ``from chariot.inference_store.inference import *``.
__all__ = [
    "create_inference",
    "filter_inferences",
    "count_inferences",
    "get_inference",
    "delete_inference",
    "create_inference_storage_request",
    "bulk_delete_inferences",
]


def create_inference(
    model_id: str, request_body: models.NewInferenceStorageRequest
) -> models.Inference | None:
    """Store a new inference record for a model.

    The request body is serialized with ``model_dump()`` and posted to the
    inference store's create endpoint for ``model_id``.

    :param model_id: The model id.
    :type model_id: str
    :param request_body: The inference and metadata to store.
    :type request_body: models.NewInferenceStorageRequest
    :return: The stored inference, or ``None`` when the response carries no data.
    :rtype: Optional[models.Inference]
    """
    post_inference = chariot._apis.inferencestore.inference_api.inference_model_id_post
    payload = request_body.model_dump()
    created = post_inference(model_id=model_id, body=payload).data
    if created:
        return _utils.convert_to_dataclass(created.model_dump(), models.Inference)
    return None
def filter_inferences(
    model_id: str, request_body: models.NewGetInferencesRequest
) -> list[models.Inference]:
    """Query a model's stored inferences, optionally narrowed by filters.

    Each returned record is one inference request/response pair. The inference
    payload is kept in the model's native output format; a detection model, for
    instance, might store a blob such as::

        {"detection_boxes": [[10, 10, 20, 20], [30, 30, 40, 40]],
         "detection_scores": [0.9, 0.95],
         "detection_labels": ["cat", "dog"]}

    :param model_id: The model id.
    :type model_id: str
    :param request_body: Filter (and pagination) applied to the full collection
        of inferences stored for the model.
    :type request_body: models.NewGetInferencesRequest
    :return: The inferences matching the filter; an empty list when the response
        carries no data.
    :rtype: List[models.Inference]
    """
    post_filter = chariot._apis.inferencestore.inference_api.inference_model_id_filter_post
    response = post_filter(model_id=model_id, body=request_body.model_dump())
    records = response.data or []
    return [
        _utils.convert_to_dataclass(record.model_dump(), models.Inference)
        for record in records
    ]
def count_inferences(model_id: str, request_body: models.NewGetInferencesRequest) -> int:
    """Get the count of all inferences for a model, optionally matching a series of filters.

    The count is read from the ``X-Record-Count`` header of the filter endpoint's
    HTTP response rather than from the response body.

    :param model_id: The model id.
    :type model_id: str
    :param request_body: The inference filter to apply upon the full collection of
        inferences stored for the model.
    :type request_body: models.NewGetInferencesRequest
    :return: The record count; ``0`` when the header is absent or empty.
    :rtype: int
    """
    response = (
        chariot._apis.inferencestore.inference_api.inference_model_id_filter_post_with_http_info(
            model_id=model_id, body=request_body.model_dump()
        )
    )
    # ``headers`` may be None on a generated API response, so fall back to an empty mapping.
    headers = response.headers or {}
    raw_count = headers.get("X-Record-Count")
    # The fallback must be applied here, not as a second argument to int():
    # int(x, 0) treats 0 as a *base*, which fails on None and on zero-padded digits.
    return int(raw_count) if raw_count else 0
def get_inference(model_id: str, inference_id: str) -> models.Inference | None:
    """Fetch a single stored inference by its id.

    :param model_id: The model id.
    :type model_id: str
    :param inference_id: The inference id.
    :type inference_id: str
    :return: The inference, or ``None`` when the response carries no data.
    :rtype: Optional[models.Inference]
    """
    fetch = chariot._apis.inferencestore.inference_api.inference_model_id_inference_id_get
    found = fetch(model_id=model_id, inference_id=inference_id).data
    if found:
        return _utils.convert_to_dataclass(found.model_dump(), models.Inference)
    return None
def delete_inference(model_id: str, inference_id: str) -> str | None:
    """Remove an inference together with all of its associated metadata.

    :param model_id: The model id.
    :type model_id: str
    :param inference_id: The inference id.
    :type inference_id: str
    :return: The id reported back by the delete endpoint, or ``None`` when the
        response carries no data.
    :rtype: Optional[str]
    """
    remove = chariot._apis.inferencestore.inference_api.inference_model_id_inference_id_delete
    response = remove(model_id=model_id, inference_id=inference_id)
    # Any falsy payload (None, empty string) is normalized to None.
    return response.data or None
def create_inference_storage_request(
    model_id: str,
    inference_id: str,
    data: str | bytes | None,
    metadata: models.NewInferenceAndMetadataCollection,
    is_protected: bool = False,
) -> models.NewInferenceStorageRequest:
    """Build an inference storage request, uploading the inferred-upon data first.

    When ``data`` is not ``None`` it is uploaded to blob storage through the
    presigned URL returned by ``upload_data(model_id)``. The resulting storage
    key is then attached to the request. This function only constructs the
    request; it does not submit it. The result is the type accepted by
    ``create_inference``.

    :param model_id: The model id.
    :type model_id: str
    :param inference_id: The inference id.
    :type inference_id: str
    :param data: Optionally, the data inferred upon.
    :type data: Optional[Union[str, bytes]]
    :param metadata: The inference and metadata.
    :type metadata: models.NewInferenceAndMetadataCollection
    :param is_protected: Whether the inference and its associated data should be
        protected from deletion by retention policy.
    :type is_protected: bool
    :return: The storage request describing the inference.
    :rtype: models.NewInferenceStorageRequest
    :raises Exception: If the data upload responds with a non-2xx status code.
    """
    request_fields = {
        "model_id": model_id,
        "inference_id": inference_id,
        "data": metadata,
        "is_protected": is_protected,
    }
    if data is not None:
        udata = upload_data(model_id)
        upload_data_response = requests.put(url=udata.data_presigned_url, data=data)
        # Any 2xx counts as success; some blob stores answer a PUT with 201 rather than 200.
        if not 200 <= upload_data_response.status_code < 300:
            raise Exception(f"Unable to upload data file: {upload_data_response.text}")
        request_fields["data_storage_key"] = udata.data_storage_key
    # Without an upload there is no storage key, so the field is omitted rather than
    # referencing an unbound variable.
    # NOTE(review): assumes NewInferenceStorageRequest gives data_storage_key a
    # default (e.g. None); confirm against the model definition.
    return models.NewInferenceStorageRequest(**request_fields)
def bulk_delete_inferences(model_id: str, filter_: models.BaseInferenceFilter) -> list[str | None]:
    """Delete inferences and all associated metadata.

    For example, this can be used to delete all inferences for a given
    snapshot/split, like this::

        deleted_inference_ids = istore.bulk_delete_inferences(
            MY_MODEL_ID,
            istore.BaseInferenceFilter(
                metadata_filter=[
                    istore.MetadataFilter(
                        key="dataset_snapshot_id",
                        type=istore.MetadataFilterType.STRING,
                        operator=istore.MetadataFilterOperator.EQUAL,
                        value=MY_SNAPSHOT_ID,
                    ),
                    istore.MetadataFilter(
                        key="dataset_snapshot_split",
                        type=istore.MetadataFilterType.STRING,
                        operator=istore.MetadataFilterOperator.EQUAL,
                        value=MY_SNAPSHOT_SPLIT,
                    ),
                ]
            ),
        )

    Matching inferences are fetched in pages of 100 and deleted one at a time
    until the filter returns nothing new.

    :param model_id: The model id.
    :type model_id: str
    :param filter_: A filter describing which inferences to delete.
    :type filter_: models.BaseInferenceFilter
    :return: The results of ``delete_inference`` for each inference attempted.
    :rtype: list[str | None]
    """
    # The same first-page request is reused on every iteration: deleting the
    # returned inferences is what advances through the matching set.
    inference_filter = models.NewGetInferencesRequest(
        filters=filter_, pagination=models.Pagination(limit=100)
    )
    deleted_inference_ids: list[str | None] = []
    # Ids already sent to delete_inference. If a page returns only these, the
    # deletions did not shrink the result set, and re-querying would loop forever.
    attempted = set()
    while True:
        try:
            page = filter_inferences(model_id=model_id, request_body=inference_filter)
        except NotFoundException:
            # Treated as "nothing (left) to delete".
            break
        pending = [inf for inf in page if inf.inference_id not in attempted]
        if not pending:
            break
        for inf in pending:
            attempted.add(inf.inference_id)
            deleted_inference_ids.append(
                delete_inference(model_id=inf.model_id, inference_id=inf.inference_id)
            )
    return deleted_inference_ids