Source code for chariot.models.evaluations

from dataclasses import dataclass
from http import HTTPStatus
from typing import Any, cast

import chariot._apis
from chariot_api._openapi.models.api_response import ApiResponse
from chariot_api._openapi.models.exceptions import ApiException
from chariot_api._openapi.models.models import (
    InputEvaluation,
    OutputGetEvaluationsResponse,
)

# Names exported by ``from chariot.models.evaluations import *``: only the two
# request helpers. The response dataclasses remain importable by explicit name.
__all__ = ["create_evaluation", "get_evaluations"]


@dataclass
class CreateEvaluationResponse:
    status: int
    message: str | None


@dataclass()
class Evaluation:
    """A single evaluation record for one model snapshot and split.

    Instances are built by ``get_evaluations`` from the models API response.
    """

    # Snapshot the evaluation was recorded against.
    snapshot_id: str
    # Split the evaluation was recorded against.
    split: str
    # Creation timestamp as produced by the API model; annotated as ``str``,
    # but the value is not validated or converted here.
    created_at: str
    # Opaque evaluation payload (whatever was supplied when it was created).
    evaluation_data: Any


@dataclass()
class GetEvaluationsResponse:
    """Outcome of a ``get_evaluations`` call.

    Holds the HTTP status, an optional message and the evaluations that were
    returned (empty when the API raised ``ApiException``).
    """

    # HTTP status code of the list request (taken from the exception on failure).
    status: int
    # None for a 200 response; otherwise the response body or an error description.
    message: str | None
    # Evaluations parsed from the response.
    evaluations: list[Evaluation]


def create_evaluation(
    model_id: str, snapshot_id: str, split: str, evaluation_data: Any
) -> CreateEvaluationResponse:
    """Record an evaluation for a model snapshot and split.

    Args:
        model_id: ID of the model the evaluation is attached to.
        snapshot_id: Snapshot ID sent in the ``InputEvaluation`` payload.
        split: Split name sent in the ``InputEvaluation`` payload.
        evaluation_data: Evaluation payload forwarded unchanged to the API.

    Returns:
        A ``CreateEvaluationResponse``. On success, ``status`` is the HTTP
        status code and ``message`` is the decoded response body. If the API
        raises ``ApiException``, ``status`` is the exception's status and
        ``message`` is an "already exists" text for 409 Conflict, or
        ``str(exception)`` otherwise. Any other exception propagates.
    """
    try:
        post_evaluation = (
            chariot._apis.models.models_api.models_id_evaluations_post_with_http_info  # type: ignore
        )
        request_body = InputEvaluation(
            evaluation_data=evaluation_data,
            snapshot_id=snapshot_id,
            split=split,
        )
        api_response = cast(ApiResponse[None], post_evaluation(model_id, request_body))
    except ApiException as exc:
        # A 409 means an evaluation for this snapshot/split pair is already stored.
        if exc.status != HTTPStatus.CONFLICT:
            reason = str(exc)
        else:
            reason = f"Evaluation already exists for snapshot/split {snapshot_id}/{split}"
        return CreateEvaluationResponse(status=exc.status, message=reason)

    body_text = api_response.raw_data.decode()
    return CreateEvaluationResponse(status=api_response.status_code, message=body_text)
def _evaluation_from_output(output_evaluation: Any) -> Evaluation:
    """Build an ``Evaluation`` from one generated API model, ignoring extra fields."""
    dumped = output_evaluation.model_dump()
    # Pick the known fields explicitly: ``Evaluation(**dumped)`` would raise
    # TypeError as soon as the generated API model gains any additional field.
    return Evaluation(
        snapshot_id=dumped["snapshot_id"],
        split=dumped["split"],
        created_at=dumped["created_at"],
        evaluation_data=dumped["evaluation_data"],
    )


def get_evaluations(model_id: str) -> GetEvaluationsResponse:
    """List the evaluations recorded for a model.

    Args:
        model_id: ID of the model whose evaluations are requested.

    Returns:
        A ``GetEvaluationsResponse``. On success, ``status`` is the HTTP
        status code, ``message`` is ``None`` for a 200 response (the decoded
        response body otherwise), and ``evaluations`` holds the parsed
        records. If the API raises ``ApiException``, ``status`` is the
        exception's status, ``message`` is ``"Model not found"`` for a 404 or
        ``str(exception)`` otherwise, and ``evaluations`` is empty. Any other
        exception propagates.
    """
    try:
        response = cast(
            ApiResponse[OutputGetEvaluationsResponse],
            chariot._apis.models.models_api.models_id_evaluations_get_with_http_info(  # type: ignore
                model_id
            ),
        )
    except ApiException as e:
        status = e.status
        message = "Model not found" if status == HTTPStatus.NOT_FOUND else str(e)
        return GetEvaluationsResponse(status=status, message=message, evaluations=[])

    status = response.status_code
    message = response.raw_data.decode() if status != HTTPStatus.OK else None
    # NOTE(review): ``response.data`` may presumably be None for a 2xx response
    # whose body was not deserialized; guard it the same way ``data.data`` is
    # already guarded, instead of failing with AttributeError.
    output_evaluations = response.data.data if response.data is not None else None
    evaluations = [
        _evaluation_from_output(output_evaluation)
        for output_evaluation in output_evaluations or []
    ]
    return GetEvaluationsResponse(status=status, message=message, evaluations=evaluations)