# Source code for chariot.drift.monitor

# ------------------------------------------------------------------------------
# Copyright Striveworks, Inc.
# All rights reserved.
# ------------------------------------------------------------------------------
# fmt:off

from datetime import datetime
from typing import Literal

from pydantic import BaseModel

from chariot import _apis
from chariot_api._openapi import model_drift

from ._types import (
    Check,
    CheckID,
    CheckJob,
    CheckPoint,
    CheckStatus,
    DriftAlert,
    DriftMetric,
    JobID,
    ModelID,
    Monitor,
    MonitorStatus,
)

# Public API of this module.
# NOTE(review): CheckID, CheckStatus and JobID are imported above and appear in
# fields of the exported types, but are not listed here; confirm the omission
# is intentional.
__all__ = [
    # Data types. Most are re-exported from ``._types``; MonitorAddSpec and
    # MonitorUpdateSpec are defined in this module.
    "Check",
    "CheckPoint",
    "DriftAlert",
    "DriftMetric",
    "ModelID",
    "Monitor",
    "MonitorStatus",
    "MonitorAddSpec",
    "MonitorUpdateSpec",
    "CheckJob",

    # Functions wrapping the model-drift service.
    "check_latest",
    "check_now",
    "check_summary",
    "check_window",
    "monitor_add",
    "monitor_update",
    "monitor_remove",
    "monitor_get"
]


class MonitorAddSpec(BaseModel):
    """Configuration for a new drift monitor, consumed by :func:`monitor_add`.

    ``monitor_add`` translates these fields into a generated
    ``model_drift.MonitorAdd`` request body. Enum fields are converted by
    their ``.value``.
    """

    # Model to monitor; ``monitor_add`` sends it as ``str(model_id)``.
    model_id: ModelID
    monitor_status: MonitorStatus
    drift_alert: DriftAlert
    # Metric this monitor tracks.
    drift_metric: DriftMetric
    # Presumably the drift-check window length in seconds (TODO confirm with the service).
    window_seconds: int
    # Optional bounds; semantics are defined server-side (TODO confirm).
    minimum_cardinality: int | None = None
    maximum_cardinality: int | None = None
    # Defaults to "*" (presumably "all data sources" — TODO confirm).
    # NOTE(review): ``monitor_add`` does not currently send this field.
    data_source: str = "*"
class MonitorUpdateSpec(BaseModel):
    """Partial update for an existing drift monitor, consumed by :func:`monitor_update`.

    Every field is optional and defaults to ``None``. ``monitor_update``
    forwards a ``None`` field as ``None`` in the generated ``MonitorUpdate``
    body. Presumably the service treats that as "leave unchanged" (TODO confirm).
    """

    monitor_status: MonitorStatus | None = None
    drift_alert: DriftAlert | None = None
    drift_metric: DriftMetric | None = None
    # Presumably the window length in seconds (TODO confirm).
    window_seconds: int | None = None
    # NOTE(review): the "NONE" literal looks like a sentinel for clearing the
    # stored value, as opposed to ``None`` meaning "not provided"; confirm with the API.
    current_workflow: str | Literal["NONE"] | None = None
    last_run: datetime | Literal["NONE"] | None = None
    minimum_cardinality: int | None = None
    maximum_cardinality: int | None = None
def monitor_add(spec: MonitorAddSpec) -> None:
    """Create a drift monitor on the model-drift service.

    Translates ``spec`` into the generated ``model_drift.MonitorAdd`` request
    body and POSTs it through the monitor API.

    :param spec: Configuration of the monitor to create.
    """
    arg = model_drift.MonitorAdd(
        model_id=str(spec.model_id),
        # Local enums are mapped onto the generated client's enums by value.
        monitor_status=model_drift.MonitorStatus(spec.monitor_status.value),
        drift_alert=model_drift.DriftAlert(spec.drift_alert.value),
        drift_metric=model_drift.DriftMetric(spec.drift_metric.value),
        window_seconds=spec.window_seconds,
        # Optional bounds pass straight through; ``None`` stays ``None``.
        minimum_cardinality=spec.minimum_cardinality,
        maximum_cardinality=spec.maximum_cardinality,
    )
    # NOTE(review): ``spec.data_source`` is not included in the request body;
    # confirm whether ``model_drift.MonitorAdd`` accepts it and should receive it.
    api = _apis.model_drift.monitor_api
    api.post_monitors_model_drift_v1_monitors_post(arg)
def monitor_remove(model_id: ModelID, metric: DriftMetric) -> None:
    """Delete the drift monitor for ``model_id`` that tracks ``metric``.

    :param model_id: Model whose monitor should be removed.
    :param metric: Drift metric selecting which of the model's monitors to delete.
    """
    api = _apis.model_drift.monitor_api
    # Stringify the ID, as monitor_add/monitor_get do, so the generated client
    # always receives a plain ``str`` whatever ``ModelID`` wraps.
    api.delete_monitors_model_drift_v1_monitors_model_id_metric_delete(
        str(model_id), metric.value
    )
def _monitor_from_record(record: model_drift.MonitorRecord) -> Monitor:
    """Convert a generated-client ``MonitorRecord`` into the public ``Monitor`` type.

    Enum fields are re-mapped onto the local enums by ``.value``. All other
    fields are copied unchanged.
    """
    return Monitor(
        monitor_status=MonitorStatus(record.monitor_status.value),
        model_id=ModelID(record.model_id),
        drift_alert=DriftAlert(record.drift_alert.value),
        drift_metric=DriftMetric(record.drift_metric.value),
        window_seconds=record.window_seconds,
        current_workflow=record.current_workflow,
        # Optional fields pass through as-is; ``None`` stays ``None``.
        last_run=record.last_run,
        minimum_cardinality=record.minimum_cardinality,
        maximum_cardinality=record.maximum_cardinality,
        time_create=record.time_create,
        time_modify=record.time_modify,
        # NOTE(review): _check_from_record unwraps ``time_delete.actual_instance``
        # for CheckRecord; confirm MonitorRecord.time_delete is a plain value here.
        time_delete=record.time_delete,
    )
def monitor_get(model_id: ModelID) -> list[Monitor]:
    """List every drift monitor configured for ``model_id``.

    :param model_id: Model whose monitors to fetch.
    :returns: One :class:`Monitor` per record in the service response.
    """
    monitor_api = _apis.model_drift.monitor_api
    response = monitor_api.get_monitors_model_drift_v1_monitors_model_id_get(
        str(model_id)
    )
    return [_monitor_from_record(rec) for rec in response.data]
def monitor_update(model_id: ModelID, metric: DriftMetric, spec: MonitorUpdateSpec) -> None:
    """Patch the drift monitor for ``model_id``/``metric`` with the fields set in ``spec``.

    Fields left as ``None`` in ``spec`` are sent as ``None`` in the generated
    ``MonitorUpdate`` body.

    :param model_id: Model whose monitor is updated.
    :param metric: Drift metric selecting which of the model's monitors to patch.
    :param spec: Fields to change.
    """
    update = model_drift.MonitorUpdate(
        monitor_status=(
            model_drift.MonitorStatus(spec.monitor_status.value)
            if spec.monitor_status is not None
            else None
        ),
        drift_alert=(
            model_drift.DriftAlert(spec.drift_alert.value)
            if spec.drift_alert is not None
            else None
        ),
        # Map by ``.value`` like the other enums; passing the local enum member
        # itself only works if it happens to compare equal to its raw value.
        drift_metric=(
            model_drift.DriftMetric(spec.drift_metric.value)
            if spec.drift_metric is not None
            else None
        ),
        window_seconds=spec.window_seconds,
        # These generated types wrap either a concrete value or the "NONE" sentinel.
        current_workflow=(
            model_drift.CurrentWorkflow(spec.current_workflow)
            if spec.current_workflow is not None
            else None
        ),
        last_run=(
            model_drift.LastRun(spec.last_run)
            if spec.last_run is not None
            else None
        ),
        minimum_cardinality=spec.minimum_cardinality,
        maximum_cardinality=spec.maximum_cardinality,
    )
    api = _apis.model_drift.monitor_api
    # Stringify the ID, as monitor_add/monitor_get do.
    api.patch_monitors_model_drift_v1_monitors_model_id_metric_patch(
        str(model_id), metric.value, update
    )
def _check_from_record(record: model_drift.CheckRecord) -> Check:
    """Convert a generated-client ``CheckRecord`` into the public ``Check`` type.

    A missing ``attributes`` becomes an empty dict. ``time_delete``, when
    present, is unwrapped via its ``actual_instance``.
    """
    # Values are evaluated in the same order as the ``Check`` keyword arguments.
    fields = {
        "check_id": CheckID(record.check_id),
        "model_id": ModelID(record.model_id),
        "check_status": CheckStatus(record.check_status.value),
        "window_begin": record.window_begin,
        "window_end": record.window_end,
        "metric": DriftMetric(record.metric.value),
        "attributes": {} if record.attributes is None else record.attributes.to_dict(),
        "time_create": record.time_create,
        "time_modify": record.time_modify,
        "time_delete": None if record.time_delete is None else record.time_delete.actual_instance,
    }
    return Check(**fields)
def check_latest(model_id: ModelID) -> list[Check]:
    """Return the latest drift checks recorded for ``model_id``.

    Logs are not requested (``include_logs=False``).

    :param model_id: Model whose checks to fetch.
    :returns: One :class:`Check` per record in the service response.
    """
    check_api = _apis.model_drift.check_api
    response = check_api.get_latest_checks_model_drift_v1_checks_model_id_latest_get(
        str(model_id), include_logs=False
    )
    return [_check_from_record(rec) for rec in response.data]
def check_window(model_id: ModelID, metric: DriftMetric, time_begin: datetime, time_end: datetime) -> list[Check]:
    """Return the drift checks for ``model_id``/``metric`` between ``time_begin`` and ``time_end``.

    Logs are not requested (``include_logs=False``).

    :param model_id: Model whose checks to fetch.
    :param metric: Drift metric to filter on.
    :param time_begin: Start of the query window.
    :param time_end: End of the query window.
    :returns: One :class:`Check` per record in the service response.
    """
    api = _apis.model_drift.check_api
    # Stringify the ID, as monitor_get/check_latest do.
    check_result = api.get_check_window_model_drift_v1_checks_model_id_metric_get(
        str(model_id), metric.value, time_begin, time_end, include_logs=False
    )
    return [_check_from_record(record) for record in check_result.data]
def check_summary(model_id: ModelID, metric: DriftMetric, time_begin: datetime, time_end: datetime) -> list[CheckPoint]:
    """Return the drift-check summary for ``model_id``/``metric`` between ``time_begin`` and ``time_end``.

    :param model_id: Model whose checks to summarize.
    :param metric: Drift metric to summarize.
    :param time_begin: Start of the query window.
    :param time_end: End of the query window.
    """
    api = _apis.model_drift.check_api
    # Use the plain (non-``_with_http_info``) operation, like every other call in
    # this module. In openapi-generator clients the ``_with_http_info`` variant
    # returns an ApiResponse whose ``.data`` is the whole deserialized payload,
    # not the record list the other endpoints expose.
    check_result = api.get_check_summary_model_drift_v1_checks_model_id_metric_summary_get(
        str(model_id), metric.value, time_begin, time_end
    )
    # Accept either a wrapper model exposing ``.data`` or a bare list payload.
    records = getattr(check_result, "data", check_result)
    # NOTE(review): the return annotation is list[CheckPoint], but records are
    # converted with _check_from_record (which builds Check objects). Confirm
    # the summary response schema and switch to a CheckPoint conversion.
    return list(map(_check_from_record, records))
def check_now(model_id: ModelID, metric: DriftMetric) -> CheckJob:
    """Trigger a drift check for ``model_id`` on ``metric`` right away.

    :param model_id: Model to check.
    :param metric: Drift metric to evaluate.
    :returns: A :class:`CheckJob` holding the job ID reported by the service.
    """
    drift_api = _apis.model_drift.drift_api
    request = model_drift.DriftCheckArgs(
        model_id=str(model_id),
        metric=model_drift.DriftMetric(metric.value),
    )
    response = drift_api.post_drift_check_model_drift_v1_drift_check_post(request)
    return CheckJob(job_id=JobID(response.job_id))