import copy

import requests

import chariot._apis
from chariot import mcp_setting
from chariot.inference_store import _utils, models
from chariot.inference_store.upload import upload_data
from chariot_api._openapi.inferencestore.exceptions import NotFoundException
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    "create_inference",
    "filter_inferences",
    "count_inferences",
    "get_inference",
    "delete_inference",
    "create_inference_storage_request",
    "bulk_delete_inferences",
]
@mcp_setting(ignore=True)
def create_inference(
    model_id: str, request_body: models.NewInferenceStorageRequest
) -> models.Inference | None:
    """Store a new inference record for a model.

    :param model_id: The id of the model the inference belongs to.
    :type model_id: str
    :param request_body: The storage request carrying the inference and its metadata.
    :type request_body: models.NewInferenceStorageRequest
    :return: The stored inference, or ``None`` when the API response carries no data.
    :rtype: Optional[models.Inference]
    """
    inference_api = chariot._apis.inferencestore.inference_api
    created = inference_api.inference_model_id_post(
        model_id=model_id, body=request_body.model_dump()
    ).data
    if created:
        # Convert the generated API model into the SDK's own dataclass.
        return _utils.convert_to_dataclass(created.model_dump(), models.Inference)
    return None
def filter_inferences(
    model_id: str, request_body: models.NewGetInferencesRequest
) -> list[models.Inference]:
    """Fetch the inferences stored for a model, optionally narrowed by filters.

    Every returned record corresponds to one inference request/response pair.
    Inference responses are stored in the model's native format, so a returned
    record might, for example, carry an inference blob such as::

        {"detection_boxes": [[10, 10, 20, 20], [30, 30, 40, 40]], "detection_scores": [0.9, 0.95], "detection_labels": ["cat", "dog"]}

    :param model_id: The model id.
    :type model_id: str
    :param request_body: The filter to apply to the full collection of inferences stored for the model.
    :type request_body: models.NewGetInferencesRequest
    :return: The inferences matching the filter; an empty list when the response carries no data.
    :rtype: List[models.Inference]
    """
    inference_api = chariot._apis.inferencestore.inference_api
    response = inference_api.inference_model_id_filter_post(
        model_id=model_id, body=request_body.model_dump()
    )
    # A missing or empty payload simply yields an empty result list.
    records = response.data or []
    return [
        _utils.convert_to_dataclass(record.model_dump(), models.Inference) for record in records
    ]
def count_inferences(model_id: str, request_body: models.NewGetInferencesRequest) -> int:
    """Get the count of all inferences for a model, optionally matching a series of filters.

    The count is obtained by issuing a filter request whose pagination asks for a
    single record with the record count included, and reading the
    ``X-Record-Count`` response header. Any pagination set on ``request_body`` is
    ignored for this call, and ``request_body`` itself is left unmodified.

    :param model_id: The model id.
    :type model_id: str
    :param request_body: The inference filter to apply upon the full collection of inferences stored for the model.
    :type request_body: models.NewGetInferencesRequest
    :return: The record count, or 0 when the response has no ``X-Record-Count`` header.
    :rtype: int
    """
    # Override pagination on a shallow copy so the caller's request object keeps
    # its own pagination settings and can be safely reused afterwards.
    count_request = copy.copy(request_body)
    count_request.pagination = models.Pagination(limit=1, include_record_count=True)
    response = (
        chariot._apis.inferencestore.inference_api.inference_model_id_filter_post_with_http_info(
            model_id=model_id, body=count_request.model_dump()
        )
    )
    return int(response.headers.get("X-Record-Count", 0))
def get_inference(
    model_id: str, inference_id: str, presign: bool = False
) -> models.Inference | None:
    """Look up a single inference by id.

    :param model_id: The model id.
    :type model_id: str
    :param inference_id: The inference id.
    :type inference_id: str
    :param presign: Whether to include a presigned url to the image/text.
    :type presign: bool
    :return: The inference, or ``None`` when the API response carries no data.
    :rtype: Optional[models.Inference]
    """
    inference_api = chariot._apis.inferencestore.inference_api
    found = inference_api.inference_model_id_inference_id_get(
        model_id=model_id, inference_id=inference_id, presign=presign
    ).data
    if found:
        # Convert the generated API model into the SDK's own dataclass.
        return _utils.convert_to_dataclass(found.model_dump(), models.Inference)
    return None
@mcp_setting(mutating=True)
def delete_inference(model_id: str, inference_id: str) -> str | None:
    """Remove one inference together with all of its associated metadata.

    :param model_id: The model id.
    :type model_id: str
    :param inference_id: The inference id.
    :type inference_id: str
    :return: The deleted inference id, or ``None`` when the API response carries no data.
    :rtype: Optional[str]
    """
    inference_api = chariot._apis.inferencestore.inference_api
    deleted_id = inference_api.inference_model_id_inference_id_delete(
        model_id=model_id, inference_id=inference_id
    ).data
    # Normalise any empty payload to None.
    return deleted_id if deleted_id else None
@mcp_setting(ignore=True)
def create_inference_storage_request(
    model_id: str,
    inference_id: str,
    data: str | bytes | None,
    metadata: models.NewInferenceAndMetadataCollection,
    is_protected: bool = False,
) -> models.NewInferenceStorageRequest | None:
    """Build an inference storage request, uploading the input data if supplied.

    This function does not itself create the inference; it returns the request
    object that :func:`create_inference` accepts. When ``data`` is not ``None``
    it is uploaded with an HTTP ``PUT`` to a presigned url obtained from
    ``upload_data``, and the resulting storage key is recorded on the request.

    :param model_id: The model id.
    :type model_id: str
    :param inference_id: The inference id.
    :type inference_id: str
    :param data: Optionally, the data inferred upon.
    :type data: Optional[Union[str, bytes]]
    :param metadata: The inference and metadata.
    :type metadata: models.NewInferenceAndMetadataCollection
    :param is_protected: Whether the inference and its associated data should be protected from deletion by retention policy
    :type is_protected: bool
    :return: The storage request, ready to be submitted.
    :rtype: Optional[models.NewInferenceStorageRequest]
    :raises Exception: If no presigned url could be obtained or the upload
        did not complete with a successful (2xx) HTTP status.
    """
    request = models.NewInferenceStorageRequest(
        model_id=model_id,
        inference_id=inference_id,
        data=metadata,
        is_protected=is_protected,
    )
    if data is None:
        # Nothing to upload; the request carries only the inference metadata.
        return request

    udata = upload_data(model_id)
    if udata is None:
        raise Exception("Failed to get presigned url")
    upload_data_response = requests.put(url=udata.data_presigned_url, data=data)
    # Any 2xx status is a successful PUT (e.g. 200, 201 Created, 204 No Content),
    # so do not insist on exactly 200.
    if not 200 <= upload_data_response.status_code < 300:
        raise Exception(f"Unable to upload data file: {upload_data_response.text}")
    request.data_storage_key = udata.data_storage_key
    return request
@mcp_setting(mutating=True)
def bulk_delete_inferences(model_id: str, filter_: models.BaseInferenceFilter) -> list[str | None]:
    """Delete inferences and all associated metadata.

    Matching inferences are fetched in pages of 100 and deleted one at a time.
    The first page is re-queried after each pass, so the loop relies on deleted
    records no longer matching the filter. To guarantee termination, a record is
    only ever attempted once: if a page contains nothing but records that were
    already attempted, the function stops instead of looping forever.

    For example, this can be used to delete all inferences for a given snapshot/split, like this::

        deleted_inference_ids = istore.bulk_delete_inferences(
            MY_MODEL_ID,
            istore.BaseInferenceFilter(
                metadata_filter=[
                    istore.MetadataFilter(
                        key="dataset_snapshot_id",
                        type=istore.MetadataFilterType.STRING,
                        operator=istore.MetadataFilterOperator.EQUAL,
                        value=MY_SNAPSHOT_ID,
                    ),
                    istore.MetadataFilter(
                        key="dataset_snapshot_split",
                        type=istore.MetadataFilterType.STRING,
                        operator=istore.MetadataFilterOperator.EQUAL,
                        value=MY_SNAPSHOT_SPLIT,
                    ),
                ]
            ),
        )

    :param model_id: The model id.
    :type model_id: str
    :param filter_: A filter describing which inferences to delete.
    :type filter_: models.BaseInferenceFilter
    :return: The value returned by each delete call, in order: the deleted
        inference id, or ``None`` when a delete response carried no data.
    :rtype: list[str | None]
    """
    # The page request never changes between iterations, so build it once.
    page_request = models.NewGetInferencesRequest(
        filters=filter_, pagination=models.Pagination(limit=100)
    )
    deleted_inference_ids: list[str | None] = []
    attempted_ids = set()
    while True:
        try:
            page = filter_inferences(model_id=model_id, request_body=page_request)
        except NotFoundException:
            break
        # Records that were already attempted but still match the filter could
        # not be removed; skipping them prevents an endless re-fetch loop.
        pending = [inf for inf in page if inf.inference_id not in attempted_ids]
        if not pending:
            break
        for inf in pending:
            attempted_ids.add(inf.inference_id)
            deleted_inference_ids.append(
                delete_inference(model_id=inf.model_id, inference_id=inf.inference_id)
            )
    return deleted_inference_ids