from collections.abc import Generator, Iterator
from dataclasses import asdict
from datetime import datetime
from chariot import _apis
from chariot.datasets import _utils, models
from chariot_api._openapi.datasets_v3 import models as openapi_models
__all__ = [
"get_datasets",
"create_dataset",
"get_dataset_statistics",
"get_dataset",
"update_dataset",
"delete_dataset",
"get_dataset_timeline",
"get_authorized_dataset_ids",
"create_dataset_timeline_description",
]
def get_datasets(
*,
exact_name_match: bool | None = None,
exclude_unlabeled: bool | None = None,
limit_to_write_access: bool | None = None,
name: str | None = None,
project_ids: list[str] | None = None,
dataset_ids: list[str] | None = None,
task_type_label_filters: list[models.TaskTypeLabelFilter] | None = None,
type: models.DatasetType | None = None,
sort: models.DatasetSortColumn | None = None,
direction: models.SortDirection | None = None,
max_items: int | None = None,
) -> Generator[models.Dataset, None, None]:
"""Get datasets with various criteria. Returns a generator over all matching datasets.
:param exact_name_match: Require name filter to match exactly (defaults to false)
:type exact_name_match: Optional[bool]
    :param exclude_unlabeled: Should unlabeled datasets be excluded from the results (defaults to false)
    :type exclude_unlabeled: Optional[bool]
:param limit_to_write_access: Should the results only include datasets that the user has write access to (defaults to false)
:type limit_to_write_access: Optional[bool]
:param name: Filter by dataset name
:type name: Optional[str]
:param project_ids: Filter by project ids
:type project_ids: Optional[List[str]]
:param dataset_ids: Filter by dataset ids
:type dataset_ids: Optional[List[str]]
:param task_type_label_filters: Filter by task types and associated labels
:type task_type_label_filters: Optional[List[models.TaskTypeLabelFilter]]
:param type: Filter by dataset type
:type type: Optional[models.DatasetType]
:param sort: How to sort the returned datasets
:type sort: Optional[models.DatasetSortColumn]
:param direction: Whether to sort in ascending or descending order
:type direction: Optional[models.SortDirection]
:param max_items: Limit the returned generator to only produce this many items
:type max_items: Optional[int]
:return: Dataset details for datasets matching the criteria
:rtype: Generator[models.Dataset, None, None]
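
    Example (an illustrative sketch, not part of the API contract; the
    dataset name is a placeholder and a configured Chariot client is
    assumed)::

        # Fetch up to 5 datasets whose name matches exactly; pagination
        # is handled transparently by the returned generator.
        for ds in get_datasets(name="my-dataset", exact_name_match=True, max_items=5):
            print(ds.id, ds.name)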
"""
    params = locals()
    params.pop("max_items", None)
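    # Filtering by explicit dataset ids returns everything in one request,
    # so skip the pagination helper and return the results directly.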
if dataset_ids:
return iter(
_get_datasets(
exact_name_match=exact_name_match,
exclude_unlabeled=exclude_unlabeled,
limit_to_write_access=limit_to_write_access,
name=name,
project_ids=project_ids,
dataset_ids=dataset_ids,
task_type_label_filters=task_type_label_filters,
type=type,
sort=sort,
direction=direction,
)
)
else:
return _utils.paginate_items(_get_datasets, params, max_items)
def _get_datasets(
*,
exact_name_match: bool | None = None,
exclude_unlabeled: bool | None = None,
limit_to_write_access: bool | None = None,
name: str | None = None,
project_ids: list[str] | None = None,
dataset_ids: list[str] | None = None,
task_type_label_filters: list[models.TaskTypeLabelFilter] | None = None,
type: models.DatasetType | None = None,
limit: int | None = None,
offset: int | None = None,
sort: models.DatasetSortColumn | None = None,
direction: models.SortDirection | None = None,
) -> list[models.Dataset]:
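    # Convert dataclass filters into plain dicts so they can be sent in
    # the OpenAPI request body.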
if task_type_label_filters is not None:
task_type_label_filters = [
asdict(f, dict_factory=_utils.dict_factory) for f in task_type_label_filters
]
request = openapi_models.InputGetDatasetsRequest(
exact_name_match=exact_name_match,
exclude_unlabeled=exclude_unlabeled,
limit_to_write_access=limit_to_write_access,
name=name,
project_ids=project_ids,
dataset_ids=dataset_ids,
task_type_label_filters=task_type_label_filters,
dataset_type=_utils.enum_value(type),
limit=limit,
offset=offset,
sort=_utils.enum_value(sort),
direction=_utils.enum_value(direction),
)
response = _apis.datasets_v3.datasets_api.get_datasets(body=request)
if not response.data:
return []
return [_utils.convert_to_dataclass(d.model_dump(), models.Dataset) for d in response.data]
def create_dataset(
*,
name: str,
type: models.DatasetType,
project_id: str,
description: str | None = None,
is_public: bool | None = None,
_is_test: bool | None = None,
) -> models.Dataset:
"""Create a new, empty dataset
:param name: Dataset name
:type name: str
:param type: Dataset type
:type type: models.DatasetType
:param project_id: Project id to create the dataset in
:type project_id: str
:param description: Dataset description
:type description: Optional[str]
    :param is_public: When set to true, the dataset will be publicly accessible
:type is_public: Optional[bool]
:return: New dataset details
:rtype: models.Dataset
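
    Example (an illustrative sketch; the project id is a placeholder and
    the ``DatasetType`` member shown is assumed, not confirmed by this
    module)::

        ds = create_dataset(
            name="my-dataset",
            type=models.DatasetType.IMAGE,  # assumed enum member
            project_id="<project-id>",
            description="Created from the SDK",
        )
        print(ds.id)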
"""
request = openapi_models.InputCreateDatasetRequest(
name=name,
type=_utils.enum_value(type),
project_id=project_id,
description=description,
is_public=is_public,
is_test=_is_test,
)
response = _apis.datasets_v3.datasets_api.create_dataset(body=request)
if not response.data:
raise RuntimeError("Received malformed response (missing `data`) from create_dataset")
return _utils.convert_to_dataclass(response.data.model_dump(), models.Dataset)
def get_dataset_statistics(
*,
exact_name_match: bool | None = None,
exclude_unlabeled: bool | None = None,
limit_to_write_access: bool | None = None,
name: str | None = None,
project_ids: list[str] | None = None,
dataset_ids: list[str] | None = None,
task_type_label_filters: list[models.TaskTypeLabelFilter] | None = None,
type: models.DatasetType | None = None,
) -> models.DatasetStatistics:
"""Get dataset statistics with various criteria.
:param exact_name_match: Require name filter to match exactly (defaults to false)
:type exact_name_match: Optional[bool]
    :param exclude_unlabeled: Should unlabeled datasets be excluded from the results (defaults to false)
    :type exclude_unlabeled: Optional[bool]
:param limit_to_write_access: Should the results only include datasets that the user has write access to (defaults to false)
:type limit_to_write_access: Optional[bool]
:param name: Filter by dataset name
:type name: Optional[str]
:param project_ids: Filter by project ids
:type project_ids: Optional[List[str]]
:param dataset_ids: Filter by dataset ids
:type dataset_ids: Optional[List[str]]
:param task_type_label_filters: Filter by task types and associated labels
:type task_type_label_filters: Optional[List[models.TaskTypeLabelFilter]]
:param type: Filter by dataset type
:type type: Optional[models.DatasetType]
:return: Statistics for the datasets
:rtype: models.DatasetStatistics
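
    Example (an illustrative sketch; the project id is a placeholder)::

        stats = get_dataset_statistics(project_ids=["<project-id>"])
        print(stats)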
"""
if task_type_label_filters is not None:
task_type_label_filters = [
asdict(f, dict_factory=_utils.dict_factory) for f in task_type_label_filters
]
request = openapi_models.ModelDatasetFilter(
exact_name_match=exact_name_match,
exclude_unlabeled=exclude_unlabeled,
limit_to_write_access=limit_to_write_access,
name=name,
project_ids=project_ids,
dataset_ids=dataset_ids,
task_type_label_filters=task_type_label_filters,
        dataset_type=_utils.enum_value(type),
)
response = _apis.datasets_v3.datasets_api.get_dataset_statistics(body=request)
if not response.data:
raise RuntimeError(
"Received malformed response (missing `data`) from get_dataset_statistics"
)
return _utils.convert_to_dataclass(response.data.model_dump(), models.DatasetStatistics)
def get_dataset(id: str) -> models.Dataset:
"""Get a dataset by id
:param id: Dataset id
:type id: str
:return: Dataset details
:rtype: models.Dataset
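
    Example (an illustrative sketch; the id is a placeholder)::

        ds = get_dataset("<dataset-id>")
        print(ds.name)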
"""
response = _apis.datasets_v3.datasets_api.get_dataset(dataset_id=id)
if not response.data:
raise RuntimeError("Received malformed response (missing `data`) from get_dataset")
return _utils.convert_to_dataclass(response.data.model_dump(), models.Dataset)
def update_dataset(
id: str,
*,
name: str | None = None,
description: str | None = None,
) -> models.Dataset:
"""Update a dataset's name or description
:param id: Dataset id to update
:type id: str
:param name: New name for the dataset. Name remains unmodified if set to None
:type name: Optional[str]
:param description: New description for the dataset. Description remains unmodified if set to None
:type description: Optional[str]
:return: Updated dataset details
:rtype: models.Dataset
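
    Example (an illustrative sketch; the id is a placeholder; ``name`` is
    omitted, so it is left unchanged)::

        ds = update_dataset("<dataset-id>", description="Nightly ingest run")
        print(ds.description)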
"""
request = openapi_models.InputUpdateDatasetRequest(
name=name,
description=description,
)
response = _apis.datasets_v3.datasets_api.update_dataset(dataset_id=id, body=request)
if not response.data:
raise RuntimeError("Received malformed response (missing `data`) from update_dataset")
return _utils.convert_to_dataclass(response.data.model_dump(), models.Dataset)
def delete_dataset(id: str) -> models.Dataset:
"""Delete a dataset by id. The artifacts for the dataset will be deleted
:param id: Id of dataset to delete
:type id: str
:return: Deleted dataset details
:rtype: models.Dataset
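
    Example (an illustrative sketch; the id is a placeholder)::

        deleted = delete_dataset("<dataset-id>")
        print(deleted.id)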
"""
response = _apis.datasets_v3.datasets_api.archive_dataset(dataset_id=id)
if not response.data:
raise RuntimeError("Received malformed response (missing `data`) from archive_dataset")
return _utils.convert_to_dataclass(response.data.model_dump(), models.Dataset)
def get_dataset_timeline(
id: str,
*,
max_items: int | None = None,
direction: models.SortDirection | None = None,
min_groups: int | None = None,
max_ungrouped_events: int | None = None,
) -> Iterator[models.DatasetTimelineEvent]:
"""Get a series of dataset change events ordered by time and grouped by event type.
:param id: Dataset id to get events for
:type id: str
:param max_items: Limit the returned generator to only produce this many items
:type max_items: Optional[int]
:param direction: Whether to sort in ascending or descending order
:type direction: Optional[models.SortDirection]
:param min_groups: How many groups are required before grouping behavior is turned on
:type min_groups: Optional[int]
:param max_ungrouped_events: The maximum number of events allowed before grouping behavior is turned on
:type max_ungrouped_events: Optional[int]
:return: Events for the dataset
:rtype: Iterator[models.DatasetTimelineEvent]
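
    Example (an illustrative sketch; the id is a placeholder)::

        for event in get_dataset_timeline("<dataset-id>", max_items=20):
            print(event)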
"""
if "upload_filter" in locals():
raise TypeError(
"get_dataset_timeline does not support 'upload_filter' - use 'get_view_timeline' with 'since_last_snapshot=True' to get view events since the last view snapshot"
)
response = _apis.datasets_v3.datasets_api.get_dataset_timeline(
dataset_id=id,
limit=max_items,
direction=_utils.enum_value(direction),
min_groups=min_groups,
max_ungrouped_events=max_ungrouped_events,
)
data = response.data or []
return (_utils.convert_to_dataclass(d.model_dump(), models.DatasetTimelineEvent) for d in data)
def get_authorized_dataset_ids(ids: list[str]) -> list[str]:
"""Given a list of Dataset Ids, return ids from the list that the user has read access to
:param ids: List of dataset ids to check
:type ids: List[str]
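
    Example (an illustrative sketch; the ids are placeholders)::

        readable = get_authorized_dataset_ids(["<dataset-id-1>", "<dataset-id-2>"])
        print(readable)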
"""
response = _apis.datasets_v3.datasets_api.get_authorized_dataset_ids(body=ids)
return response.data or []
def create_dataset_timeline_description(id: str, description: str, timestamp: datetime) -> None:
    """Add a user-defined description event for a particular timeline event group.
    :param id: Id of the dataset
    :type id: str
    :param description: Description of the timeline event group. Must be less than 200 characters
    :type description: str
    :param timestamp: Timestamp representing the event time of the group leader to which this description will be added
    :type timestamp: datetime
    :raises APIException: If API communication fails or the request is unauthorized or unauthenticated
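
    Example (an illustrative sketch; the id is a placeholder, and the
    timestamp must match the event time of an existing group leader)::

        from datetime import datetime, timezone

        create_dataset_timeline_description(
            "<dataset-id>",
            "Initial labeling pass complete",
            datetime(2024, 1, 1, tzinfo=timezone.utc),
        )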
"""
request = openapi_models.InputCreateDatasetTimelineEventRequest(
description=description, timestamp=_utils.format_datetime_utc(timestamp)
)
_apis.datasets_v3.datasets_api.create_dataset_timeline_event(dataset_id=id, body=request)