# ------------------------------------------------------------------------------
# Copyright Striveworks, Inc.
# All rights reserved.
# ------------------------------------------------------------------------------
# fmt:off
import json
from pydantic import BaseModel
from chariot import _apis
from chariot_api._openapi import representer
from chariot.models import get_model_by_id
from ._types import (
DetectorJob,
DetectorJobState,
DriftError,
DriftMetric,
EmbeddingJob,
EmbeddingJobState,
ModelID,
)
# Public API of this module: the names exported by a star-import.
__all__ = [
"DetectorJob",
"DetectorJobSpec",
"DetectorJobState",
"EmbeddingJob",
"EmbeddingJobSpec",
"EmbeddingJobState",
"ModelID",
"register_detector",
"detector_job_model",
"detector_job_run",
"detector_job_remove",
"detector_job_status",
"embedding_job_dataset",
"embedding_job_remove",
"embedding_job_run",
"embedding_job_status",
]
[docs]
class EmbeddingJobSpec(BaseModel):
    """Request parameters accepted by :func:`embedding_job_run`.

    Only ``snapshot_id`` is required; every other field defaults to ``None``.
    """

    # Identifier of the snapshot the job is created for (required).
    snapshot_id: str

    # Optional split name, forwarded unchanged to the create-embedding request.
    split: str | None = None

    # Optional embedding model id. When left as ``None``, ``embedding_job_run``
    # asks the representer registry for one.
    embedding_model_id: str | None = None

    # Optional embedding count, forwarded unchanged to the create-embedding request.
    embedding_count: int | None = None
[docs]
class DetectorJobSpec(BaseModel):
    """Request parameters accepted by :func:`detector_job_run`.

    ``project_id``, ``model_id`` and ``metric`` are required.
    """

    # Project identifier sent with the create-detector request.
    project_id: str

    # Model identifier sent with the create-detector request.
    model_id: str

    # Drift metric; its ``.value`` is what gets sent to the service.
    metric: DriftMetric

    # Optional per-dataset-version embedding count, forwarded unchanged.
    embedding_count_per_dataset_version: int | None = None
def _embedding_model() -> str:
    """Return the ``generated_id`` reported by the registry's embedding-models endpoint.

    The response is indexed as ``response["data"]["generated_id"]``.
    """
    response = _apis.representer.registry_api.get_embedding_models_representer_v1_registry_models_get()
    payload = response["data"]
    return payload["generated_id"]
def _embedding_job_from_record(record: representer.EmbeddingJobRecord) -> EmbeddingJob:
    """Convert a representer embedding job record into an :class:`EmbeddingJob`.

    Every field is copied straight from ``record``; only ``job_state`` is
    re-wrapped, by building an ``EmbeddingJobState`` from the record's
    ``job_state.value``.

    :param record: The record returned by the representer embedding API.
    :type record: representer.EmbeddingJobRecord
    :return: The SDK-side view of the same job.
    :rtype: EmbeddingJob
    """
    # ``snapshot_id``, ``split`` and ``workflow_id`` are passed through as-is
    # (including ``None``); no conditional handling is needed for them.
    return EmbeddingJob(
        generated_id=record.generated_id,
        project_id=record.project_id,
        dataset_id=record.dataset_id,
        dataset_version_id=record.dataset_version_id,
        snapshot_id=record.snapshot_id,
        split=record.split,
        embedding_model_id=record.embedding_model_id,
        workflow_id=record.workflow_id,
        job_state=EmbeddingJobState(record.job_state.value),
        attributes=record.attributes,
    )
[docs]
def embedding_job_run(spec: EmbeddingJobSpec) -> EmbeddingJob:
    """Create an embedding job from ``spec`` via the representer embedding API.

    When ``spec.embedding_model_id`` is ``None`` the model id is looked up
    with :func:`_embedding_model`. The caller's ``spec`` is not modified.

    :param spec: Parameters of the embedding job to create.
    :type spec: EmbeddingJobSpec
    :return: The job built from the record returned by the service.
    :rtype: EmbeddingJob
    """
    # Resolve the default into a local rather than writing it back onto the
    # caller-owned spec, so the argument is left exactly as it was passed in.
    embedding_model_id = spec.embedding_model_id
    if embedding_model_id is None:
        embedding_model_id = _embedding_model()

    args = representer.CreateEmbeddingArgs(
        snapshot_id=spec.snapshot_id,
        split=spec.split,
        embedding_model_id=embedding_model_id,
        embedding_count=spec.embedding_count,
    )
    api = _apis.representer.embedding_api
    res = api.create_embedding_job_representer_v1_embedding_post(args)
    return _embedding_job_from_record(res.data)
[docs]
def embedding_job_status(job_id: str) -> EmbeddingJob | None:
    """Fetch the embedding job with the given id.

    :param job_id: Identifier passed as ``job_id`` to the embedding jobs query.
    :type job_id: str
    :return: The matching job, or ``None`` when the response carries no
        instance or the service raises ``NotFoundException``.
    :rtype: EmbeddingJob | None
    """
    embedding_api = _apis.representer.embedding_api
    try:
        response = embedding_api.get_embedding_jobs_representer_v1_embedding_jobs_get(
            job_id=job_id,
        )
        record = response.data.actual_instance
        return _embedding_job_from_record(record) if record else None
    except representer.exceptions.NotFoundException:
        # A missing job is reported as "no result", not as an error.
        return None
[docs]
def embedding_job_dataset(
    snapshot_id: str,
    split: str | None = None,
) -> EmbeddingJob | None:
    """Fetch the embedding job associated with a snapshot (and optional split).

    :param snapshot_id: Snapshot identifier used to filter the jobs query.
    :type snapshot_id: str
    :param split: Optional split name used to filter the jobs query.
    :type split: str | None
    :return: The matching job, or ``None`` when the response carries no
        instance or the service raises ``NotFoundException``.
    :rtype: EmbeddingJob | None
    """
    embedding_api = _apis.representer.embedding_api
    try:
        response = embedding_api.get_embedding_jobs_representer_v1_embedding_jobs_get(
            snapshot_id=snapshot_id,
            split=split,
        )
        record = response.data.actual_instance
        return _embedding_job_from_record(record) if record else None
    except representer.exceptions.NotFoundException:
        # A missing job is reported as "no result", not as an error.
        return None
[docs]
def embedding_job_remove(
    snapshot_id: str | None = None,
    split: str | None = None,
) -> None:
    """Delete the embedding job matching ``snapshot_id`` / ``split``.

    :param snapshot_id: Snapshot identifier forwarded to the delete request.
    :type snapshot_id: str | None
    :param split: Optional split name forwarded to the delete request.
    :type split: str | None
    :raises DriftError: If the value returned by the delete call is not ``1``.
    """
    api = _apis.representer.embedding_api
    count = api.delete_embedding_jobs_representer_v1_embedding_jobs_delete(
        snapshot_id=snapshot_id,
        split=split,
    )
    # Anything other than exactly one removal is treated as a failure.
    if count != 1:
        raise DriftError(f"failed to remove embedding job (count = {count})")
def _detector_job_from_record(record: representer.DriftDetectorJobRecord) -> DetectorJob:
    """Convert a representer drift detector job record into a :class:`DetectorJob`.

    :param record: The record returned by the representer detector API.
    :type record: representer.DriftDetectorJobRecord
    :return: The SDK-side view of the same job.
    :rtype: DetectorJob
    """
    # NOTE(review): ``str()`` turns ``None`` into the literal "None" -- confirm
    # that ``workflow_id`` and ``storage_url`` are always set on the record.
    workflow = str(record.workflow_id)
    storage = str(record.storage_url)
    state = DetectorJobState(record.job_state.value)
    return DetectorJob(
        generated_id=record.generated_id,
        project_id=record.project_id,
        model_id=record.model_id,
        metric=record.metric.value,
        workflow_id=workflow,
        job_state=state,
        storage_url=storage,
        attributes=record.attributes,
    )
[docs]
def detector_job_run(spec: DetectorJobSpec) -> DetectorJob:
    """Create a drift detector job from ``spec`` via the representer detector API.

    The request is always sent with ``create_needed_embeddings=False``.

    :param spec: Parameters of the detector job to create.
    :type spec: DetectorJobSpec
    :return: The job built from the record returned by the service.
    :rtype: DetectorJob
    """
    request = representer.CreateDriftDetectorArgs(
        project_id=spec.project_id,
        model_id=spec.model_id,
        metric=spec.metric.value,
        embedding_count_per_dataset_version=spec.embedding_count_per_dataset_version,
        create_needed_embeddings=False,
    )
    detector_api = _apis.representer.detector_api
    response = detector_api.create_drift_detector_job_representer_v1_detector_post(request)
    return _detector_job_from_record(response.data)
[docs]
def detector_job_status(job_id: str) -> DetectorJob | None:
    """Fetch the drift detector job with the given id.

    :param job_id: Identifier passed as ``job_id`` to the detector jobs query.
    :type job_id: str
    :return: The matching job, or ``None`` when the response carries no
        instance or the service raises ``NotFoundException``.
    :rtype: DetectorJob | None
    """
    detector_api = _apis.representer.detector_api
    try:
        response = detector_api.get_drift_detector_jobs_representer_v1_detector_jobs_get(
            job_id=job_id,
        )
        record = response.data.actual_instance
        return _detector_job_from_record(record) if record else None
    except representer.exceptions.NotFoundException:
        # A missing job is reported as "no result", not as an error.
        return None
[docs]
def detector_job_model(model_id: ModelID, metric: DriftMetric) -> DetectorJob | None:
    """Fetch the drift detector job for a given model and metric.

    :param model_id: Model identifier used to filter the jobs query.
    :type model_id: ModelID
    :param metric: Drift metric; its ``.value`` is used to filter the jobs query.
    :type metric: DriftMetric
    :return: The matching job, or ``None`` when the response carries no
        instance or the service raises ``NotFoundException``.
    :rtype: DetectorJob | None
    """
    detector_api = _apis.representer.detector_api
    try:
        response = detector_api.get_drift_detector_jobs_representer_v1_detector_jobs_get(
            model_id=model_id,
            metric=metric.value,
        )
        record = response.data.actual_instance
        return _detector_job_from_record(record) if record else None
    except representer.exceptions.NotFoundException:
        # A missing job is reported as "no result", not as an error.
        return None
[docs]
def detector_job_remove(model_id: ModelID, metric: DriftMetric) -> None:
    """Delete the drift detector job for a given model and metric.

    :param model_id: Model identifier forwarded to the delete request.
    :type model_id: ModelID
    :param metric: Drift metric; its ``.value`` is forwarded to the delete request.
    :type metric: DriftMetric
    :raises DriftError: If the value returned by the delete call is not ``1``.
    """
    # NOTE(review): the arguments are passed positionally here, while the
    # sibling calls use keywords -- confirm the generated client's parameter order.
    count = _apis.representer.detector_api.delete_drift_detector_jobs_representer_v1_detector_jobs_delete(
        model_id,
        metric.value,
    )
    # Anything other than exactly one removal is treated as a failure.
    if count != 1:
        raise DriftError(f"failed to remove detector job (count = {count})")
[docs]
@_apis.login_required
def register_detector(
    model_id: str,
    metric: DriftMetric,
    storage_url: str,
) -> None:
    """Register a detector within the representer service.

    The service is first queried for an existing detector job for this model
    and metric. If that query returns HTTP 200 and the job's ``job_state`` is
    ``DetectorJobState.COMPLETED``, nothing is uploaded. Otherwise the
    detector is registered with the given ``storage_url``.

    :param model_id: The model where the detector is stored.
    :type model_id: str
    :param metric: The metric that the detector is computing.
    :type metric: DriftMetric
    :param storage_url: The url where the detector is stored.
    :type storage_url: str
    """
    model = get_model_by_id(model_id)
    detector_api = _apis.representer.detector_api

    # GET - check if a completed detector is already registered. The
    # non-preloaded variant is used so the raw JSON body can be inspected.
    registered_detectors = detector_api.get_drift_detector_jobs_representer_v1_detector_jobs_get_without_preload_content(
        model_id=model.id,
        metric=metric.value,
    )
    if registered_detectors.status == 200:
        payload = json.load(registered_detectors)
        data = payload.get("data") if isinstance(payload, dict) else None
        # ``data`` may be null (or not a JSON object) when no job exists; treat
        # that as "not registered" instead of failing on ``.get``.
        if isinstance(data, dict) and data.get("job_state") == DetectorJobState.COMPLETED.value:
            return

    # POST - register detector with the representer service
    detector_api.upload_drift_detector_representer_v1_detector_upload_post(
        {
            "project_id": model.project_id,
            "model_id": model.id,
            "metric": metric.value,
            "storage_url": storage_url,
        }
    )