# ------------------------------------------------------------------------------
# Copyright Striveworks, Inc.
# All rights reserved.
# ------------------------------------------------------------------------------
# fmt:off
from pydantic import BaseModel
from chariot import _apis
from chariot_api._openapi import representer
from ._types import (
DetectorJob,
DetectorJobState,
DriftError,
DriftMetric,
EmbeddingJob,
EmbeddingJobState,
ModelID,
)
# Public names of this module (what `from <module> import *` exposes).
# NOTE(review): DriftError and DriftMetric are imported from ._types above but
# are not listed here — confirm whether they should be exported as well.
__all__ = [
# Types re-exported from ._types, plus the spec models defined in this module.
"DetectorJob",
"DetectorJobSpec",
"DetectorJobState",
"EmbeddingJob",
"EmbeddingJobSpec",
"EmbeddingJobState",
"ModelID",
# Drift-detector job helpers.
"detector_job_model",
"detector_job_run",
"detector_job_remove",
"detector_job_status",
# Embedding job helpers.
"embedding_job_dataset",
"embedding_job_remove",
"embedding_job_run",
"embedding_job_status",
]
[docs]
class EmbeddingJobSpec(BaseModel):
    """Input parameters for launching an embedding job.

    ``embedding_job_run`` forwards these fields to the representer service's
    create-embedding request. Only ``snapshot_id`` is required.
    """

    # Identifier of the snapshot to embed (required).
    snapshot_id: str
    # Optional split selector; forwarded unchanged when given.
    split: str | None = None
    # Optional embedding model id; when None, ``embedding_job_run`` looks one
    # up from the registry before submitting the request.
    embedding_model_id: str | None = None
    # Optional embedding count; forwarded unchanged when given.
    embedding_count: int | None = None
[docs]
class DetectorJobSpec(BaseModel):
    """Input parameters for launching a drift-detector job.

    ``detector_job_run`` forwards these fields to the representer service's
    create-drift-detector request.
    """

    # Identifier of the project the detector job belongs to (required).
    project_id: str
    # Identifier of the model the detector is created for (required).
    model_id: str
    # Drift metric to use; its ``.value`` is what gets sent to the API.
    metric: DriftMetric
    # Optional per-dataset-version embedding count; forwarded unchanged.
    embedding_count_per_dataset_version: int | None = None
def _embedding_model() -> str:
    """Return the ``generated_id`` found in the registry's embedding-model response.

    Used as the fallback model id when a spec does not name one.
    """
    registry = _apis.representer.registry_api
    response = registry.get_embedding_models_representer_v1_registry_models_get()
    payload = response["data"]
    return payload["generated_id"]
def _embedding_job_from_record(record: representer.EmbeddingJobRecord) -> EmbeddingJob:
    """Convert an API ``EmbeddingJobRecord`` into an ``EmbeddingJob``.

    Every field is copied across as-is, except ``job_state``, which is
    re-wrapped as an ``EmbeddingJobState`` built from the record's enum value.

    Args:
        record: Embedding job record returned by the representer API.

    Returns:
        The equivalent ``EmbeddingJob``.
    """
    # snapshot_id, split and workflow_id are passed straight through, whatever
    # their value (including None/empty). The earlier
    # ``if x := record.x: x = record.x`` statements were no-ops, since the
    # assignment expression already binds the value regardless of truthiness.
    return EmbeddingJob(
        generated_id=record.generated_id,
        project_id=record.project_id,
        dataset_id=record.dataset_id,
        dataset_version_id=record.dataset_version_id,
        snapshot_id=record.snapshot_id,
        split=record.split,
        embedding_model_id=record.embedding_model_id,
        workflow_id=record.workflow_id,
        job_state=EmbeddingJobState(record.job_state.value),
        attributes=record.attributes,
    )
[docs]
def embedding_job_run(spec: EmbeddingJobSpec) -> EmbeddingJob:
    """Submit an embedding job described by *spec* and return the created job.

    When ``spec.embedding_model_id`` is None, a model id is looked up from the
    registry and used for the request. The caller's *spec* is not modified.

    Args:
        spec: Parameters of the embedding job to create.

    Returns:
        The ``EmbeddingJob`` built from the record in the API response.
    """
    # Resolve the default into a local rather than writing it back onto the
    # caller's spec object, so the call has no side effect on its argument.
    embedding_model_id = spec.embedding_model_id
    if embedding_model_id is None:
        embedding_model_id = _embedding_model()

    api = _apis.representer.embedding_api
    args = representer.CreateEmbeddingArgs(
        snapshot_id=spec.snapshot_id,
        split=spec.split,
        embedding_model_id=embedding_model_id,
        embedding_count=spec.embedding_count,
    )
    res = api.create_embedding_job_representer_v1_embedding_post(args)
    return _embedding_job_from_record(res.data)
[docs]
def embedding_job_status(job_id: str) -> EmbeddingJob | None:
    """Fetch the embedding job with the given id.

    Args:
        job_id: Identifier of the embedding job to look up.

    Returns:
        The matching ``EmbeddingJob``, or None when the API raises
        ``NotFoundException`` or the response carries no record.
    """
    client = _apis.representer.embedding_api
    try:
        response = client.get_embedding_jobs_representer_v1_embedding_jobs_get(
            job_id=job_id,
        )
        record = response.data.actual_instance
        return _embedding_job_from_record(record) if record else None
    except representer.exceptions.NotFoundException:
        # A missing job is reported as "no result", not as an error.
        return None
[docs]
def embedding_job_dataset(
    snapshot_id: str,
    split: str | None = None,
) -> EmbeddingJob | None:
    """Fetch the embedding job associated with a snapshot (and optional split).

    Args:
        snapshot_id: Snapshot identifier to query by.
        split: Optional split to query by; forwarded as given.

    Returns:
        The matching ``EmbeddingJob``, or None when the API raises
        ``NotFoundException`` or the response carries no record.
    """
    client = _apis.representer.embedding_api
    try:
        response = client.get_embedding_jobs_representer_v1_embedding_jobs_get(
            snapshot_id=snapshot_id,
            split=split,
        )
        record = response.data.actual_instance
        return _embedding_job_from_record(record) if record else None
    except representer.exceptions.NotFoundException:
        # A missing job is reported as "no result", not as an error.
        return None
[docs]
def embedding_job_remove(
snapshot_id: str = None,
split: str | None = None
) -> None:
api = _apis.representer.embedding_api
count = api.delete_embedding_jobs_representer_v1_embedding_jobs_delete(
snapshot_id=snapshot_id,
split=split
)
if count != 1:
raise DriftError(f"failed to remove embedding job (count = {count})")
def _detector_job_from_record(record: representer.DriftDetectorJobRecord) -> DetectorJob:
    """Convert an API ``DriftDetectorJobRecord`` into a ``DetectorJob``.

    ``metric`` is passed as the record enum's raw value, ``job_state`` is
    re-wrapped as a ``DetectorJobState``, and ``workflow_id`` / ``storage_url``
    are stringified.
    """
    # NOTE(review): str() turns a None workflow_id/storage_url into the literal
    # string "None" — confirm that is the intended representation.
    fields = {
        "generated_id": record.generated_id,
        "project_id": record.project_id,
        "model_id": record.model_id,
        "metric": record.metric.value,
        "workflow_id": str(record.workflow_id),
        "job_state": DetectorJobState(record.job_state.value),
        "storage_url": str(record.storage_url),
        "attributes": record.attributes,
    }
    return DetectorJob(**fields)
[docs]
def detector_job_run(spec: DetectorJobSpec) -> DetectorJob:
    """Submit a drift-detector job described by *spec* and return the created job.

    The request is always sent with ``create_needed_embeddings=False``.

    Args:
        spec: Parameters of the detector job to create.

    Returns:
        The ``DetectorJob`` built from the record in the API response.
    """
    client = _apis.representer.detector_api
    request = representer.CreateDriftDetectorArgs(
        project_id=spec.project_id,
        model_id=spec.model_id,
        metric=spec.metric.value,
        embedding_count_per_dataset_version=spec.embedding_count_per_dataset_version,
        create_needed_embeddings=False,
    )
    response = client.create_drift_detector_job_representer_v1_detector_post(request)
    created = response.data
    return _detector_job_from_record(created)
[docs]
def detector_job_status(job_id: str) -> DetectorJob | None:
    """Fetch the drift-detector job with the given id.

    Args:
        job_id: Identifier of the detector job to look up.

    Returns:
        The matching ``DetectorJob``, or None when the API raises
        ``NotFoundException`` or the response carries no record.
    """
    client = _apis.representer.detector_api
    try:
        response = client.get_drift_detector_jobs_representer_v1_detector_jobs_get(
            job_id=job_id,
        )
        record = response.data.actual_instance
        return _detector_job_from_record(record) if record else None
    except representer.exceptions.NotFoundException:
        # A missing job is reported as "no result", not as an error.
        return None
[docs]
def detector_job_model(model_id: ModelID, metric: DriftMetric) -> DetectorJob | None:
    """Fetch the drift-detector job for a given model and metric.

    Args:
        model_id: Identifier of the model to query by.
        metric: Drift metric to query by; its ``.value`` is sent to the API.

    Returns:
        The matching ``DetectorJob``, or None when the API raises
        ``NotFoundException`` or the response carries no record.
    """
    client = _apis.representer.detector_api
    try:
        response = client.get_drift_detector_jobs_representer_v1_detector_jobs_get(
            model_id=model_id,
            metric=metric.value,
        )
        record = response.data.actual_instance
        return _detector_job_from_record(record) if record else None
    except representer.exceptions.NotFoundException:
        # A missing job is reported as "no result", not as an error.
        return None
[docs]
def detector_job_remove(model_id: ModelID, metric: DriftMetric) -> None:
    """Delete the drift-detector job for a given model and metric.

    Args:
        model_id: Identifier of the model whose detector job is removed.
        metric: Drift metric of the job; its ``.value`` is sent to the API.

    Raises:
        DriftError: If the delete call returns anything other than 1.
    """
    client = _apis.representer.detector_api
    # NOTE(review): arguments are passed positionally here, unlike the keyword
    # calls elsewhere in this module — verify the generated client's parameter
    # order is (model_id, metric).
    count = client.delete_drift_detector_jobs_representer_v1_detector_jobs_delete(
        model_id,
        metric.value,
    )
    if count == 1:
        return
    raise DriftError(f"failed to remove detector job (count = {count})")