"""Source code for chariot.mcp.build."""

import importlib
import inspect
import chariot
from typing import List, Callable, Tuple, Optional, Any
import functools
from types import FunctionType
import logging

from chariot.mcp.enum import ChariotMCPPackage
from chariot.mcp.config import SDKConfig
from chariot.mcp.tools.patch import sdk_functions_to_patch
from chariot.mcp.tools.inference import (
    perform_model_inference_on_datum,
    get_action_docstring,
)
from chariot.mcp.utils import is_mutating, is_file_based, is_ignored, get_package_prefix

# Appended (after a blank line) to the docstring of any tool wrapped with
# is_mutating=True in wrap_tool.
MUTATING_TOOL_DOCSTRING = "This is a mutating tool."
# Template appended (after a blank line) to tool docstrings by wrap_tool;
# ``{path}`` is filled with the dotted path of the wrapped SDK object.
OBJECT_PATH_DOCSTRING = "Object path: {path}"


def sanitize_function(function: Callable, origin: Optional[Callable] = None) -> Callable:
    """
    Remove "spicy" parameters from a callable's signature and annotations.

    Parameters whose names start with "_" are dropped, as is any ``**kwargs``
    (VAR_KEYWORD) parameter, whether typed or not. The reduced signature is
    stored on ``function.__signature__``. ``function.__annotations__`` is rebuilt
    so it only describes parameters that survive, plus the ``return`` annotation.
    This keeps annotations and signature in agreement.

    Parameters
    ----------
    function: Callable
        The callable to modify in place and return.
    origin: Callable, optional
        The callable whose signature and annotations are read. Defaults to
        ``function`` itself.

    Returns
    -------
    Callable
        ``function``, with ``__signature__`` and ``__annotations__`` replaced.
    """
    bad_prefixes = ("_",)
    origin = origin or function

    # Signature: drop private-looking params and any **kwargs catch-all.
    signature = inspect.signature(origin)
    new_params = [
        p
        for p in signature.parameters.values()
        if not p.name.startswith(bad_prefixes)
        and p.kind != inspect.Parameter.VAR_KEYWORD
    ]
    function.__signature__ = signature.replace(parameters=new_params)

    # Annotations: keep only entries for parameters still in the signature (and
    # "return"). A typed **kwargs must not leave an orphan annotation behind.
    # getattr: some callables (e.g. functools.partial) have no __annotations__.
    kept_names = {p.name for p in new_params}
    anns = getattr(origin, "__annotations__", None) or {}
    function.__annotations__ = {
        k: v for k, v in anns.items() if k == "return" or k in kept_names
    }
    return function
def wrap_tool(
    func: Callable,
    package: Optional[ChariotMCPPackage] = None,
    obj_path: Optional[str] = None,
    name_override: Optional[str] = None,
    is_mutating: bool = False,
) -> Callable:
    """
    Wrap a tool with extra metadata.

    1. Prefix the tool name using ``get_package_prefix(package)``.
    2. Append the object path (and, if ``is_mutating``, a mutating-tool notice)
       to the docstring.
    3. Sanitize the signature/annotations via ``sanitize_function``.

    Parameters
    ----------
    func: Callable
        The callable to expose as a tool. Calls are forwarded to it unchanged.
    package: ChariotMCPPackage, optional
        Package used to compute the name prefix.
    obj_path: str, optional
        Dotted path of the wrapped object, appended to the docstring when given.
    name_override: str, optional
        Base name to use instead of ``func.__name__``.
    is_mutating: bool
        Whether to append ``MUTATING_TOOL_DOCSTRING`` to the docstring.

    Returns
    -------
    Callable
        A new forwarding function carrying the prefixed name, extended
        docstring and sanitized signature.
    """
    # NOTE: inside this function ``is_mutating`` is the boolean parameter. It
    # shadows the module-level ``is_mutating`` helper.

    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    base_name = name_override or func.__name__
    wrapper.__name__ = get_package_prefix(package) + base_name
    # Keep __qualname__ in sync so reprs/log lines show the tool name instead
    # of "wrap_tool.<locals>.wrapper".
    wrapper.__qualname__ = wrapper.__name__

    doc = func.__doc__ or ""
    if obj_path is not None:
        doc += "\n\n" + OBJECT_PATH_DOCSTRING.format(path=obj_path)
    if is_mutating:
        doc += f"\n\n{MUTATING_TOOL_DOCSTRING}"
    wrapper.__doc__ = doc

    # Signature and annotations are taken from ``func``, minus private params.
    return sanitize_function(wrapper, origin=func)
def _flatten_class(
    cls: Any,
    id_arg_name: str,
    get_instance_by_id: Callable,
    root_package: ChariotMCPPackage,
    disable_mutating_tools: bool = False,
    disable_file_based_tools: bool = False,
) -> List[Callable]:
    """
    Create standalone functions from the instance methods of the given class.

    The class must be able to be instantiated using a single resource ID in
    Chariot. Using chariot.models.model.Model as an example, given an instance
    method of this class:

        class Model(...):
            ...
            def method(self, ...):
                ...

    define a function:

        def Model_method(model_id, ...):
            model = get_model_by_id(model_id)
            return model.method(...)

    This allows MCP to execute instance methods without instantiating the
    class, in effect converting it into a functional framework. The id may be
    passed as the keyword ``id_arg_name`` or as the first positional argument.

    Only functions defined on ``cls`` itself are flattened. Dunder methods,
    static methods, ignored methods, and names also present on a direct base
    class are skipped.

    Parameters
    ----------
    cls: Any
        The class. Should be something that can be instantiated using a single
        resource id.
    id_arg_name: str
        The name of the arg that represents the resource id (e.g. "model_id").
    get_instance_by_id: Callable
        A function that accepts an id and returns an instance of cls.
    root_package: ChariotMCPPackage
        Location of the cls within the sdk.
    disable_mutating_tools: bool
        Whether to skip methods of the class that are mutating.
    disable_file_based_tools: bool
        Whether to skip methods of the class that require a file system.

    Returns
    -------
    tools: List[Callable]
        The list of flattened instance methods.
    """
    allowed_instance_methods = [
        f
        for n, f in inspect.getmembers(cls, inspect.isfunction)
        if not is_ignored(f)
        and not n.startswith("__")
        and not isinstance(cls.__dict__.get(n), staticmethod)
        and not any(hasattr(b, n) for b in cls.__bases__)
    ]

    def patch_signature(obj: Callable, origin: Callable) -> Callable:
        """
        Give ``obj`` the signature of ``origin``, with "self" removed and a
        leading ``{id_arg_name}: str`` parameter added instead.
        """
        orig_sig = inspect.signature(origin)
        params = [p for p in orig_sig.parameters.values() if p.name != "self"]
        id_param = inspect.Parameter(
            id_arg_name, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=str
        )
        obj.__signature__ = orig_sig.replace(parameters=[id_param] + params)
        # Build a fresh dict rather than writing into origin.__annotations__:
        # that dict belongs to the SDK class method and must not be modified.
        obj.__annotations__ = {
            **getattr(origin, "__annotations__", {}),
            id_arg_name: str,
        }
        return obj

    def flatten_instance_method(method_name: str) -> Callable:
        """Return a function that loads the instance by id and calls ``method_name``."""

        def wrapper(*args, **kwargs):
            if id_arg_name in kwargs:
                id_ = kwargs.pop(id_arg_name)
            elif args:
                # The published signature lists the id first, so accept it
                # positionally as well.
                id_, args = args[0], args[1:]
            else:
                raise ValueError(f"Missing required parameter {id_arg_name}")
            instance = get_instance_by_id(id_)
            return getattr(instance, method_name)(*args, **kwargs)

        method = getattr(cls, method_name)
        wrapper.__doc__ = method.__doc__
        wrapper.__name__ = method.__name__
        return wrapper

    tools = []
    for instance_method in allowed_instance_methods:
        mutating = is_mutating(instance_method)
        # Handle mutating tools
        if mutating and disable_mutating_tools:
            continue
        # Handle file-based tools
        if is_file_based(instance_method) and disable_file_based_tools:
            continue
        method_name = instance_method.__name__
        obj_flat = flatten_instance_method(method_name)
        tools.append(
            wrap_tool(
                patch_signature(obj_flat, instance_method),
                package=root_package,
                obj_path=f"{cls.__module__}.{cls.__name__}.{method_name}",
                name_override=f"{cls.__name__}_{method_name.strip('_')}",
                is_mutating=mutating,
            )
        )
    return tools


def _build_sdk(
    packages: List[ChariotMCPPackage],
    disable_mutating_tools: bool = False,
    disable_file_based_tools: bool = False,
) -> List[Callable]:
    """
    Build tools from the module-level functions exported by each package.

    For every package, import ``chariot.<package.value>`` and wrap each plain
    function in ``dir(package)`` that is defined in a ``chariot*`` module.
    Functions listed in ``sdk_functions_to_patch`` (keyed by dotted path) are
    replaced by their patched version. Ignored functions are skipped, and so
    are mutating or file-based functions when the matching flag is set.
    """
    tools = []
    for package in packages:
        package_name = package.value
        importlib.import_module("chariot." + package_name)
        package_obj = getattr(chariot, package_name)
        for attr_name in dir(package_obj):
            obj = getattr(package_obj, attr_name)
            # Cheap type test first, so inspect.getmodule only runs on functions.
            if not isinstance(obj, FunctionType):
                continue
            module = inspect.getmodule(obj)
            # getmodule returns None when obj.__module__ is not in sys.modules;
            # such a function cannot be attributed to chariot.
            if module is None or not module.__name__.startswith("chariot"):
                continue
            obj_path = module.__name__ + "." + obj.__name__
            # Handle patched functions
            if obj_path in sdk_functions_to_patch:
                obj = sdk_functions_to_patch[obj_path]
            # Handle ignored functions
            if is_ignored(obj):
                continue
            mutating = is_mutating(obj)
            # Handle mutating tools
            if mutating and disable_mutating_tools:
                continue
            # Handle file-based tools
            if is_file_based(obj) and disable_file_based_tools:
                continue
            tools.append(
                wrap_tool(obj, package=package, obj_path=obj_path, is_mutating=mutating)
            )
    return tools


def _build_models(
    disable_mutating_tools: bool = False, disable_file_based_tools: bool = False
) -> List[Callable]:
    """
    Build tools for chariot.models.model.Model, plus the model inference tools.
    """
    # Import explicitly instead of relying on _build_sdk having loaded it.
    models = importlib.import_module("chariot.models")
    tools = _flatten_class(
        models.Model,
        "model_id",
        models.get_model_by_id,
        ChariotMCPPackage.Models,
        disable_mutating_tools=disable_mutating_tools,
        disable_file_based_tools=disable_file_based_tools,
    )
    # Add inference on datum tool. get_action_docstring is not itself marked
    # mutating, but it is only added alongside the inference tool (presumably
    # as its companion helper).
    if not disable_mutating_tools:
        tools.extend(
            [
                wrap_tool(
                    perform_model_inference_on_datum,
                    package=ChariotMCPPackage.Models,
                    is_mutating=True,
                ),
                wrap_tool(
                    get_action_docstring,
                    package=ChariotMCPPackage.Models,
                ),
            ]
        )
    return tools


def _build_training_v2_Run(
    disable_mutating_tools: bool = False, disable_file_based_tools: bool = False
) -> List[Callable]:
    """
    Build tools for chariot.training_v2.run.Run instance methods.
    """
    # Import the submodule explicitly; importing the parent package alone does
    # not guarantee that ``.run`` is loaded.
    run_module = importlib.import_module("chariot.training_v2.run")
    return _flatten_class(
        run_module.Run,
        "run_id",
        lambda id_: run_module.Run.from_id(id_),
        ChariotMCPPackage.Training,
        disable_mutating_tools=disable_mutating_tools,
        disable_file_based_tools=disable_file_based_tools,
    )
def build(
    sdk_config: SDKConfig,
) -> List[Callable]:
    """
    Assemble the list of MCP tool callables described by ``sdk_config``.

    Package selection:

    - ``include_packages`` is used as-is whenever it is not None;
    - otherwise, if ``exclude_packages`` is not None, every
      ``ChariotMCPPackage`` not listed in it;
    - otherwise, every ``ChariotMCPPackage``.

    Module-level SDK function tools come first. They are followed by the
    flattened ``Model`` tools when Models is selected, then the flattened
    training ``Run`` tools when Training is selected. The config's
    ``disable_mutating_tools`` / ``disable_file_based_tools`` flags are passed
    to every builder.

    Returns
    -------
    List[Callable]
        The assembled tool callables.
    """
    flags = {
        "disable_mutating_tools": sdk_config.disable_mutating_tools,
        "disable_file_based_tools": sdk_config.disable_file_based_tools,
    }

    if sdk_config.include_packages is not None:
        packages = sdk_config.include_packages
    elif sdk_config.exclude_packages is None:
        packages = list(ChariotMCPPackage)
    else:
        excluded = sdk_config.exclude_packages
        packages = [pkg for pkg in ChariotMCPPackage if pkg not in excluded]

    tools = _build_sdk(packages, **flags)

    # Class-based builders, applied in this order when their package is selected.
    class_builders = (
        (ChariotMCPPackage.Models, _build_models),
        (ChariotMCPPackage.Training, _build_training_v2_Run),
    )
    for required_package, builder in class_builders:
        if required_package in packages:
            tools.extend(builder(**flags))
    return tools