Source code for chariot.bulk_inference.job

from json import loads

import chariot._apis
from chariot.bulk_inference import _utils, models

# Names exported by ``from chariot.bulk_inference.job import *``.
__all__ = ["create_bulk_inference_job", "get_bulk_inference_job"]


def create_bulk_inference_job(job_spec: models.NewBulkInferenceJobRequest) -> str | None:
    """Submit a new bulk inference job to the inference API.

    :param job_spec: Specification of the model, dataset and run options for the job.
    :type job_spec: models.NewBulkInferenceJobRequest
    :return: The id of the created job, or ``None`` when the response carries no raw data.
    :rtype: Optional[str]
    """
    dataset_spec = {
        "dataset_project_id": job_spec.dataset_project_id,
        "dataset_id": job_spec.dataset_id,
        "dataset_snapshot_id": job_spec.dataset_snapshot_id,
        # The request field is "dataset_version_split" even though the spec
        # attribute is named ``dataset_snapshot_split``.
        "dataset_version_split": job_spec.dataset_snapshot_split,
    }
    request_body = {
        "project_id": job_spec.model_project_id,
        "model_id": job_spec.model_id,
        "inference_method": job_spec.inference_method,
        "dataset_spec": dataset_spec,
        "evaluate_metrics": job_spec.evaluate_metrics,
        "batch_size": job_spec.batch_size,
    }

    inference_api = chariot._apis.models.inference_api
    result = inference_api.inference_execution_post_with_http_info(body=request_body)

    raw = result.raw_data
    if raw:
        # The job id is nested under the "id" object of the JSON payload.
        return loads(raw)["id"]["job_id"]
    return None


def get_bulk_inference_job(job_id: str) -> models.BulkInferenceJob | None:
    """Get a bulk inference job.

    :param job_id: The bulk inference job id.
    :type job_id: str
    :return: The bulk inference job, or ``None`` if the API response carries no data.
    :rtype: Optional[models.BulkInferenceJob]
    """
    response = chariot._apis.models.inference_api.inference_execution_id_get_with_http_info(job_id)
    if not response.data:
        return None
    # Dump the API response model to a plain dict, then convert it into this
    # package's ``models.BulkInferenceJob`` type.
    return _utils.convert_to_dataclass(response.data.model_dump(), models.BulkInferenceJob)