mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
Merge pull request #9642 from BerriAI/litellm_mcp_improvements_expose_sse_urls
[Feat] - MCP improvements, add support for using SSE MCP servers
This commit is contained in:
commit
5df985f964
25 changed files with 1220 additions and 30 deletions
|
@ -418,6 +418,7 @@ RESPONSE_FORMAT_TOOL_NAME = "json_tool_call" # default tool name used when conv
|
||||||
|
|
||||||
########################### Logging Callback Constants ###########################
|
########################### Logging Callback Constants ###########################
|
||||||
AZURE_STORAGE_MSFT_VERSION = "2019-07-07"
|
AZURE_STORAGE_MSFT_VERSION = "2019-07-07"
|
||||||
|
MCP_TOOL_NAME_PREFIX = "mcp_tool"
|
||||||
|
|
||||||
########################### LiteLLM Proxy Specific Constants ###########################
|
########################### LiteLLM Proxy Specific Constants ###########################
|
||||||
########################################################################################
|
########################################################################################
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import json
|
import json
|
||||||
from typing import List, Literal, Union
|
from typing import Dict, List, Literal, Union
|
||||||
|
|
||||||
from mcp import ClientSession
|
from mcp import ClientSession
|
||||||
from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
|
from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
|
||||||
|
@ -76,8 +76,8 @@ def _get_function_arguments(function: FunctionDefinition) -> dict:
|
||||||
return arguments if isinstance(arguments, dict) else {}
|
return arguments if isinstance(arguments, dict) else {}
|
||||||
|
|
||||||
|
|
||||||
def _transform_openai_tool_call_to_mcp_tool_call_request(
|
def transform_openai_tool_call_request_to_mcp_tool_call_request(
|
||||||
openai_tool: ChatCompletionMessageToolCall,
|
openai_tool: Union[ChatCompletionMessageToolCall, Dict],
|
||||||
) -> MCPCallToolRequestParams:
|
) -> MCPCallToolRequestParams:
|
||||||
"""Convert an OpenAI ChatCompletionMessageToolCall to an MCP CallToolRequestParams."""
|
"""Convert an OpenAI ChatCompletionMessageToolCall to an MCP CallToolRequestParams."""
|
||||||
function = openai_tool["function"]
|
function = openai_tool["function"]
|
||||||
|
@ -100,8 +100,10 @@ async def call_openai_tool(
|
||||||
Returns:
|
Returns:
|
||||||
The result of the MCP tool call.
|
The result of the MCP tool call.
|
||||||
"""
|
"""
|
||||||
mcp_tool_call_request_params = _transform_openai_tool_call_to_mcp_tool_call_request(
|
mcp_tool_call_request_params = (
|
||||||
openai_tool=openai_tool,
|
transform_openai_tool_call_request_to_mcp_tool_call_request(
|
||||||
|
openai_tool=openai_tool,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return await call_mcp_tool(
|
return await call_mcp_tool(
|
||||||
session=session,
|
session=session,
|
||||||
|
|
|
@ -67,6 +67,7 @@ from litellm.types.utils import (
|
||||||
StandardCallbackDynamicParams,
|
StandardCallbackDynamicParams,
|
||||||
StandardLoggingAdditionalHeaders,
|
StandardLoggingAdditionalHeaders,
|
||||||
StandardLoggingHiddenParams,
|
StandardLoggingHiddenParams,
|
||||||
|
StandardLoggingMCPToolCall,
|
||||||
StandardLoggingMetadata,
|
StandardLoggingMetadata,
|
||||||
StandardLoggingModelCostFailureDebugInformation,
|
StandardLoggingModelCostFailureDebugInformation,
|
||||||
StandardLoggingModelInformation,
|
StandardLoggingModelInformation,
|
||||||
|
@ -1095,7 +1096,7 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||||
status="success",
|
status="success",
|
||||||
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
||||||
)
|
)
|
||||||
elif isinstance(result, dict): # pass-through endpoints
|
elif isinstance(result, dict) or isinstance(result, list):
|
||||||
## STANDARDIZED LOGGING PAYLOAD
|
## STANDARDIZED LOGGING PAYLOAD
|
||||||
self.model_call_details[
|
self.model_call_details[
|
||||||
"standard_logging_object"
|
"standard_logging_object"
|
||||||
|
@ -3106,6 +3107,7 @@ class StandardLoggingPayloadSetup:
|
||||||
litellm_params: Optional[dict] = None,
|
litellm_params: Optional[dict] = None,
|
||||||
prompt_integration: Optional[str] = None,
|
prompt_integration: Optional[str] = None,
|
||||||
applied_guardrails: Optional[List[str]] = None,
|
applied_guardrails: Optional[List[str]] = None,
|
||||||
|
mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] = None,
|
||||||
) -> StandardLoggingMetadata:
|
) -> StandardLoggingMetadata:
|
||||||
"""
|
"""
|
||||||
Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
|
Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
|
||||||
|
@ -3152,6 +3154,7 @@ class StandardLoggingPayloadSetup:
|
||||||
user_api_key_end_user_id=None,
|
user_api_key_end_user_id=None,
|
||||||
prompt_management_metadata=prompt_management_metadata,
|
prompt_management_metadata=prompt_management_metadata,
|
||||||
applied_guardrails=applied_guardrails,
|
applied_guardrails=applied_guardrails,
|
||||||
|
mcp_tool_call_metadata=mcp_tool_call_metadata,
|
||||||
)
|
)
|
||||||
if isinstance(metadata, dict):
|
if isinstance(metadata, dict):
|
||||||
# Filter the metadata dictionary to include only the specified keys
|
# Filter the metadata dictionary to include only the specified keys
|
||||||
|
@ -3478,6 +3481,7 @@ def get_standard_logging_object_payload(
|
||||||
litellm_params=litellm_params,
|
litellm_params=litellm_params,
|
||||||
prompt_integration=kwargs.get("prompt_integration", None),
|
prompt_integration=kwargs.get("prompt_integration", None),
|
||||||
applied_guardrails=kwargs.get("applied_guardrails", None),
|
applied_guardrails=kwargs.get("applied_guardrails", None),
|
||||||
|
mcp_tool_call_metadata=kwargs.get("mcp_tool_call_metadata", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
_request_body = proxy_server_request.get("body", {})
|
_request_body = proxy_server_request.get("body", {})
|
||||||
|
@ -3617,6 +3621,7 @@ def get_standard_logging_metadata(
|
||||||
user_api_key_end_user_id=None,
|
user_api_key_end_user_id=None,
|
||||||
prompt_management_metadata=None,
|
prompt_management_metadata=None,
|
||||||
applied_guardrails=None,
|
applied_guardrails=None,
|
||||||
|
mcp_tool_call_metadata=None,
|
||||||
)
|
)
|
||||||
if isinstance(metadata, dict):
|
if isinstance(metadata, dict):
|
||||||
# Filter the metadata dictionary to include only the specified keys
|
# Filter the metadata dictionary to include only the specified keys
|
||||||
|
|
153
litellm/proxy/_experimental/mcp_server/mcp_server_manager.py
Normal file
153
litellm/proxy/_experimental/mcp_server/mcp_server_manager.py
Normal file
|
@ -0,0 +1,153 @@
|
||||||
|
"""
|
||||||
|
MCP Client Manager
|
||||||
|
|
||||||
|
This class is responsible for managing MCP SSE clients.
|
||||||
|
|
||||||
|
This is a Proxy
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from mcp import ClientSession
|
||||||
|
from mcp.client.sse import sse_client
|
||||||
|
from mcp.types import Tool as MCPTool
|
||||||
|
|
||||||
|
from litellm._logging import verbose_logger
|
||||||
|
from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPSSEServer
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServerManager:
|
||||||
|
def __init__(self):
|
||||||
|
self.mcp_servers: List[MCPSSEServer] = []
|
||||||
|
"""
|
||||||
|
eg.
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "zapier_mcp_server",
|
||||||
|
"url": "https://actions.zapier.com/mcp/sk-ak-2ew3bofIeQIkNoeKIdXrF1Hhhp/sse"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "google_drive_mcp_server",
|
||||||
|
"url": "https://actions.zapier.com/mcp/sk-ak-2ew3bofIeQIkNoeKIdXrF1Hhhp/sse"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.tool_name_to_mcp_server_name_mapping: Dict[str, str] = {}
|
||||||
|
"""
|
||||||
|
{
|
||||||
|
"gmail_send_email": "zapier_mcp_server",
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def load_servers_from_config(self, mcp_servers_config: Dict[str, Any]):
|
||||||
|
"""
|
||||||
|
Load the MCP Servers from the config
|
||||||
|
"""
|
||||||
|
for server_name, server_config in mcp_servers_config.items():
|
||||||
|
_mcp_info: dict = server_config.get("mcp_info", None) or {}
|
||||||
|
mcp_info = MCPInfo(**_mcp_info)
|
||||||
|
mcp_info["server_name"] = server_name
|
||||||
|
self.mcp_servers.append(
|
||||||
|
MCPSSEServer(
|
||||||
|
name=server_name,
|
||||||
|
url=server_config["url"],
|
||||||
|
mcp_info=mcp_info,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
verbose_logger.debug(
|
||||||
|
f"Loaded MCP Servers: {json.dumps(self.mcp_servers, indent=4, default=str)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.initialize_tool_name_to_mcp_server_name_mapping()
|
||||||
|
|
||||||
|
async def list_tools(self) -> List[MCPTool]:
|
||||||
|
"""
|
||||||
|
List all tools available across all MCP Servers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[MCPTool]: Combined list of tools from all servers
|
||||||
|
"""
|
||||||
|
list_tools_result: List[MCPTool] = []
|
||||||
|
verbose_logger.debug("SSE SERVER MANAGER LISTING TOOLS")
|
||||||
|
|
||||||
|
for server in self.mcp_servers:
|
||||||
|
tools = await self._get_tools_from_server(server)
|
||||||
|
list_tools_result.extend(tools)
|
||||||
|
|
||||||
|
return list_tools_result
|
||||||
|
|
||||||
|
async def _get_tools_from_server(self, server: MCPSSEServer) -> List[MCPTool]:
|
||||||
|
"""
|
||||||
|
Helper method to get tools from a single MCP server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server (MCPSSEServer): The server to query tools from
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[MCPTool]: List of tools available on the server
|
||||||
|
"""
|
||||||
|
verbose_logger.debug(f"Connecting to url: {server.url}")
|
||||||
|
|
||||||
|
async with sse_client(url=server.url) as (read, write):
|
||||||
|
async with ClientSession(read, write) as session:
|
||||||
|
await session.initialize()
|
||||||
|
|
||||||
|
tools_result = await session.list_tools()
|
||||||
|
verbose_logger.debug(f"Tools from {server.name}: {tools_result}")
|
||||||
|
|
||||||
|
# Update tool to server mapping
|
||||||
|
for tool in tools_result.tools:
|
||||||
|
self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name
|
||||||
|
|
||||||
|
return tools_result.tools
|
||||||
|
|
||||||
|
def initialize_tool_name_to_mcp_server_name_mapping(self):
|
||||||
|
"""
|
||||||
|
On startup, initialize the tool name to MCP server name mapping
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if asyncio.get_running_loop():
|
||||||
|
asyncio.create_task(
|
||||||
|
self._initialize_tool_name_to_mcp_server_name_mapping()
|
||||||
|
)
|
||||||
|
except RuntimeError as e: # no running event loop
|
||||||
|
verbose_logger.exception(
|
||||||
|
f"No running event loop - skipping tool name to MCP server name mapping initialization: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _initialize_tool_name_to_mcp_server_name_mapping(self):
|
||||||
|
"""
|
||||||
|
Call list_tools for each server and update the tool name to MCP server name mapping
|
||||||
|
"""
|
||||||
|
for server in self.mcp_servers:
|
||||||
|
tools = await self._get_tools_from_server(server)
|
||||||
|
for tool in tools:
|
||||||
|
self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name
|
||||||
|
|
||||||
|
async def call_tool(self, name: str, arguments: Dict[str, Any]):
|
||||||
|
"""
|
||||||
|
Call a tool with the given name and arguments
|
||||||
|
"""
|
||||||
|
mcp_server = self._get_mcp_server_from_tool_name(name)
|
||||||
|
if mcp_server is None:
|
||||||
|
raise ValueError(f"Tool {name} not found")
|
||||||
|
async with sse_client(url=mcp_server.url) as (read, write):
|
||||||
|
async with ClientSession(read, write) as session:
|
||||||
|
await session.initialize()
|
||||||
|
return await session.call_tool(name, arguments)
|
||||||
|
|
||||||
|
def _get_mcp_server_from_tool_name(self, tool_name: str) -> Optional[MCPSSEServer]:
|
||||||
|
"""
|
||||||
|
Get the MCP Server from the tool name
|
||||||
|
"""
|
||||||
|
if tool_name in self.tool_name_to_mcp_server_name_mapping:
|
||||||
|
for server in self.mcp_servers:
|
||||||
|
if server.name == self.tool_name_to_mcp_server_name_mapping[tool_name]:
|
||||||
|
return server
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
global_mcp_server_manager: MCPServerManager = MCPServerManager()
|
|
@ -3,14 +3,21 @@ LiteLLM MCP Server Routes
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Any, Dict, List, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from anyio import BrokenResourceError
|
from anyio import BrokenResourceError
|
||||||
from fastapi import APIRouter, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import ValidationError
|
from pydantic import ConfigDict, ValidationError
|
||||||
|
|
||||||
from litellm._logging import verbose_logger
|
from litellm._logging import verbose_logger
|
||||||
|
from litellm.constants import MCP_TOOL_NAME_PREFIX
|
||||||
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
|
from litellm.types.mcp_server.mcp_server_manager import MCPInfo
|
||||||
|
from litellm.types.utils import StandardLoggingMCPToolCall
|
||||||
|
from litellm.utils import client
|
||||||
|
|
||||||
# Check if MCP is available
|
# Check if MCP is available
|
||||||
# "mcp" requires python 3.10 or higher, but several litellm users use python 3.8
|
# "mcp" requires python 3.10 or higher, but several litellm users use python 3.8
|
||||||
|
@ -36,9 +43,23 @@ if MCP_AVAILABLE:
|
||||||
from mcp.types import TextContent as MCPTextContent
|
from mcp.types import TextContent as MCPTextContent
|
||||||
from mcp.types import Tool as MCPTool
|
from mcp.types import Tool as MCPTool
|
||||||
|
|
||||||
|
from .mcp_server_manager import global_mcp_server_manager
|
||||||
from .sse_transport import SseServerTransport
|
from .sse_transport import SseServerTransport
|
||||||
from .tool_registry import global_mcp_tool_registry
|
from .tool_registry import global_mcp_tool_registry
|
||||||
|
|
||||||
|
######################################################
|
||||||
|
############ MCP Tools List REST API Response Object #
|
||||||
|
# Defined here because we don't want to add `mcp` as a
|
||||||
|
# required dependency for `litellm` pip package
|
||||||
|
######################################################
|
||||||
|
class ListMCPToolsRestAPIResponseObject(MCPTool):
|
||||||
|
"""
|
||||||
|
Object returned by the /tools/list REST API route.
|
||||||
|
"""
|
||||||
|
|
||||||
|
mcp_info: Optional[MCPInfo] = None
|
||||||
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
########################################################
|
########################################################
|
||||||
############ Initialize the MCP Server #################
|
############ Initialize the MCP Server #################
|
||||||
########################################################
|
########################################################
|
||||||
|
@ -52,9 +73,14 @@ if MCP_AVAILABLE:
|
||||||
########################################################
|
########################################################
|
||||||
############### MCP Server Routes #######################
|
############### MCP Server Routes #######################
|
||||||
########################################################
|
########################################################
|
||||||
|
|
||||||
@server.list_tools()
|
@server.list_tools()
|
||||||
async def list_tools() -> list[MCPTool]:
|
async def list_tools() -> list[MCPTool]:
|
||||||
|
"""
|
||||||
|
List all available tools
|
||||||
|
"""
|
||||||
|
return await _list_mcp_tools()
|
||||||
|
|
||||||
|
async def _list_mcp_tools() -> List[MCPTool]:
|
||||||
"""
|
"""
|
||||||
List all available tools
|
List all available tools
|
||||||
"""
|
"""
|
||||||
|
@ -67,24 +93,116 @@ if MCP_AVAILABLE:
|
||||||
inputSchema=tool.input_schema,
|
inputSchema=tool.input_schema,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
verbose_logger.debug(
|
||||||
|
"GLOBAL MCP TOOLS: %s", global_mcp_tool_registry.list_tools()
|
||||||
|
)
|
||||||
|
sse_tools: List[MCPTool] = await global_mcp_server_manager.list_tools()
|
||||||
|
verbose_logger.debug("SSE TOOLS: %s", sse_tools)
|
||||||
|
if sse_tools is not None:
|
||||||
|
tools.extend(sse_tools)
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
@server.call_tool()
|
@server.call_tool()
|
||||||
async def handle_call_tool(
|
async def mcp_server_tool_call(
|
||||||
name: str, arguments: Dict[str, Any] | None
|
name: str, arguments: Dict[str, Any] | None
|
||||||
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
|
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
|
||||||
"""
|
"""
|
||||||
Call a specific tool with the provided arguments
|
Call a specific tool with the provided arguments
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): Name of the tool to call
|
||||||
|
arguments (Dict[str, Any] | None): Arguments to pass to the tool
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]: Tool execution results
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If tool not found or arguments missing
|
||||||
|
"""
|
||||||
|
# Validate arguments
|
||||||
|
response = await call_mcp_tool(
|
||||||
|
name=name,
|
||||||
|
arguments=arguments,
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
@client
|
||||||
|
async def call_mcp_tool(
|
||||||
|
name: str, arguments: Optional[Dict[str, Any]] = None, **kwargs: Any
|
||||||
|
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
|
||||||
|
"""
|
||||||
|
Call a specific tool with the provided arguments
|
||||||
"""
|
"""
|
||||||
tool = global_mcp_tool_registry.get_tool(name)
|
|
||||||
if not tool:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Tool '{name}' not found")
|
|
||||||
if arguments is None:
|
if arguments is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="Request arguments are required"
|
status_code=400, detail="Request arguments are required"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
standard_logging_mcp_tool_call: StandardLoggingMCPToolCall = (
|
||||||
|
_get_standard_logging_mcp_tool_call(
|
||||||
|
name=name,
|
||||||
|
arguments=arguments,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
litellm_logging_obj: Optional[LiteLLMLoggingObj] = kwargs.get(
|
||||||
|
"litellm_logging_obj", None
|
||||||
|
)
|
||||||
|
if litellm_logging_obj:
|
||||||
|
litellm_logging_obj.model_call_details["mcp_tool_call_metadata"] = (
|
||||||
|
standard_logging_mcp_tool_call
|
||||||
|
)
|
||||||
|
litellm_logging_obj.model_call_details["model"] = (
|
||||||
|
f"{MCP_TOOL_NAME_PREFIX}: {standard_logging_mcp_tool_call.get('name') or ''}"
|
||||||
|
)
|
||||||
|
litellm_logging_obj.model_call_details["custom_llm_provider"] = (
|
||||||
|
standard_logging_mcp_tool_call.get("mcp_server_name")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try managed server tool first
|
||||||
|
if name in global_mcp_server_manager.tool_name_to_mcp_server_name_mapping:
|
||||||
|
return await _handle_managed_mcp_tool(name, arguments)
|
||||||
|
|
||||||
|
# Fall back to local tool registry
|
||||||
|
return await _handle_local_mcp_tool(name, arguments)
|
||||||
|
|
||||||
|
def _get_standard_logging_mcp_tool_call(
|
||||||
|
name: str,
|
||||||
|
arguments: Dict[str, Any],
|
||||||
|
) -> StandardLoggingMCPToolCall:
|
||||||
|
mcp_server = global_mcp_server_manager._get_mcp_server_from_tool_name(name)
|
||||||
|
if mcp_server:
|
||||||
|
mcp_info = mcp_server.mcp_info or {}
|
||||||
|
return StandardLoggingMCPToolCall(
|
||||||
|
name=name,
|
||||||
|
arguments=arguments,
|
||||||
|
mcp_server_name=mcp_info.get("server_name"),
|
||||||
|
mcp_server_logo_url=mcp_info.get("logo_url"),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return StandardLoggingMCPToolCall(
|
||||||
|
name=name,
|
||||||
|
arguments=arguments,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _handle_managed_mcp_tool(
|
||||||
|
name: str, arguments: Dict[str, Any]
|
||||||
|
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
|
||||||
|
"""Handle tool execution for managed server tools"""
|
||||||
|
call_tool_result = await global_mcp_server_manager.call_tool(
|
||||||
|
name=name,
|
||||||
|
arguments=arguments,
|
||||||
|
)
|
||||||
|
verbose_logger.debug("CALL TOOL RESULT: %s", call_tool_result)
|
||||||
|
return call_tool_result.content
|
||||||
|
|
||||||
|
async def _handle_local_mcp_tool(
|
||||||
|
name: str, arguments: Dict[str, Any]
|
||||||
|
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
|
||||||
|
"""Handle tool execution for local registry tools"""
|
||||||
|
tool = global_mcp_tool_registry.get_tool(name)
|
||||||
|
if not tool:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Tool '{name}' not found")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = tool.handler(**arguments)
|
result = tool.handler(**arguments)
|
||||||
return [MCPTextContent(text=str(result), type="text")]
|
return [MCPTextContent(text=str(result), type="text")]
|
||||||
|
@ -113,6 +231,74 @@ if MCP_AVAILABLE:
|
||||||
await sse.handle_post_message(request.scope, request.receive, request._send)
|
await sse.handle_post_message(request.scope, request.receive, request._send)
|
||||||
await request.close()
|
await request.close()
|
||||||
|
|
||||||
|
########################################################
|
||||||
|
############ MCP Server REST API Routes #################
|
||||||
|
########################################################
|
||||||
|
@router.get("/tools/list", dependencies=[Depends(user_api_key_auth)])
|
||||||
|
async def list_tool_rest_api() -> List[ListMCPToolsRestAPIResponseObject]:
|
||||||
|
"""
|
||||||
|
List all available tools with information about the server they belong to.
|
||||||
|
|
||||||
|
Example response:
|
||||||
|
Tools:
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "create_zap",
|
||||||
|
"description": "Create a new zap",
|
||||||
|
"inputSchema": "tool_input_schema",
|
||||||
|
"mcp_info": {
|
||||||
|
"server_name": "zapier",
|
||||||
|
"logo_url": "https://www.zapier.com/logo.png",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "fetch_data",
|
||||||
|
"description": "Fetch data from a URL",
|
||||||
|
"inputSchema": "tool_input_schema",
|
||||||
|
"mcp_info": {
|
||||||
|
"server_name": "fetch",
|
||||||
|
"logo_url": "https://www.fetch.com/logo.png",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
list_tools_result: List[ListMCPToolsRestAPIResponseObject] = []
|
||||||
|
for server in global_mcp_server_manager.mcp_servers:
|
||||||
|
try:
|
||||||
|
tools = await global_mcp_server_manager._get_tools_from_server(server)
|
||||||
|
for tool in tools:
|
||||||
|
list_tools_result.append(
|
||||||
|
ListMCPToolsRestAPIResponseObject(
|
||||||
|
name=tool.name,
|
||||||
|
description=tool.description,
|
||||||
|
inputSchema=tool.inputSchema,
|
||||||
|
mcp_info=server.mcp_info,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
verbose_logger.exception(f"Error getting tools from {server.name}: {e}")
|
||||||
|
continue
|
||||||
|
return list_tools_result
|
||||||
|
|
||||||
|
@router.post("/tools/call", dependencies=[Depends(user_api_key_auth)])
|
||||||
|
async def call_tool_rest_api(
|
||||||
|
request: Request,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
REST API to call a specific MCP tool with the provided arguments
|
||||||
|
"""
|
||||||
|
from litellm.proxy.proxy_server import add_litellm_data_to_request, proxy_config
|
||||||
|
|
||||||
|
data = await request.json()
|
||||||
|
data = await add_litellm_data_to_request(
|
||||||
|
data=data,
|
||||||
|
request=request,
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
proxy_config=proxy_config,
|
||||||
|
)
|
||||||
|
return await call_mcp_tool(**data)
|
||||||
|
|
||||||
options = InitializationOptions(
|
options = InitializationOptions(
|
||||||
server_name="litellm-mcp-server",
|
server_name="litellm-mcp-server",
|
||||||
server_version="0.1.0",
|
server_version="0.1.0",
|
||||||
|
|
|
@ -27,6 +27,7 @@ from litellm.types.utils import (
|
||||||
ModelResponse,
|
ModelResponse,
|
||||||
ProviderField,
|
ProviderField,
|
||||||
StandardCallbackDynamicParams,
|
StandardCallbackDynamicParams,
|
||||||
|
StandardLoggingMCPToolCall,
|
||||||
StandardLoggingPayloadErrorInformation,
|
StandardLoggingPayloadErrorInformation,
|
||||||
StandardLoggingPayloadStatus,
|
StandardLoggingPayloadStatus,
|
||||||
StandardPassThroughResponseObject,
|
StandardPassThroughResponseObject,
|
||||||
|
@ -1928,6 +1929,7 @@ class SpendLogsMetadata(TypedDict):
|
||||||
] # special param to log k,v pairs to spendlogs for a call
|
] # special param to log k,v pairs to spendlogs for a call
|
||||||
requester_ip_address: Optional[str]
|
requester_ip_address: Optional[str]
|
||||||
applied_guardrails: Optional[List[str]]
|
applied_guardrails: Optional[List[str]]
|
||||||
|
mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall]
|
||||||
status: StandardLoggingPayloadStatus
|
status: StandardLoggingPayloadStatus
|
||||||
proxy_server_request: Optional[str]
|
proxy_server_request: Optional[str]
|
||||||
batch_models: Optional[List[str]]
|
batch_models: Optional[List[str]]
|
||||||
|
|
|
@ -1,7 +1,14 @@
|
||||||
model_list:
|
model_list:
|
||||||
- model_name: fake-openai-endpoint
|
- model_name: gpt-4o
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: openai/fake
|
model: openai/gpt-4o
|
||||||
api_key: fake-key
|
|
||||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
|
||||||
|
|
||||||
|
mcp_servers:
|
||||||
|
{
|
||||||
|
"Zapier MCP": {
|
||||||
|
"url": "os.environ/ZAPIER_MCP_SERVER_URL",
|
||||||
|
"mcp_info": {
|
||||||
|
"logo_url": "https://espysys.com/wp-content/uploads/2024/08/zapier-logo.webp",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1960,6 +1960,14 @@ class ProxyConfig:
|
||||||
if mcp_tools_config:
|
if mcp_tools_config:
|
||||||
global_mcp_tool_registry.load_tools_from_config(mcp_tools_config)
|
global_mcp_tool_registry.load_tools_from_config(mcp_tools_config)
|
||||||
|
|
||||||
|
mcp_servers_config = config.get("mcp_servers", None)
|
||||||
|
if mcp_servers_config:
|
||||||
|
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
|
||||||
|
global_mcp_server_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
global_mcp_server_manager.load_servers_from_config(mcp_servers_config)
|
||||||
|
|
||||||
## CREDENTIALS
|
## CREDENTIALS
|
||||||
credential_list_dict = self.load_credential_list(config=config)
|
credential_list_dict = self.load_credential_list(config=config)
|
||||||
litellm.credential_list = credential_list_dict
|
litellm.credential_list = credential_list_dict
|
||||||
|
|
|
@ -13,7 +13,7 @@ from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.litellm_core_utils.core_helpers import get_litellm_metadata_from_kwargs
|
from litellm.litellm_core_utils.core_helpers import get_litellm_metadata_from_kwargs
|
||||||
from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
|
from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
|
||||||
from litellm.proxy.utils import PrismaClient, hash_token
|
from litellm.proxy.utils import PrismaClient, hash_token
|
||||||
from litellm.types.utils import StandardLoggingPayload
|
from litellm.types.utils import StandardLoggingMCPToolCall, StandardLoggingPayload
|
||||||
from litellm.utils import get_end_user_id_for_cost_tracking
|
from litellm.utils import get_end_user_id_for_cost_tracking
|
||||||
|
|
||||||
|
|
||||||
|
@ -38,6 +38,7 @@ def _get_spend_logs_metadata(
|
||||||
metadata: Optional[dict],
|
metadata: Optional[dict],
|
||||||
applied_guardrails: Optional[List[str]] = None,
|
applied_guardrails: Optional[List[str]] = None,
|
||||||
batch_models: Optional[List[str]] = None,
|
batch_models: Optional[List[str]] = None,
|
||||||
|
mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] = None,
|
||||||
) -> SpendLogsMetadata:
|
) -> SpendLogsMetadata:
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
return SpendLogsMetadata(
|
return SpendLogsMetadata(
|
||||||
|
@ -55,6 +56,7 @@ def _get_spend_logs_metadata(
|
||||||
error_information=None,
|
error_information=None,
|
||||||
proxy_server_request=None,
|
proxy_server_request=None,
|
||||||
batch_models=None,
|
batch_models=None,
|
||||||
|
mcp_tool_call_metadata=None,
|
||||||
)
|
)
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
"getting payload for SpendLogs, available keys in metadata: "
|
"getting payload for SpendLogs, available keys in metadata: "
|
||||||
|
@ -71,6 +73,7 @@ def _get_spend_logs_metadata(
|
||||||
)
|
)
|
||||||
clean_metadata["applied_guardrails"] = applied_guardrails
|
clean_metadata["applied_guardrails"] = applied_guardrails
|
||||||
clean_metadata["batch_models"] = batch_models
|
clean_metadata["batch_models"] = batch_models
|
||||||
|
clean_metadata["mcp_tool_call_metadata"] = mcp_tool_call_metadata
|
||||||
return clean_metadata
|
return clean_metadata
|
||||||
|
|
||||||
|
|
||||||
|
@ -200,6 +203,11 @@ def get_logging_payload( # noqa: PLR0915
|
||||||
if standard_logging_payload is not None
|
if standard_logging_payload is not None
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
|
mcp_tool_call_metadata=(
|
||||||
|
standard_logging_payload["metadata"].get("mcp_tool_call_metadata", None)
|
||||||
|
if standard_logging_payload is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
|
special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
|
||||||
|
|
16
litellm/types/mcp_server/mcp_server_manager.py
Normal file
16
litellm/types/mcp_server/mcp_server_manager.py
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
|
||||||
|
class MCPInfo(TypedDict, total=False):
|
||||||
|
server_name: str
|
||||||
|
logo_url: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class MCPSSEServer(BaseModel):
|
||||||
|
name: str
|
||||||
|
url: str
|
||||||
|
mcp_info: Optional[MCPInfo] = None
|
||||||
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
@ -1644,6 +1644,33 @@ class StandardLoggingUserAPIKeyMetadata(TypedDict):
|
||||||
user_api_key_end_user_id: Optional[str]
|
user_api_key_end_user_id: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class StandardLoggingMCPToolCall(TypedDict, total=False):
|
||||||
|
name: str
|
||||||
|
"""
|
||||||
|
Name of the tool to call
|
||||||
|
"""
|
||||||
|
arguments: dict
|
||||||
|
"""
|
||||||
|
Arguments to pass to the tool
|
||||||
|
"""
|
||||||
|
result: dict
|
||||||
|
"""
|
||||||
|
Result of the tool call
|
||||||
|
"""
|
||||||
|
|
||||||
|
mcp_server_name: Optional[str]
|
||||||
|
"""
|
||||||
|
Name of the MCP server that the tool call was made to
|
||||||
|
"""
|
||||||
|
|
||||||
|
mcp_server_logo_url: Optional[str]
|
||||||
|
"""
|
||||||
|
Optional logo URL of the MCP server that the tool call was made to
|
||||||
|
|
||||||
|
(this is to render the logo on the logs page on litellm ui)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class StandardBuiltInToolsParams(TypedDict, total=False):
|
class StandardBuiltInToolsParams(TypedDict, total=False):
|
||||||
"""
|
"""
|
||||||
Standard built-in OpenAItools parameters
|
Standard built-in OpenAItools parameters
|
||||||
|
@ -1674,6 +1701,7 @@ class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):
|
||||||
requester_ip_address: Optional[str]
|
requester_ip_address: Optional[str]
|
||||||
requester_metadata: Optional[dict]
|
requester_metadata: Optional[dict]
|
||||||
prompt_management_metadata: Optional[StandardLoggingPromptManagementMetadata]
|
prompt_management_metadata: Optional[StandardLoggingPromptManagementMetadata]
|
||||||
|
mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall]
|
||||||
applied_guardrails: Optional[List[str]]
|
applied_guardrails: Optional[List[str]]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -19,11 +19,11 @@ from mcp.types import Tool as MCPTool
|
||||||
|
|
||||||
from litellm.experimental_mcp_client.tools import (
|
from litellm.experimental_mcp_client.tools import (
|
||||||
_get_function_arguments,
|
_get_function_arguments,
|
||||||
_transform_openai_tool_call_to_mcp_tool_call_request,
|
|
||||||
call_mcp_tool,
|
call_mcp_tool,
|
||||||
call_openai_tool,
|
call_openai_tool,
|
||||||
load_mcp_tools,
|
load_mcp_tools,
|
||||||
transform_mcp_tool_to_openai_tool,
|
transform_mcp_tool_to_openai_tool,
|
||||||
|
transform_openai_tool_call_request_to_mcp_tool_call_request,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -76,11 +76,11 @@ def test_transform_mcp_tool_to_openai_tool(mock_mcp_tool):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_transform_openai_tool_call_to_mcp_tool_call_request(mock_mcp_tool):
|
def testtransform_openai_tool_call_request_to_mcp_tool_call_request(mock_mcp_tool):
|
||||||
openai_tool = {
|
openai_tool = {
|
||||||
"function": {"name": "test_tool", "arguments": json.dumps({"test": "value"})}
|
"function": {"name": "test_tool", "arguments": json.dumps({"test": "value"})}
|
||||||
}
|
}
|
||||||
mcp_tool_call_request = _transform_openai_tool_call_to_mcp_tool_call_request(
|
mcp_tool_call_request = transform_openai_tool_call_request_to_mcp_tool_call_request(
|
||||||
openai_tool
|
openai_tool
|
||||||
)
|
)
|
||||||
assert mcp_tool_call_request.name == "test_tool"
|
assert mcp_tool_call_request.name == "test_tool"
|
||||||
|
|
|
@ -457,7 +457,7 @@ class TestSpendLogsPayload:
|
||||||
"model": "gpt-4o",
|
"model": "gpt-4o",
|
||||||
"user": "",
|
"user": "",
|
||||||
"team_id": "",
|
"team_id": "",
|
||||||
"metadata": '{"applied_guardrails": [], "batch_models": null, "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": null}}',
|
"metadata": '{"applied_guardrails": [], "batch_models": null, "mcp_tool_call_metadata": null, "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": null}}',
|
||||||
"cache_key": "Cache OFF",
|
"cache_key": "Cache OFF",
|
||||||
"spend": 0.00022500000000000002,
|
"spend": 0.00022500000000000002,
|
||||||
"total_tokens": 30,
|
"total_tokens": 30,
|
||||||
|
@ -555,7 +555,7 @@ class TestSpendLogsPayload:
|
||||||
"model": "claude-3-7-sonnet-20250219",
|
"model": "claude-3-7-sonnet-20250219",
|
||||||
"user": "",
|
"user": "",
|
||||||
"team_id": "",
|
"team_id": "",
|
||||||
"metadata": '{"applied_guardrails": [], "batch_models": null, "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": {"audio_tokens": null, "cached_tokens": 0, "text_tokens": null, "image_tokens": null}, "cache_creation_input_tokens": 0, "cache_read_input_tokens": 0}}',
|
"metadata": '{"applied_guardrails": [], "batch_models": null, "mcp_tool_call_metadata": null, "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": {"audio_tokens": null, "cached_tokens": 0, "text_tokens": null, "image_tokens": null}, "cache_creation_input_tokens": 0, "cache_read_input_tokens": 0}}',
|
||||||
"cache_key": "Cache OFF",
|
"cache_key": "Cache OFF",
|
||||||
"spend": 0.01383,
|
"spend": 0.01383,
|
||||||
"total_tokens": 2598,
|
"total_tokens": 2598,
|
||||||
|
@ -651,7 +651,7 @@ class TestSpendLogsPayload:
|
||||||
"model": "claude-3-7-sonnet-20250219",
|
"model": "claude-3-7-sonnet-20250219",
|
||||||
"user": "",
|
"user": "",
|
||||||
"team_id": "",
|
"team_id": "",
|
||||||
"metadata": '{"applied_guardrails": [], "batch_models": null, "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": {"audio_tokens": null, "cached_tokens": 0, "text_tokens": null, "image_tokens": null}, "cache_creation_input_tokens": 0, "cache_read_input_tokens": 0}}',
|
"metadata": '{"applied_guardrails": [], "batch_models": null, "mcp_tool_call_metadata": null, "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": {"audio_tokens": null, "cached_tokens": 0, "text_tokens": null, "image_tokens": null}, "cache_creation_input_tokens": 0, "cache_read_input_tokens": 0}}',
|
||||||
"cache_key": "Cache OFF",
|
"cache_key": "Cache OFF",
|
||||||
"spend": 0.01383,
|
"spend": 0.01383,
|
||||||
"total_tokens": 2598,
|
"total_tokens": 2598,
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
"model": "gpt-4o",
|
"model": "gpt-4o",
|
||||||
"user": "",
|
"user": "",
|
||||||
"team_id": "",
|
"team_id": "",
|
||||||
"metadata": "{\"applied_guardrails\": [], \"batch_models\": null, \"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}",
|
"metadata": "{\"applied_guardrails\": [], \"batch_models\": null, \"mcp_tool_call_metadata\": null, \"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}",
|
||||||
"cache_key": "Cache OFF",
|
"cache_key": "Cache OFF",
|
||||||
"spend": 0.00022500000000000002,
|
"spend": 0.00022500000000000002,
|
||||||
"total_tokens": 30,
|
"total_tokens": 30,
|
||||||
|
|
|
@ -25,7 +25,8 @@
|
||||||
"requester_metadata": null,
|
"requester_metadata": null,
|
||||||
"user_api_key_end_user_id": null,
|
"user_api_key_end_user_id": null,
|
||||||
"prompt_management_metadata": null,
|
"prompt_management_metadata": null,
|
||||||
"applied_guardrails": []
|
"applied_guardrails": [],
|
||||||
|
"mcp_tool_call_metadata": null
|
||||||
},
|
},
|
||||||
"cache_key": null,
|
"cache_key": null,
|
||||||
"response_cost": 0.00022500000000000002,
|
"response_cost": 0.00022500000000000002,
|
||||||
|
|
|
@ -274,6 +274,7 @@ def validate_redacted_message_span_attributes(span):
|
||||||
"metadata.user_api_key_end_user_id",
|
"metadata.user_api_key_end_user_id",
|
||||||
"metadata.user_api_key_user_email",
|
"metadata.user_api_key_user_email",
|
||||||
"metadata.applied_guardrails",
|
"metadata.applied_guardrails",
|
||||||
|
"metadata.mcp_tool_call_metadata",
|
||||||
]
|
]
|
||||||
|
|
||||||
_all_attributes = set(
|
_all_attributes = set(
|
||||||
|
|
35
tests/mcp_tests/test_mcp_server.py
Normal file
35
tests/mcp_tests/test_mcp_server.py
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
# Create server parameters for stdio connection
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
|
||||||
|
MCPServerManager,
|
||||||
|
MCPSSEServer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
mcp_server_manager = MCPServerManager()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.skip(reason="Local only test")
|
||||||
|
async def test_mcp_server_manager():
|
||||||
|
mcp_server_manager.load_servers_from_config(
|
||||||
|
{
|
||||||
|
"zapier_mcp_server": {
|
||||||
|
"url": os.environ.get("ZAPIER_MCP_SERVER_URL"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tools = await mcp_server_manager.list_tools()
|
||||||
|
print("TOOLS FROM MCP SERVER MANAGER== ", tools)
|
||||||
|
|
||||||
|
result = await mcp_server_manager.call_tool(
|
||||||
|
name="gmail_send_email", arguments={"body": "Test"}
|
||||||
|
)
|
||||||
|
print("RESULT FROM CALLING TOOL FROM MCP SERVER MANAGER== ", result)
|
|
@ -32,6 +32,8 @@ import GuardrailsPanel from "@/components/guardrails";
|
||||||
import TransformRequestPanel from "@/components/transform_request";
|
import TransformRequestPanel from "@/components/transform_request";
|
||||||
import { fetchUserModels } from "@/components/create_key_button";
|
import { fetchUserModels } from "@/components/create_key_button";
|
||||||
import { fetchTeams } from "@/components/common_components/fetch_teams";
|
import { fetchTeams } from "@/components/common_components/fetch_teams";
|
||||||
|
import MCPToolsViewer from "@/components/mcp_tools";
|
||||||
|
|
||||||
function getCookie(name: string) {
|
function getCookie(name: string) {
|
||||||
const cookieValue = document.cookie
|
const cookieValue = document.cookie
|
||||||
.split("; ")
|
.split("; ")
|
||||||
|
@ -347,6 +349,12 @@ export default function CreateKeyPage() {
|
||||||
accessToken={accessToken}
|
accessToken={accessToken}
|
||||||
allTeams={teams as Team[] ?? []}
|
allTeams={teams as Team[] ?? []}
|
||||||
/>
|
/>
|
||||||
|
) : page == "mcp-tools" ? (
|
||||||
|
<MCPToolsViewer
|
||||||
|
accessToken={accessToken}
|
||||||
|
userRole={userRole}
|
||||||
|
userID={userID}
|
||||||
|
/>
|
||||||
) : page == "new_usage" ? (
|
) : page == "new_usage" ? (
|
||||||
<NewUsagePage
|
<NewUsagePage
|
||||||
userID={userID}
|
userID={userID}
|
||||||
|
|
|
@ -20,7 +20,8 @@ import {
|
||||||
SafetyOutlined,
|
SafetyOutlined,
|
||||||
ExperimentOutlined,
|
ExperimentOutlined,
|
||||||
ThunderboltOutlined,
|
ThunderboltOutlined,
|
||||||
LockOutlined
|
LockOutlined,
|
||||||
|
ToolOutlined,
|
||||||
} from '@ant-design/icons';
|
} from '@ant-design/icons';
|
||||||
import { old_admin_roles, v2_admin_role_names, all_admin_roles, rolesAllowedToSeeUsage, rolesWithWriteAccess } from '../utils/roles';
|
import { old_admin_roles, v2_admin_role_names, all_admin_roles, rolesAllowedToSeeUsage, rolesWithWriteAccess } from '../utils/roles';
|
||||||
|
|
||||||
|
@ -69,6 +70,7 @@ const menuItems: MenuItem[] = [
|
||||||
{ key: "10", page: "budgets", label: "Budgets", icon: <BankOutlined />, roles: all_admin_roles },
|
{ key: "10", page: "budgets", label: "Budgets", icon: <BankOutlined />, roles: all_admin_roles },
|
||||||
{ key: "11", page: "guardrails", label: "Guardrails", icon: <SafetyOutlined />, roles: all_admin_roles },
|
{ key: "11", page: "guardrails", label: "Guardrails", icon: <SafetyOutlined />, roles: all_admin_roles },
|
||||||
{ key: "12", page: "new_usage", label: "New Usage", icon: <BarChartOutlined />, roles: all_admin_roles },
|
{ key: "12", page: "new_usage", label: "New Usage", icon: <BarChartOutlined />, roles: all_admin_roles },
|
||||||
|
{ key: "18", page: "mcp-tools", label: "MCP Tools", icon: <ToolOutlined />, roles: all_admin_roles },
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
|
@ -0,0 +1,92 @@
|
||||||
|
import React from 'react';
|
||||||
|
|
||||||
|
const codeString = `import asyncio
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
from openai.types.chat import ChatCompletionUserMessageParam
|
||||||
|
from mcp import ClientSession
|
||||||
|
from mcp.client.sse import sse_client
|
||||||
|
from litellm.experimental_mcp_client.tools import (
|
||||||
|
transform_mcp_tool_to_openai_tool,
|
||||||
|
transform_openai_tool_call_request_to_mcp_tool_call_request,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
# Initialize clients
|
||||||
|
client = AsyncOpenAI(
|
||||||
|
api_key="sk-1234",
|
||||||
|
base_url="http://localhost:4000"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Connect to MCP
|
||||||
|
async with sse_client("http://localhost:4000/mcp/") as (read, write):
|
||||||
|
async with ClientSession(read, write) as session:
|
||||||
|
await session.initialize()
|
||||||
|
mcp_tools = await session.list_tools()
|
||||||
|
print("List of MCP tools for MCP server:", mcp_tools.tools)
|
||||||
|
|
||||||
|
# Create message
|
||||||
|
messages = [
|
||||||
|
ChatCompletionUserMessageParam(
|
||||||
|
content="Send an email about LiteLLM supporting MCP",
|
||||||
|
role="user"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Request with tools
|
||||||
|
response = await client.chat.completions.create(
|
||||||
|
model="gpt-4o",
|
||||||
|
messages=messages,
|
||||||
|
tools=[transform_mcp_tool_to_openai_tool(tool) for tool in mcp_tools.tools],
|
||||||
|
tool_choice="auto"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle tool call
|
||||||
|
if response.choices[0].message.tool_calls:
|
||||||
|
tool_call = response.choices[0].message.tool_calls[0]
|
||||||
|
if tool_call:
|
||||||
|
# Convert format
|
||||||
|
mcp_call = transform_openai_tool_call_request_to_mcp_tool_call_request(
|
||||||
|
openai_tool=tool_call.model_dump()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute tool
|
||||||
|
result = await session.call_tool(
|
||||||
|
name=mcp_call.name,
|
||||||
|
arguments=mcp_call.arguments
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Result:", result)
|
||||||
|
|
||||||
|
# Run it
|
||||||
|
asyncio.run(main())`;
|
||||||
|
|
||||||
|
export const CodeExample: React.FC = () => {
|
||||||
|
return (
|
||||||
|
<div className="bg-white rounded-lg shadow h-full">
|
||||||
|
<div className="border-b px-4 py-3">
|
||||||
|
<h3 className="text-base font-medium text-gray-900">Using MCP Tools</h3>
|
||||||
|
</div>
|
||||||
|
<div className="p-4">
|
||||||
|
<div className="flex items-center gap-2 mb-2">
|
||||||
|
<div className="flex-1">
|
||||||
|
<div className="text-sm font-medium text-gray-700">Python integration</div>
|
||||||
|
</div>
|
||||||
|
<button
|
||||||
|
className="text-xs bg-gray-100 hover:bg-gray-200 text-gray-600 px-2 py-1 rounded-md transition-colors"
|
||||||
|
onClick={() => {
|
||||||
|
navigator.clipboard.writeText(codeString);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
Copy
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="overflow-auto rounded-md bg-gray-50 border" style={{ maxHeight: "calc(100vh - 280px)" }}>
|
||||||
|
<pre className="p-3 text-xs font-mono text-gray-800 whitespace-pre overflow-x-auto">
|
||||||
|
{codeString}
|
||||||
|
</pre>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
309
ui/litellm-dashboard/src/components/mcp_tools/columns.tsx
Normal file
309
ui/litellm-dashboard/src/components/mcp_tools/columns.tsx
Normal file
|
@ -0,0 +1,309 @@
|
||||||
|
import React from "react";
|
||||||
|
import { ColumnDef } from "@tanstack/react-table";
|
||||||
|
import { MCPTool, InputSchema } from "./types";
|
||||||
|
import { Button } from "@tremor/react"
|
||||||
|
|
||||||
|
export const columns: ColumnDef<MCPTool>[] = [
|
||||||
|
{
|
||||||
|
accessorKey: "mcp_info.server_name",
|
||||||
|
header: "Provider",
|
||||||
|
cell: ({ row }) => {
|
||||||
|
const serverName = row.original.mcp_info.server_name;
|
||||||
|
const logoUrl = row.original.mcp_info.logo_url;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex items-center space-x-2">
|
||||||
|
{logoUrl && (
|
||||||
|
<img
|
||||||
|
src={logoUrl}
|
||||||
|
alt={`${serverName} logo`}
|
||||||
|
className="h-5 w-5 object-contain"
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
<span className="font-medium">{serverName}</span>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
accessorKey: "name",
|
||||||
|
header: "Tool Name",
|
||||||
|
cell: ({ row }) => {
|
||||||
|
const name = row.getValue("name") as string;
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
<span className="font-mono text-sm">{name}</span>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
accessorKey: "description",
|
||||||
|
header: "Description",
|
||||||
|
cell: ({ row }) => {
|
||||||
|
const description = row.getValue("description") as string;
|
||||||
|
return (
|
||||||
|
<div className="max-w-md">
|
||||||
|
<span className="text-sm text-gray-700">{description}</span>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: "actions",
|
||||||
|
header: "Actions",
|
||||||
|
cell: ({ row }) => {
|
||||||
|
const tool = row.original;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex items-center space-x-2">
|
||||||
|
<Button
|
||||||
|
size="xs"
|
||||||
|
variant="light"
|
||||||
|
className="font-mono text-blue-500 bg-blue-50 hover:bg-blue-100 text-xs font-normal px-2 py-0.5 text-left overflow-hidden truncate max-w-[200px]"
|
||||||
|
onClick={() => {
|
||||||
|
if (typeof row.original.onToolSelect === 'function') {
|
||||||
|
row.original.onToolSelect(tool);
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
Test Tool
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
},
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
// Tool Panel component to display when a tool is selected
|
||||||
|
export function ToolTestPanel({
|
||||||
|
tool,
|
||||||
|
onSubmit,
|
||||||
|
isLoading,
|
||||||
|
result,
|
||||||
|
error,
|
||||||
|
onClose
|
||||||
|
}: {
|
||||||
|
tool: MCPTool;
|
||||||
|
onSubmit: (args: Record<string, any>) => void;
|
||||||
|
isLoading: boolean;
|
||||||
|
result: any | null;
|
||||||
|
error: Error | null;
|
||||||
|
onClose: () => void;
|
||||||
|
}) {
|
||||||
|
const [formState, setFormState] = React.useState<Record<string, any>>({});
|
||||||
|
|
||||||
|
// Create a placeholder schema if we only have the "tool_input_schema" string
|
||||||
|
const schema: InputSchema = React.useMemo(() => {
|
||||||
|
if (typeof tool.inputSchema === 'string') {
|
||||||
|
// Default schema with a single text field
|
||||||
|
return {
|
||||||
|
type: "object",
|
||||||
|
properties: {
|
||||||
|
input: {
|
||||||
|
type: "string",
|
||||||
|
description: "Input for this tool"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
required: ["input"]
|
||||||
|
};
|
||||||
|
}
|
||||||
|
return tool.inputSchema as InputSchema;
|
||||||
|
}, [tool.inputSchema]);
|
||||||
|
|
||||||
|
const handleInputChange = (key: string, value: any) => {
|
||||||
|
setFormState(prev => ({
|
||||||
|
...prev,
|
||||||
|
[key]: value
|
||||||
|
}));
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleSubmit = (e: React.FormEvent) => {
|
||||||
|
e.preventDefault();
|
||||||
|
onSubmit(formState);
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="bg-white rounded-lg shadow-lg border p-6 max-w-4xl w-full">
|
||||||
|
<div className="flex justify-between items-start mb-4">
|
||||||
|
<div>
|
||||||
|
<h2 className="text-xl font-bold">Test Tool: <span className="font-mono">{tool.name}</span></h2>
|
||||||
|
<p className="text-gray-600">{tool.description}</p>
|
||||||
|
<p className="text-sm text-gray-500 mt-1">Provider: {tool.mcp_info.server_name}</p>
|
||||||
|
</div>
|
||||||
|
<button
|
||||||
|
onClick={onClose}
|
||||||
|
className="p-1 rounded-full hover:bg-gray-200"
|
||||||
|
>
|
||||||
|
<svg
|
||||||
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
|
width="20"
|
||||||
|
height="20"
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
fill="none"
|
||||||
|
stroke="currentColor"
|
||||||
|
strokeWidth="2"
|
||||||
|
strokeLinecap="round"
|
||||||
|
strokeLinejoin="round"
|
||||||
|
>
|
||||||
|
<line x1="18" y1="6" x2="6" y2="18"></line>
|
||||||
|
<line x1="6" y1="6" x2="18" y2="18"></line>
|
||||||
|
</svg>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="grid grid-cols-1 md:grid-cols-2 gap-6">
|
||||||
|
{/* Form Section */}
|
||||||
|
<div className="bg-gray-50 p-4 rounded-lg">
|
||||||
|
<h3 className="font-medium mb-4">Input Parameters</h3>
|
||||||
|
<form onSubmit={handleSubmit}>
|
||||||
|
{typeof tool.inputSchema === 'string' ? (
|
||||||
|
<div className="mb-4">
|
||||||
|
<p className="text-xs text-gray-500 mb-1">This tool uses a dynamic input schema.</p>
|
||||||
|
<div className="mb-4">
|
||||||
|
<label className="block text-sm font-medium text-gray-700 mb-1">
|
||||||
|
Input <span className="text-red-500">*</span>
|
||||||
|
</label>
|
||||||
|
<input
|
||||||
|
type="text"
|
||||||
|
value={formState.input || ""}
|
||||||
|
onChange={(e) => handleInputChange("input", e.target.value)}
|
||||||
|
required
|
||||||
|
className="w-full px-3 py-2 border border-gray-300 rounded-md shadow-sm focus:outline-none focus:ring-blue-500 focus:border-blue-500 sm:text-sm"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
Object.entries(schema.properties).map(([key, prop]) => (
|
||||||
|
<div key={key} className="mb-4">
|
||||||
|
<label className="block text-sm font-medium text-gray-700 mb-1">
|
||||||
|
{key}{" "}
|
||||||
|
{schema.required?.includes(key) && (
|
||||||
|
<span className="text-red-500">*</span>
|
||||||
|
)}
|
||||||
|
</label>
|
||||||
|
{prop.description && (
|
||||||
|
<p className="text-xs text-gray-500 mb-1">{prop.description}</p>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Render appropriate input based on type */}
|
||||||
|
{prop.type === "string" && (
|
||||||
|
<input
|
||||||
|
type="text"
|
||||||
|
value={formState[key] || ""}
|
||||||
|
onChange={(e) => handleInputChange(key, e.target.value)}
|
||||||
|
required={schema.required?.includes(key)}
|
||||||
|
className="w-full px-3 py-2 border border-gray-300 rounded-md shadow-sm focus:outline-none focus:ring-blue-500 focus:border-blue-500 sm:text-sm"
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{prop.type === "number" && (
|
||||||
|
<input
|
||||||
|
type="number"
|
||||||
|
value={formState[key] || ""}
|
||||||
|
onChange={(e) => handleInputChange(key, parseFloat(e.target.value))}
|
||||||
|
required={schema.required?.includes(key)}
|
||||||
|
className="w-full px-3 py-2 border border-gray-300 rounded-md shadow-sm focus:outline-none focus:ring-blue-500 focus:border-blue-500 sm:text-sm"
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{prop.type === "boolean" && (
|
||||||
|
<div className="flex items-center">
|
||||||
|
<input
|
||||||
|
type="checkbox"
|
||||||
|
checked={formState[key] || false}
|
||||||
|
onChange={(e) => handleInputChange(key, e.target.checked)}
|
||||||
|
className="h-4 w-4 text-blue-600 focus:ring-blue-500 border-gray-300 rounded"
|
||||||
|
/>
|
||||||
|
<span className="ml-2 text-sm text-gray-600">Enable</span>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
))
|
||||||
|
)}
|
||||||
|
|
||||||
|
<div className="mt-6">
|
||||||
|
<Button
|
||||||
|
type="submit"
|
||||||
|
disabled={isLoading}
|
||||||
|
className="w-full px-4 py-2 border border-transparent rounded-md shadow-sm text-sm font-medium text-white"
|
||||||
|
>
|
||||||
|
{isLoading ? "Calling..." : "Call Tool"}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Result Section */}
|
||||||
|
<div className="bg-gray-50 p-4 rounded-lg overflow-auto max-h-[500px]">
|
||||||
|
<h3 className="font-medium mb-4">Result</h3>
|
||||||
|
|
||||||
|
{isLoading && (
|
||||||
|
<div className="flex justify-center items-center py-8">
|
||||||
|
<div className="animate-spin rounded-full h-8 w-8 border-b-2 border-blue-700"></div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{error && (
|
||||||
|
<div className="bg-red-50 border border-red-200 text-red-800 px-4 py-3 rounded-md">
|
||||||
|
<p className="font-medium">Error</p>
|
||||||
|
<pre className="mt-2 text-xs overflow-auto whitespace-pre-wrap">{error.message}</pre>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{result && !isLoading && !error && (
|
||||||
|
<div>
|
||||||
|
{result.map((content: any, idx: number) => (
|
||||||
|
<div key={idx} className="mb-4">
|
||||||
|
{content.type === "text" && (
|
||||||
|
<div className="bg-white border p-3 rounded-md">
|
||||||
|
<p className="whitespace-pre-wrap text-sm">{content.text}</p>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{content.type === "image" && content.url && (
|
||||||
|
<div className="bg-white border p-3 rounded-md">
|
||||||
|
<img src={content.url} alt="Tool result" className="max-w-full h-auto rounded" />
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{content.type === "embedded_resource" && (
|
||||||
|
<div className="bg-white border p-3 rounded-md">
|
||||||
|
<p className="text-sm font-medium">Embedded Resource</p>
|
||||||
|
<p className="text-xs text-gray-500">Type: {content.resource_type}</p>
|
||||||
|
{content.url && (
|
||||||
|
<a
|
||||||
|
href={content.url}
|
||||||
|
target="_blank"
|
||||||
|
rel="noopener noreferrer"
|
||||||
|
className="text-sm text-blue-600 hover:underline"
|
||||||
|
>
|
||||||
|
View Resource
|
||||||
|
</a>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
|
||||||
|
<div className="mt-2">
|
||||||
|
<details className="text-xs">
|
||||||
|
<summary className="cursor-pointer text-gray-500 hover:text-gray-700">Raw JSON Response</summary>
|
||||||
|
<pre className="mt-2 bg-gray-100 p-2 rounded-md overflow-auto max-h-[300px]">
|
||||||
|
{JSON.stringify(result, null, 2)}
|
||||||
|
</pre>
|
||||||
|
</details>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{!result && !isLoading && !error && (
|
||||||
|
<div className="text-center py-8 text-gray-500">
|
||||||
|
<p>The result will appear here after you call the tool.</p>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
172
ui/litellm-dashboard/src/components/mcp_tools/index.tsx
Normal file
172
ui/litellm-dashboard/src/components/mcp_tools/index.tsx
Normal file
|
@ -0,0 +1,172 @@
|
||||||
|
import React, { useState } from 'react';
|
||||||
|
import { useQuery, useMutation } from '@tanstack/react-query';
|
||||||
|
import { DataTable } from '../view_logs/table';
|
||||||
|
import { columns, ToolTestPanel } from './columns';
|
||||||
|
import { MCPTool, MCPToolsViewerProps, CallMCPToolResponse } from './types';
|
||||||
|
import { listMCPTools, callMCPTool } from '../networking';
|
||||||
|
|
||||||
|
// Wrapper to handle the type mismatch between MCPTool and DataTable's expected type
|
||||||
|
function DataTableWrapper({
|
||||||
|
columns,
|
||||||
|
data,
|
||||||
|
isLoading,
|
||||||
|
}: {
|
||||||
|
columns: any;
|
||||||
|
data: MCPTool[];
|
||||||
|
isLoading: boolean;
|
||||||
|
}) {
|
||||||
|
// Create a dummy renderSubComponent and getRowCanExpand function
|
||||||
|
const renderSubComponent = () => <div />;
|
||||||
|
const getRowCanExpand = () => false;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<DataTable
|
||||||
|
columns={columns as any}
|
||||||
|
data={data as any}
|
||||||
|
isLoading={isLoading}
|
||||||
|
renderSubComponent={renderSubComponent}
|
||||||
|
getRowCanExpand={getRowCanExpand}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function MCPToolsViewer({
|
||||||
|
accessToken,
|
||||||
|
userRole,
|
||||||
|
userID,
|
||||||
|
}: MCPToolsViewerProps) {
|
||||||
|
const [searchTerm, setSearchTerm] = useState('');
|
||||||
|
const [selectedTool, setSelectedTool] = useState<MCPTool | null>(null);
|
||||||
|
const [toolResult, setToolResult] = useState<CallMCPToolResponse | null>(null);
|
||||||
|
const [toolError, setToolError] = useState<Error | null>(null);
|
||||||
|
|
||||||
|
// Query to fetch MCP tools
|
||||||
|
const { data: mcpTools, isLoading: isLoadingTools } = useQuery({
|
||||||
|
queryKey: ['mcpTools'],
|
||||||
|
queryFn: () => {
|
||||||
|
if (!accessToken) throw new Error('Access Token required');
|
||||||
|
return listMCPTools(accessToken);
|
||||||
|
},
|
||||||
|
enabled: !!accessToken,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Mutation for calling a tool
|
||||||
|
const { mutate: executeTool, isPending: isCallingTool } = useMutation({
|
||||||
|
mutationFn: (args: { tool: MCPTool; arguments: Record<string, any> }) => {
|
||||||
|
if (!accessToken) throw new Error('Access Token required');
|
||||||
|
return callMCPTool(
|
||||||
|
accessToken,
|
||||||
|
args.tool.name,
|
||||||
|
args.arguments
|
||||||
|
);
|
||||||
|
},
|
||||||
|
onSuccess: (data) => {
|
||||||
|
setToolResult(data);
|
||||||
|
setToolError(null);
|
||||||
|
},
|
||||||
|
onError: (error: Error) => {
|
||||||
|
setToolError(error);
|
||||||
|
setToolResult(null);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Add onToolSelect handler to each tool
|
||||||
|
const toolsData = React.useMemo(() => {
|
||||||
|
if (!mcpTools) return [];
|
||||||
|
|
||||||
|
return mcpTools.map((tool: MCPTool) => ({
|
||||||
|
...tool,
|
||||||
|
onToolSelect: (tool: MCPTool) => {
|
||||||
|
setSelectedTool(tool);
|
||||||
|
setToolResult(null);
|
||||||
|
setToolError(null);
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}, [mcpTools]);
|
||||||
|
|
||||||
|
// Filter tools based on search term
|
||||||
|
const filteredTools = React.useMemo(() => {
|
||||||
|
return toolsData.filter((tool: MCPTool) => {
|
||||||
|
const searchLower = searchTerm.toLowerCase();
|
||||||
|
return (
|
||||||
|
tool.name.toLowerCase().includes(searchLower) ||
|
||||||
|
tool.description.toLowerCase().includes(searchLower) ||
|
||||||
|
tool.mcp_info.server_name.toLowerCase().includes(searchLower)
|
||||||
|
);
|
||||||
|
});
|
||||||
|
}, [toolsData, searchTerm]);
|
||||||
|
|
||||||
|
// Handle tool call submission
|
||||||
|
const handleToolSubmit = (args: Record<string, any>) => {
|
||||||
|
if (!selectedTool) return;
|
||||||
|
|
||||||
|
executeTool({
|
||||||
|
tool: selectedTool,
|
||||||
|
arguments: args,
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
if (!accessToken || !userRole || !userID) {
|
||||||
|
return <div className="p-6 text-center text-gray-500">Missing required authentication parameters.</div>;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="w-full p-6">
|
||||||
|
<div className="flex items-center justify-between mb-4">
|
||||||
|
<h1 className="text-xl font-semibold">MCP Tools</h1>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="bg-white rounded-lg shadow">
|
||||||
|
<div className="border-b px-6 py-4">
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<div className="relative w-64">
|
||||||
|
<input
|
||||||
|
type="text"
|
||||||
|
placeholder="Search tools..."
|
||||||
|
className="w-full px-3 py-2 pl-8 border rounded-md text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
|
||||||
|
value={searchTerm}
|
||||||
|
onChange={(e) => setSearchTerm(e.target.value)}
|
||||||
|
/>
|
||||||
|
<svg
|
||||||
|
className="absolute left-2.5 top-2.5 h-4 w-4 text-gray-500"
|
||||||
|
fill="none"
|
||||||
|
stroke="currentColor"
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
>
|
||||||
|
<path
|
||||||
|
strokeLinecap="round"
|
||||||
|
strokeLinejoin="round"
|
||||||
|
strokeWidth={2}
|
||||||
|
d="M21 21l-6-6m2-5a7 7 0 11-14 0 7 7 0 0114 0z"
|
||||||
|
/>
|
||||||
|
</svg>
|
||||||
|
</div>
|
||||||
|
<div className="text-sm text-gray-500">
|
||||||
|
{filteredTools.length} tool{filteredTools.length !== 1 ? "s" : ""} available
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<DataTableWrapper
|
||||||
|
columns={columns}
|
||||||
|
data={filteredTools}
|
||||||
|
isLoading={isLoadingTools}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Tool Test Panel - Show when a tool is selected */}
|
||||||
|
{selectedTool && (
|
||||||
|
<div className="fixed inset-0 bg-gray-800 bg-opacity-75 flex items-center justify-center z-50 p-4">
|
||||||
|
<ToolTestPanel
|
||||||
|
tool={selectedTool}
|
||||||
|
onSubmit={handleToolSubmit}
|
||||||
|
isLoading={isCallingTool}
|
||||||
|
result={toolResult}
|
||||||
|
error={toolError}
|
||||||
|
onClose={() => setSelectedTool(null)}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
71
ui/litellm-dashboard/src/components/mcp_tools/types.tsx
Normal file
71
ui/litellm-dashboard/src/components/mcp_tools/types.tsx
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
// Define the structure for tool input schema properties
|
||||||
|
export interface InputSchemaProperty {
|
||||||
|
type: string;
|
||||||
|
description?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define the structure for the input schema of a tool
|
||||||
|
export interface InputSchema {
|
||||||
|
type: "object";
|
||||||
|
properties: Record<string, InputSchemaProperty>;
|
||||||
|
required?: string[];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define MCP provider info
|
||||||
|
export interface MCPInfo {
|
||||||
|
server_name: string;
|
||||||
|
logo_url?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define the structure for a single MCP tool
|
||||||
|
export interface MCPTool {
|
||||||
|
name: string;
|
||||||
|
description: string;
|
||||||
|
inputSchema: InputSchema | string; // API returns string "tool_input_schema" or the actual schema
|
||||||
|
mcp_info: MCPInfo;
|
||||||
|
// Function to select a tool (added in the component)
|
||||||
|
onToolSelect?: (tool: MCPTool) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define the response structure for the listMCPTools endpoint - now a flat array
|
||||||
|
export type ListMCPToolsResponse = MCPTool[];
|
||||||
|
|
||||||
|
// Define the argument structure for calling an MCP tool
|
||||||
|
export interface CallMCPToolArgs {
|
||||||
|
name: string;
|
||||||
|
arguments: Record<string, any> | null;
|
||||||
|
server_name?: string; // Now using server_name from mcp_info
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define the possible content types in the response
|
||||||
|
export interface MCPTextContent {
|
||||||
|
type: "text";
|
||||||
|
text: string;
|
||||||
|
annotations?: any;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface MCPImageContent {
|
||||||
|
type: "image";
|
||||||
|
url?: string;
|
||||||
|
data?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface MCPEmbeddedResource {
|
||||||
|
type: "embedded_resource";
|
||||||
|
resource_type?: string;
|
||||||
|
url?: string;
|
||||||
|
data?: any;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define the union type for the content array in the response
|
||||||
|
export type MCPContent = MCPTextContent | MCPImageContent | MCPEmbeddedResource;
|
||||||
|
|
||||||
|
// Define the response structure for the callMCPTool endpoint
|
||||||
|
export type CallMCPToolResponse = MCPContent[];
|
||||||
|
|
||||||
|
// Props for the main component
|
||||||
|
export interface MCPToolsViewerProps {
|
||||||
|
accessToken: string | null;
|
||||||
|
userRole: string | null;
|
||||||
|
userID: string | null;
|
||||||
|
}
|
|
@ -4084,3 +4084,73 @@ export const updateInternalUserSettings = async (accessToken: string, settings:
|
||||||
throw error;
|
throw error;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
export const listMCPTools = async (accessToken: string) => {
|
||||||
|
try {
|
||||||
|
// Construct base URL
|
||||||
|
let url = proxyBaseUrl
|
||||||
|
? `${proxyBaseUrl}/mcp/tools/list`
|
||||||
|
: `/mcp/tools/list`;
|
||||||
|
|
||||||
|
console.log("Fetching MCP tools from:", url);
|
||||||
|
|
||||||
|
const response = await fetch(url, {
|
||||||
|
method: "GET",
|
||||||
|
headers: {
|
||||||
|
[globalLitellmHeaderName]: `Bearer ${accessToken}`,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
const errorData = await response.text();
|
||||||
|
handleError(errorData);
|
||||||
|
throw new Error("Network response was not ok");
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
console.log("Fetched MCP tools:", data);
|
||||||
|
return data;
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Failed to fetch MCP tools:", error);
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
export const callMCPTool = async (accessToken: string, toolName: string, toolArguments: Record<string, any>) => {
|
||||||
|
try {
|
||||||
|
// Construct base URL
|
||||||
|
let url = proxyBaseUrl
|
||||||
|
? `${proxyBaseUrl}/mcp/tools/call`
|
||||||
|
: `/mcp/tools/call`;
|
||||||
|
|
||||||
|
console.log("Calling MCP tool:", toolName, "with arguments:", toolArguments);
|
||||||
|
|
||||||
|
const response = await fetch(url, {
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
[globalLitellmHeaderName]: `Bearer ${accessToken}`,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
name: toolName,
|
||||||
|
arguments: toolArguments,
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
const errorData = await response.text();
|
||||||
|
handleError(errorData);
|
||||||
|
throw new Error("Network response was not ok");
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
console.log("MCP tool call response:", data);
|
||||||
|
return data;
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Failed to call MCP tool:", error);
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
};
|
|
@ -8,6 +8,19 @@ import { Tooltip } from "antd";
|
||||||
import { TimeCell } from "./time_cell";
|
import { TimeCell } from "./time_cell";
|
||||||
import { Button } from "@tremor/react";
|
import { Button } from "@tremor/react";
|
||||||
|
|
||||||
|
// Helper to get the appropriate logo URL
|
||||||
|
const getLogoUrl = (
|
||||||
|
row: LogEntry,
|
||||||
|
provider: string
|
||||||
|
) => {
|
||||||
|
// Check if mcp_tool_call_metadata exists and contains mcp_server_logo_url
|
||||||
|
if (row.metadata?.mcp_tool_call_metadata?.mcp_server_logo_url) {
|
||||||
|
return row.metadata.mcp_tool_call_metadata.mcp_server_logo_url;
|
||||||
|
}
|
||||||
|
// Fall back to default provider logo
|
||||||
|
return provider ? getProviderLogoAndName(provider).logo : '';
|
||||||
|
};
|
||||||
|
|
||||||
export type LogEntry = {
|
export type LogEntry = {
|
||||||
request_id: string;
|
request_id: string;
|
||||||
api_key: string;
|
api_key: string;
|
||||||
|
@ -177,7 +190,7 @@ export const columns: ColumnDef<LogEntry>[] = [
|
||||||
<div className="flex items-center space-x-2">
|
<div className="flex items-center space-x-2">
|
||||||
{provider && (
|
{provider && (
|
||||||
<img
|
<img
|
||||||
src={getProviderLogoAndName(provider).logo}
|
src={getLogoUrl(row, provider)}
|
||||||
alt=""
|
alt=""
|
||||||
className="w-4 h-4"
|
className="w-4 h-4"
|
||||||
onError={(e) => {
|
onError={(e) => {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue