mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
call_openai_tool on MCP client
This commit is contained in:
parent
1f3aa82095
commit
147787b9e0
2 changed files with 90 additions and 2 deletions
|
@ -1,6 +1,8 @@
|
|||
import json
|
||||
from typing import List, Literal, Union
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.types import CallToolResult
|
||||
from mcp.types import Tool as MCPTool
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
from openai.types.shared_params.function_definition import FunctionDefinition
|
||||
|
@ -19,6 +21,27 @@ def transform_mcp_tool_to_openai_tool(mcp_tool: MCPTool) -> ChatCompletionToolPa
|
|||
)
|
||||
|
||||
|
||||
def _get_function_arguments(function: FunctionDefinition) -> dict:
|
||||
"""Helper to safely get and parse function arguments."""
|
||||
arguments = function.get("arguments", {})
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
return arguments if isinstance(arguments, dict) else {}
|
||||
|
||||
|
||||
def transform_openai_tool_to_mcp_tool(openai_tool: ChatCompletionToolParam) -> MCPTool:
|
||||
"""Convert an OpenAI tool to an MCP tool."""
|
||||
function = openai_tool["function"]
|
||||
return MCPTool(
|
||||
name=function["name"],
|
||||
description=function.get("description", ""),
|
||||
inputSchema=_get_function_arguments(function),
|
||||
)
|
||||
|
||||
|
||||
async def load_mcp_tools(
|
||||
session: ClientSession, format: Literal["mcp", "openai"] = "mcp"
|
||||
) -> Union[List[MCPTool], List[ChatCompletionToolParam]]:
|
||||
|
@ -38,3 +61,31 @@ async def load_mcp_tools(
|
|||
transform_mcp_tool_to_openai_tool(mcp_tool=tool) for tool in tools.tools
|
||||
]
|
||||
return tools.tools
|
||||
|
||||
|
||||
async def call_mcp_tool(
|
||||
session: ClientSession,
|
||||
name: str,
|
||||
arguments: dict,
|
||||
) -> CallToolResult:
|
||||
"""Call an MCP tool."""
|
||||
tool_result = await session.call_tool(
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
)
|
||||
return tool_result
|
||||
|
||||
|
||||
async def call_openai_tool(
|
||||
session: ClientSession,
|
||||
openai_tool: ChatCompletionToolParam,
|
||||
) -> CallToolResult:
|
||||
"""Call an OpenAI tool using MCP client."""
|
||||
mcp_tool = transform_openai_tool_to_mcp_tool(
|
||||
openai_tool=openai_tool,
|
||||
)
|
||||
return await call_mcp_tool(
|
||||
session=session,
|
||||
name=mcp_tool.name,
|
||||
arguments=mcp_tool.inputSchema,
|
||||
)
|
||||
|
|
|
@ -10,7 +10,11 @@ sys.path.insert(
|
|||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
import os
|
||||
from litellm.mcp_client.tools import load_mcp_tools
|
||||
from litellm.mcp_client.tools import (
|
||||
load_mcp_tools,
|
||||
transform_openai_tool_to_mcp_tool,
|
||||
call_openai_tool,
|
||||
)
|
||||
import litellm
|
||||
import pytest
|
||||
import json
|
||||
|
@ -34,11 +38,12 @@ async def test_mcp_agent():
|
|||
print("MCP TOOLS: ", tools)
|
||||
|
||||
# Create and run the agent
|
||||
messages = [{"role": "user", "content": "what's (3 + 5)"}]
|
||||
print(os.getenv("OPENAI_API_KEY"))
|
||||
llm_response = await litellm.acompletion(
|
||||
model="gpt-4o",
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
messages=[{"role": "user", "content": "what's (3 + 5) x 12?"}],
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
)
|
||||
print("LLM RESPONSE: ", json.dumps(llm_response, indent=4, default=str))
|
||||
|
@ -51,3 +56,35 @@ async def test_mcp_agent():
|
|||
]
|
||||
== "add"
|
||||
)
|
||||
openai_tool = llm_response["choices"][0]["message"]["tool_calls"][0]
|
||||
|
||||
# Convert the OpenAI tool to an MCP tool
|
||||
mcp_tool = transform_openai_tool_to_mcp_tool(openai_tool)
|
||||
print("MCP TOOL: ", mcp_tool)
|
||||
|
||||
# Call the tool using MCP client
|
||||
call_result = await call_openai_tool(
|
||||
session=session,
|
||||
openai_tool=openai_tool,
|
||||
)
|
||||
print("CALL RESULT: ", call_result)
|
||||
|
||||
# send the tool result to the LLM
|
||||
messages.append(llm_response["choices"][0]["message"])
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": str(call_result.content[0].text),
|
||||
"tool_call_id": openai_tool["id"],
|
||||
}
|
||||
)
|
||||
print("final messages: ", messages)
|
||||
llm_response = await litellm.acompletion(
|
||||
model="gpt-4o",
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
)
|
||||
print(
|
||||
"FINAL LLM RESPONSE: ", json.dumps(llm_response, indent=4, default=str)
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue