# Source code for chariot.drift.embed

import json
import time

import numpy as np
from tqdm import tqdm

from chariot.bulk_inference import (
    NewBulkInferenceJobRequest,
    ProgressState,
    create_bulk_inference_job,
    get_bulk_inference_job,
    get_bulk_inference_jobs,
)
from chariot.datasets import get_snapshot
from chariot.inference_store import (
    BaseInferenceFilter,
    MetadataFilter,
    MetadataFilterOperator,
    MetadataFilterType,
    NewGetInferencesRequest,
    filter_inferences,
)
from chariot.models import Model, get_models
from chariot.projects import get_organization_id, get_project_id

# Public API of this module: everything else is an implementation detail.
__all__ = [
    "get_chariot_common_embedding_model",
    "embed",
]


def get_chariot_common_embedding_model(timeout: int = 120) -> Model:
    """Get the global embedding model.

    Looks up the "model-recommender-embed" model inside the "Common"
    project of the "Chariot" organization and waits for its inference
    server to come up before returning it.

    :param timeout: Amount of time (in seconds) to wait for inference
        server start up. Defaults to 120.
    :type timeout: int
    :raises ValueError: If the lookup does not match exactly one model.
    :return: The embedding model.
    :rtype: Model
    """
    common_project_id = get_project_id(
        project_name="Common",
        organization_id=get_organization_id("Chariot"),
    )
    embedding_models = get_models(
        project_id=common_project_id,
        model_name="model-recommender-embed",
    )
    # The name should be unique within the Common project; anything else
    # indicates a misconfigured deployment.
    if len(embedding_models) != 1:
        raise ValueError(f"expected one model, received {len(embedding_models)}")
    embedding_model = embedding_models[0]
    embedding_model.wait_for_inference_server(timeout=timeout)
    return embedding_model
def embed(
    snapshot_id: str,
    split: str,
    verbose: bool = False,
) -> np.ndarray:
    """Create embeddings for a dataset snapshot.

    Reuses any previously successful bulk-inference embedding job for this
    model/snapshot/split; otherwise creates a new job and polls it to
    completion. The resulting embeddings are then loaded from the
    inference store.

    :param snapshot_id: The dataset snapshot id.
    :type snapshot_id: str
    :param split: Dataset snapshot split.
    :type split: str
    :param verbose: Toggles verbose output. Defaults to False.
    :type verbose: bool
    :raises ValueError: If the snapshot has no reference dataset.
    :raises RuntimeError: If a bulk inference job cannot be created or
        found, or if the job fails/aborts.
    :return: NumPy array with shape (n_embeddings, n_features).
    :rtype: np.ndarray
    """
    snapshot = get_snapshot(snapshot_id)
    dataset = snapshot.view.dataset
    if not dataset:
        raise ValueError("snapshot does not contain a reference dataset")

    embedding_model = get_chariot_common_embedding_model()
    existing_jobs = get_bulk_inference_jobs(
        model_id=embedding_model.id,
        dataset_version_split=split,
        dataset_snapshot_id=snapshot.id,
    )

    # Create a new job only if no successful job already exists.
    if not any(job.execution_status == ProgressState.SUCCEEDED for job in existing_jobs):
        job_spec = NewBulkInferenceJobRequest(
            model_project_id=embedding_model.project_id,
            model_id=embedding_model.id,
            dataset_id=dataset.id,
            dataset_project_id=dataset.project_id,
            dataset_snapshot_id=snapshot.id,
            dataset_snapshot_split=split,
            evaluate_metrics=False,
            inference_method="embed",
        )
        job_id = create_bulk_inference_job(job_spec=job_spec)
        if not job_id:
            raise RuntimeError("bulk inference job did not return an ID")

        latest = get_bulk_inference_job(job_id=job_id)
        if not latest:
            raise RuntimeError(f"could not find bulk inference job with id '{job_id}'")
        status = latest.execution_status
        total = latest.expected_total

        # Poll until the job succeeds, surfacing any terminal failure state.
        pbar = tqdm(total=total, desc=f"Running embedding job '{job_id}'", disable=(not verbose))
        while status != ProgressState.SUCCEEDED:
            latest = get_bulk_inference_job(job_id=job_id)
            if not latest:
                raise RuntimeError(f"could not find bulk inference job with id '{job_id}'")
            status = latest.execution_status
            current = latest.num_computed
            if status in {
                ProgressState.FAILED,
                ProgressState.RUNNING_FAILED,
                ProgressState.ABORTED,
            }:
                raise RuntimeError(
                    f"bulk inference job '{job_id}' failed with status {latest.execution_status}"
                )
            pbar.n = current
            pbar.refresh()
            time.sleep(5)
        pbar.close()

        new_job = get_bulk_inference_job(job_id)
        if not new_job:
            raise RuntimeError(f"could not find bulk inference job with id '{job_id}'")
        existing_jobs.append(new_job)

    embeddings = []
    # Progress total covers every job we are about to read from, not just
    # the first one.
    total_embeddings = sum(job.num_computed for job in existing_jobs)
    pbar = tqdm(
        total=total_embeddings,
        desc="Loading embeddings from inference store",
        disable=(not verbose),
    )
    for job in existing_jobs:
        job_id = job.execution_id
        inferences = filter_inferences(
            model_id=embedding_model.id,
            request_body=NewGetInferencesRequest(
                filters=BaseInferenceFilter(
                    inference_action_filter="embed",
                    metadata_filter=[
                        MetadataFilter(
                            key="bulk_inference_job_id",
                            operator=MetadataFilterOperator.EQUAL,
                            type=MetadataFilterType.STRING,
                            value=job_id,
                        )
                    ],
                )
            ),
        )
        embeddings.extend(inferences)
        pbar.update(len(inferences))
    pbar.close()

    # Parse each payload once; keep only full-size vectors. The 2048
    # dimensionality matches the common embedding model's output width.
    vectors = []
    for embedding in embeddings:
        if not embedding.data:
            continue
        vector = json.loads(embedding.data)
        if len(vector) == 2048:
            vectors.append(vector)
    return np.array(vectors)