call_openai_tool on MCP client

This commit is contained in:
Ishaan Jaff 2025-03-21 14:36:32 -07:00
parent bde703b90c
commit 0b021b8334
2 changed files with 90 additions and 2 deletions

View file

@ -1,6 +1,8 @@
import json
from typing import List, Literal, Union
from mcp import ClientSession
from mcp.types import CallToolResult
from mcp.types import Tool as MCPTool
from openai.types.chat import ChatCompletionToolParam
from openai.types.shared_params.function_definition import FunctionDefinition
@ -19,6 +21,27 @@ def transform_mcp_tool_to_openai_tool(mcp_tool: MCPTool) -> ChatCompletionToolPa
)
def _get_function_arguments(function: FunctionDefinition) -> dict:
"""Helper to safely get and parse function arguments."""
arguments = function.get("arguments", {})
if isinstance(arguments, str):
try:
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {}
return arguments if isinstance(arguments, dict) else {}
def transform_openai_tool_to_mcp_tool(openai_tool: ChatCompletionToolParam) -> MCPTool:
"""Convert an OpenAI tool to an MCP tool."""
function = openai_tool["function"]
return MCPTool(
name=function["name"],
description=function.get("description", ""),
inputSchema=_get_function_arguments(function),
)
async def load_mcp_tools(
session: ClientSession, format: Literal["mcp", "openai"] = "mcp"
) -> Union[List[MCPTool], List[ChatCompletionToolParam]]:
@ -38,3 +61,31 @@ async def load_mcp_tools(
transform_mcp_tool_to_openai_tool(mcp_tool=tool) for tool in tools.tools
]
return tools.tools
async def call_mcp_tool(
session: ClientSession,
name: str,
arguments: dict,
) -> CallToolResult:
"""Call an MCP tool."""
tool_result = await session.call_tool(
name=name,
arguments=arguments,
)
return tool_result
async def call_openai_tool(
session: ClientSession,
openai_tool: ChatCompletionToolParam,
) -> CallToolResult:
"""Call an OpenAI tool using MCP client."""
mcp_tool = transform_openai_tool_to_mcp_tool(
openai_tool=openai_tool,
)
return await call_mcp_tool(
session=session,
name=mcp_tool.name,
arguments=mcp_tool.inputSchema,
)