Source code for chariot.drift.embed

import json
import time

import numpy as np

from chariot.bulk_inference import (
    BulkInferenceJob,
    ProgressState,
    get_bulk_inference_jobs,
)
from chariot.datasets import get_snapshot, get_snapshot_datum_count
from chariot.drift.representer import (
    EmbeddingJobSpec,
    EmbeddingJobState,
    embedding_job_dataset,
    embedding_job_run,
    embedding_job_status,
)
from chariot.inference_store import (
    BaseInferenceFilter,
    MetadataFilter,
    MetadataFilterOperator,
    MetadataFilterType,
    NewGetInferencesRequest,
    Pagination,
    filter_inferences,
)
from chariot.models import Model, get_models
from chariot.projects import get_organization_id, get_project_id

__all__ = [
    "get_chariot_common_embedding_model",
    "embed",
]


COMMON_EMBEDDING_DIMENSION = 2048


def get_chariot_common_embedding_model(timeout: int = 120) -> Model:
    """Get the global embedding model.

    :param timeout: Amount of time to wait for inference server start up. Defaults to 120.
    :type timeout: int
    :return: The embedding model.
    :rtype: Model
    """
    common_project_id = get_project_id(
        project_name="Common",
        organization_id=get_organization_id("Chariot"),
    )
    embedding_models = get_models(
        project_id=common_project_id,
        model_name="model-recommender-embed",
    )
    if len(embedding_models) != 1:
        raise ValueError(f"expected one model, received {len(embedding_models)}")
    embedding_model = embedding_models[0]
    return embedding_model
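
# Usage sketch (not part of the module): fetching the shared embedding model.
# Assumes Chariot credentials and the "Chariot" organization / "Common" project
# are available in the environment; the timeout value is illustrative.
#
#     from chariot.drift.embed import get_chariot_common_embedding_model
#
#     model = get_chariot_common_embedding_model(timeout=300)
#     print(model.id)
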
def _get_successful_bulk_inference_embedding_job(
    embedding_model_id: str,
    snapshot_id: str,
    split: str,
) -> BulkInferenceJob | None:
    """Return a successful bulk inference job for the snapshot split, or None if none exists."""
    jobs = [
        job
        for job in get_bulk_inference_jobs(
            model_id=embedding_model_id,
            dataset_snapshot_id=snapshot_id,
            dataset_version_split=split,
        )
        if job.execution_status == ProgressState.SUCCEEDED
    ]
    if not jobs:
        return None
    return jobs[0]
def embed(
    snapshot_id: str,
    split: str,
    limit: int = 1000,
) -> np.ndarray:
    """Create embeddings for a dataset snapshot.

    :param snapshot_id: The dataset snapshot id.
    :type snapshot_id: str
    :param split: Dataset snapshot split.
    :type split: str
    :param limit: Pagination limit. Defaults to 1000.
    :type limit: int
    :return: NumPy array with shape (n_embeddings, n_features).
    :rtype: np.ndarray
    """
    embedding_model = get_chariot_common_embedding_model()
    snapshot = get_snapshot(snapshot_id)
    dataset = snapshot.view.dataset
    if not dataset:
        raise ValueError("snapshot does not contain a reference dataset")

    # create representer embedding job if one does not exist
    rep_job = embedding_job_dataset(snapshot_id=snapshot_id, split=split)
    if rep_job is None:
        spec = EmbeddingJobSpec(
            snapshot_id=snapshot_id,
            split=split,
        )
        rep_job = embedding_job_run(spec)

    # poll until the embedding job completes, aborting on failure
    while rep_job.job_state != EmbeddingJobState.COMPLETED:
        rep_job = embedding_job_status(rep_job.generated_id)
        assert rep_job
        if rep_job.job_state in {
            EmbeddingJobState.ABORTED,
            EmbeddingJobState.FAILED,
        }:
            raise RuntimeError(
                f"embedding job '{rep_job.generated_id}' failed with status '{rep_job.job_state}'"
            )
        time.sleep(3)

    # find bulk inference job
    bi_job = _get_successful_bulk_inference_embedding_job(
        embedding_model_id=embedding_model.id,
        snapshot_id=snapshot_id,
        split=split,
    )
    if bi_job is None:
        raise RuntimeError("could not find successful bulk inference job")

    # page through stored inferences and collect embeddings of the expected dimension
    embedding_pages = []
    offset = 0
    while True:
        inferences = filter_inferences(
            model_id=embedding_model.id,
            request_body=NewGetInferencesRequest(
                filters=BaseInferenceFilter(
                    inference_action_filter="embed",
                    metadata_filter=[
                        MetadataFilter(
                            key="bulk_inference_job_id",
                            operator=MetadataFilterOperator.EQUAL,
                            type=MetadataFilterType.STRING,
                            value=bi_job.execution_id,
                        )
                    ],
                ),
                pagination=Pagination(limit=limit, offset=offset),
            ),
        )
        offset += limit
        if not inferences:
            break
        embedding_pages.append(
            np.array(
                [
                    json.loads(inference.data)
                    for inference in inferences
                    if inference.data
                    and len(json.loads(inference.data)) == COMMON_EMBEDDING_DIMENSION
                ]
            )
        )

    embeddings = np.concatenate(embedding_pages, axis=0)

    num_embeddings = embeddings.shape[0]
    num_snapshot_datums = get_snapshot_datum_count(snapshot_id=snapshot_id, split=split)
    if num_embeddings != num_snapshot_datums:
        raise RuntimeError(
            f"number of embeddings does not match dataset size: {num_embeddings} != {num_snapshot_datums}"
        )

    return embeddings
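
# Usage sketch (not part of the module): embedding one split of a snapshot and
# checking the result shape. The snapshot id and split name are placeholders.
#
#     from chariot.drift.embed import COMMON_EMBEDDING_DIMENSION, embed
#
#     embeddings = embed(snapshot_id="<snapshot-id>", split="train")
#     assert embeddings.shape[1] == COMMON_EMBEDDING_DIMENSION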