import importlib
import inspect
import chariot
from typing import List, Callable, Tuple, Optional, Any
import functools
from types import FunctionType
import logging
from dataclasses import dataclass
from chariot import mcp_setting
from chariot.mcp.enum import ChariotMCPPackage
from chariot.mcp.config import SDKConfig
from chariot.mcp.tools.patch import sdk_functions_to_patch
from chariot.mcp.tools.inference import (
perform_model_inference_on_datum,
get_action_docstring,
)
from chariot.mcp.utils import is_mutating, is_file_based, is_ignored, get_package_prefix
# Text marking a tool as mutating. Not referenced in the code visible here;
# presumably consumed by wrap_tool when is_mutating=True — TODO confirm.
MUTATING_TOOL_DOCSTRING = "This is a mutating tool."
# str.format template with a single ``{path}`` placeholder. Presumably filled
# with a tool's obj_path by wrap_tool — verify against its definition.
OBJECT_PATH_DOCSTRING = "Object path: {path}"
[docs]
def sanitize_function(function: Callable, origin: Optional[Callable] = None) -> Callable:
    """
    Remove parameters that should not be exposed as tool inputs.

    Parameters whose name starts with ``_`` and any ``**kwargs``
    (VAR_KEYWORD) parameter are removed from the signature of ``origin``.
    The result is stored on ``function.__signature__``. The annotations of
    ``origin`` are filtered the same way, so signature and annotations stay
    consistent, and are stored on ``function.__annotations__``.

    Parameters
    ----------
    function: Callable
        The callable whose ``__signature__``/``__annotations__`` are overwritten.
    origin: Callable, optional
        The callable to read the signature and annotations from. Defaults
        to ``function`` itself.

    Returns
    -------
    Callable
        ``function`` (the same object, modified in place).
    """
    bad_prefixes = ("_",)
    if origin is None:
        origin = function

    # Signature
    signature = inspect.signature(origin)
    kept_params = []
    removed_names = set()
    for param in signature.parameters.values():
        if param.name.startswith(bad_prefixes) or param.kind == inspect.Parameter.VAR_KEYWORD:
            removed_names.add(param.name)
        else:
            kept_params.append(param)
    function.__signature__ = signature.replace(parameters=kept_params)

    # Annotations: also drop entries for every removed parameter (including
    # a typed **kwargs) so they agree with the new signature. Some callables
    # (e.g. functools.partial) have no __annotations__ at all.
    annotations = getattr(origin, "__annotations__", None) or {}
    function.__annotations__ = {
        name: ann
        for name, ann in annotations.items()
        if name not in removed_names and not name.startswith(bad_prefixes)
    }
    return function
def _flatten_class(
    cls: Any,
    id_arg_name: str,
    get_instance_by_id: Callable,
    root_package: ChariotMCPPackage,
    disable_mutating_tools: bool = False,
    disable_file_based_tools: bool = False,
    soft_failures_enabled: bool = False,
) -> List[Callable]:
    """
    Create standalone functions from the instance methods of the given class.

    The class must be able to be instantiated using a single resource ID
    in Chariot.

    Using chariot.models.model.Model as an example, the pattern is, given an
    instance method of this class::

        class Model(...):
            ...
            def method(self, ...):
                ...

    define a function::

        def Model_method(model_id, ...):
            model = get_model_by_id(model_id)
            return model.method(...)

    This allows MCP to execute instance methods without instantiating
    the class, in effect converting it into a functional framework.

    Only plain functions found on ``cls`` are considered. The following are
    skipped: dunder methods, staticmethods, methods for which ``is_ignored``
    is true, and methods whose name is also available on one of
    ``cls.__bases__``.

    Parameters
    ----------
    cls: Any
        The class. Should be something that can be instantiated using a
        single resource id.
    id_arg_name: str
        The name of the arg that represents the resource id. (e.g. "model_id")
    get_instance_by_id: Callable
        A function that accepts an id and returns an instance of cls
    root_package: ChariotMCPPackage
        Location of the cls within the sdk
    disable_mutating_tools: bool
        Whether to skip methods of the class that are mutating
    disable_file_based_tools: bool
        Whether to skip methods of the class that require a file system
    soft_failures_enabled: bool
        Whether returned functions should hard fail or soft fail (raise exceptions or
        return exception strings)

    Returns
    -------
    tools: List[Callable]
        The list of flattened instance methods, each passed through ``wrap_tool``.
    """
    allowed_instance_methods = [
        f
        for n, f in inspect.getmembers(cls, inspect.isfunction)
        if not is_ignored(f)
        and not n.startswith("__")
        and not isinstance(cls.__dict__.get(n), staticmethod)
        and not any(hasattr(b, n) for b in cls.__bases__)
    ]

    def patch_signature(obj: Callable, origin: Callable) -> Callable:
        """
        Give ``obj`` the signature of ``origin`` with "self" removed and a
        leading ``{id_arg_name}: str`` parameter added, plus matching annotations.
        """
        orig_sig = inspect.signature(origin)
        params = [p for p in orig_sig.parameters.values() if p.name != "self"]
        id_param = inspect.Parameter(
            id_arg_name, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=str
        )
        obj.__signature__ = orig_sig.replace(parameters=[id_param, *params])
        # Build a fresh dict. Writing into origin.__annotations__ would
        # permanently add {id_arg_name} to the real class method's annotations.
        anns = {k: v for k, v in origin.__annotations__.items() if k != "self"}
        anns[id_arg_name] = str
        obj.__annotations__ = anns
        return obj

    def flatten_instance_method(method_name: str) -> Callable:
        """Return a function that looks up the instance by id and calls ``method_name`` on it."""
        method = getattr(cls, method_name)

        def wrapper(*args, **kwargs):
            if id_arg_name in kwargs:
                id_ = kwargs.pop(id_arg_name)
            elif args:
                # The patched signature declares the id as the first
                # positional-or-keyword parameter, so accept it positionally too.
                id_, args = args[0], args[1:]
            else:
                raise ValueError(f"Missing required parameter {id_arg_name}")
            instance = get_instance_by_id(id_)
            return getattr(instance, method_name)(*args, **kwargs)

        wrapper.__doc__ = method.__doc__
        wrapper.__name__ = method.__name__
        return wrapper

    tools = []
    for instance_method in allowed_instance_methods:
        method_name = instance_method.__name__
        mutating = is_mutating(instance_method)
        # Apply the filters before doing any signature work on the method.
        # Handle mutating tools
        if mutating and disable_mutating_tools:
            continue
        # Handle file-based tools
        if is_file_based(instance_method) and disable_file_based_tools:
            continue
        obj_path = f"{cls.__module__}.{cls.__name__}.{method_name}"
        obj_name = f"{cls.__name__}_{method_name.strip('_')}"
        # Handle patched functions: a hand-written replacement takes
        # precedence over the generated wrapper.
        if obj_path in sdk_functions_to_patch:
            obj_flat = sdk_functions_to_patch[obj_path]
        else:
            obj_flat = patch_signature(flatten_instance_method(method_name), instance_method)
        tools.append(
            wrap_tool(
                obj_flat,
                package=root_package,
                obj_path=obj_path,
                name_override=obj_name,
                is_mutating=mutating,
                soft_failures_enabled=soft_failures_enabled,
            )
        )
    return tools
def _build_sdk(packages: List[ChariotMCPPackage], sdk_config: SDKConfig) -> List[Callable]:
    """
    Build tools from the plain functions exposed by each package.

    For each package, ``chariot.<package.value>`` is imported. Every
    attribute of it that is a Python function defined in a module whose
    name starts with ``chariot`` becomes a candidate tool. Each candidate is
    then:

    * skipped if it is ``mcp_setting``;
    * replaced by its entry in ``sdk_functions_to_patch``, if one exists;
    * skipped if ignored, or if mutating / file-based while the matching
      ``sdk_config`` flag disables such tools.

    Parameters
    ----------
    packages: List[ChariotMCPPackage]
        Packages whose functions should be exposed.
    sdk_config: SDKConfig
        Supplies ``disable_mutating_tools``, ``disable_file_based_tools`` and
        ``soft_failures_enabled``.

    Returns
    -------
    tools: List[Callable]
        One ``wrap_tool`` result per exposed function.
    """
    tools = []
    for package in packages:
        # import_module returns the submodule itself. That also works for
        # dotted values, where getattr(chariot, "a.b") would fail.
        package_obj = importlib.import_module("chariot." + package.value)
        for attr_name in dir(package_obj):
            obj = getattr(package_obj, attr_name)
            if not isinstance(obj, FunctionType):
                continue
            module = inspect.getmodule(obj)
            # getmodule() can return None (e.g. for dynamically created
            # functions). Only functions from chariot modules become tools.
            if module is None or not module.__name__.startswith("chariot"):
                continue
            # Ignore mcp_setting function
            if obj is mcp_setting:
                continue
            # The path always refers to the original function, even if patched.
            obj_path = f"{module.__name__}.{obj.__name__}"
            # Handle patched functions
            if obj_path in sdk_functions_to_patch:
                obj = sdk_functions_to_patch[obj_path]
            # Handle ignored functions
            if is_ignored(obj):
                continue
            mutating = is_mutating(obj)
            # Handle mutating tools
            if mutating and sdk_config.disable_mutating_tools:
                continue
            # Handle file-based tools
            if is_file_based(obj) and sdk_config.disable_file_based_tools:
                continue
            tools.append(
                wrap_tool(
                    obj,
                    package=package,
                    obj_path=obj_path,
                    is_mutating=mutating,
                    soft_failures_enabled=sdk_config.soft_failures_enabled,
                )
            )
    return tools
def _build_models(sdk_config: SDKConfig) -> List[Callable]:
    """
    Build tools for chariot.models.model.Model, plus the model-inference helpers.

    Each Model instance method is flattened into a function that takes
    ``model_id`` (see ``_flatten_class``). If mutating tools are enabled,
    ``perform_model_inference_on_datum`` (marked mutating) and
    ``get_action_docstring`` are added as well.

    Parameters
    ----------
    sdk_config: SDKConfig
        Supplies ``disable_mutating_tools``, ``disable_file_based_tools`` and
        ``soft_failures_enabled``.

    Returns
    -------
    tools: List[Callable]
        The wrapped Model tools.
    """
    # Import explicitly so this builder does not depend on chariot.models
    # having already been imported elsewhere (e.g. by _build_sdk).
    models = importlib.import_module("chariot.models")
    tools = _flatten_class(
        models.Model,
        "model_id",
        models.get_model_by_id,
        ChariotMCPPackage.Models,
        disable_mutating_tools=sdk_config.disable_mutating_tools,
        disable_file_based_tools=sdk_config.disable_file_based_tools,
        soft_failures_enabled=sdk_config.soft_failures_enabled,
    )
    # Add inference on datum tool. get_action_docstring is not flagged as
    # mutating, but it is added only together with the inference tool
    # (presumably because it only supports that tool — confirm).
    if not sdk_config.disable_mutating_tools:
        tools.extend(
            [
                wrap_tool(
                    perform_model_inference_on_datum,
                    package=ChariotMCPPackage.Models,
                    is_mutating=True,
                    soft_failures_enabled=sdk_config.soft_failures_enabled,
                ),
                wrap_tool(
                    get_action_docstring,
                    package=ChariotMCPPackage.Models,
                    soft_failures_enabled=sdk_config.soft_failures_enabled,
                ),
            ]
        )
    return tools
def _build_training_v2_Run(sdk_config: SDKConfig) -> List[Callable]:
    """
    Build tools for chariot.training_v2.run.Run.

    Each Run instance method is flattened into a function that takes
    ``run_id``. The instance is resolved through ``Run.from_id`` (see
    ``_flatten_class``).

    Parameters
    ----------
    sdk_config: SDKConfig
        Supplies ``disable_mutating_tools``, ``disable_file_based_tools`` and
        ``soft_failures_enabled``.

    Returns
    -------
    tools: List[Callable]
        The wrapped Run tools.
    """
    # Import the submodule explicitly. The attribute chain
    # chariot.training_v2.run only resolves if the submodule was already
    # imported (and not shadowed by a package-level name).
    run_cls = importlib.import_module("chariot.training_v2.run").Run
    return _flatten_class(
        run_cls,
        "run_id",
        # Resolve from_id at call time (matching the original late lookup).
        lambda id_: run_cls.from_id(id_),
        ChariotMCPPackage.Training,
        disable_mutating_tools=sdk_config.disable_mutating_tools,
        disable_file_based_tools=sdk_config.disable_file_based_tools,
        soft_failures_enabled=sdk_config.soft_failures_enabled,
    )
[docs]
def build(
    sdk_config: SDKConfig,
) -> List[Callable]:
    """
    Assemble every SDK tool enabled by ``sdk_config``.

    Package selection: ``include_packages`` wins if it is set. Otherwise all
    packages except ``exclude_packages`` are used. If neither is set, every
    ``ChariotMCPPackage`` is used. Module-level functions of the selected
    packages come first. Then come the flattened Model tools (if Models is
    selected) and the flattened training Run tools (if Training is selected).
    """
    include = sdk_config.include_packages
    if include is None:
        exclude = sdk_config.exclude_packages
        if exclude is None:
            packages = list(ChariotMCPPackage)
        else:
            packages = [pkg for pkg in ChariotMCPPackage if pkg not in exclude]
    else:
        packages = include

    all_tools = _build_sdk(packages, sdk_config)
    class_builders = (
        (ChariotMCPPackage.Models, _build_models),
        (ChariotMCPPackage.Training, _build_training_v2_Run),
    )
    for pkg, builder in class_builders:
        if pkg in packages:
            all_tools.extend(builder(sdk_config))
    return all_tools