import json
import time
import numpy as np
from chariot.bulk_inference import (
    BulkInferenceJob,
    ProgressState,
    get_bulk_inference_jobs,
)
from chariot.datasets import get_snapshot, get_snapshot_datum_count
from chariot.drift.representer import (
    EmbeddingJobSpec,
    EmbeddingJobState,
    embedding_job_dataset,
    embedding_job_run,
    embedding_job_status,
)
from chariot.inference_store import (
    BaseInferenceFilter,
    MetadataFilter,
    MetadataFilterOperator,
    MetadataFilterType,
    NewGetInferencesRequest,
    Pagination,
    filter_inferences,
)
from chariot.models import Model, get_models
from chariot.projects import get_organization_id, get_project_id

__all__ = [
    "get_chariot_common_embedding_model",
    "embed",
]
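
# Dimensionality of the vectors produced by the common embedding model; embeddings of
# any other length are discarded when inference results are collected in ``embed``.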
COMMON_EMBEDDING_DIMENSION = 2048


def get_chariot_common_embedding_model(timeout: int = 120) -> Model:
    """Get the global embedding model from the Chariot "Common" project.

    :param timeout: Amount of time to wait for the inference server to start up. Defaults to 120.
    :type timeout: int
    :return: The embedding model.
    :rtype: Model
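
    Illustrative usage (a sketch only; it assumes the environment is already
    authenticated against a Chariot deployment that exposes the "Common" project)::

        model = get_chariot_common_embedding_model()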
"""
    common_project_id = get_project_id(
        project_name="Common",
        organization_id=get_organization_id("Chariot"),
    )
    embedding_models = get_models(
        project_id=common_project_id,
        model_name="model-recommender-embed",
    )
    if len(embedding_models) != 1:
        raise ValueError(f"expected one model, received {len(embedding_models)}")
    embedding_model = embedding_models[0]
    return embedding_model


def _get_successful_bulk_inference_embedding_job(
    embedding_model_id: str,
    snapshot_id: str,
    split: str,
) -> BulkInferenceJob | None:
    """Return the first successful bulk inference job for the model, snapshot, and split, or ``None``."""
    jobs = [
        job
        for job in get_bulk_inference_jobs(
            model_id=embedding_model_id,
            dataset_snapshot_id=snapshot_id,
            dataset_version_split=split,
        )
        if job.execution_status == ProgressState.SUCCEEDED
    ]
    if not jobs:
        return None
    return jobs[0]


def embed(
    snapshot_id: str,
    split: str,
    limit: int = 1000,
) -> np.ndarray:
    """Create embeddings for a dataset snapshot split.

    :param snapshot_id: The dataset snapshot id.
    :type snapshot_id: str
    :param split: Dataset snapshot split.
    :type split: str
    :param limit: Pagination limit used when fetching stored inferences. Defaults to 1000.
    :type limit: int
    :return: NumPy array with shape (n_embeddings, n_features).
    :rtype: np.ndarray
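
    Illustrative usage (a sketch only; the snapshot id and split below are placeholders
    and assume the snapshot has already been ingested into Chariot)::

        embeddings = embed(snapshot_id="<snapshot-id>", split="train")
        print(embeddings.shape)  # (n_embeddings, 2048)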
"""
    embedding_model = get_chariot_common_embedding_model()
    snapshot = get_snapshot(snapshot_id)
    dataset = snapshot.view.dataset
    if not dataset:
        raise ValueError("snapshot does not contain a reference dataset")
    # create representer embedding job if one does not exist
    rep_job = embedding_job_dataset(snapshot_id=snapshot_id, split=split)
    if rep_job is None:
        spec = EmbeddingJobSpec(
            snapshot_id=snapshot_id,
            split=split,
        )
        rep_job = embedding_job_run(spec)
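    # poll the representer job until it completes, failing fast if it is aborted or fails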
    while rep_job.job_state != EmbeddingJobState.COMPLETED:
        rep_job = embedding_job_status(rep_job.generated_id)
        assert rep_job
        if rep_job.job_state in {
            EmbeddingJobState.ABORTED,
            EmbeddingJobState.FAILED,
        }:
            raise RuntimeError(
                f"embedding job '{rep_job.generated_id}' failed with status '{rep_job.job_state}'"
            )
        time.sleep(3)
    # find bulk inference job
    bi_job = _get_successful_bulk_inference_embedding_job(
        embedding_model_id=embedding_model.id,
        snapshot_id=snapshot_id,
        split=split,
    )
    if bi_job is None:
        raise RuntimeError("could not find successful bulk inference job")
    embedding_pages = []
    offset = 0
    while True:
        inferences = filter_inferences(
            model_id=embedding_model.id,
            request_body=NewGetInferencesRequest(
                filters=BaseInferenceFilter(
                    inference_action_filter="embed",
                    metadata_filter=[
                        MetadataFilter(
                            key="bulk_inference_job_id",
                            operator=MetadataFilterOperator.EQUAL,
                            type=MetadataFilterType.STRING,
                            value=bi_job.execution_id,
                        )
                    ],
                ),
                pagination=Pagination(limit=limit, offset=offset),
            ),
        )
        offset += limit
        if not inferences:
            break
        # parse each inference payload once and keep only full-length embeddings
        page = []
        for inference in inferences:
            if not inference.data:
                continue
            embedding = json.loads(inference.data)
            if len(embedding) == COMMON_EMBEDDING_DIMENSION:
                page.append(embedding)
        embedding_pages.append(np.array(page))
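    # stack the pages and verify that every datum in the requested split produced
    # exactly one valid embedding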
    if not embedding_pages:
        raise RuntimeError("no embeddings were found for the bulk inference job")
    embeddings = np.concatenate(embedding_pages, axis=0)
    num_embeddings = embeddings.shape[0]
    num_snapshot_datums = get_snapshot_datum_count(snapshot_id=snapshot_id, split=split)
    if num_embeddings != num_snapshot_datums:
        raise RuntimeError(
            f"number of embeddings does not match dataset size: {num_embeddings} != {num_snapshot_datums}"
        )
    return embeddings