import json
import chariot._apis
from chariot import mcp_setting
from ._utils import convert_to_dataclass
from .models import BulkInferenceJob, NewBulkInferenceJobRequest
# Public API of this module: the three bulk inference job helpers defined below.
__all__ = [
    "create_bulk_inference_job",
    "get_bulk_inference_job",
    "get_bulk_inference_jobs",
]
[docs]
@mcp_setting(mutating=True)
def create_bulk_inference_job(job_spec: NewBulkInferenceJobRequest) -> str | None:
    """Submit a bulk inference job built from ``job_spec``.

    :param job_spec: Specification of the bulk inference job to create.
    :type job_spec: models.NewBulkInferenceJobRequest
    :return: The id of the created job, or ``None`` when the response carries
        no body.
    :rtype: Optional[str]
    """
    # Some request keys differ from the spec attribute names:
    # ``model_project_id`` -> ``project_id`` and
    # ``dataset_snapshot_split`` -> ``dataset_spec.dataset_version_split``.
    request_body = {
        "project_id": job_spec.model_project_id,
        "model_id": job_spec.model_id,
        "inference_method": job_spec.inference_method,
        "dataset_spec": {
            "dataset_snapshot_id": job_spec.dataset_snapshot_id,
            "dataset_version_split": job_spec.dataset_snapshot_split,
        },
        "evaluate_metrics": job_spec.evaluate_metrics,
        "batch_size": job_spec.batch_size,
    }
    inference_api = chariot._apis.models.inference_api
    api_response = inference_api.inference_execution_post_with_http_info(body=request_body)
    raw_body = api_response.raw_data
    if raw_body:
        # The job id is nested under "id" -> "job_id" in the response JSON.
        return json.loads(raw_body)["id"]["job_id"]
    return None
[docs]
def get_bulk_inference_job(job_id: str) -> BulkInferenceJob | None:
    """Fetch a single bulk inference job by its id.

    :param job_id: Id of the bulk inference job to look up.
    :type job_id: str
    :return: The job converted to a ``BulkInferenceJob``, or ``None`` when the
        response carries no data.
    :rtype: Optional[BulkInferenceJob]
    """
    api_response = chariot._apis.models.inference_api.inference_execution_id_get_with_http_info(
        job_id
    )
    job_model = api_response.data
    # The deserialized response model is dumped to a dict, then re-typed as
    # the SDK's BulkInferenceJob dataclass.
    return convert_to_dataclass(job_model.model_dump(), BulkInferenceJob) if job_model else None
[docs]
def get_bulk_inference_jobs(
    model_id: str,
    dataset_version_id: str | None = None,
    dataset_snapshot_id: str | None = None,
    dataset_version_split: str | None = None,
    inference_method: str | None = None,
    limit: int | None = None,
    offset: int | None = None,
) -> list[BulkInferenceJob]:
    """List the bulk inference jobs of a model.

    Every optional argument is forwarded as-is (including ``None``) to the
    inference-execution listing request for ``model_id``.

    :param model_id: Id of the model whose bulk inference jobs are listed.
    :type model_id: str
    :param dataset_version_id: Optional dataset version ID. Defaults to None.
    :type dataset_version_id: Optional[str]
    :param dataset_snapshot_id: Optional dataset snapshot ID. Defaults to None.
    :type dataset_snapshot_id: Optional[str]
    :param dataset_version_split: Optional dataset snapshot split. Defaults to None.
    :type dataset_version_split: Optional[str]
    :param inference_method: Optional model inference method. Defaults to None.
    :type inference_method: Optional[str]
    :param limit: Optional pagination limit. Defaults to None.
    :type limit: Optional[int]
    :param offset: Optional pagination offset. Defaults to None.
    :type offset: Optional[int]
    :return: The jobs found, each converted to a ``BulkInferenceJob``. An empty
        list is returned when the response body is empty or its ``"data"``
        field is null.
    :rtype: List[BulkInferenceJob]
    """
    inference_api = chariot._apis.models.inference_api
    raw_response = inference_api.inference_execution_model_id_get_without_preload_content(
        id=model_id,
        dataset_version_id=dataset_version_id,
        dataset_snapshot_id=dataset_snapshot_id,
        dataset_version_split=dataset_version_split,
        inference_method=inference_method,
        limit=limit,
        offset=offset,
    )
    # The body is JSON decoded here rather than by the generated client.
    # An empty body and a null "data" field both yield no jobs.
    body = raw_response.data
    job_records = json.loads(body)["data"] if body else None
    if job_records is None:
        return []
    return [convert_to_dataclass(record, BulkInferenceJob) for record in job_records]