Source code for chariot.datasets.datasets

from collections.abc import Generator, Iterator
from dataclasses import asdict
from datetime import datetime

from chariot import _apis
from chariot.datasets import _utils, models
from chariot_api._openapi.datasets_v3 import models as openapi_models

__all__ = [
    "get_datasets",
    "create_dataset",
    "get_dataset_statistics",
    "get_dataset",
    "update_dataset",
    "delete_dataset",
    "get_dataset_timeline",
    "get_authorized_dataset_ids",
    "create_dataset_timeline_description",
]


def get_datasets(
    *,
    exact_name_match: bool | None = None,
    exclude_unlabeled: bool | None = None,
    limit_to_write_access: bool | None = None,
    name: str | None = None,
    project_ids: list[str] | None = None,
    dataset_ids: list[str] | None = None,
    task_type_label_filters: list[models.TaskTypeLabelFilter] | None = None,
    type: models.DatasetType | None = None,
    sort: models.DatasetSortColumn | None = None,
    direction: models.SortDirection | None = None,
    max_items: int | None = None,
) -> Generator[models.Dataset, None, None]:
    """Get datasets matching various criteria. Returns a generator over all matching datasets.

    :param exact_name_match: Require the name filter to match exactly (defaults to false)
    :type exact_name_match: Optional[bool]
    :param exclude_unlabeled: Should unlabeled datasets be excluded (defaults to false)
    :type exclude_unlabeled: Optional[bool]
    :param limit_to_write_access: Should the results only include datasets that the user has write access to (defaults to false)
    :type limit_to_write_access: Optional[bool]
    :param name: Filter by dataset name
    :type name: Optional[str]
    :param project_ids: Filter by project ids
    :type project_ids: Optional[List[str]]
    :param dataset_ids: Filter by dataset ids
    :type dataset_ids: Optional[List[str]]
    :param task_type_label_filters: Filter by task types and associated labels
    :type task_type_label_filters: Optional[List[models.TaskTypeLabelFilter]]
    :param type: Filter by dataset type
    :type type: Optional[models.DatasetType]
    :param sort: How to sort the returned datasets
    :type sort: Optional[models.DatasetSortColumn]
    :param direction: Whether to sort in ascending or descending order
    :type direction: Optional[models.SortDirection]
    :param max_items: Limit the returned generator to produce at most this many items
    :type max_items: Optional[int]
    :return: Dataset details for datasets matching the criteria
    :rtype: Generator[models.Dataset, None, None]
    """
    params = locals()
    # max_items controls local pagination and is not a server-side filter
    del params["max_items"]
    if dataset_ids:
        # Explicit ids bypass pagination: fetch the full list in a single request
        return iter(
            _get_datasets(
                exact_name_match=exact_name_match,
                exclude_unlabeled=exclude_unlabeled,
                limit_to_write_access=limit_to_write_access,
                name=name,
                project_ids=project_ids,
                dataset_ids=dataset_ids,
                task_type_label_filters=task_type_label_filters,
                type=type,
                sort=sort,
                direction=direction,
            )
        )
    else:
        return _utils.paginate_items(_get_datasets, params, max_items)
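# Usage sketch (illustrative, not part of the module): assumes an
# authenticated chariot session; the name filter value is a placeholder.
# Streams up to ten datasets whose name matches the filter.
#
#     for ds in get_datasets(name="traffic", max_items=10):
#         print(ds)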
def _get_datasets(
    *,
    exact_name_match: bool | None = None,
    exclude_unlabeled: bool | None = None,
    limit_to_write_access: bool | None = None,
    name: str | None = None,
    project_ids: list[str] | None = None,
    dataset_ids: list[str] | None = None,
    task_type_label_filters: list[models.TaskTypeLabelFilter] | None = None,
    type: models.DatasetType | None = None,
    limit: int | None = None,
    offset: int | None = None,
    sort: models.DatasetSortColumn | None = None,
    direction: models.SortDirection | None = None,
) -> list[models.Dataset]:
    if task_type_label_filters is not None:
        # Convert filter dataclasses to plain dicts for the request body
        task_type_label_filters = [
            asdict(f, dict_factory=_utils.dict_factory) for f in task_type_label_filters
        ]
    request = openapi_models.InputGetDatasetsRequest(
        exact_name_match=exact_name_match,
        exclude_unlabeled=exclude_unlabeled,
        limit_to_write_access=limit_to_write_access,
        name=name,
        project_ids=project_ids,
        dataset_ids=dataset_ids,
        task_type_label_filters=task_type_label_filters,
        dataset_type=_utils.enum_value(type),
        limit=limit,
        offset=offset,
        sort=_utils.enum_value(sort),
        direction=_utils.enum_value(direction),
    )
    response = _apis.datasets_v3.datasets_api.get_datasets(body=request)
    if not response.data:
        return []
    return [_utils.convert_to_dataclass(d.model_dump(), models.Dataset) for d in response.data]
def create_dataset(
    *,
    name: str,
    type: models.DatasetType,
    project_id: str,
    description: str | None = None,
    is_public: bool | None = None,
    _is_test: bool | None = None,
) -> models.Dataset:
    """Create a new, empty dataset

    :param name: Dataset name
    :type name: str
    :param type: Dataset type
    :type type: models.DatasetType
    :param project_id: Project id to create the dataset in
    :type project_id: str
    :param description: Dataset description
    :type description: Optional[str]
    :param is_public: When set to true, the dataset will be publicly accessible
    :type is_public: Optional[bool]
    :return: New dataset details
    :rtype: models.Dataset
    """
    request = openapi_models.InputCreateDatasetRequest(
        name=name,
        type=_utils.enum_value(type),
        project_id=project_id,
        description=description,
        is_public=is_public,
        is_test=_is_test,
    )
    response = _apis.datasets_v3.datasets_api.create_dataset(body=request)
    if not response.data:
        raise RuntimeError("Received malformed response (missing `data`) from create_dataset")
    return _utils.convert_to_dataclass(response.data.model_dump(), models.Dataset)
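# Usage sketch (illustrative, not part of the module): create an empty
# dataset in a project. `models.DatasetType.IMAGE` is a hypothetical enum
# member (check models.DatasetType for the real values); the project id is
# a placeholder.
#
#     dataset = create_dataset(
#         name="my-new-dataset",
#         type=models.DatasetType.IMAGE,  # hypothetical member
#         project_id="proj-123",  # placeholder
#         description="Scratch dataset for experiments",
#     )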
def get_dataset_statistics(
    *,
    exact_name_match: bool | None = None,
    exclude_unlabeled: bool | None = None,
    limit_to_write_access: bool | None = None,
    name: str | None = None,
    project_ids: list[str] | None = None,
    dataset_ids: list[str] | None = None,
    task_type_label_filters: list[models.TaskTypeLabelFilter] | None = None,
    type: models.DatasetType | None = None,
) -> models.DatasetStatistics:
    """Get dataset statistics for datasets matching various criteria.

    :param exact_name_match: Require the name filter to match exactly (defaults to false)
    :type exact_name_match: Optional[bool]
    :param exclude_unlabeled: Should unlabeled datasets be excluded (defaults to false)
    :type exclude_unlabeled: Optional[bool]
    :param limit_to_write_access: Should the results only include datasets that the user has write access to (defaults to false)
    :type limit_to_write_access: Optional[bool]
    :param name: Filter by dataset name
    :type name: Optional[str]
    :param project_ids: Filter by project ids
    :type project_ids: Optional[List[str]]
    :param dataset_ids: Filter by dataset ids
    :type dataset_ids: Optional[List[str]]
    :param task_type_label_filters: Filter by task types and associated labels
    :type task_type_label_filters: Optional[List[models.TaskTypeLabelFilter]]
    :param type: Filter by dataset type
    :type type: Optional[models.DatasetType]
    :return: Statistics for the matching datasets
    :rtype: models.DatasetStatistics
    """
    if task_type_label_filters is not None:
        task_type_label_filters = [
            asdict(f, dict_factory=_utils.dict_factory) for f in task_type_label_filters
        ]
    request = openapi_models.ModelDatasetFilter(
        exact_name_match=exact_name_match,
        exclude_unlabeled=exclude_unlabeled,
        limit_to_write_access=limit_to_write_access,
        name=name,
        project_ids=project_ids,
        dataset_ids=dataset_ids,
        task_type_label_filters=task_type_label_filters,
        dataset_type=_utils.enum_value(type),  # normalized to its raw value, as in _get_datasets
    )
    response = _apis.datasets_v3.datasets_api.get_dataset_statistics(body=request)
    if not response.data:
        raise RuntimeError(
            "Received malformed response (missing `data`) from get_dataset_statistics"
        )
    return _utils.convert_to_dataclass(response.data.model_dump(), models.DatasetStatistics)
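# Usage sketch (illustrative, not part of the module): aggregate statistics
# over every dataset the caller can write to in one project; the project id
# is a placeholder.
#
#     stats = get_dataset_statistics(
#         project_ids=["proj-123"],
#         limit_to_write_access=True,
#     )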
def get_dataset(id: str) -> models.Dataset:
    """Get a dataset by id

    :param id: Dataset id
    :type id: str
    :return: Dataset details
    :rtype: models.Dataset
    """
    response = _apis.datasets_v3.datasets_api.get_dataset(dataset_id=id)
    if not response.data:
        raise RuntimeError("Received malformed response (missing `data`) from get_dataset")
    return _utils.convert_to_dataclass(response.data.model_dump(), models.Dataset)
def update_dataset(
    id: str,
    *,
    name: str | None = None,
    description: str | None = None,
) -> models.Dataset:
    """Update a dataset's name or description

    :param id: Dataset id to update
    :type id: str
    :param name: New name for the dataset. Name remains unmodified if set to None
    :type name: Optional[str]
    :param description: New description for the dataset. Description remains unmodified if set to None
    :type description: Optional[str]
    :return: Updated dataset details
    :rtype: models.Dataset
    """
    request = openapi_models.InputUpdateDatasetRequest(
        name=name,
        description=description,
    )
    response = _apis.datasets_v3.datasets_api.update_dataset(dataset_id=id, body=request)
    if not response.data:
        raise RuntimeError("Received malformed response (missing `data`) from update_dataset")
    return _utils.convert_to_dataclass(response.data.model_dump(), models.Dataset)
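# Usage sketch (illustrative, not part of the module): fetch a dataset and
# rename it, leaving the description untouched (None means "leave as-is");
# the id is a placeholder.
#
#     ds = get_dataset("ds-456")
#     ds = update_dataset("ds-456", name="renamed-dataset")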
def delete_dataset(id: str) -> models.Dataset:
    """Delete a dataset by id. The artifacts for the dataset will be deleted

    :param id: Id of dataset to delete
    :type id: str
    :return: Deleted dataset details
    :rtype: models.Dataset
    """
    response = _apis.datasets_v3.datasets_api.archive_dataset(dataset_id=id)
    if not response.data:
        raise RuntimeError("Received malformed response (missing `data`) from archive_dataset")
    return _utils.convert_to_dataclass(response.data.model_dump(), models.Dataset)
def get_dataset_timeline(
    id: str,
    *,
    max_items: int | None = None,
    direction: models.SortDirection | None = None,
    min_groups: int | None = None,
    max_ungrouped_events: int | None = None,
) -> Iterator[models.DatasetTimelineEvent]:
    """Get a series of dataset change events ordered by time and grouped by event type.

    Note: upload filtering is not supported here; use ``get_view_timeline`` with
    ``since_last_snapshot=True`` to get view events since the last view snapshot.

    :param id: Dataset id to get events for
    :type id: str
    :param max_items: Limit the returned generator to only produce this many items
    :type max_items: Optional[int]
    :param direction: Whether to sort in ascending or descending order
    :type direction: Optional[models.SortDirection]
    :param min_groups: How many groups are required before grouping behavior is turned on
    :type min_groups: Optional[int]
    :param max_ungrouped_events: The maximum number of events allowed before grouping behavior is turned on
    :type max_ungrouped_events: Optional[int]
    :return: Events for the dataset
    :rtype: Iterator[models.DatasetTimelineEvent]
    """
    response = _apis.datasets_v3.datasets_api.get_dataset_timeline(
        dataset_id=id,
        limit=max_items,
        direction=_utils.enum_value(direction),
        min_groups=min_groups,
        max_ungrouped_events=max_ungrouped_events,
    )
    data = response.data or []
    return (_utils.convert_to_dataclass(d.model_dump(), models.DatasetTimelineEvent) for d in data)
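# Usage sketch (illustrative, not part of the module): walk the 20 most
# recent timeline events. `models.SortDirection.DESCENDING` is a
# hypothetical member name; check models.SortDirection for the real values.
#
#     for event in get_dataset_timeline(
#         "ds-456",  # placeholder id
#         max_items=20,
#         direction=models.SortDirection.DESCENDING,  # hypothetical member
#     ):
#         print(event)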
def get_authorized_dataset_ids(ids: list[str]) -> list[str]:
    """Given a list of dataset ids, return the ids from the list that the user has read access to

    :param ids: List of dataset ids to check
    :type ids: List[str]
    :return: The subset of the given ids that the user can read
    :rtype: List[str]
    """
    response = _apis.datasets_v3.datasets_api.get_authorized_dataset_ids(body=ids)
    return response.data or []
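# Usage sketch (illustrative, not part of the module): narrow a candidate
# list to the ids the caller may read; unauthorized ids are simply absent
# from the result. Ids are placeholders.
#
#     readable = get_authorized_dataset_ids(["ds-456", "ds-789"])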
def create_dataset_timeline_description(id: str, description: str, timestamp: datetime):
    """Add a user-defined description event for a particular timeline event group.

    :param id: Id of the dataset
    :type id: str
    :param description: Description of the timeline event group. Must be less than 200 characters
    :type description: str
    :param timestamp: Timestamp representing the event time of the group leader to which this description will be added
    :type timestamp: datetime
    :raises APIException: If api communication fails, or the request is unauthorized or unauthenticated
    """
    request = openapi_models.InputCreateDatasetTimelineEventRequest(
        description=description, timestamp=_utils.format_datetime_utc(timestamp)
    )
    _apis.datasets_v3.datasets_api.create_dataset_timeline_event(dataset_id=id, body=request)
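# Usage sketch (illustrative, not part of the module): attach a description
# to a timeline event group. The timestamp must match the group leader's
# event time; a literal UTC datetime stands in for it here, and the id is
# a placeholder.
#
#     from datetime import timezone
#
#     create_dataset_timeline_description(
#         "ds-456",  # placeholder id
#         "Initial bulk upload complete",
#         timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
#     )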