mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
fix mcp client
This commit is contained in:
parent
1faf82f768
commit
8d770f0ccf
1 changed files with 42 additions and 32 deletions
|
@ -2,12 +2,18 @@ import json
|
||||||
from typing import List, Literal, Union
|
from typing import List, Literal, Union
|
||||||
|
|
||||||
from mcp import ClientSession
|
from mcp import ClientSession
|
||||||
from mcp.types import CallToolResult
|
from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
|
||||||
|
from mcp.types import CallToolResult as MCPCallToolResult
|
||||||
from mcp.types import Tool as MCPTool
|
from mcp.types import Tool as MCPTool
|
||||||
from openai.types.chat import ChatCompletionToolParam
|
from openai.types.chat import ChatCompletionToolParam
|
||||||
from openai.types.shared_params.function_definition import FunctionDefinition
|
from openai.types.shared_params.function_definition import FunctionDefinition
|
||||||
|
|
||||||
|
from litellm.types.utils import ChatCompletionMessageToolCall
|
||||||
|
|
||||||
|
|
||||||
|
########################################################
|
||||||
|
# List MCP Tool functions
|
||||||
|
########################################################
|
||||||
def transform_mcp_tool_to_openai_tool(mcp_tool: MCPTool) -> ChatCompletionToolParam:
|
def transform_mcp_tool_to_openai_tool(mcp_tool: MCPTool) -> ChatCompletionToolParam:
|
||||||
"""Convert an MCP tool to an OpenAI tool."""
|
"""Convert an MCP tool to an OpenAI tool."""
|
||||||
return ChatCompletionToolParam(
|
return ChatCompletionToolParam(
|
||||||
|
@ -21,27 +27,6 @@ def transform_mcp_tool_to_openai_tool(mcp_tool: MCPTool) -> ChatCompletionToolPa
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_function_arguments(function: FunctionDefinition) -> dict:
|
|
||||||
"""Helper to safely get and parse function arguments."""
|
|
||||||
arguments = function.get("arguments", {})
|
|
||||||
if isinstance(arguments, str):
|
|
||||||
try:
|
|
||||||
arguments = json.loads(arguments)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
arguments = {}
|
|
||||||
return arguments if isinstance(arguments, dict) else {}
|
|
||||||
|
|
||||||
|
|
||||||
def transform_openai_tool_to_mcp_tool(openai_tool: ChatCompletionToolParam) -> MCPTool:
|
|
||||||
"""Convert an OpenAI tool to an MCP tool."""
|
|
||||||
function = openai_tool["function"]
|
|
||||||
return MCPTool(
|
|
||||||
name=function["name"],
|
|
||||||
description=function.get("description", ""),
|
|
||||||
inputSchema=_get_function_arguments(function),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def load_mcp_tools(
|
async def load_mcp_tools(
|
||||||
session: ClientSession, format: Literal["mcp", "openai"] = "mcp"
|
session: ClientSession, format: Literal["mcp", "openai"] = "mcp"
|
||||||
) -> Union[List[MCPTool], List[ChatCompletionToolParam]]:
|
) -> Union[List[MCPTool], List[ChatCompletionToolParam]]:
|
||||||
|
@ -63,23 +48,49 @@ async def load_mcp_tools(
|
||||||
return tools.tools
|
return tools.tools
|
||||||
|
|
||||||
|
|
||||||
|
########################################################
|
||||||
|
# Call MCP Tool functions
|
||||||
|
########################################################
|
||||||
|
|
||||||
|
|
||||||
async def call_mcp_tool(
|
async def call_mcp_tool(
|
||||||
session: ClientSession,
|
session: ClientSession,
|
||||||
name: str,
|
call_tool_request_params: MCPCallToolRequestParams,
|
||||||
arguments: dict,
|
) -> MCPCallToolResult:
|
||||||
) -> CallToolResult:
|
|
||||||
"""Call an MCP tool."""
|
"""Call an MCP tool."""
|
||||||
tool_result = await session.call_tool(
|
tool_result = await session.call_tool(
|
||||||
name=name,
|
name=call_tool_request_params.name,
|
||||||
arguments=arguments,
|
arguments=call_tool_request_params.arguments,
|
||||||
)
|
)
|
||||||
return tool_result
|
return tool_result
|
||||||
|
|
||||||
|
|
||||||
|
def _get_function_arguments(function: FunctionDefinition) -> dict:
|
||||||
|
"""Helper to safely get and parse function arguments."""
|
||||||
|
arguments = function.get("arguments", {})
|
||||||
|
if isinstance(arguments, str):
|
||||||
|
try:
|
||||||
|
arguments = json.loads(arguments)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
arguments = {}
|
||||||
|
return arguments if isinstance(arguments, dict) else {}
|
||||||
|
|
||||||
|
|
||||||
|
def _transform_openai_tool_call_to_mcp_tool_call_request(
|
||||||
|
openai_tool: ChatCompletionMessageToolCall,
|
||||||
|
) -> MCPCallToolRequestParams:
|
||||||
|
"""Convert an OpenAI ChatCompletionMessageToolCall to an MCP CallToolRequestParams."""
|
||||||
|
function = openai_tool["function"]
|
||||||
|
return MCPCallToolRequestParams(
|
||||||
|
name=function["name"],
|
||||||
|
arguments=_get_function_arguments(function),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def call_openai_tool(
|
async def call_openai_tool(
|
||||||
session: ClientSession,
|
session: ClientSession,
|
||||||
openai_tool: ChatCompletionToolParam,
|
openai_tool: ChatCompletionMessageToolCall,
|
||||||
) -> CallToolResult:
|
) -> MCPCallToolResult:
|
||||||
"""
|
"""
|
||||||
Call an OpenAI tool using MCP client.
|
Call an OpenAI tool using MCP client.
|
||||||
|
|
||||||
|
@ -89,11 +100,10 @@ async def call_openai_tool(
|
||||||
Returns:
|
Returns:
|
||||||
The result of the MCP tool call.
|
The result of the MCP tool call.
|
||||||
"""
|
"""
|
||||||
mcp_tool = transform_openai_tool_to_mcp_tool(
|
mcp_tool_call_request_params = _transform_openai_tool_call_to_mcp_tool_call_request(
|
||||||
openai_tool=openai_tool,
|
openai_tool=openai_tool,
|
||||||
)
|
)
|
||||||
return await call_mcp_tool(
|
return await call_mcp_tool(
|
||||||
session=session,
|
session=session,
|
||||||
name=mcp_tool.name,
|
call_tool_request_params=mcp_tool_call_request_params,
|
||||||
arguments=mcp_tool.inputSchema,
|
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue