from collections.abc import Generator
from typing import Any
import chariot_api._openapi.awm.models as api_models
from chariot import _apis, mcp_setting
from chariot.awm import _utils, models
from chariot_api._openapi.awm.api import WorkflowsApi
# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "create_workflow",
    "update_workflow",
    "delete_workflow",
    "get_workflow",
    "get_workflows",
    "initialize_workflow_user",
]
def _workflow_api() -> WorkflowsApi:
    """Return the ``WorkflowsApi`` client held by ``_apis.awm``.

    The lookup happens on every call rather than being cached here.
    """
    awm_client = _apis.awm
    return awm_client.workflows_api
@mcp_setting(mutating=True)
@_apis.login_required
def create_workflow(
    project_id: str,
    name: str,
    description: str,
    config: dict[str, Any],
    paused: bool = False,
) -> models.Workflow:
    """Create a new workflow inside a project.

    :param project_id: Id of the project the workflow is created in
    :type project_id: str
    :param name: Name of the workflow
    :type name: str
    :param description: Description of the workflow
    :type description: str
    :param config: The workflow configuration
    :type config: dict[str, Any]
    :param paused: Set to true to not create an agent process pod
    :type paused: bool
    :raises RuntimeError: If the service response carries no ``data`` payload
    :return: The created workflow, converted to the SDK dataclass
    :rtype: models.Workflow
    """
    api = _workflow_api()
    request = api_models.InputCreateWorkflowRequest(
        project_id=project_id,
        name=name,
        description=description,
        config=config,
        paused=paused,
    )
    payload = api.create_workflow(request).data
    # Guard against a response envelope that arrived without its body.
    if not payload:
        raise RuntimeError("Received malformed response (missing `data`) from create_workflow")
    return _utils.convert_to_dataclass(payload.model_dump(), models.Workflow)
@mcp_setting(mutating=True)
@_apis.login_required
def update_workflow(
    workflow_id: str,
    name: str,
    description: str,
    config: dict[str, Any],
    paused: bool = False,
) -> models.Workflow:
    """Update an existing workflow, identified by its id.

    ``name``, ``description``, ``config`` and ``paused`` are always sent in the
    update request. Because ``paused`` defaults to ``False``, omitting it sends
    ``paused=False``.

    :param workflow_id: Id of the workflow to update
    :type workflow_id: str
    :param name: Name of the workflow
    :type name: str
    :param description: Description of the workflow
    :type description: str
    :param config: The workflow configuration
    :type config: dict[str, Any]
    :param paused: Set to true to not create an agent process pod
    :type paused: bool
    :raises RuntimeError: If the service response carries no ``data`` payload
    :return: The updated workflow, converted to the SDK dataclass
    :rtype: models.Workflow
    """
    api = _workflow_api()
    changes = api_models.InputUpdateWorkflowRequest(
        name=name,
        description=description,
        config=config,
        paused=paused,
    )
    payload = api.update_workflow(workflow_id, changes).data
    # Guard against a response envelope that arrived without its body.
    if not payload:
        raise RuntimeError("Received malformed response (missing `data`) from update_workflow")
    return _utils.convert_to_dataclass(payload.model_dump(), models.Workflow)
@_apis.login_required
def get_workflow(workflow_id: str) -> models.WorkflowDetails:
    """Fetch the details of a single workflow by its id.

    :param workflow_id: Id of the workflow to fetch
    :type workflow_id: str
    :raises RuntimeError: If the service response carries no ``data`` payload
    :return: The workflow details, converted to the SDK dataclass
    :rtype: models.WorkflowDetails
    """
    payload = _workflow_api().get_workflow_details(workflow_id).data
    # Guard against a response envelope that arrived without its body.
    if not payload:
        raise RuntimeError("Received malformed response (missing `data`) from get_workflow")
    return _utils.convert_to_dataclass(payload.model_dump(), models.WorkflowDetails)
@mcp_setting(mutating=True)
@_apis.login_required
def delete_workflow(workflow_id: str) -> None:
    """Delete a workflow by its id.

    Whatever the underlying API call returns is discarded; this function
    always returns ``None``.

    :param workflow_id: Id of the workflow to delete
    :type workflow_id: str
    :rtype: None
    """
    _workflow_api().delete_workflow(workflow_id)
@_apis.login_required
def get_workflows(
    name: str | None = None,
    project_ids: list[str] | None = None,
    include_deleted: bool | None = None,
    limit: int | None = None,
    after_key: str | None = None,
) -> Generator[models.Workflow, None, None]:
    """List workflows matching the specified filters.

    Only filters that are not ``None`` are forwarded to ``list_workflows``.
    Note that ``project_ids`` is sent under the ``project_id`` keyword.
    Pagination is delegated to ``_utils.paginate``. Each page request looks up
    the API client anew via ``_workflow_api()``, and each item is converted to
    a ``models.Workflow`` dataclass.

    :param name: Partial, case-insensitive name filter
    :type name: Optional[str]
    :param project_ids: List of project ids to filter
    :type project_ids: Optional[list[str]]
    :param include_deleted: Set to True to include deleted workflows
    :type include_deleted: Optional[bool]
    :param limit: The max number of items to return
    :type limit: Optional[int]
    :param after_key: The pagination key to fetch more items
    :type after_key: Optional[str]
    :rtype: Generator[models.Workflow, None, None]
    """
    params = {}
    for key, value in (
        ("name", name),
        ("project_id", project_ids),
        ("include_deleted", include_deleted),
        ("limit", limit),
        ("after_key", after_key),
    ):
        if value is not None:
            params[key] = value

    def _fetch_page(**query):
        # Named ``query`` (not ``params``) to avoid shadowing the outer dict.
        return _workflow_api().list_workflows(**query)

    def _to_workflow(item):
        return _utils.convert_to_dataclass(item.model_dump(), models.Workflow)

    return _utils.paginate(
        func=_fetch_page,
        params=params,
        transform_item=_to_workflow,
    )
@mcp_setting(ignore=True)
@_apis.login_required
def initialize_workflow_user(
    workflow_id: str,
    agent_name: str,
) -> None:
    """Initialize a user for a particular workflow and agent.

    Calls ``initialize_session_user`` on the workflows API client with an
    ``InputInitializeWorkflowUserRequest`` carrying ``agent_name``. Whatever
    that call returns is discarded; this function always returns ``None``.

    :param workflow_id: Workflow id
    :type workflow_id: str
    :param agent_name: Agent name
    :type agent_name: str
    :rtype: None
    """
    api = _workflow_api()
    request = api_models.InputInitializeWorkflowUserRequest(agent_name=agent_name)
    api.initialize_session_user(workflow_id, request)