from dataclasses import dataclass
from http import HTTPStatus
from typing import Any, cast
import chariot._apis
from chariot_api._openapi.models.api_response import ApiResponse
from chariot_api._openapi.models.exceptions import ApiException
from chariot_api._openapi.models.models import (
InputEvaluation,
OutputGetEvaluationsResponse,
)
# Public entry points of this module; the response dataclasses are reachable
# as the return types of these functions.
__all__ = [
    "create_evaluation",
    "get_evaluations",
]
@dataclass
class CreateEvaluationResponse:
status: int
message: str | None
@dataclass
class Evaluation:
    """One evaluation record of a model, keyed by snapshot and data split."""

    snapshot_id: str
    split: str
    # Kept as the string the API returns. NOTE(review): presumably an ISO-8601
    # timestamp; confirm the format against the API.
    created_at: str
    # Arbitrary evaluation payload, stored as received (no validation here).
    evaluation_data: Any
@dataclass
class GetEvaluationsResponse:
    """Result object returned by ``get_evaluations``."""

    # HTTP status of the response, or of the ApiException on failure.
    status: int
    # None for a successful HTTP 200; otherwise a decoded body or error text.
    message: str | None
    # Evaluations returned by the API; empty when the request failed.
    evaluations: list[Evaluation]
[docs]
def create_evaluation(
    model_id: str, snapshot_id: str, split: str, evaluation_data: Any
) -> CreateEvaluationResponse:
    """Submit an evaluation for one snapshot/split of a model.

    Args:
        model_id: Model the evaluation belongs to.
        snapshot_id: Snapshot the evaluation was computed on.
        split: Data split the evaluation was computed on.
        evaluation_data: Evaluation payload forwarded to the API unchanged.

    Returns:
        A ``CreateEvaluationResponse``. On success it carries the HTTP status
        code and the decoded response body. If the API raises
        ``ApiException``, it carries the exception's status. The message is
        then a dedicated "already exists" text for HTTP 409 Conflict, or
        ``str(exception)`` for any other error.
    """
    try:
        # Resolve the endpoint before building the body, mirroring the
        # evaluation order of a direct call expression.
        post_evaluation = chariot._apis.models.models_api.models_id_evaluations_post_with_http_info  # type: ignore
        request_body = InputEvaluation(
            evaluation_data=evaluation_data,
            snapshot_id=snapshot_id,
            split=split,
        )
        api_response = cast(ApiResponse[None], post_evaluation(model_id, request_body))
    except ApiException as err:
        error_text = (
            f"Evaluation already exists for snapshot/split {snapshot_id}/{split}"
            if err.status == HTTPStatus.CONFLICT
            else str(err)
        )
        return CreateEvaluationResponse(status=err.status, message=error_text)

    body_text = api_response.raw_data.decode()
    return CreateEvaluationResponse(status=api_response.status_code, message=body_text)
[docs]
def get_evaluations(model_id: str) -> GetEvaluationsResponse:
    """List every evaluation recorded for a model.

    Args:
        model_id: Model whose evaluations are requested.

    Returns:
        A ``GetEvaluationsResponse``.

        If the call succeeds:

        - ``status`` is the response status code.
        - ``message`` is ``None`` for HTTP 200 and the decoded response body
          for any other status.
        - ``evaluations`` holds one ``Evaluation`` per item in the response
          payload. It is empty when the payload is missing or empty.

        If the API raises ``ApiException``:

        - ``status`` is the exception's status.
        - ``message`` is ``"Model not found"`` for HTTP 404 and
          ``str(exception)`` for any other error.
        - ``evaluations`` is empty.
    """
    # Only the API call can raise ApiException, so keep the try body minimal.
    try:
        response = cast(
            ApiResponse[OutputGetEvaluationsResponse],
            chariot._apis.models.models_api.models_id_evaluations_get_with_http_info(  # type: ignore
                model_id
            ),
        )
    except ApiException as e:
        status = e.status
        message = "Model not found" if status == HTTPStatus.NOT_FOUND else str(e)
        return GetEvaluationsResponse(status=status, message=message, evaluations=[])

    status = response.status_code
    message = response.raw_data.decode() if status != HTTPStatus.OK else None
    # A success status without a mapped response type leaves ``response.data``
    # as None. Treat that as "no evaluations" rather than failing with
    # AttributeError on ``.data``.
    payload = response.data.data if response.data is not None else None
    evaluations = [
        Evaluation(**output_evaluation.model_dump())
        for output_evaluation in payload or []
    ]
    return GetEvaluationsResponse(status=status, message=message, evaluations=evaluations)