# ------------------------------------------------------------------------------
# Copyright Striveworks, Inc.
# All rights reserved.
# ------------------------------------------------------------------------------
# fmt:off
from datetime import datetime
from typing import Literal
from pydantic import BaseModel
from chariot import _apis
from chariot_api._openapi import model_drift
from ._types import (
Check,
CheckID,
CheckJob,
CheckPoint,
CheckStatus,
DriftAlert,
DriftMetric,
JobID,
ModelID,
Monitor,
MonitorStatus,
)
# Public surface of this module: drift types re-exported from ``._types``,
# the two request-spec models defined in this file, and the ``check_*`` /
# ``monitor_*`` entry points.
# NOTE(review): CheckID, CheckStatus and JobID are imported by this module but
# are not listed here -- confirm that omitting them is intentional.
__all__ = [
    # Re-exported types
    "Check",
    "CheckPoint",
    "DriftAlert",
    "DriftMetric",
    "ModelID",
    "Monitor",
    "MonitorStatus",
    # Request specs defined in this module
    "MonitorAddSpec",
    "MonitorUpdateSpec",
    "CheckJob",
    # Check operations
    "check_latest",
    "check_now",
    "check_summary",
    "check_window",
    # Monitor operations
    "monitor_add",
    "monitor_update",
    "monitor_remove",
    "monitor_get"
]
[docs]
class MonitorAddSpec(BaseModel):
    """Specification of a drift monitor to create with :func:`monitor_add`.

    Attributes
    ----------
    model_id:
        Identifier of the model the monitor is created for.
    monitor_status:
        Initial status of the monitor.
    drift_alert:
        Alert setting for the monitor.
    drift_metric:
        Drift metric the monitor evaluates.
    window_seconds:
        Window size in seconds (presumably the span of data each check
        covers -- confirm against the service documentation).
    minimum_cardinality, maximum_cardinality:
        Optional cardinality bounds; ``None`` (the default) leaves the bound
        unset. Their exact semantics are defined by the drift service.
    data_source:
        Data-source selector, ``"*"`` by default.
    """

    # Required fields.
    model_id: ModelID
    monitor_status: MonitorStatus
    drift_alert: DriftAlert
    drift_metric: DriftMetric
    window_seconds: int

    # Optional fields.
    minimum_cardinality: int | None = None
    maximum_cardinality: int | None = None
    # NOTE(review): monitor_add() does not pass this field on to the API
    # payload -- confirm whether that is intentional.
    data_source: str = "*"
[docs]
class MonitorUpdateSpec(BaseModel):
    """Partial update for an existing drift monitor, used by :func:`monitor_update`.

    Every field defaults to ``None``. :func:`monitor_update` forwards ``None``
    to the API payload as-is (presumably meaning "leave unchanged" -- confirm
    against the service).

    Attributes
    ----------
    monitor_status, drift_alert, drift_metric:
        New enum values for the monitor, or ``None``.
    window_seconds:
        New window size in seconds, or ``None``.
    current_workflow:
        A workflow string, the literal ``"NONE"``, or ``None``.
    last_run:
        A timestamp, the literal ``"NONE"``, or ``None``.
    minimum_cardinality, maximum_cardinality:
        New cardinality bounds, or ``None``.
    """

    monitor_status: MonitorStatus | None = None
    drift_alert: DriftAlert | None = None
    drift_metric: DriftMetric | None = None
    window_seconds: int | None = None

    # NOTE(review): the "NONE" literal presumably asks the service to clear
    # the stored value, as opposed to None meaning "not provided" -- confirm.
    current_workflow: str | Literal["NONE"] | None = None
    last_run: datetime | Literal["NONE"] | None = None

    minimum_cardinality: int | None = None
    maximum_cardinality: int | None = None
[docs]
def monitor_add(spec: MonitorAddSpec) -> None:
    """Create a drift monitor through the model-drift monitor API.

    Parameters
    ----------
    spec:
        Description of the monitor to create. Enum fields are converted to
        their generated-client counterparts by value.
    """
    payload = model_drift.MonitorAdd(
        model_id=str(spec.model_id),
        monitor_status=model_drift.MonitorStatus(spec.monitor_status.value),
        drift_alert=model_drift.DriftAlert(spec.drift_alert.value),
        drift_metric=model_drift.DriftMetric(spec.drift_metric.value),
        window_seconds=spec.window_seconds,
        # Optional bounds are forwarded unchanged; None stays None.
        minimum_cardinality=spec.minimum_cardinality,
        maximum_cardinality=spec.maximum_cardinality,
    )
    # NOTE(review): spec.data_source is not included in the payload --
    # confirm whether the API model accepts it and it should be forwarded.
    _apis.model_drift.monitor_api.post_monitors_model_drift_v1_monitors_post(payload)
[docs]
def monitor_remove(model_id: ModelID, metric: DriftMetric) -> None:
    """Delete the monitor registered for ``model_id`` and ``metric``.

    Parameters
    ----------
    model_id:
        Identifier of the model whose monitor is removed.
    metric:
        Drift metric of the monitor to remove; sent to the API by value.
    """
    api = _apis.model_drift.monitor_api
    api.delete_monitors_model_drift_v1_monitors_model_id_metric_delete(
        # Normalise to str, as monitor_add / monitor_get / check_latest do,
        # so every call hands the generated client the same type.
        str(model_id),
        metric.value,
    )
def _monitor_from_record(record: model_drift.MonitorRecord) -> Monitor:
    """Translate a generated-client ``MonitorRecord`` into a ``Monitor``.

    Enum fields are rebuilt from the record's enum values and the model id is
    wrapped in ``ModelID``; all other fields are copied over as they are.
    """
    status = MonitorStatus(record.monitor_status.value)
    model = ModelID(record.model_id)
    alert = DriftAlert(record.drift_alert.value)
    metric = DriftMetric(record.drift_metric.value)
    return Monitor(
        monitor_status=status,
        model_id=model,
        drift_alert=alert,
        drift_metric=metric,
        window_seconds=record.window_seconds,
        current_workflow=record.current_workflow,
        # Optional fields: a missing value is already None on the record.
        last_run=record.last_run,
        minimum_cardinality=record.minimum_cardinality,
        maximum_cardinality=record.maximum_cardinality,
        time_create=record.time_create,
        time_modify=record.time_modify,
        time_delete=record.time_delete,
    )
[docs]
def monitor_get(model_id: ModelID) -> list[Monitor]:
    """Return the monitors the API reports for ``model_id``.

    Parameters
    ----------
    model_id:
        Identifier of the model to look up; sent to the API as a string.

    Returns
    -------
    list[Monitor]
        One ``Monitor`` per record in the API response's ``data``.
    """
    response = _apis.model_drift.monitor_api.get_monitors_model_drift_v1_monitors_model_id_get(
        str(model_id)
    )
    return [_monitor_from_record(rec) for rec in response.data]
[docs]
def monitor_update(model_id: ModelID, metric: DriftMetric, spec: MonitorUpdateSpec) -> None:
    """Patch the monitor registered for ``model_id`` and ``metric``.

    Parameters
    ----------
    model_id:
        Identifier of the model whose monitor is updated.
    metric:
        Drift metric of the monitor to update; sent to the API by value.
    spec:
        Fields to change. Fields left as ``None`` are sent as ``None``
        (presumably "leave unchanged" -- confirm against the service).
    """
    update = model_drift.MonitorUpdate(
        monitor_status=(
            model_drift.MonitorStatus(spec.monitor_status.value)
            if spec.monitor_status is not None
            else None
        ),
        drift_alert=(
            model_drift.DriftAlert(spec.drift_alert.value)
            if spec.drift_alert is not None
            else None
        ),
        # Convert by ``.value`` like every other enum here; passing the SDK
        # enum member itself relies on it comparing equal to its raw value.
        drift_metric=(
            model_drift.DriftMetric(spec.drift_metric.value)
            if spec.drift_metric is not None
            else None
        ),
        window_seconds=spec.window_seconds,
        # These two are wrapped in dedicated generated-client types because
        # they accept either a real value or the literal "NONE".
        current_workflow=(
            model_drift.CurrentWorkflow(spec.current_workflow)
            if spec.current_workflow is not None
            else None
        ),
        last_run=(
            model_drift.LastRun(spec.last_run)
            if spec.last_run is not None
            else None
        ),
        minimum_cardinality=spec.minimum_cardinality,
        maximum_cardinality=spec.maximum_cardinality,
    )
    api = _apis.model_drift.monitor_api
    api.patch_monitors_model_drift_v1_monitors_model_id_metric_patch(
        # Normalise to str, consistent with monitor_add / monitor_get.
        str(model_id),
        metric.value,
        update,
    )
def _check_from_record(record: model_drift.CheckRecord) -> Check:
    """Translate a generated-client ``CheckRecord`` into a ``Check``.

    Identifiers and enums are re-wrapped in the SDK's own types. Missing
    ``attributes`` become an empty dict, and ``time_delete`` is unwrapped via
    ``actual_instance`` when it is present.
    """
    return Check(
        check_id=CheckID(record.check_id),
        model_id=ModelID(record.model_id),
        check_status=CheckStatus(record.check_status.value),
        window_begin=record.window_begin,
        window_end=record.window_end,
        metric=DriftMetric(record.metric.value),
        # Always hand callers a dict, even when the record carries none.
        attributes={} if record.attributes is None else record.attributes.to_dict(),
        time_create=record.time_create,
        time_modify=record.time_modify,
        # time_delete is a wrapper object on the record; expose the wrapped value.
        time_delete=None if record.time_delete is None else record.time_delete.actual_instance,
    )
[docs]
def check_latest(model_id: ModelID) -> list[Check]:
    """Return the latest checks the API reports for ``model_id``.

    Logs are not requested (``include_logs=False``).

    Parameters
    ----------
    model_id:
        Identifier of the model to query; sent to the API as a string.

    Returns
    -------
    list[Check]
        One ``Check`` per record in the API response's ``data``.
    """
    response = _apis.model_drift.check_api.get_latest_checks_model_drift_v1_checks_model_id_latest_get(
        str(model_id),
        include_logs=False,
    )
    return [_check_from_record(rec) for rec in response.data]
[docs]
def check_window(model_id: ModelID, metric: DriftMetric, time_begin: datetime, time_end: datetime) -> list[Check]:
    """Return the checks for ``model_id`` and ``metric`` in a time window.

    Logs are not requested (``include_logs=False``).

    Parameters
    ----------
    model_id:
        Identifier of the model to query.
    metric:
        Drift metric to query; sent to the API by value.
    time_begin, time_end:
        Bounds of the window, passed to the API unchanged.

    Returns
    -------
    list[Check]
        One ``Check`` per record in the API response's ``data``.
    """
    api = _apis.model_drift.check_api
    check_result = api.get_check_window_model_drift_v1_checks_model_id_metric_get(
        # Normalise to str, consistent with check_latest / monitor_get.
        str(model_id),
        metric.value,
        time_begin,
        time_end,
        include_logs=False,
    )
    return [_check_from_record(rec) for rec in check_result.data]
[docs]
def check_summary(model_id: ModelID, metric: DriftMetric, time_begin: datetime, time_end: datetime) -> list[CheckPoint]:
    """Request the check summary for ``model_id`` and ``metric`` in a time window.

    Parameters
    ----------
    model_id:
        Identifier of the model to query.
    metric:
        Drift metric to query; sent to the API by value.
    time_begin, time_end:
        Bounds of the window, passed to the API unchanged.

    Returns
    -------
    list[CheckPoint]
        Declared return type; see the review notes in the body.
    """
    api = _apis.model_drift.check_api
    # NOTE(review): this is the only call in the module that uses the
    # ``_with_http_info`` variant; the sibling functions call the plain
    # method and read ``.data`` from its result. Confirm that ``.data`` here
    # is the list of points and not the enclosing response object.
    check_result = api.get_check_summary_model_drift_v1_checks_model_id_metric_summary_get_with_http_info(
        # Normalise to str, consistent with check_latest / monitor_get.
        str(model_id),
        metric.value,
        time_begin,
        time_end,
    )
    # NOTE(review): _check_from_record builds ``Check`` objects, but this
    # function is annotated as returning ``list[CheckPoint]``. Confirm which
    # is intended; a dedicated CheckPoint converter may be missing.
    return list(map(_check_from_record, check_result.data))
[docs]
def check_now(model_id: ModelID, metric: DriftMetric) -> CheckJob:
    """Submit a drift check for ``model_id`` and ``metric``.

    Parameters
    ----------
    model_id:
        Identifier of the model to check; sent to the API as a string.
    metric:
        Drift metric to check; converted to the generated-client enum by value.

    Returns
    -------
    CheckJob
        Wraps the ``job_id`` returned by the API.
    """
    drift_api = _apis.model_drift.drift_api
    request = model_drift.DriftCheckArgs(
        model_id=str(model_id),
        metric=model_drift.DriftMetric(metric.value),
    )
    submitted = drift_api.post_drift_check_model_drift_v1_drift_check_post(request)
    return CheckJob(job_id=JobID(submitted.job_id))