mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
fix mcp client
This commit is contained in:
parent
1faf82f768
commit
8d770f0ccf
1 changed files with 42 additions and 32 deletions
|
@ -2,12 +2,18 @@ import json
|
|||
from typing import List, Literal, Union
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.types import CallToolResult
|
||||
from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
|
||||
from mcp.types import CallToolResult as MCPCallToolResult
|
||||
from mcp.types import Tool as MCPTool
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
from openai.types.shared_params.function_definition import FunctionDefinition
|
||||
|
||||
from litellm.types.utils import ChatCompletionMessageToolCall
|
||||
|
||||
|
||||
########################################################
|
||||
# List MCP Tool functions
|
||||
########################################################
|
||||
def transform_mcp_tool_to_openai_tool(mcp_tool: MCPTool) -> ChatCompletionToolParam:
|
||||
"""Convert an MCP tool to an OpenAI tool."""
|
||||
return ChatCompletionToolParam(
|
||||
|
@ -21,27 +27,6 @@ def transform_mcp_tool_to_openai_tool(mcp_tool: MCPTool) -> ChatCompletionToolPa
|
|||
)
|
||||
|
||||
|
||||
def _get_function_arguments(function: FunctionDefinition) -> dict:
|
||||
"""Helper to safely get and parse function arguments."""
|
||||
arguments = function.get("arguments", {})
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
return arguments if isinstance(arguments, dict) else {}
|
||||
|
||||
|
||||
def transform_openai_tool_to_mcp_tool(openai_tool: ChatCompletionToolParam) -> MCPTool:
|
||||
"""Convert an OpenAI tool to an MCP tool."""
|
||||
function = openai_tool["function"]
|
||||
return MCPTool(
|
||||
name=function["name"],
|
||||
description=function.get("description", ""),
|
||||
inputSchema=_get_function_arguments(function),
|
||||
)
|
||||
|
||||
|
||||
async def load_mcp_tools(
|
||||
session: ClientSession, format: Literal["mcp", "openai"] = "mcp"
|
||||
) -> Union[List[MCPTool], List[ChatCompletionToolParam]]:
|
||||
|
@ -63,23 +48,49 @@ async def load_mcp_tools(
|
|||
return tools.tools
|
||||
|
||||
|
||||
########################################################
|
||||
# Call MCP Tool functions
|
||||
########################################################
|
||||
|
||||
|
||||
async def call_mcp_tool(
|
||||
session: ClientSession,
|
||||
name: str,
|
||||
arguments: dict,
|
||||
) -> CallToolResult:
|
||||
call_tool_request_params: MCPCallToolRequestParams,
|
||||
) -> MCPCallToolResult:
|
||||
"""Call an MCP tool."""
|
||||
tool_result = await session.call_tool(
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
name=call_tool_request_params.name,
|
||||
arguments=call_tool_request_params.arguments,
|
||||
)
|
||||
return tool_result
|
||||
|
||||
|
||||
def _get_function_arguments(function: FunctionDefinition) -> dict:
|
||||
"""Helper to safely get and parse function arguments."""
|
||||
arguments = function.get("arguments", {})
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
return arguments if isinstance(arguments, dict) else {}
|
||||
|
||||
|
||||
def _transform_openai_tool_call_to_mcp_tool_call_request(
|
||||
openai_tool: ChatCompletionMessageToolCall,
|
||||
) -> MCPCallToolRequestParams:
|
||||
"""Convert an OpenAI ChatCompletionMessageToolCall to an MCP CallToolRequestParams."""
|
||||
function = openai_tool["function"]
|
||||
return MCPCallToolRequestParams(
|
||||
name=function["name"],
|
||||
arguments=_get_function_arguments(function),
|
||||
)
|
||||
|
||||
|
||||
async def call_openai_tool(
|
||||
session: ClientSession,
|
||||
openai_tool: ChatCompletionToolParam,
|
||||
) -> CallToolResult:
|
||||
openai_tool: ChatCompletionMessageToolCall,
|
||||
) -> MCPCallToolResult:
|
||||
"""
|
||||
Call an OpenAI tool using MCP client.
|
||||
|
||||
|
@ -89,11 +100,10 @@ async def call_openai_tool(
|
|||
Returns:
|
||||
The result of the MCP tool call.
|
||||
"""
|
||||
mcp_tool = transform_openai_tool_to_mcp_tool(
|
||||
mcp_tool_call_request_params = _transform_openai_tool_call_to_mcp_tool_call_request(
|
||||
openai_tool=openai_tool,
|
||||
)
|
||||
return await call_mcp_tool(
|
||||
session=session,
|
||||
name=mcp_tool.name,
|
||||
arguments=mcp_tool.inputSchema,
|
||||
call_tool_request_params=mcp_tool_call_request_params,
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue