diff --git a/litellm/experimental_mcp_client/tools.py b/litellm/experimental_mcp_client/tools.py index aa4d02184a..f4ebbf4af4 100644 --- a/litellm/experimental_mcp_client/tools.py +++ b/litellm/experimental_mcp_client/tools.py @@ -2,12 +2,18 @@ import json from typing import List, Literal, Union from mcp import ClientSession -from mcp.types import CallToolResult +from mcp.types import CallToolRequestParams as MCPCallToolRequestParams +from mcp.types import CallToolResult as MCPCallToolResult from mcp.types import Tool as MCPTool from openai.types.chat import ChatCompletionToolParam from openai.types.shared_params.function_definition import FunctionDefinition +from litellm.types.utils import ChatCompletionMessageToolCall + +######################################################## +# List MCP Tool functions +######################################################## def transform_mcp_tool_to_openai_tool(mcp_tool: MCPTool) -> ChatCompletionToolParam: """Convert an MCP tool to an OpenAI tool.""" return ChatCompletionToolParam( @@ -21,27 +27,6 @@ def transform_mcp_tool_to_openai_tool(mcp_tool: MCPTool) -> ChatCompletionToolPa ) -def _get_function_arguments(function: FunctionDefinition) -> dict: - """Helper to safely get and parse function arguments.""" - arguments = function.get("arguments", {}) - if isinstance(arguments, str): - try: - arguments = json.loads(arguments) - except json.JSONDecodeError: - arguments = {} - return arguments if isinstance(arguments, dict) else {} - - -def transform_openai_tool_to_mcp_tool(openai_tool: ChatCompletionToolParam) -> MCPTool: - """Convert an OpenAI tool to an MCP tool.""" - function = openai_tool["function"] - return MCPTool( - name=function["name"], - description=function.get("description", ""), - inputSchema=_get_function_arguments(function), - ) - - async def load_mcp_tools( session: ClientSession, format: Literal["mcp", "openai"] = "mcp" ) -> Union[List[MCPTool], List[ChatCompletionToolParam]]: @@ -63,23 +48,49 @@ async def load_mcp_tools( return tools.tools +######################################################## +# Call MCP Tool functions +######################################################## + + async def call_mcp_tool( session: ClientSession, - name: str, - arguments: dict, -) -> CallToolResult: + call_tool_request_params: MCPCallToolRequestParams, +) -> MCPCallToolResult: """Call an MCP tool.""" tool_result = await session.call_tool( - name=name, - arguments=arguments, + name=call_tool_request_params.name, + arguments=call_tool_request_params.arguments, ) return tool_result +def _get_function_arguments(function: FunctionDefinition) -> dict: + """Helper to safely get and parse function arguments.""" + arguments = function.get("arguments", {}) + if isinstance(arguments, str): + try: + arguments = json.loads(arguments) + except json.JSONDecodeError: + arguments = {} + return arguments if isinstance(arguments, dict) else {} + + +def _transform_openai_tool_call_to_mcp_tool_call_request( + openai_tool: ChatCompletionMessageToolCall, +) -> MCPCallToolRequestParams: + """Convert an OpenAI ChatCompletionMessageToolCall to an MCP CallToolRequestParams.""" + function = openai_tool["function"] + return MCPCallToolRequestParams( + name=function["name"], + arguments=_get_function_arguments(function), + ) + + async def call_openai_tool( session: ClientSession, - openai_tool: ChatCompletionToolParam, -) -> CallToolResult: + openai_tool: ChatCompletionMessageToolCall, +) -> MCPCallToolResult: """ Call an OpenAI tool using MCP client. @@ -89,11 +100,10 @@ async def call_openai_tool( Returns: The result of the MCP tool call. """ - mcp_tool = transform_openai_tool_to_mcp_tool( + mcp_tool_call_request_params = _transform_openai_tool_call_to_mcp_tool_call_request( openai_tool=openai_tool, ) return await call_mcp_tool( session=session, - name=mcp_tool.name, - arguments=mcp_tool.inputSchema, + call_tool_request_params=mcp_tool_call_request_params, )