from json import loads
import chariot._apis
from chariot.bulk_inference import _utils, models
# Public API of this module: the names exported by a star-import.
__all__ = [
"create_bulk_inference_job",
"get_bulk_inference_job",
]
# [docs]
def create_bulk_inference_job(job_spec: models.NewBulkInferenceJobRequest) -> str | None:
    """Submit a new bulk inference job and return its id.

    The request body is assembled from the fields of ``job_spec`` and posted
    through ``inference_execution_post_with_http_info``. The job id is read
    from the ``id.job_id`` entry of the JSON-decoded raw response.

    :param job_spec: The bulk inference job specification.
    :type job_spec: models.NewBulkInferenceJobRequest
    :return: The id of the created job, or ``None`` when the response
        carries no raw data.
    :rtype: Optional[str]
    """
    # Dataset selection is sent as a nested object inside the request body.
    dataset_spec = {
        "dataset_project_id": job_spec.dataset_project_id,
        "dataset_id": job_spec.dataset_id,
        "dataset_snapshot_id": job_spec.dataset_snapshot_id,
        "dataset_version_split": job_spec.dataset_snapshot_split,
    }
    payload = {
        "project_id": job_spec.model_project_id,
        "model_id": job_spec.model_id,
        "inference_method": job_spec.inference_method,
        "dataset_spec": dataset_spec,
        "evaluate_metrics": job_spec.evaluate_metrics,
        "batch_size": job_spec.batch_size,
    }

    inference_api = chariot._apis.models.inference_api
    reply = inference_api.inference_execution_post_with_http_info(body=payload)

    raw_payload = reply.raw_data
    if raw_payload:
        return loads(raw_payload)["id"]["job_id"]
    # An empty raw response means there is no id to report.
    return None
# [docs]
def get_bulk_inference_job(job_id: str) -> models.BulkInferenceJob | None:
    """Fetch a bulk inference job by its id.

    The job is requested through ``inference_execution_id_get_with_http_info``
    and the returned data is dumped and converted into a
    ``models.BulkInferenceJob`` via ``_utils.convert_to_dataclass``.

    :param job_id: The bulk inference job id.
    :type job_id: str
    :return: The bulk inference job, or ``None`` when the response holds
        no data.
    :rtype: Optional[models.BulkInferenceJob]
    """
    inference_api = chariot._apis.models.inference_api
    job_data = inference_api.inference_execution_id_get_with_http_info(job_id).data
    if job_data:
        return _utils.convert_to_dataclass(job_data.model_dump(), models.BulkInferenceJob)
    # Nothing came back for this id.
    return None