fix mcp client

This commit is contained in:
Ishaan Jaff 2025-03-21 18:18:23 -07:00
parent 1faf82f768
commit 8d770f0ccf

View file

@ -2,12 +2,18 @@ import json
from typing import List, Literal, Union from typing import List, Literal, Union
from mcp import ClientSession from mcp import ClientSession
from mcp.types import CallToolResult from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
from mcp.types import CallToolResult as MCPCallToolResult
from mcp.types import Tool as MCPTool from mcp.types import Tool as MCPTool
from openai.types.chat import ChatCompletionToolParam from openai.types.chat import ChatCompletionToolParam
from openai.types.shared_params.function_definition import FunctionDefinition from openai.types.shared_params.function_definition import FunctionDefinition
from litellm.types.utils import ChatCompletionMessageToolCall
########################################################
# List MCP Tool functions
########################################################
def transform_mcp_tool_to_openai_tool(mcp_tool: MCPTool) -> ChatCompletionToolParam: def transform_mcp_tool_to_openai_tool(mcp_tool: MCPTool) -> ChatCompletionToolParam:
"""Convert an MCP tool to an OpenAI tool.""" """Convert an MCP tool to an OpenAI tool."""
return ChatCompletionToolParam( return ChatCompletionToolParam(
@ -21,27 +27,6 @@ def transform_mcp_tool_to_openai_tool(mcp_tool: MCPTool) -> ChatCompletionToolPa
) )
def _get_function_arguments(function: FunctionDefinition) -> dict:
"""Helper to safely get and parse function arguments."""
arguments = function.get("arguments", {})
if isinstance(arguments, str):
try:
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {}
return arguments if isinstance(arguments, dict) else {}
def transform_openai_tool_to_mcp_tool(openai_tool: ChatCompletionToolParam) -> MCPTool:
"""Convert an OpenAI tool to an MCP tool."""
function = openai_tool["function"]
return MCPTool(
name=function["name"],
description=function.get("description", ""),
inputSchema=_get_function_arguments(function),
)
async def load_mcp_tools( async def load_mcp_tools(
session: ClientSession, format: Literal["mcp", "openai"] = "mcp" session: ClientSession, format: Literal["mcp", "openai"] = "mcp"
) -> Union[List[MCPTool], List[ChatCompletionToolParam]]: ) -> Union[List[MCPTool], List[ChatCompletionToolParam]]:
@ -63,23 +48,49 @@ async def load_mcp_tools(
return tools.tools return tools.tools
########################################################
# Call MCP Tool functions
########################################################
async def call_mcp_tool( async def call_mcp_tool(
session: ClientSession, session: ClientSession,
name: str, call_tool_request_params: MCPCallToolRequestParams,
arguments: dict, ) -> MCPCallToolResult:
) -> CallToolResult:
"""Call an MCP tool.""" """Call an MCP tool."""
tool_result = await session.call_tool( tool_result = await session.call_tool(
name=name, name=call_tool_request_params.name,
arguments=arguments, arguments=call_tool_request_params.arguments,
) )
return tool_result return tool_result
def _get_function_arguments(function: FunctionDefinition) -> dict:
"""Helper to safely get and parse function arguments."""
arguments = function.get("arguments", {})
if isinstance(arguments, str):
try:
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {}
return arguments if isinstance(arguments, dict) else {}
def _transform_openai_tool_call_to_mcp_tool_call_request(
openai_tool: ChatCompletionMessageToolCall,
) -> MCPCallToolRequestParams:
"""Convert an OpenAI ChatCompletionMessageToolCall to an MCP CallToolRequestParams."""
function = openai_tool["function"]
return MCPCallToolRequestParams(
name=function["name"],
arguments=_get_function_arguments(function),
)
async def call_openai_tool( async def call_openai_tool(
session: ClientSession, session: ClientSession,
openai_tool: ChatCompletionToolParam, openai_tool: ChatCompletionMessageToolCall,
) -> CallToolResult: ) -> MCPCallToolResult:
""" """
Call an OpenAI tool using MCP client. Call an OpenAI tool using MCP client.
@ -89,11 +100,10 @@ async def call_openai_tool(
Returns: Returns:
The result of the MCP tool call. The result of the MCP tool call.
""" """
mcp_tool = transform_openai_tool_to_mcp_tool( mcp_tool_call_request_params = _transform_openai_tool_call_to_mcp_tool_call_request(
openai_tool=openai_tool, openai_tool=openai_tool,
) )
return await call_mcp_tool( return await call_mcp_tool(
session=session, session=session,
name=mcp_tool.name, call_tool_request_params=mcp_tool_call_request_params,
arguments=mcp_tool.inputSchema,
) )