import importlib
import inspect
import chariot
from typing import List, Callable, Tuple, Optional, Any
import functools
from types import FunctionType
import logging
from chariot.mcp.enum import ChariotMCPPackage
from chariot.mcp.config import SDKConfig
from chariot.mcp.tools.patch import sdk_functions_to_patch
from chariot.mcp.tools.inference import (
perform_model_inference_on_datum,
get_action_docstring,
)
from chariot.mcp.utils import is_mutating, is_file_based, is_ignored, get_package_prefix
# Fixed marker sentence, presumably added to the description of tools built
# with is_mutating=True. NOTE(review): its consumer (likely wrap_tool) is not
# visible here — confirm.
MUTATING_TOOL_DOCSTRING = "This is a mutating tool."
# str.format template with a single ``path`` field, presumably filled with the
# tool's SDK object path (the obj_path passed to wrap_tool) — confirm.
OBJECT_PATH_DOCSTRING = "Object path: {path}"
[docs]
def sanitize_function(function: Callable, origin: Optional[Callable] = None) -> Callable:
    """
    Strip private and catch-all keyword parameters from a tool's public surface.

    Parameters whose names start with '_' and any ``**kwargs``-style
    (VAR_KEYWORD) parameter are removed from the signature. The matching
    entries are removed from ``__annotations__`` too, so the signature and
    the annotations stay consistent.

    Parameters
    ----------
    function: Callable
        The callable to patch in place. Its ``__signature__`` and
        ``__annotations__`` are overwritten.
    origin: Optional[Callable]
        The callable whose signature/annotations are read. Defaults to
        ``function`` itself.

    Returns
    -------
    function: Callable
        The same object that was passed in, after patching.
    """
    bad_prefixes = ("_",)
    if origin is None:
        origin = function

    # Signature: keep only public, explicitly named parameters.
    signature = inspect.signature(origin)
    kept_params = []
    dropped_names = set()
    for p in signature.parameters.values():
        if p.name.startswith(bad_prefixes) or p.kind == inspect.Parameter.VAR_KEYWORD:
            dropped_names.add(p.name)
            continue
        kept_params.append(p)
    function.__signature__ = signature.replace(parameters=kept_params)

    # Annotations: mirror the signature filtering. getattr guards callables
    # (e.g. callable instances) that have no __annotations__ attribute.
    anns = getattr(origin, "__annotations__", None) or {}
    function.__annotations__ = {
        k: v
        for k, v in anns.items()
        if not k.startswith(bad_prefixes) and k not in dropped_names
    }
    return function
def _flatten_class(
    cls: Any,
    id_arg_name: str,
    get_instance_by_id: Callable,
    root_package: ChariotMCPPackage,
    disable_mutating_tools: bool = False,
    disable_file_based_tools: bool = False,
) -> List[Callable]:
    """
    Create standalone functions from the instance methods of the given class.

    The class must be able to be instantiated using a single resource ID
    in Chariot.

    Using chariot.models.model.Model as an example, given an instance
    method of this class::

        class Model(...):
            ...
            def method(self, ...):
                ...

    define a function::

        def Model_method(model_id, ...):
            model = get_model_by_id(model_id)
            return model.method(...)

    This allows MCP to execute instance methods without instantiating
    the class, in effect converting it into a functional framework.

    Dunder methods, static methods, methods for which ``is_ignored`` is true,
    and methods whose name is also an attribute of one of ``cls.__bases__``
    are skipped.

    Parameters
    ----------
    cls: Any
        The class. Should be something that can be instantiated using a
        single resource id.
    id_arg_name: str
        The name of the arg that represents the resource id. (e.g. "model_id")
    get_instance_by_id: Callable
        A function that accepts an id and returns an instance of cls
    root_package: ChariotMCPPackage
        Location of the cls within the sdk
    disable_mutating_tools: bool
        Whether to skip methods of the class that are mutating
    disable_file_based_tools: bool
        Whether to skip methods of the class that require a file system

    Returns
    -------
    tools: List[Callable]
        The list of flattened instance methods.
    """
    allowed_instance_methods = [
        f
        for n, f in inspect.getmembers(cls, inspect.isfunction)
        if not is_ignored(f)
        and not n.startswith("__")
        and not isinstance(cls.__dict__.get(n), staticmethod)
        and not any(hasattr(b, n) for b in cls.__bases__)
    ]

    def patch_signature(obj: Callable, origin: Callable) -> Callable:
        """
        Give ``obj`` the signature of ``origin`` with "self" replaced by a
        leading ``{id_arg_name}: str`` parameter, plus matching annotations.
        """
        orig_sig = inspect.signature(origin)
        params = [p for p in orig_sig.parameters.values() if p.name != "self"]
        id_param = inspect.Parameter(
            id_arg_name, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=str
        )
        obj.__signature__ = orig_sig.replace(parameters=[id_param] + params)
        # Work on a copy. Writing into origin.__annotations__ directly would
        # permanently add {id_arg_name} to the class's own method annotations.
        anns = dict(getattr(origin, "__annotations__", {}))
        anns.pop("self", None)  # "self" is no longer part of the signature
        anns[id_arg_name] = str
        obj.__annotations__ = anns
        return obj

    def flatten_instance_method(method_name: str) -> Callable:
        """
        Return a function that resolves an instance from its id, then calls
        ``method_name`` on it with the remaining arguments.
        """
        method = getattr(cls, method_name)

        def wrapper(*args, **kwargs):
            if id_arg_name in kwargs:
                id_ = kwargs.pop(id_arg_name)
            elif args:
                # The patched signature declares the id as the first
                # positional-or-keyword parameter, so accept it positionally too.
                id_, args = args[0], args[1:]
            else:
                raise ValueError(f"Missing required parameter {id_arg_name}")
            instance = get_instance_by_id(id_)
            return getattr(instance, method_name)(*args, **kwargs)

        wrapper.__doc__ = method.__doc__
        wrapper.__name__ = method.__name__
        return wrapper

    tools = []
    for instance_method in allowed_instance_methods:
        mutating = is_mutating(instance_method)
        # Handle mutating tools
        if mutating and disable_mutating_tools:
            continue
        # Handle file-based tools
        if is_file_based(instance_method) and disable_file_based_tools:
            continue
        method_name = instance_method.__name__
        obj_flat = flatten_instance_method(method_name)
        obj_name = f"{cls.__name__}_{method_name.strip('_')}"
        wrapped_tool = wrap_tool(
            patch_signature(obj_flat, instance_method),
            package=root_package,
            obj_path=f"{cls.__module__}.{cls.__name__}.{method_name}",
            name_override=obj_name,
            is_mutating=mutating,
        )
        tools.append(wrapped_tool)
    return tools
def _build_sdk(
    packages: List[ChariotMCPPackage],
    disable_mutating_tools: bool = False,
    disable_file_based_tools: bool = False,
) -> List[Callable]:
    """
    Build tools from the module-level functions exposed by each SDK package.

    For every package, ``chariot.<package.value>`` is imported. Each
    attribute of it that is a plain Python function defined in a module whose
    name starts with "chariot" becomes a tool. A function listed in
    ``sdk_functions_to_patch`` is replaced by its patched version before the
    ignored/mutating/file-based filters are applied.

    Parameters
    ----------
    packages: List[ChariotMCPPackage]
        Packages to scan.
    disable_mutating_tools: bool
        Skip functions for which ``is_mutating`` is true.
    disable_file_based_tools: bool
        Skip functions for which ``is_file_based`` is true.

    Returns
    -------
    tools: List[Callable]
        The wrapped tools, in package order and then ``dir()`` order.
    """
    tools = []
    for package in packages:
        # Use the module returned by import_module directly. Unlike
        # getattr(chariot, name), this also works for dotted package values.
        package_obj = importlib.import_module("chariot." + package.value)
        # One function object can be reachable under several names (aliases,
        # re-exports). Emit a single tool for it within the package.
        seen = set()
        for attr_name in dir(package_obj):
            obj = getattr(package_obj, attr_name)
            if not isinstance(obj, FunctionType) or obj in seen:
                continue
            # getmodule() returns None when the defining module cannot be
            # resolved (e.g. dynamically created functions).
            module = inspect.getmodule(obj)
            if module is None or not module.__name__.startswith("chariot"):
                continue
            seen.add(obj)
            obj_path = module.__name__ + "." + obj.__name__
            # Handle patched functions
            if obj_path in sdk_functions_to_patch:
                obj = sdk_functions_to_patch[obj_path]
            # Handle ignored functions
            if is_ignored(obj):
                continue
            mutating = is_mutating(obj)
            # Handle mutating tools
            if mutating and disable_mutating_tools:
                continue
            # Handle file-based tools
            if is_file_based(obj) and disable_file_based_tools:
                continue
            wrapped_tool = wrap_tool(
                obj, package=package, obj_path=obj_path, is_mutating=mutating
            )
            tools.append(wrapped_tool)
    return tools
def _build_models(
    disable_mutating_tools: bool = False, disable_file_based_tools: bool = False
) -> List[Callable]:
    """
    Build tools for chariot.models.model.Model.

    Every eligible ``Model`` instance method is flattened into a standalone
    function keyed by ``model_id``. Unless mutating tools are disabled, two
    more tools are appended: datum inference (flagged as mutating), followed
    by the action-docstring lookup. Both extras are gated on the same flag,
    although only the inference tool is marked mutating.
    """
    model_tools = _flatten_class(
        chariot.models.Model,
        "model_id",
        chariot.models.get_model_by_id,
        ChariotMCPPackage.Models,
        disable_mutating_tools=disable_mutating_tools,
        disable_file_based_tools=disable_file_based_tools,
    )
    if disable_mutating_tools:
        return model_tools

    inference_tool = wrap_tool(
        perform_model_inference_on_datum,
        package=ChariotMCPPackage.Models,
        is_mutating=True,
    )
    docstring_tool = wrap_tool(get_action_docstring, package=ChariotMCPPackage.Models)
    model_tools += [inference_tool, docstring_tool]
    return model_tools
def _build_training_v2_Run(
    disable_mutating_tools: bool = False, disable_file_based_tools: bool = False
) -> List[Callable]:
    """
    Build tools for chariot.training_v2.run.Run.

    Each eligible ``Run`` instance method becomes a standalone function that
    takes a ``run_id`` and resolves the run through ``Run.from_id``.
    """

    def _run_from_id(id_):
        # The attribute chain is looked up when a tool runs, not at build time.
        return chariot.training_v2.run.Run.from_id(id_)

    return _flatten_class(
        cls=chariot.training_v2.run.Run,
        id_arg_name="run_id",
        get_instance_by_id=_run_from_id,
        root_package=ChariotMCPPackage.Training,
        disable_mutating_tools=disable_mutating_tools,
        disable_file_based_tools=disable_file_based_tools,
    )
[docs]
def build(
    sdk_config: SDKConfig,
) -> List[Callable]:
    """
    Build every SDK tool enabled by ``sdk_config``.

    Package selection works as follows:

    - If ``include_packages`` is set, it is used as-is.
    - Otherwise, if ``exclude_packages`` is set, every ``ChariotMCPPackage``
      member not in it is used.
    - Otherwise, every member is used.

    Tools are returned in this order:

    1. Module-level SDK functions of the selected packages.
    2. Flattened ``Model`` methods, if the Models package is selected.
    3. Flattened training ``Run`` methods, if the Training package is selected.

    The mutating and file-based flags from ``sdk_config`` are forwarded to
    every builder.
    """
    if sdk_config.include_packages is not None:
        selected = sdk_config.include_packages
    elif sdk_config.exclude_packages is None:
        selected = list(ChariotMCPPackage)
    else:
        excluded = sdk_config.exclude_packages
        selected = [pkg for pkg in ChariotMCPPackage if pkg not in excluded]

    flags = dict(
        disable_mutating_tools=sdk_config.disable_mutating_tools,
        disable_file_based_tools=sdk_config.disable_file_based_tools,
    )
    tools = _build_sdk(selected, **flags)

    # Class-based builders, each gated on the package that must be selected.
    class_builders = (
        (ChariotMCPPackage.Models, _build_models),
        (ChariotMCPPackage.Training, _build_training_v2_Run),
    )
    for required_package, builder in class_builders:
        if required_package in selected:
            tools.extend(builder(**flags))
    return tools