diff --git a/litellm/mcp_client/tools.py b/litellm/mcp_client/tools.py index e6b403f975..bd803b995d 100644 --- a/litellm/mcp_client/tools.py +++ b/litellm/mcp_client/tools.py @@ -1,6 +1,8 @@ +import json from typing import List, Literal, Union from mcp import ClientSession +from mcp.types import CallToolResult from mcp.types import Tool as MCPTool from openai.types.chat import ChatCompletionToolParam from openai.types.shared_params.function_definition import FunctionDefinition @@ -19,6 +21,27 @@ def transform_mcp_tool_to_openai_tool(mcp_tool: MCPTool) -> ChatCompletionToolPa ) +def _get_function_arguments(function: FunctionDefinition) -> dict: + """Helper to safely get and parse function arguments.""" + arguments = function.get("arguments", {}) + if isinstance(arguments, str): + try: + arguments = json.loads(arguments) + except json.JSONDecodeError: + arguments = {} + return arguments if isinstance(arguments, dict) else {} + + +def transform_openai_tool_to_mcp_tool(openai_tool: ChatCompletionToolParam) -> MCPTool: + """Convert an OpenAI tool to an MCP tool.""" + function = openai_tool["function"] + return MCPTool( + name=function["name"], + description=function.get("description", ""), + inputSchema=_get_function_arguments(function), + ) + + async def load_mcp_tools( session: ClientSession, format: Literal["mcp", "openai"] = "mcp" ) -> Union[List[MCPTool], List[ChatCompletionToolParam]]: @@ -38,3 +61,31 @@ async def load_mcp_tools( transform_mcp_tool_to_openai_tool(mcp_tool=tool) for tool in tools.tools ] return tools.tools + + +async def call_mcp_tool( + session: ClientSession, + name: str, + arguments: dict, +) -> CallToolResult: + """Call an MCP tool.""" + tool_result = await session.call_tool( + name=name, + arguments=arguments, + ) + return tool_result + + +async def call_openai_tool( + session: ClientSession, + openai_tool: ChatCompletionToolParam, +) -> CallToolResult: + """Call an OpenAI tool using MCP client.""" + mcp_tool = transform_openai_tool_to_mcp_tool( + openai_tool=openai_tool, + ) + return await call_mcp_tool( + session=session, + name=mcp_tool.name, + arguments=mcp_tool.inputSchema, + ) diff --git a/tests/mcp_tests/test_mcp_litellm_client.py b/tests/mcp_tests/test_mcp_litellm_client.py index a4ca90eb1f..cb38614e58 100644 --- a/tests/mcp_tests/test_mcp_litellm_client.py +++ b/tests/mcp_tests/test_mcp_litellm_client.py @@ -10,7 +10,11 @@ sys.path.insert( from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client import os -from litellm.mcp_client.tools import load_mcp_tools +from litellm.mcp_client.tools import ( + load_mcp_tools, + transform_openai_tool_to_mcp_tool, + call_openai_tool, +) import litellm import pytest import json @@ -34,11 +38,12 @@ async def test_mcp_agent(): print("MCP TOOLS: ", tools) # Create and run the agent + messages = [{"role": "user", "content": "what's (3 + 5)"}] print(os.getenv("OPENAI_API_KEY")) llm_response = await litellm.acompletion( model="gpt-4o", api_key=os.getenv("OPENAI_API_KEY"), - messages=[{"role": "user", "content": "what's (3 + 5) x 12?"}], + messages=messages, tools=tools, ) print("LLM RESPONSE: ", json.dumps(llm_response, indent=4, default=str)) @@ -51,3 +56,35 @@ async def test_mcp_agent(): ] == "add" ) + openai_tool = llm_response["choices"][0]["message"]["tool_calls"][0] + + # Convert the OpenAI tool to an MCP tool + mcp_tool = transform_openai_tool_to_mcp_tool(openai_tool) + print("MCP TOOL: ", mcp_tool) + + # Call the tool using MCP client + call_result = await call_openai_tool( + session=session, + openai_tool=openai_tool, + ) + print("CALL RESULT: ", call_result) + + # send the tool result to the LLM + messages.append(llm_response["choices"][0]["message"]) + messages.append( + { + "role": "tool", + "content": str(call_result.content[0].text), + "tool_call_id": openai_tool["id"], + } + ) + print("final messages: ", messages) + llm_response = await litellm.acompletion( + model="gpt-4o", + api_key=os.getenv("OPENAI_API_KEY"), + messages=messages, + tools=tools, + ) + print( + "FINAL LLM RESPONSE: ", json.dumps(llm_response, indent=4, default=str) + )