import json
import time
import numpy as np
from tqdm import tqdm
from chariot.bulk_inference import (
NewBulkInferenceJobRequest,
ProgressState,
create_bulk_inference_job,
get_bulk_inference_job,
get_bulk_inference_jobs,
)
from chariot.datasets import get_snapshot
from chariot.inference_store import (
BaseInferenceFilter,
MetadataFilter,
MetadataFilterOperator,
MetadataFilterType,
NewGetInferencesRequest,
filter_inferences,
)
from chariot.models import Model, get_models
from chariot.projects import get_organization_id, get_project_id
# Public API of this module; everything else is an implementation detail.
__all__ = [
    "get_chariot_common_embedding_model",
    "embed",
]
# [docs]  (Sphinx "view source" link artifact left over from a documentation scrape)
def get_chariot_common_embedding_model(timeout: int = 120) -> Model:
    """Get the global embedding model.

    Looks up the "model-recommender-embed" model in the "Common" project of
    the "Chariot" organization and waits for its inference server to be up.

    :param timeout: Amount of time (seconds) to wait for inference server start up. Defaults to 120.
    :type timeout: int
    :raises ValueError: If the lookup does not return exactly one model.
    :return: The embedding model, ready for inference.
    :rtype: Model
    """
    common_project_id = get_project_id(
        project_name="Common",
        organization_id=get_organization_id("Chariot"),
    )
    embedding_models = get_models(
        project_id=common_project_id,
        model_name="model-recommender-embed",
    )
    # The name should be unique within the Common project; anything else is
    # a misconfiguration we refuse to guess our way through.
    if len(embedding_models) != 1:
        raise ValueError(f"expected one model, received {len(embedding_models)}")
    embedding_model = embedding_models[0]
    embedding_model.wait_for_inference_server(timeout=timeout)
    return embedding_model
# [docs]  (Sphinx "view source" link artifact left over from a documentation scrape)
def embed(
    snapshot_id: str,
    split: str,
    verbose: bool = False,
) -> np.ndarray:
    """Create embeddings for a dataset snapshot.

    Reuses an already-succeeded bulk embedding job for this snapshot/split
    when one exists; otherwise launches a new job and polls until it
    completes. Embeddings are then loaded from the inference store.

    :param snapshot_id: The dataset snapshot id.
    :type snapshot_id: str
    :param split: Dataset snapshot split.
    :type split: str
    :param verbose: Toggles verbose output. Defaults to False.
    :type verbose: bool
    :raises ValueError: If the snapshot has no reference dataset.
    :raises RuntimeError: If a bulk inference job cannot be created, cannot
        be found, or ends in a failure state.
    :return: NumPy array with shape (n_embeddings, n_features).
    :rtype: np.ndarray
    """
    snapshot = get_snapshot(snapshot_id)
    dataset = snapshot.view.dataset
    if not dataset:
        raise ValueError("snapshot does not contain a reference dataset")
    embedding_model = get_chariot_common_embedding_model()
    existing_jobs = get_bulk_inference_jobs(
        model_id=embedding_model.id,
        dataset_version_split=split,
        dataset_snapshot_id=snapshot.id,
    )
    # Only run a new job if no successful job already exists.
    if not any(job.execution_status == ProgressState.SUCCEEDED for job in existing_jobs):
        completed = _run_embedding_job(embedding_model, dataset, snapshot, split, verbose)
        existing_jobs.append(completed)
    embeddings = _load_embeddings(embedding_model, existing_jobs, verbose)
    # Keep only well-formed vectors of the expected dimensionality (2048,
    # presumably the embedding model's output size — TODO confirm), parsing
    # each payload exactly once.
    vectors = []
    for embedding in embeddings:
        if not embedding.data:
            continue
        vector = json.loads(embedding.data)
        if len(vector) == 2048:
            vectors.append(vector)
    return np.array(vectors)


def _run_embedding_job(embedding_model, dataset, snapshot, split, verbose):
    """Launch a bulk embedding job and poll (every 5 s) until it succeeds.

    :raises RuntimeError: If the job returns no id, disappears, or ends in
        FAILED / RUNNING_FAILED / ABORTED.
    :return: The completed bulk inference job record.
    """
    job_spec = NewBulkInferenceJobRequest(
        model_project_id=embedding_model.project_id,
        model_id=embedding_model.id,
        dataset_id=dataset.id,
        dataset_project_id=dataset.project_id,
        dataset_snapshot_id=snapshot.id,
        dataset_snapshot_split=split,
        evaluate_metrics=False,
        inference_method="embed",
    )
    job_id = create_bulk_inference_job(job_spec=job_spec)
    if not job_id:
        raise RuntimeError("bulk inference job did not return an ID")
    latest = get_bulk_inference_job(job_id=job_id)
    if not latest:
        raise RuntimeError(f"could not find bulk inference job with id '{job_id}'")
    pbar = tqdm(
        total=latest.expected_total,
        desc=f"Running embedding job '{job_id}'",
        disable=(not verbose),
    )
    try:
        while True:
            latest = get_bulk_inference_job(job_id=job_id)
            if not latest:
                raise RuntimeError(f"could not find bulk inference job with id '{job_id}'")
            pbar.n = latest.num_computed
            pbar.refresh()
            if latest.execution_status == ProgressState.SUCCEEDED:
                return latest
            if latest.execution_status in {
                ProgressState.FAILED,
                ProgressState.RUNNING_FAILED,
                ProgressState.ABORTED,
            }:
                raise RuntimeError(
                    f"bulk inference job '{job_id}' failed with status {latest.execution_status}"
                )
            time.sleep(5)
    finally:
        # Close the bar even on failure so a raised error doesn't leave a
        # dangling progress line.
        pbar.close()


def _load_embeddings(embedding_model, jobs, verbose):
    """Fetch the "embed" inferences produced by ``jobs`` from the inference store.

    :return: Flat list of inference records across all given jobs.
    """
    embeddings = []
    pbar = tqdm(
        # Sum across all jobs so the bar total matches what is actually
        # loaded (inferences are fetched for every job below).
        total=sum(job.num_computed for job in jobs),
        desc="Loading embeddings from inference store",
        disable=(not verbose),
    )
    try:
        for job in jobs:
            inferences = filter_inferences(
                model_id=embedding_model.id,
                request_body=NewGetInferencesRequest(
                    filters=BaseInferenceFilter(
                        inference_action_filter="embed",
                        metadata_filter=[
                            MetadataFilter(
                                key="bulk_inference_job_id",
                                operator=MetadataFilterOperator.EQUAL,
                                type=MetadataFilterType.STRING,
                                value=job.execution_id,
                            )
                        ],
                    )
                ),
            )
            embeddings.extend(inferences)
            pbar.update(len(inferences))
    finally:
        pbar.close()
    return embeddings