Source code for chariot.drift.representer

# ------------------------------------------------------------------------------
# Copyright Striveworks, Inc.
# All rights reserved.
# ------------------------------------------------------------------------------
# fmt:off


from pydantic import BaseModel

from chariot import _apis
from chariot_api._openapi import representer

from ._types import (
    DetectorJob,
    DetectorJobState,
    DriftError,
    DriftMetric,
    EmbeddingJob,
    EmbeddingJobState,
    ModelID,
)

# Public API of this module: names exported by ``from ... import *``.
__all__ = [
    # Types: job/state/ID types re-exported from ``._types`` plus the two
    # ``*Spec`` models defined in this module.
    "DetectorJob",
    "DetectorJobSpec",
    "DetectorJobState",
    "EmbeddingJob",
    "EmbeddingJobSpec",
    "EmbeddingJobState",
    "ModelID",

    # Functions: run / look up / remove detector and embedding jobs.
    "detector_job_model",
    "detector_job_run",
    "detector_job_remove",
    "detector_job_status",
    "embedding_job_dataset",
    "embedding_job_remove",
    "embedding_job_run",
    "embedding_job_status",
]


class EmbeddingJobSpec(BaseModel):
    """Request parameters for an embedding job (see ``embedding_job_run``).

    Only ``snapshot_id`` is required; all other fields default to ``None``.
    """

    # Forwarded as ``snapshot_id`` to the representer service.
    snapshot_id: str

    # Forwarded as ``split``; optional.
    split: str | None = None

    # ``embedding_job_run`` fills this in from the model registry when it is
    # left as ``None``.
    embedding_model_id: str | None = None

    # Forwarded as ``embedding_count``; optional.
    embedding_count: int | None = None
class DetectorJobSpec(BaseModel):
    """Request parameters for a drift detector job (see ``detector_job_run``).

    ``project_id``, ``model_id`` and ``metric`` are required.
    """

    # Forwarded as ``project_id`` to the representer service.
    project_id: str

    # Forwarded as ``model_id`` to the representer service.
    model_id: str

    # Drift metric; its ``.value`` is what gets sent to the service.
    metric: DriftMetric

    # Forwarded unchanged; optional.
    embedding_count_per_dataset_version: int | None = None
def _embedding_model() -> str:
    """Return the ``generated_id`` from the representer model-registry response.

    ``embedding_job_run`` uses this as the embedding model when the spec does
    not name one.
    """
    api = _apis.representer.registry_api
    res = api.get_embedding_models_representer_v1_registry_models_get()
    return res["data"]["generated_id"]


def _embedding_job_from_record(record: representer.EmbeddingJobRecord) -> EmbeddingJob:
    """Convert an API ``EmbeddingJobRecord`` into an ``EmbeddingJob``.

    Every field is copied across as-is, except ``job_state``, which is
    re-wrapped as an ``EmbeddingJobState`` built from the record's
    ``job_state.value``.

    Args:
        record: The embedding job record returned by the representer API.

    Returns:
        The equivalent ``EmbeddingJob``.
    """
    # ``snapshot_id``, ``split`` and ``workflow_id`` are passed through
    # unchanged, whatever their value (including falsy ones such as ``None``).
    return EmbeddingJob(
        generated_id=record.generated_id,
        project_id=record.project_id,
        dataset_id=record.dataset_id,
        dataset_version_id=record.dataset_version_id,
        snapshot_id=record.snapshot_id,
        split=record.split,
        embedding_model_id=record.embedding_model_id,
        workflow_id=record.workflow_id,
        job_state=EmbeddingJobState(record.job_state.value),
        attributes=record.attributes,
    )
def embedding_job_run(spec: EmbeddingJobSpec) -> EmbeddingJob:
    """Submit an embedding job described by ``spec`` and return the created job.

    Args:
        spec: Job parameters. If ``spec.embedding_model_id`` is ``None`` it is
            filled in, on this same ``spec`` object, with the ID returned by
            ``_embedding_model()``.

    Returns:
        The ``EmbeddingJob`` built from the service's response record.
    """
    # Resolve the default model first; note this writes back to the caller's spec.
    if spec.embedding_model_id is None:
        spec.embedding_model_id = _embedding_model()

    client = _apis.representer.embedding_api
    request = representer.CreateEmbeddingArgs(
        snapshot_id=spec.snapshot_id,
        split=spec.split,
        embedding_model_id=spec.embedding_model_id,
        embedding_count=spec.embedding_count,
    )
    response = client.create_embedding_job_representer_v1_embedding_post(request)
    return _embedding_job_from_record(response.data)
def embedding_job_status(job_id: str) -> EmbeddingJob | None:
    """Look up an embedding job by its job ID.

    Args:
        job_id: ID passed as ``job_id`` to the embedding-jobs query.

    Returns:
        The matching ``EmbeddingJob``, or ``None`` when the service raises
        ``NotFoundException`` or the response carries no record.
    """
    client = _apis.representer.embedding_api
    try:
        response = client.get_embedding_jobs_representer_v1_embedding_jobs_get(
            job_id=job_id
        )
        record = response.data.actual_instance
        return _embedding_job_from_record(record) if record else None
    except representer.exceptions.NotFoundException:
        return None
def embedding_job_dataset(
    snapshot_id: str,
    split: str | None = None,
) -> EmbeddingJob | None:
    """Look up an embedding job by snapshot ID and optional split.

    Args:
        snapshot_id: Passed as ``snapshot_id`` to the embedding-jobs query.
        split: Passed as ``split`` to the embedding-jobs query.

    Returns:
        The matching ``EmbeddingJob``, or ``None`` when the service raises
        ``NotFoundException`` or the response carries no record.
    """
    client = _apis.representer.embedding_api
    try:
        response = client.get_embedding_jobs_representer_v1_embedding_jobs_get(
            snapshot_id=snapshot_id,
            split=split,
        )
        record = response.data.actual_instance
        return _embedding_job_from_record(record) if record else None
    except representer.exceptions.NotFoundException:
        return None
def embedding_job_remove(
    snapshot_id: str | None = None,
    split: str | None = None,
) -> None:
    """Delete the embedding job selected by ``snapshot_id`` and ``split``.

    Args:
        snapshot_id: Passed as ``snapshot_id`` to the delete call. Defaults to
            ``None`` (the annotation now says so explicitly).
        split: Passed as ``split`` to the delete call.

    Raises:
        DriftError: If the value returned by the delete call is not ``1``.
    """
    api = _apis.representer.embedding_api
    count = api.delete_embedding_jobs_representer_v1_embedding_jobs_delete(
        snapshot_id=snapshot_id,
        split=split,
    )
    # Anything other than exactly one reported removal is treated as a failure.
    if count != 1:
        raise DriftError(f"failed to remove embedding job (count = {count})")
def _detector_job_from_record(record: representer.DriftDetectorJobRecord) -> DetectorJob:
    """Convert an API ``DriftDetectorJobRecord`` into a ``DetectorJob``.

    ``metric`` is taken from ``record.metric.value``, ``job_state`` is
    re-wrapped as a ``DetectorJobState``, and ``workflow_id`` / ``storage_url``
    are converted with ``str()``. Remaining fields are copied as-is.

    Args:
        record: The detector job record returned by the representer API.

    Returns:
        The equivalent ``DetectorJob``.
    """
    # NOTE(review): ``str()`` turns a ``None`` workflow_id / storage_url into
    # the literal string "None" -- confirm these fields are never null.
    return DetectorJob(
        generated_id=record.generated_id,
        project_id=record.project_id,
        model_id=record.model_id,
        metric=record.metric.value,
        workflow_id=str(record.workflow_id),
        job_state=DetectorJobState(record.job_state.value),
        storage_url=str(record.storage_url),
        attributes=record.attributes,
    )
def detector_job_run(spec: DetectorJobSpec) -> DetectorJob:
    """Submit a drift detector job described by ``spec`` and return the created job.

    The request is always sent with ``create_needed_embeddings=False``.

    Args:
        spec: Job parameters; ``spec.metric.value`` is what is sent as the metric.

    Returns:
        The ``DetectorJob`` built from the service's response record.
    """
    client = _apis.representer.detector_api
    request = representer.CreateDriftDetectorArgs(
        project_id=spec.project_id,
        model_id=spec.model_id,
        metric=spec.metric.value,
        embedding_count_per_dataset_version=spec.embedding_count_per_dataset_version,
        create_needed_embeddings=False,
    )
    response = client.create_drift_detector_job_representer_v1_detector_post(request)
    return _detector_job_from_record(response.data)
def detector_job_status(job_id: str) -> DetectorJob | None:
    """Look up a drift detector job by its job ID.

    Args:
        job_id: ID passed as ``job_id`` to the detector-jobs query.

    Returns:
        The matching ``DetectorJob``, or ``None`` when the service raises
        ``NotFoundException`` or the response carries no record.
    """
    client = _apis.representer.detector_api
    try:
        response = client.get_drift_detector_jobs_representer_v1_detector_jobs_get(
            job_id=job_id
        )
        record = response.data.actual_instance
        return _detector_job_from_record(record) if record else None
    except representer.exceptions.NotFoundException:
        return None
def detector_job_model(model_id: ModelID, metric: DriftMetric) -> DetectorJob | None:
    """Look up a drift detector job by model ID and drift metric.

    Args:
        model_id: Passed as ``model_id`` to the detector-jobs query.
        metric: Drift metric; its ``.value`` is passed as ``metric``.

    Returns:
        The matching ``DetectorJob``, or ``None`` when the service raises
        ``NotFoundException`` or the response carries no record.
    """
    client = _apis.representer.detector_api
    try:
        response = client.get_drift_detector_jobs_representer_v1_detector_jobs_get(
            model_id=model_id,
            metric=metric.value,
        )
        record = response.data.actual_instance
        return _detector_job_from_record(record) if record else None
    except representer.exceptions.NotFoundException:
        return None
def detector_job_remove(model_id: ModelID, metric: DriftMetric) -> None:
    """Delete the drift detector job selected by ``model_id`` and ``metric``.

    Args:
        model_id: Passed as the first positional argument to the delete call.
        metric: Drift metric; its ``.value`` is passed as the second
            positional argument.

    Raises:
        DriftError: If the value returned by the delete call is not ``1``.
    """
    client = _apis.representer.detector_api
    count = client.delete_drift_detector_jobs_representer_v1_detector_jobs_delete(
        model_id,
        metric.value,
    )
    # Exactly one reported removal is the only successful outcome.
    if count == 1:
        return
    raise DriftError(f"failed to remove detector job (count = {count})")