mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
Merge pull request #9642 from BerriAI/litellm_mcp_improvements_expose_sse_urls
[Feat] - MCP improvements, add support for using SSE MCP servers
This commit is contained in:
parent
0865e52db3
commit
5daf40ce24
25 changed files with 1210 additions and 29 deletions
|
@ -414,6 +414,7 @@ RESPONSE_FORMAT_TOOL_NAME = "json_tool_call" # default tool name used when conv
|
|||
|
||||
########################### Logging Callback Constants ###########################
|
||||
AZURE_STORAGE_MSFT_VERSION = "2019-07-07"
|
||||
MCP_TOOL_NAME_PREFIX = "mcp_tool"
|
||||
|
||||
########################### LiteLLM Proxy Specific Constants ###########################
|
||||
########################################################################################
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import json
|
||||
from typing import List, Literal, Union
|
||||
from typing import Dict, List, Literal, Union
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
|
||||
|
@ -76,8 +76,8 @@ def _get_function_arguments(function: FunctionDefinition) -> dict:
|
|||
return arguments if isinstance(arguments, dict) else {}
|
||||
|
||||
|
||||
def _transform_openai_tool_call_to_mcp_tool_call_request(
|
||||
openai_tool: ChatCompletionMessageToolCall,
|
||||
def transform_openai_tool_call_request_to_mcp_tool_call_request(
|
||||
openai_tool: Union[ChatCompletionMessageToolCall, Dict],
|
||||
) -> MCPCallToolRequestParams:
|
||||
"""Convert an OpenAI ChatCompletionMessageToolCall to an MCP CallToolRequestParams."""
|
||||
function = openai_tool["function"]
|
||||
|
@ -100,8 +100,10 @@ async def call_openai_tool(
|
|||
Returns:
|
||||
The result of the MCP tool call.
|
||||
"""
|
||||
mcp_tool_call_request_params = _transform_openai_tool_call_to_mcp_tool_call_request(
|
||||
openai_tool=openai_tool,
|
||||
mcp_tool_call_request_params = (
|
||||
transform_openai_tool_call_request_to_mcp_tool_call_request(
|
||||
openai_tool=openai_tool,
|
||||
)
|
||||
)
|
||||
return await call_mcp_tool(
|
||||
session=session,
|
||||
|
|
|
@ -67,6 +67,7 @@ from litellm.types.utils import (
|
|||
StandardCallbackDynamicParams,
|
||||
StandardLoggingAdditionalHeaders,
|
||||
StandardLoggingHiddenParams,
|
||||
StandardLoggingMCPToolCall,
|
||||
StandardLoggingMetadata,
|
||||
StandardLoggingModelCostFailureDebugInformation,
|
||||
StandardLoggingModelInformation,
|
||||
|
@ -1099,7 +1100,7 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
||||
)
|
||||
)
|
||||
elif isinstance(result, dict): # pass-through endpoints
|
||||
elif isinstance(result, dict) or isinstance(result, list):
|
||||
## STANDARDIZED LOGGING PAYLOAD
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
get_standard_logging_object_payload(
|
||||
|
@ -3114,6 +3115,7 @@ class StandardLoggingPayloadSetup:
|
|||
litellm_params: Optional[dict] = None,
|
||||
prompt_integration: Optional[str] = None,
|
||||
applied_guardrails: Optional[List[str]] = None,
|
||||
mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] = None,
|
||||
) -> StandardLoggingMetadata:
|
||||
"""
|
||||
Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
|
||||
|
@ -3160,6 +3162,7 @@ class StandardLoggingPayloadSetup:
|
|||
user_api_key_end_user_id=None,
|
||||
prompt_management_metadata=prompt_management_metadata,
|
||||
applied_guardrails=applied_guardrails,
|
||||
mcp_tool_call_metadata=mcp_tool_call_metadata,
|
||||
)
|
||||
if isinstance(metadata, dict):
|
||||
# Filter the metadata dictionary to include only the specified keys
|
||||
|
@ -3486,6 +3489,7 @@ def get_standard_logging_object_payload(
|
|||
litellm_params=litellm_params,
|
||||
prompt_integration=kwargs.get("prompt_integration", None),
|
||||
applied_guardrails=kwargs.get("applied_guardrails", None),
|
||||
mcp_tool_call_metadata=kwargs.get("mcp_tool_call_metadata", None),
|
||||
)
|
||||
|
||||
_request_body = proxy_server_request.get("body", {})
|
||||
|
@ -3626,6 +3630,7 @@ def get_standard_logging_metadata(
|
|||
user_api_key_end_user_id=None,
|
||||
prompt_management_metadata=None,
|
||||
applied_guardrails=None,
|
||||
mcp_tool_call_metadata=None,
|
||||
)
|
||||
if isinstance(metadata, dict):
|
||||
# Filter the metadata dictionary to include only the specified keys
|
||||
|
|
153
litellm/proxy/_experimental/mcp_server/mcp_server_manager.py
Normal file
153
litellm/proxy/_experimental/mcp_server/mcp_server_manager.py
Normal file
|
@ -0,0 +1,153 @@
|
|||
"""
|
||||
MCP Client Manager
|
||||
|
||||
This class is responsible for managing MCP SSE clients.
|
||||
|
||||
This is a Proxy
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.types import Tool as MCPTool
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPSSEServer
|
||||
|
||||
|
||||
class MCPServerManager:
|
||||
def __init__(self):
|
||||
self.mcp_servers: List[MCPSSEServer] = []
|
||||
"""
|
||||
eg.
|
||||
[
|
||||
{
|
||||
"name": "zapier_mcp_server",
|
||||
"url": "https://actions.zapier.com/mcp/sk-ak-2ew3bofIeQIkNoeKIdXrF1Hhhp/sse"
|
||||
},
|
||||
{
|
||||
"name": "google_drive_mcp_server",
|
||||
"url": "https://actions.zapier.com/mcp/sk-ak-2ew3bofIeQIkNoeKIdXrF1Hhhp/sse"
|
||||
}
|
||||
]
|
||||
"""
|
||||
|
||||
self.tool_name_to_mcp_server_name_mapping: Dict[str, str] = {}
|
||||
"""
|
||||
{
|
||||
"gmail_send_email": "zapier_mcp_server",
|
||||
}
|
||||
"""
|
||||
|
||||
def load_servers_from_config(self, mcp_servers_config: Dict[str, Any]):
|
||||
"""
|
||||
Load the MCP Servers from the config
|
||||
"""
|
||||
for server_name, server_config in mcp_servers_config.items():
|
||||
_mcp_info: dict = server_config.get("mcp_info", None) or {}
|
||||
mcp_info = MCPInfo(**_mcp_info)
|
||||
mcp_info["server_name"] = server_name
|
||||
self.mcp_servers.append(
|
||||
MCPSSEServer(
|
||||
name=server_name,
|
||||
url=server_config["url"],
|
||||
mcp_info=mcp_info,
|
||||
)
|
||||
)
|
||||
verbose_logger.debug(
|
||||
f"Loaded MCP Servers: {json.dumps(self.mcp_servers, indent=4, default=str)}"
|
||||
)
|
||||
|
||||
self.initialize_tool_name_to_mcp_server_name_mapping()
|
||||
|
||||
async def list_tools(self) -> List[MCPTool]:
|
||||
"""
|
||||
List all tools available across all MCP Servers.
|
||||
|
||||
Returns:
|
||||
List[MCPTool]: Combined list of tools from all servers
|
||||
"""
|
||||
list_tools_result: List[MCPTool] = []
|
||||
verbose_logger.debug("SSE SERVER MANAGER LISTING TOOLS")
|
||||
|
||||
for server in self.mcp_servers:
|
||||
tools = await self._get_tools_from_server(server)
|
||||
list_tools_result.extend(tools)
|
||||
|
||||
return list_tools_result
|
||||
|
||||
async def _get_tools_from_server(self, server: MCPSSEServer) -> List[MCPTool]:
|
||||
"""
|
||||
Helper method to get tools from a single MCP server.
|
||||
|
||||
Args:
|
||||
server (MCPSSEServer): The server to query tools from
|
||||
|
||||
Returns:
|
||||
List[MCPTool]: List of tools available on the server
|
||||
"""
|
||||
verbose_logger.debug(f"Connecting to url: {server.url}")
|
||||
|
||||
async with sse_client(url=server.url) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
await session.initialize()
|
||||
|
||||
tools_result = await session.list_tools()
|
||||
verbose_logger.debug(f"Tools from {server.name}: {tools_result}")
|
||||
|
||||
# Update tool to server mapping
|
||||
for tool in tools_result.tools:
|
||||
self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name
|
||||
|
||||
return tools_result.tools
|
||||
|
||||
def initialize_tool_name_to_mcp_server_name_mapping(self):
|
||||
"""
|
||||
On startup, initialize the tool name to MCP server name mapping
|
||||
"""
|
||||
try:
|
||||
if asyncio.get_running_loop():
|
||||
asyncio.create_task(
|
||||
self._initialize_tool_name_to_mcp_server_name_mapping()
|
||||
)
|
||||
except RuntimeError as e: # no running event loop
|
||||
verbose_logger.exception(
|
||||
f"No running event loop - skipping tool name to MCP server name mapping initialization: {str(e)}"
|
||||
)
|
||||
|
||||
async def _initialize_tool_name_to_mcp_server_name_mapping(self):
|
||||
"""
|
||||
Call list_tools for each server and update the tool name to MCP server name mapping
|
||||
"""
|
||||
for server in self.mcp_servers:
|
||||
tools = await self._get_tools_from_server(server)
|
||||
for tool in tools:
|
||||
self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name
|
||||
|
||||
async def call_tool(self, name: str, arguments: Dict[str, Any]):
|
||||
"""
|
||||
Call a tool with the given name and arguments
|
||||
"""
|
||||
mcp_server = self._get_mcp_server_from_tool_name(name)
|
||||
if mcp_server is None:
|
||||
raise ValueError(f"Tool {name} not found")
|
||||
async with sse_client(url=mcp_server.url) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
await session.initialize()
|
||||
return await session.call_tool(name, arguments)
|
||||
|
||||
def _get_mcp_server_from_tool_name(self, tool_name: str) -> Optional[MCPSSEServer]:
|
||||
"""
|
||||
Get the MCP Server from the tool name
|
||||
"""
|
||||
if tool_name in self.tool_name_to_mcp_server_name_mapping:
|
||||
for server in self.mcp_servers:
|
||||
if server.name == self.tool_name_to_mcp_server_name_mapping[tool_name]:
|
||||
return server
|
||||
return None
|
||||
|
||||
|
||||
global_mcp_server_manager: MCPServerManager = MCPServerManager()
|
|
@ -3,14 +3,21 @@ LiteLLM MCP Server Routes
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from anyio import BrokenResourceError
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import ValidationError
|
||||
from pydantic import ConfigDict, ValidationError
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import MCP_TOOL_NAME_PREFIX
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPInfo
|
||||
from litellm.types.utils import StandardLoggingMCPToolCall
|
||||
from litellm.utils import client
|
||||
|
||||
# Check if MCP is available
|
||||
# "mcp" requires python 3.10 or higher, but several litellm users use python 3.8
|
||||
|
@ -36,9 +43,23 @@ if MCP_AVAILABLE:
|
|||
from mcp.types import TextContent as MCPTextContent
|
||||
from mcp.types import Tool as MCPTool
|
||||
|
||||
from .mcp_server_manager import global_mcp_server_manager
|
||||
from .sse_transport import SseServerTransport
|
||||
from .tool_registry import global_mcp_tool_registry
|
||||
|
||||
######################################################
|
||||
############ MCP Tools List REST API Response Object #
|
||||
# Defined here because we don't want to add `mcp` as a
|
||||
# required dependency for `litellm` pip package
|
||||
######################################################
|
||||
class ListMCPToolsRestAPIResponseObject(MCPTool):
|
||||
"""
|
||||
Object returned by the /tools/list REST API route.
|
||||
"""
|
||||
|
||||
mcp_info: Optional[MCPInfo] = None
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
########################################################
|
||||
############ Initialize the MCP Server #################
|
||||
########################################################
|
||||
|
@ -52,9 +73,14 @@ if MCP_AVAILABLE:
|
|||
########################################################
|
||||
############### MCP Server Routes #######################
|
||||
########################################################
|
||||
|
||||
@server.list_tools()
|
||||
async def list_tools() -> list[MCPTool]:
|
||||
"""
|
||||
List all available tools
|
||||
"""
|
||||
return await _list_mcp_tools()
|
||||
|
||||
async def _list_mcp_tools() -> List[MCPTool]:
|
||||
"""
|
||||
List all available tools
|
||||
"""
|
||||
|
@ -67,24 +93,116 @@ if MCP_AVAILABLE:
|
|||
inputSchema=tool.input_schema,
|
||||
)
|
||||
)
|
||||
|
||||
verbose_logger.debug(
|
||||
"GLOBAL MCP TOOLS: %s", global_mcp_tool_registry.list_tools()
|
||||
)
|
||||
sse_tools: List[MCPTool] = await global_mcp_server_manager.list_tools()
|
||||
verbose_logger.debug("SSE TOOLS: %s", sse_tools)
|
||||
if sse_tools is not None:
|
||||
tools.extend(sse_tools)
|
||||
return tools
|
||||
|
||||
@server.call_tool()
|
||||
async def handle_call_tool(
|
||||
async def mcp_server_tool_call(
|
||||
name: str, arguments: Dict[str, Any] | None
|
||||
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
|
||||
"""
|
||||
Call a specific tool with the provided arguments
|
||||
|
||||
Args:
|
||||
name (str): Name of the tool to call
|
||||
arguments (Dict[str, Any] | None): Arguments to pass to the tool
|
||||
|
||||
Returns:
|
||||
List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]: Tool execution results
|
||||
|
||||
Raises:
|
||||
HTTPException: If tool not found or arguments missing
|
||||
"""
|
||||
# Validate arguments
|
||||
response = await call_mcp_tool(
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
)
|
||||
return response
|
||||
|
||||
@client
|
||||
async def call_mcp_tool(
|
||||
name: str, arguments: Optional[Dict[str, Any]] = None, **kwargs: Any
|
||||
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
|
||||
"""
|
||||
Call a specific tool with the provided arguments
|
||||
"""
|
||||
tool = global_mcp_tool_registry.get_tool(name)
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail=f"Tool '{name}' not found")
|
||||
if arguments is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Request arguments are required"
|
||||
)
|
||||
|
||||
standard_logging_mcp_tool_call: StandardLoggingMCPToolCall = (
|
||||
_get_standard_logging_mcp_tool_call(
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
)
|
||||
)
|
||||
litellm_logging_obj: Optional[LiteLLMLoggingObj] = kwargs.get(
|
||||
"litellm_logging_obj", None
|
||||
)
|
||||
if litellm_logging_obj:
|
||||
litellm_logging_obj.model_call_details["mcp_tool_call_metadata"] = (
|
||||
standard_logging_mcp_tool_call
|
||||
)
|
||||
litellm_logging_obj.model_call_details["model"] = (
|
||||
f"{MCP_TOOL_NAME_PREFIX}: {standard_logging_mcp_tool_call.get('name') or ''}"
|
||||
)
|
||||
litellm_logging_obj.model_call_details["custom_llm_provider"] = (
|
||||
standard_logging_mcp_tool_call.get("mcp_server_name")
|
||||
)
|
||||
|
||||
# Try managed server tool first
|
||||
if name in global_mcp_server_manager.tool_name_to_mcp_server_name_mapping:
|
||||
return await _handle_managed_mcp_tool(name, arguments)
|
||||
|
||||
# Fall back to local tool registry
|
||||
return await _handle_local_mcp_tool(name, arguments)
|
||||
|
||||
def _get_standard_logging_mcp_tool_call(
|
||||
name: str,
|
||||
arguments: Dict[str, Any],
|
||||
) -> StandardLoggingMCPToolCall:
|
||||
mcp_server = global_mcp_server_manager._get_mcp_server_from_tool_name(name)
|
||||
if mcp_server:
|
||||
mcp_info = mcp_server.mcp_info or {}
|
||||
return StandardLoggingMCPToolCall(
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
mcp_server_name=mcp_info.get("server_name"),
|
||||
mcp_server_logo_url=mcp_info.get("logo_url"),
|
||||
)
|
||||
else:
|
||||
return StandardLoggingMCPToolCall(
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
)
|
||||
|
||||
async def _handle_managed_mcp_tool(
|
||||
name: str, arguments: Dict[str, Any]
|
||||
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
|
||||
"""Handle tool execution for managed server tools"""
|
||||
call_tool_result = await global_mcp_server_manager.call_tool(
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
)
|
||||
verbose_logger.debug("CALL TOOL RESULT: %s", call_tool_result)
|
||||
return call_tool_result.content
|
||||
|
||||
async def _handle_local_mcp_tool(
|
||||
name: str, arguments: Dict[str, Any]
|
||||
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
|
||||
"""Handle tool execution for local registry tools"""
|
||||
tool = global_mcp_tool_registry.get_tool(name)
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail=f"Tool '{name}' not found")
|
||||
|
||||
try:
|
||||
result = tool.handler(**arguments)
|
||||
return [MCPTextContent(text=str(result), type="text")]
|
||||
|
@ -113,6 +231,74 @@ if MCP_AVAILABLE:
|
|||
await sse.handle_post_message(request.scope, request.receive, request._send)
|
||||
await request.close()
|
||||
|
||||
########################################################
|
||||
############ MCP Server REST API Routes #################
|
||||
########################################################
|
||||
@router.get("/tools/list", dependencies=[Depends(user_api_key_auth)])
|
||||
async def list_tool_rest_api() -> List[ListMCPToolsRestAPIResponseObject]:
|
||||
"""
|
||||
List all available tools with information about the server they belong to.
|
||||
|
||||
Example response:
|
||||
Tools:
|
||||
[
|
||||
{
|
||||
"name": "create_zap",
|
||||
"description": "Create a new zap",
|
||||
"inputSchema": "tool_input_schema",
|
||||
"mcp_info": {
|
||||
"server_name": "zapier",
|
||||
"logo_url": "https://www.zapier.com/logo.png",
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "fetch_data",
|
||||
"description": "Fetch data from a URL",
|
||||
"inputSchema": "tool_input_schema",
|
||||
"mcp_info": {
|
||||
"server_name": "fetch",
|
||||
"logo_url": "https://www.fetch.com/logo.png",
|
||||
}
|
||||
}
|
||||
]
|
||||
"""
|
||||
list_tools_result: List[ListMCPToolsRestAPIResponseObject] = []
|
||||
for server in global_mcp_server_manager.mcp_servers:
|
||||
try:
|
||||
tools = await global_mcp_server_manager._get_tools_from_server(server)
|
||||
for tool in tools:
|
||||
list_tools_result.append(
|
||||
ListMCPToolsRestAPIResponseObject(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
inputSchema=tool.inputSchema,
|
||||
mcp_info=server.mcp_info,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"Error getting tools from {server.name}: {e}")
|
||||
continue
|
||||
return list_tools_result
|
||||
|
||||
@router.post("/tools/call", dependencies=[Depends(user_api_key_auth)])
|
||||
async def call_tool_rest_api(
|
||||
request: Request,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
REST API to call a specific MCP tool with the provided arguments
|
||||
"""
|
||||
from litellm.proxy.proxy_server import add_litellm_data_to_request, proxy_config
|
||||
|
||||
data = await request.json()
|
||||
data = await add_litellm_data_to_request(
|
||||
data=data,
|
||||
request=request,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
proxy_config=proxy_config,
|
||||
)
|
||||
return await call_mcp_tool(**data)
|
||||
|
||||
options = InitializationOptions(
|
||||
server_name="litellm-mcp-server",
|
||||
server_version="0.1.0",
|
||||
|
|
|
@ -27,6 +27,7 @@ from litellm.types.utils import (
|
|||
ModelResponse,
|
||||
ProviderField,
|
||||
StandardCallbackDynamicParams,
|
||||
StandardLoggingMCPToolCall,
|
||||
StandardLoggingPayloadErrorInformation,
|
||||
StandardLoggingPayloadStatus,
|
||||
StandardPassThroughResponseObject,
|
||||
|
@ -1913,6 +1914,7 @@ class SpendLogsMetadata(TypedDict):
|
|||
] # special param to log k,v pairs to spendlogs for a call
|
||||
requester_ip_address: Optional[str]
|
||||
applied_guardrails: Optional[List[str]]
|
||||
mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall]
|
||||
status: StandardLoggingPayloadStatus
|
||||
proxy_server_request: Optional[str]
|
||||
batch_models: Optional[List[str]]
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
model_list:
|
||||
- model_name: fake-openai-endpoint
|
||||
- model_name: gpt-4o
|
||||
litellm_params:
|
||||
model: openai/fake
|
||||
api_key: fake-key
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
|
||||
general_settings:
|
||||
allow_requests_on_db_unavailable: True
|
|
@ -2165,6 +2165,14 @@ class ProxyConfig:
|
|||
if mcp_tools_config:
|
||||
global_mcp_tool_registry.load_tools_from_config(mcp_tools_config)
|
||||
|
||||
mcp_servers_config = config.get("mcp_servers", None)
|
||||
if mcp_servers_config:
|
||||
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
|
||||
global_mcp_server_manager,
|
||||
)
|
||||
|
||||
global_mcp_server_manager.load_servers_from_config(mcp_servers_config)
|
||||
|
||||
## CREDENTIALS
|
||||
credential_list_dict = self.load_credential_list(config=config)
|
||||
litellm.credential_list = credential_list_dict
|
||||
|
|
|
@ -13,7 +13,7 @@ from litellm._logging import verbose_proxy_logger
|
|||
from litellm.litellm_core_utils.core_helpers import get_litellm_metadata_from_kwargs
|
||||
from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
|
||||
from litellm.proxy.utils import PrismaClient, hash_token
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
from litellm.types.utils import StandardLoggingMCPToolCall, StandardLoggingPayload
|
||||
from litellm.utils import get_end_user_id_for_cost_tracking
|
||||
|
||||
|
||||
|
@ -38,6 +38,7 @@ def _get_spend_logs_metadata(
|
|||
metadata: Optional[dict],
|
||||
applied_guardrails: Optional[List[str]] = None,
|
||||
batch_models: Optional[List[str]] = None,
|
||||
mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] = None,
|
||||
) -> SpendLogsMetadata:
|
||||
if metadata is None:
|
||||
return SpendLogsMetadata(
|
||||
|
@ -55,6 +56,7 @@ def _get_spend_logs_metadata(
|
|||
error_information=None,
|
||||
proxy_server_request=None,
|
||||
batch_models=None,
|
||||
mcp_tool_call_metadata=None,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
"getting payload for SpendLogs, available keys in metadata: "
|
||||
|
@ -71,6 +73,7 @@ def _get_spend_logs_metadata(
|
|||
)
|
||||
clean_metadata["applied_guardrails"] = applied_guardrails
|
||||
clean_metadata["batch_models"] = batch_models
|
||||
clean_metadata["mcp_tool_call_metadata"] = mcp_tool_call_metadata
|
||||
return clean_metadata
|
||||
|
||||
|
||||
|
@ -200,6 +203,11 @@ def get_logging_payload( # noqa: PLR0915
|
|||
if standard_logging_payload is not None
|
||||
else None
|
||||
),
|
||||
mcp_tool_call_metadata=(
|
||||
standard_logging_payload["metadata"].get("mcp_tool_call_metadata", None)
|
||||
if standard_logging_payload is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
|
||||
|
|
16
litellm/types/mcp_server/mcp_server_manager.py
Normal file
16
litellm/types/mcp_server/mcp_server_manager.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class MCPInfo(TypedDict, total=False):
|
||||
server_name: str
|
||||
logo_url: Optional[str]
|
||||
|
||||
|
||||
class MCPSSEServer(BaseModel):
|
||||
name: str
|
||||
url: str
|
||||
mcp_info: Optional[MCPInfo] = None
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
@ -1629,6 +1629,33 @@ class StandardLoggingUserAPIKeyMetadata(TypedDict):
|
|||
user_api_key_end_user_id: Optional[str]
|
||||
|
||||
|
||||
class StandardLoggingMCPToolCall(TypedDict, total=False):
|
||||
name: str
|
||||
"""
|
||||
Name of the tool to call
|
||||
"""
|
||||
arguments: dict
|
||||
"""
|
||||
Arguments to pass to the tool
|
||||
"""
|
||||
result: dict
|
||||
"""
|
||||
Result of the tool call
|
||||
"""
|
||||
|
||||
mcp_server_name: Optional[str]
|
||||
"""
|
||||
Name of the MCP server that the tool call was made to
|
||||
"""
|
||||
|
||||
mcp_server_logo_url: Optional[str]
|
||||
"""
|
||||
Optional logo URL of the MCP server that the tool call was made to
|
||||
|
||||
(this is to render the logo on the logs page on litellm ui)
|
||||
"""
|
||||
|
||||
|
||||
class StandardBuiltInToolsParams(TypedDict, total=False):
|
||||
"""
|
||||
Standard built-in OpenAItools parameters
|
||||
|
@ -1659,6 +1686,7 @@ class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):
|
|||
requester_ip_address: Optional[str]
|
||||
requester_metadata: Optional[dict]
|
||||
prompt_management_metadata: Optional[StandardLoggingPromptManagementMetadata]
|
||||
mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall]
|
||||
applied_guardrails: Optional[List[str]]
|
||||
|
||||
|
||||
|
|
|
@ -19,11 +19,11 @@ from mcp.types import Tool as MCPTool
|
|||
|
||||
from litellm.experimental_mcp_client.tools import (
|
||||
_get_function_arguments,
|
||||
_transform_openai_tool_call_to_mcp_tool_call_request,
|
||||
call_mcp_tool,
|
||||
call_openai_tool,
|
||||
load_mcp_tools,
|
||||
transform_mcp_tool_to_openai_tool,
|
||||
transform_openai_tool_call_request_to_mcp_tool_call_request,
|
||||
)
|
||||
|
||||
|
||||
|
@ -76,11 +76,11 @@ def test_transform_mcp_tool_to_openai_tool(mock_mcp_tool):
|
|||
}
|
||||
|
||||
|
||||
def test_transform_openai_tool_call_to_mcp_tool_call_request(mock_mcp_tool):
|
||||
def testtransform_openai_tool_call_request_to_mcp_tool_call_request(mock_mcp_tool):
|
||||
openai_tool = {
|
||||
"function": {"name": "test_tool", "arguments": json.dumps({"test": "value"})}
|
||||
}
|
||||
mcp_tool_call_request = _transform_openai_tool_call_to_mcp_tool_call_request(
|
||||
mcp_tool_call_request = transform_openai_tool_call_request_to_mcp_tool_call_request(
|
||||
openai_tool
|
||||
)
|
||||
assert mcp_tool_call_request.name == "test_tool"
|
||||
|
|
|
@ -456,7 +456,7 @@ class TestSpendLogsPayload:
|
|||
"model": "gpt-4o",
|
||||
"user": "",
|
||||
"team_id": "",
|
||||
"metadata": '{"applied_guardrails": [], "batch_models": null, "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": null}}',
|
||||
"metadata": '{"applied_guardrails": [], "batch_models": null, "mcp_tool_call_metadata": null, "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": null}}',
|
||||
"cache_key": "Cache OFF",
|
||||
"spend": 0.00022500000000000002,
|
||||
"total_tokens": 30,
|
||||
|
@ -553,7 +553,7 @@ class TestSpendLogsPayload:
|
|||
"model": "claude-3-7-sonnet-20250219",
|
||||
"user": "",
|
||||
"team_id": "",
|
||||
"metadata": '{"applied_guardrails": [], "batch_models": null, "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": {"audio_tokens": null, "cached_tokens": 0, "text_tokens": null, "image_tokens": null}, "cache_creation_input_tokens": 0, "cache_read_input_tokens": 0}}',
|
||||
"metadata": '{"applied_guardrails": [], "batch_models": null, "mcp_tool_call_metadata": null, "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": {"audio_tokens": null, "cached_tokens": 0, "text_tokens": null, "image_tokens": null}, "cache_creation_input_tokens": 0, "cache_read_input_tokens": 0}}',
|
||||
"cache_key": "Cache OFF",
|
||||
"spend": 0.01383,
|
||||
"total_tokens": 2598,
|
||||
|
@ -648,7 +648,7 @@ class TestSpendLogsPayload:
|
|||
"model": "claude-3-7-sonnet-20250219",
|
||||
"user": "",
|
||||
"team_id": "",
|
||||
"metadata": '{"applied_guardrails": [], "batch_models": null, "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": {"audio_tokens": null, "cached_tokens": 0, "text_tokens": null, "image_tokens": null}, "cache_creation_input_tokens": 0, "cache_read_input_tokens": 0}}',
|
||||
"metadata": '{"applied_guardrails": [], "batch_models": null, "mcp_tool_call_metadata": null, "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": {"audio_tokens": null, "cached_tokens": 0, "text_tokens": null, "image_tokens": null}, "cache_creation_input_tokens": 0, "cache_read_input_tokens": 0}}',
|
||||
"cache_key": "Cache OFF",
|
||||
"spend": 0.01383,
|
||||
"total_tokens": 2598,
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
"model": "gpt-4o",
|
||||
"user": "",
|
||||
"team_id": "",
|
||||
"metadata": "{\"applied_guardrails\": [], \"batch_models\": null, \"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}",
|
||||
"metadata": "{\"applied_guardrails\": [], \"batch_models\": null, \"mcp_tool_call_metadata\": null, \"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}",
|
||||
"cache_key": "Cache OFF",
|
||||
"spend": 0.00022500000000000002,
|
||||
"total_tokens": 30,
|
||||
|
|
|
@ -25,7 +25,8 @@
|
|||
"requester_metadata": null,
|
||||
"user_api_key_end_user_id": null,
|
||||
"prompt_management_metadata": null,
|
||||
"applied_guardrails": []
|
||||
"applied_guardrails": [],
|
||||
"mcp_tool_call_metadata": null
|
||||
},
|
||||
"cache_key": null,
|
||||
"response_cost": 0.00022500000000000002,
|
||||
|
|
|
@ -274,6 +274,7 @@ def validate_redacted_message_span_attributes(span):
|
|||
"metadata.user_api_key_end_user_id",
|
||||
"metadata.user_api_key_user_email",
|
||||
"metadata.applied_guardrails",
|
||||
"metadata.mcp_tool_call_metadata",
|
||||
]
|
||||
|
||||
_all_attributes = set(
|
||||
|
|
35
tests/mcp_tests/test_mcp_server.py
Normal file
35
tests/mcp_tests/test_mcp_server.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
# Create server parameters for stdio connection
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
|
||||
MCPServerManager,
|
||||
MCPSSEServer,
|
||||
)
|
||||
|
||||
|
||||
mcp_server_manager = MCPServerManager()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip(reason="Local only test")
|
||||
async def test_mcp_server_manager():
|
||||
mcp_server_manager.load_servers_from_config(
|
||||
{
|
||||
"zapier_mcp_server": {
|
||||
"url": os.environ.get("ZAPIER_MCP_SERVER_URL"),
|
||||
}
|
||||
}
|
||||
)
|
||||
tools = await mcp_server_manager.list_tools()
|
||||
print("TOOLS FROM MCP SERVER MANAGER== ", tools)
|
||||
|
||||
result = await mcp_server_manager.call_tool(
|
||||
name="gmail_send_email", arguments={"body": "Test"}
|
||||
)
|
||||
print("RESULT FROM CALLING TOOL FROM MCP SERVER MANAGER== ", result)
|
|
@ -32,6 +32,8 @@ import GuardrailsPanel from "@/components/guardrails";
|
|||
import TransformRequestPanel from "@/components/transform_request";
|
||||
import { fetchUserModels } from "@/components/create_key_button";
|
||||
import { fetchTeams } from "@/components/common_components/fetch_teams";
|
||||
import MCPToolsViewer from "@/components/mcp_tools";
|
||||
|
||||
function getCookie(name: string) {
|
||||
const cookieValue = document.cookie
|
||||
.split("; ")
|
||||
|
@ -347,6 +349,12 @@ export default function CreateKeyPage() {
|
|||
accessToken={accessToken}
|
||||
allTeams={teams as Team[] ?? []}
|
||||
/>
|
||||
) : page == "mcp-tools" ? (
|
||||
<MCPToolsViewer
|
||||
accessToken={accessToken}
|
||||
userRole={userRole}
|
||||
userID={userID}
|
||||
/>
|
||||
) : page == "new_usage" ? (
|
||||
<NewUsagePage
|
||||
userID={userID}
|
||||
|
|
|
@ -20,7 +20,8 @@ import {
|
|||
SafetyOutlined,
|
||||
ExperimentOutlined,
|
||||
ThunderboltOutlined,
|
||||
LockOutlined
|
||||
LockOutlined,
|
||||
ToolOutlined,
|
||||
} from '@ant-design/icons';
|
||||
import { old_admin_roles, v2_admin_role_names, all_admin_roles, rolesAllowedToSeeUsage, rolesWithWriteAccess } from '../utils/roles';
|
||||
|
||||
|
@ -69,6 +70,7 @@ const menuItems: MenuItem[] = [
|
|||
{ key: "10", page: "budgets", label: "Budgets", icon: <BankOutlined />, roles: all_admin_roles },
|
||||
{ key: "11", page: "guardrails", label: "Guardrails", icon: <SafetyOutlined />, roles: all_admin_roles },
|
||||
{ key: "12", page: "new_usage", label: "New Usage", icon: <BarChartOutlined />, roles: all_admin_roles },
|
||||
{ key: "18", page: "mcp-tools", label: "MCP Tools", icon: <ToolOutlined />, roles: all_admin_roles },
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
import React from 'react';
|
||||
|
||||
const codeString = `import asyncio
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat import ChatCompletionUserMessageParam
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from litellm.experimental_mcp_client.tools import (
|
||||
transform_mcp_tool_to_openai_tool,
|
||||
transform_openai_tool_call_request_to_mcp_tool_call_request,
|
||||
)
|
||||
|
||||
async def main():
|
||||
# Initialize clients
|
||||
client = AsyncOpenAI(
|
||||
api_key="sk-1234",
|
||||
base_url="http://localhost:4000"
|
||||
)
|
||||
|
||||
# Connect to MCP
|
||||
async with sse_client("http://localhost:4000/mcp/") as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
await session.initialize()
|
||||
mcp_tools = await session.list_tools()
|
||||
print("List of MCP tools for MCP server:", mcp_tools.tools)
|
||||
|
||||
# Create message
|
||||
messages = [
|
||||
ChatCompletionUserMessageParam(
|
||||
content="Send an email about LiteLLM supporting MCP",
|
||||
role="user"
|
||||
)
|
||||
]
|
||||
|
||||
# Request with tools
|
||||
response = await client.chat.completions.create(
|
||||
model="gpt-4o",
|
||||
messages=messages,
|
||||
tools=[transform_mcp_tool_to_openai_tool(tool) for tool in mcp_tools.tools],
|
||||
tool_choice="auto"
|
||||
)
|
||||
|
||||
# Handle tool call
|
||||
if response.choices[0].message.tool_calls:
|
||||
tool_call = response.choices[0].message.tool_calls[0]
|
||||
if tool_call:
|
||||
# Convert format
|
||||
mcp_call = transform_openai_tool_call_request_to_mcp_tool_call_request(
|
||||
openai_tool=tool_call.model_dump()
|
||||
)
|
||||
|
||||
# Execute tool
|
||||
result = await session.call_tool(
|
||||
name=mcp_call.name,
|
||||
arguments=mcp_call.arguments
|
||||
)
|
||||
|
||||
print("Result:", result)
|
||||
|
||||
# Run it
|
||||
asyncio.run(main())`;
|
||||
|
||||
export const CodeExample: React.FC = () => {
|
||||
return (
|
||||
<div className="bg-white rounded-lg shadow h-full">
|
||||
<div className="border-b px-4 py-3">
|
||||
<h3 className="text-base font-medium text-gray-900">Using MCP Tools</h3>
|
||||
</div>
|
||||
<div className="p-4">
|
||||
<div className="flex items-center gap-2 mb-2">
|
||||
<div className="flex-1">
|
||||
<div className="text-sm font-medium text-gray-700">Python integration</div>
|
||||
</div>
|
||||
<button
|
||||
className="text-xs bg-gray-100 hover:bg-gray-200 text-gray-600 px-2 py-1 rounded-md transition-colors"
|
||||
onClick={() => {
|
||||
navigator.clipboard.writeText(codeString);
|
||||
}}
|
||||
>
|
||||
Copy
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div className="overflow-auto rounded-md bg-gray-50 border" style={{ maxHeight: "calc(100vh - 280px)" }}>
|
||||
<pre className="p-3 text-xs font-mono text-gray-800 whitespace-pre overflow-x-auto">
|
||||
{codeString}
|
||||
</pre>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
309
ui/litellm-dashboard/src/components/mcp_tools/columns.tsx
Normal file
309
ui/litellm-dashboard/src/components/mcp_tools/columns.tsx
Normal file
|
@ -0,0 +1,309 @@
|
|||
import React from "react";
|
||||
import { ColumnDef } from "@tanstack/react-table";
|
||||
import { MCPTool, InputSchema } from "./types";
|
||||
import { Button } from "@tremor/react"
|
||||
|
||||
export const columns: ColumnDef<MCPTool>[] = [
|
||||
{
|
||||
accessorKey: "mcp_info.server_name",
|
||||
header: "Provider",
|
||||
cell: ({ row }) => {
|
||||
const serverName = row.original.mcp_info.server_name;
|
||||
const logoUrl = row.original.mcp_info.logo_url;
|
||||
|
||||
return (
|
||||
<div className="flex items-center space-x-2">
|
||||
{logoUrl && (
|
||||
<img
|
||||
src={logoUrl}
|
||||
alt={`${serverName} logo`}
|
||||
className="h-5 w-5 object-contain"
|
||||
/>
|
||||
)}
|
||||
<span className="font-medium">{serverName}</span>
|
||||
</div>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
accessorKey: "name",
|
||||
header: "Tool Name",
|
||||
cell: ({ row }) => {
|
||||
const name = row.getValue("name") as string;
|
||||
return (
|
||||
<div>
|
||||
<span className="font-mono text-sm">{name}</span>
|
||||
</div>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
accessorKey: "description",
|
||||
header: "Description",
|
||||
cell: ({ row }) => {
|
||||
const description = row.getValue("description") as string;
|
||||
return (
|
||||
<div className="max-w-md">
|
||||
<span className="text-sm text-gray-700">{description}</span>
|
||||
</div>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
id: "actions",
|
||||
header: "Actions",
|
||||
cell: ({ row }) => {
|
||||
const tool = row.original;
|
||||
|
||||
return (
|
||||
<div className="flex items-center space-x-2">
|
||||
<Button
|
||||
size="xs"
|
||||
variant="light"
|
||||
className="font-mono text-blue-500 bg-blue-50 hover:bg-blue-100 text-xs font-normal px-2 py-0.5 text-left overflow-hidden truncate max-w-[200px]"
|
||||
onClick={() => {
|
||||
if (typeof row.original.onToolSelect === 'function') {
|
||||
row.original.onToolSelect(tool);
|
||||
}
|
||||
}}
|
||||
>
|
||||
Test Tool
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
// Tool Panel component to display when a tool is selected
|
||||
export function ToolTestPanel({
|
||||
tool,
|
||||
onSubmit,
|
||||
isLoading,
|
||||
result,
|
||||
error,
|
||||
onClose
|
||||
}: {
|
||||
tool: MCPTool;
|
||||
onSubmit: (args: Record<string, any>) => void;
|
||||
isLoading: boolean;
|
||||
result: any | null;
|
||||
error: Error | null;
|
||||
onClose: () => void;
|
||||
}) {
|
||||
const [formState, setFormState] = React.useState<Record<string, any>>({});
|
||||
|
||||
// Create a placeholder schema if we only have the "tool_input_schema" string
|
||||
const schema: InputSchema = React.useMemo(() => {
|
||||
if (typeof tool.inputSchema === 'string') {
|
||||
// Default schema with a single text field
|
||||
return {
|
||||
type: "object",
|
||||
properties: {
|
||||
input: {
|
||||
type: "string",
|
||||
description: "Input for this tool"
|
||||
}
|
||||
},
|
||||
required: ["input"]
|
||||
};
|
||||
}
|
||||
return tool.inputSchema as InputSchema;
|
||||
}, [tool.inputSchema]);
|
||||
|
||||
const handleInputChange = (key: string, value: any) => {
|
||||
setFormState(prev => ({
|
||||
...prev,
|
||||
[key]: value
|
||||
}));
|
||||
};
|
||||
|
||||
const handleSubmit = (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
onSubmit(formState);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="bg-white rounded-lg shadow-lg border p-6 max-w-4xl w-full">
|
||||
<div className="flex justify-between items-start mb-4">
|
||||
<div>
|
||||
<h2 className="text-xl font-bold">Test Tool: <span className="font-mono">{tool.name}</span></h2>
|
||||
<p className="text-gray-600">{tool.description}</p>
|
||||
<p className="text-sm text-gray-500 mt-1">Provider: {tool.mcp_info.server_name}</p>
|
||||
</div>
|
||||
<button
|
||||
onClick={onClose}
|
||||
className="p-1 rounded-full hover:bg-gray-200"
|
||||
>
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="20"
|
||||
height="20"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
>
|
||||
<line x1="18" y1="6" x2="6" y2="18"></line>
|
||||
<line x1="6" y1="6" x2="18" y2="18"></line>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 gap-6">
|
||||
{/* Form Section */}
|
||||
<div className="bg-gray-50 p-4 rounded-lg">
|
||||
<h3 className="font-medium mb-4">Input Parameters</h3>
|
||||
<form onSubmit={handleSubmit}>
|
||||
{typeof tool.inputSchema === 'string' ? (
|
||||
<div className="mb-4">
|
||||
<p className="text-xs text-gray-500 mb-1">This tool uses a dynamic input schema.</p>
|
||||
<div className="mb-4">
|
||||
<label className="block text-sm font-medium text-gray-700 mb-1">
|
||||
Input <span className="text-red-500">*</span>
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={formState.input || ""}
|
||||
onChange={(e) => handleInputChange("input", e.target.value)}
|
||||
required
|
||||
className="w-full px-3 py-2 border border-gray-300 rounded-md shadow-sm focus:outline-none focus:ring-blue-500 focus:border-blue-500 sm:text-sm"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
Object.entries(schema.properties).map(([key, prop]) => (
|
||||
<div key={key} className="mb-4">
|
||||
<label className="block text-sm font-medium text-gray-700 mb-1">
|
||||
{key}{" "}
|
||||
{schema.required?.includes(key) && (
|
||||
<span className="text-red-500">*</span>
|
||||
)}
|
||||
</label>
|
||||
{prop.description && (
|
||||
<p className="text-xs text-gray-500 mb-1">{prop.description}</p>
|
||||
)}
|
||||
|
||||
{/* Render appropriate input based on type */}
|
||||
{prop.type === "string" && (
|
||||
<input
|
||||
type="text"
|
||||
value={formState[key] || ""}
|
||||
onChange={(e) => handleInputChange(key, e.target.value)}
|
||||
required={schema.required?.includes(key)}
|
||||
className="w-full px-3 py-2 border border-gray-300 rounded-md shadow-sm focus:outline-none focus:ring-blue-500 focus:border-blue-500 sm:text-sm"
|
||||
/>
|
||||
)}
|
||||
|
||||
{prop.type === "number" && (
|
||||
<input
|
||||
type="number"
|
||||
value={formState[key] || ""}
|
||||
onChange={(e) => handleInputChange(key, parseFloat(e.target.value))}
|
||||
required={schema.required?.includes(key)}
|
||||
className="w-full px-3 py-2 border border-gray-300 rounded-md shadow-sm focus:outline-none focus:ring-blue-500 focus:border-blue-500 sm:text-sm"
|
||||
/>
|
||||
)}
|
||||
|
||||
{prop.type === "boolean" && (
|
||||
<div className="flex items-center">
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={formState[key] || false}
|
||||
onChange={(e) => handleInputChange(key, e.target.checked)}
|
||||
className="h-4 w-4 text-blue-600 focus:ring-blue-500 border-gray-300 rounded"
|
||||
/>
|
||||
<span className="ml-2 text-sm text-gray-600">Enable</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
))
|
||||
)}
|
||||
|
||||
<div className="mt-6">
|
||||
<Button
|
||||
type="submit"
|
||||
disabled={isLoading}
|
||||
className="w-full px-4 py-2 border border-transparent rounded-md shadow-sm text-sm font-medium text-white"
|
||||
>
|
||||
{isLoading ? "Calling..." : "Call Tool"}
|
||||
</Button>
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
|
||||
{/* Result Section */}
|
||||
<div className="bg-gray-50 p-4 rounded-lg overflow-auto max-h-[500px]">
|
||||
<h3 className="font-medium mb-4">Result</h3>
|
||||
|
||||
{isLoading && (
|
||||
<div className="flex justify-center items-center py-8">
|
||||
<div className="animate-spin rounded-full h-8 w-8 border-b-2 border-blue-700"></div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{error && (
|
||||
<div className="bg-red-50 border border-red-200 text-red-800 px-4 py-3 rounded-md">
|
||||
<p className="font-medium">Error</p>
|
||||
<pre className="mt-2 text-xs overflow-auto whitespace-pre-wrap">{error.message}</pre>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{result && !isLoading && !error && (
|
||||
<div>
|
||||
{result.map((content: any, idx: number) => (
|
||||
<div key={idx} className="mb-4">
|
||||
{content.type === "text" && (
|
||||
<div className="bg-white border p-3 rounded-md">
|
||||
<p className="whitespace-pre-wrap text-sm">{content.text}</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{content.type === "image" && content.url && (
|
||||
<div className="bg-white border p-3 rounded-md">
|
||||
<img src={content.url} alt="Tool result" className="max-w-full h-auto rounded" />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{content.type === "embedded_resource" && (
|
||||
<div className="bg-white border p-3 rounded-md">
|
||||
<p className="text-sm font-medium">Embedded Resource</p>
|
||||
<p className="text-xs text-gray-500">Type: {content.resource_type}</p>
|
||||
{content.url && (
|
||||
<a
|
||||
href={content.url}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="text-sm text-blue-600 hover:underline"
|
||||
>
|
||||
View Resource
|
||||
</a>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
|
||||
<div className="mt-2">
|
||||
<details className="text-xs">
|
||||
<summary className="cursor-pointer text-gray-500 hover:text-gray-700">Raw JSON Response</summary>
|
||||
<pre className="mt-2 bg-gray-100 p-2 rounded-md overflow-auto max-h-[300px]">
|
||||
{JSON.stringify(result, null, 2)}
|
||||
</pre>
|
||||
</details>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{!result && !isLoading && !error && (
|
||||
<div className="text-center py-8 text-gray-500">
|
||||
<p>The result will appear here after you call the tool.</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
172
ui/litellm-dashboard/src/components/mcp_tools/index.tsx
Normal file
172
ui/litellm-dashboard/src/components/mcp_tools/index.tsx
Normal file
|
@ -0,0 +1,172 @@
|
|||
import React, { useState } from 'react';
|
||||
import { useQuery, useMutation } from '@tanstack/react-query';
|
||||
import { DataTable } from '../view_logs/table';
|
||||
import { columns, ToolTestPanel } from './columns';
|
||||
import { MCPTool, MCPToolsViewerProps, CallMCPToolResponse } from './types';
|
||||
import { listMCPTools, callMCPTool } from '../networking';
|
||||
|
||||
// Wrapper to handle the type mismatch between MCPTool and DataTable's expected type
|
||||
function DataTableWrapper({
|
||||
columns,
|
||||
data,
|
||||
isLoading,
|
||||
}: {
|
||||
columns: any;
|
||||
data: MCPTool[];
|
||||
isLoading: boolean;
|
||||
}) {
|
||||
// Create a dummy renderSubComponent and getRowCanExpand function
|
||||
const renderSubComponent = () => <div />;
|
||||
const getRowCanExpand = () => false;
|
||||
|
||||
return (
|
||||
<DataTable
|
||||
columns={columns as any}
|
||||
data={data as any}
|
||||
isLoading={isLoading}
|
||||
renderSubComponent={renderSubComponent}
|
||||
getRowCanExpand={getRowCanExpand}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
export default function MCPToolsViewer({
|
||||
accessToken,
|
||||
userRole,
|
||||
userID,
|
||||
}: MCPToolsViewerProps) {
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
const [selectedTool, setSelectedTool] = useState<MCPTool | null>(null);
|
||||
const [toolResult, setToolResult] = useState<CallMCPToolResponse | null>(null);
|
||||
const [toolError, setToolError] = useState<Error | null>(null);
|
||||
|
||||
// Query to fetch MCP tools
|
||||
const { data: mcpTools, isLoading: isLoadingTools } = useQuery({
|
||||
queryKey: ['mcpTools'],
|
||||
queryFn: () => {
|
||||
if (!accessToken) throw new Error('Access Token required');
|
||||
return listMCPTools(accessToken);
|
||||
},
|
||||
enabled: !!accessToken,
|
||||
});
|
||||
|
||||
// Mutation for calling a tool
|
||||
const { mutate: executeTool, isPending: isCallingTool } = useMutation({
|
||||
mutationFn: (args: { tool: MCPTool; arguments: Record<string, any> }) => {
|
||||
if (!accessToken) throw new Error('Access Token required');
|
||||
return callMCPTool(
|
||||
accessToken,
|
||||
args.tool.name,
|
||||
args.arguments
|
||||
);
|
||||
},
|
||||
onSuccess: (data) => {
|
||||
setToolResult(data);
|
||||
setToolError(null);
|
||||
},
|
||||
onError: (error: Error) => {
|
||||
setToolError(error);
|
||||
setToolResult(null);
|
||||
},
|
||||
});
|
||||
|
||||
// Add onToolSelect handler to each tool
|
||||
const toolsData = React.useMemo(() => {
|
||||
if (!mcpTools) return [];
|
||||
|
||||
return mcpTools.map((tool: MCPTool) => ({
|
||||
...tool,
|
||||
onToolSelect: (tool: MCPTool) => {
|
||||
setSelectedTool(tool);
|
||||
setToolResult(null);
|
||||
setToolError(null);
|
||||
}
|
||||
}));
|
||||
}, [mcpTools]);
|
||||
|
||||
// Filter tools based on search term
|
||||
const filteredTools = React.useMemo(() => {
|
||||
return toolsData.filter((tool: MCPTool) => {
|
||||
const searchLower = searchTerm.toLowerCase();
|
||||
return (
|
||||
tool.name.toLowerCase().includes(searchLower) ||
|
||||
tool.description.toLowerCase().includes(searchLower) ||
|
||||
tool.mcp_info.server_name.toLowerCase().includes(searchLower)
|
||||
);
|
||||
});
|
||||
}, [toolsData, searchTerm]);
|
||||
|
||||
// Handle tool call submission
|
||||
const handleToolSubmit = (args: Record<string, any>) => {
|
||||
if (!selectedTool) return;
|
||||
|
||||
executeTool({
|
||||
tool: selectedTool,
|
||||
arguments: args,
|
||||
});
|
||||
};
|
||||
|
||||
if (!accessToken || !userRole || !userID) {
|
||||
return <div className="p-6 text-center text-gray-500">Missing required authentication parameters.</div>;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="w-full p-6">
|
||||
<div className="flex items-center justify-between mb-4">
|
||||
<h1 className="text-xl font-semibold">MCP Tools</h1>
|
||||
</div>
|
||||
|
||||
<div className="bg-white rounded-lg shadow">
|
||||
<div className="border-b px-6 py-4">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="relative w-64">
|
||||
<input
|
||||
type="text"
|
||||
placeholder="Search tools..."
|
||||
className="w-full px-3 py-2 pl-8 border rounded-md text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
|
||||
value={searchTerm}
|
||||
onChange={(e) => setSearchTerm(e.target.value)}
|
||||
/>
|
||||
<svg
|
||||
className="absolute left-2.5 top-2.5 h-4 w-4 text-gray-500"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
viewBox="0 0 24 24"
|
||||
>
|
||||
<path
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
strokeWidth={2}
|
||||
d="M21 21l-6-6m2-5a7 7 0 11-14 0 7 7 0 0114 0z"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
<div className="text-sm text-gray-500">
|
||||
{filteredTools.length} tool{filteredTools.length !== 1 ? "s" : ""} available
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<DataTableWrapper
|
||||
columns={columns}
|
||||
data={filteredTools}
|
||||
isLoading={isLoadingTools}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Tool Test Panel - Show when a tool is selected */}
|
||||
{selectedTool && (
|
||||
<div className="fixed inset-0 bg-gray-800 bg-opacity-75 flex items-center justify-center z-50 p-4">
|
||||
<ToolTestPanel
|
||||
tool={selectedTool}
|
||||
onSubmit={handleToolSubmit}
|
||||
isLoading={isCallingTool}
|
||||
result={toolResult}
|
||||
error={toolError}
|
||||
onClose={() => setSelectedTool(null)}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
71
ui/litellm-dashboard/src/components/mcp_tools/types.tsx
Normal file
71
ui/litellm-dashboard/src/components/mcp_tools/types.tsx
Normal file
|
@ -0,0 +1,71 @@
|
|||
// Define the structure for tool input schema properties
|
||||
export interface InputSchemaProperty {
|
||||
type: string;
|
||||
description?: string;
|
||||
}
|
||||
|
||||
// Define the structure for the input schema of a tool
|
||||
export interface InputSchema {
|
||||
type: "object";
|
||||
properties: Record<string, InputSchemaProperty>;
|
||||
required?: string[];
|
||||
}
|
||||
|
||||
// Define MCP provider info
|
||||
export interface MCPInfo {
|
||||
server_name: string;
|
||||
logo_url?: string;
|
||||
}
|
||||
|
||||
// Define the structure for a single MCP tool
|
||||
export interface MCPTool {
|
||||
name: string;
|
||||
description: string;
|
||||
inputSchema: InputSchema | string; // API returns string "tool_input_schema" or the actual schema
|
||||
mcp_info: MCPInfo;
|
||||
// Function to select a tool (added in the component)
|
||||
onToolSelect?: (tool: MCPTool) => void;
|
||||
}
|
||||
|
||||
// Define the response structure for the listMCPTools endpoint - now a flat array
|
||||
export type ListMCPToolsResponse = MCPTool[];
|
||||
|
||||
// Define the argument structure for calling an MCP tool
|
||||
export interface CallMCPToolArgs {
|
||||
name: string;
|
||||
arguments: Record<string, any> | null;
|
||||
server_name?: string; // Now using server_name from mcp_info
|
||||
}
|
||||
|
||||
// Define the possible content types in the response
|
||||
export interface MCPTextContent {
|
||||
type: "text";
|
||||
text: string;
|
||||
annotations?: any;
|
||||
}
|
||||
|
||||
export interface MCPImageContent {
|
||||
type: "image";
|
||||
url?: string;
|
||||
data?: string;
|
||||
}
|
||||
|
||||
export interface MCPEmbeddedResource {
|
||||
type: "embedded_resource";
|
||||
resource_type?: string;
|
||||
url?: string;
|
||||
data?: any;
|
||||
}
|
||||
|
||||
// Define the union type for the content array in the response
|
||||
export type MCPContent = MCPTextContent | MCPImageContent | MCPEmbeddedResource;
|
||||
|
||||
// Define the response structure for the callMCPTool endpoint
|
||||
export type CallMCPToolResponse = MCPContent[];
|
||||
|
||||
// Props for the main component
|
||||
export interface MCPToolsViewerProps {
|
||||
accessToken: string | null;
|
||||
userRole: string | null;
|
||||
userID: string | null;
|
||||
}
|
|
@ -4084,3 +4084,73 @@ export const updateInternalUserSettings = async (accessToken: string, settings:
|
|||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
export const listMCPTools = async (accessToken: string) => {
|
||||
try {
|
||||
// Construct base URL
|
||||
let url = proxyBaseUrl
|
||||
? `${proxyBaseUrl}/mcp/tools/list`
|
||||
: `/mcp/tools/list`;
|
||||
|
||||
console.log("Fetching MCP tools from:", url);
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
[globalLitellmHeaderName]: `Bearer ${accessToken}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.text();
|
||||
handleError(errorData);
|
||||
throw new Error("Network response was not ok");
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
console.log("Fetched MCP tools:", data);
|
||||
return data;
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch MCP tools:", error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
export const callMCPTool = async (accessToken: string, toolName: string, toolArguments: Record<string, any>) => {
|
||||
try {
|
||||
// Construct base URL
|
||||
let url = proxyBaseUrl
|
||||
? `${proxyBaseUrl}/mcp/tools/call`
|
||||
: `/mcp/tools/call`;
|
||||
|
||||
console.log("Calling MCP tool:", toolName, "with arguments:", toolArguments);
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
[globalLitellmHeaderName]: `Bearer ${accessToken}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
name: toolName,
|
||||
arguments: toolArguments,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.text();
|
||||
handleError(errorData);
|
||||
throw new Error("Network response was not ok");
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
console.log("MCP tool call response:", data);
|
||||
return data;
|
||||
} catch (error) {
|
||||
console.error("Failed to call MCP tool:", error);
|
||||
throw error;
|
||||
}
|
||||
};
|
|
@ -8,6 +8,19 @@ import { Tooltip } from "antd";
|
|||
import { TimeCell } from "./time_cell";
|
||||
import { Button } from "@tremor/react";
|
||||
|
||||
// Helper to get the appropriate logo URL
|
||||
const getLogoUrl = (
|
||||
row: LogEntry,
|
||||
provider: string
|
||||
) => {
|
||||
// Check if mcp_tool_call_metadata exists and contains mcp_server_logo_url
|
||||
if (row.metadata?.mcp_tool_call_metadata?.mcp_server_logo_url) {
|
||||
return row.metadata.mcp_tool_call_metadata.mcp_server_logo_url;
|
||||
}
|
||||
// Fall back to default provider logo
|
||||
return provider ? getProviderLogoAndName(provider).logo : '';
|
||||
};
|
||||
|
||||
export type LogEntry = {
|
||||
request_id: string;
|
||||
api_key: string;
|
||||
|
@ -177,7 +190,7 @@ export const columns: ColumnDef<LogEntry>[] = [
|
|||
<div className="flex items-center space-x-2">
|
||||
{provider && (
|
||||
<img
|
||||
src={getProviderLogoAndName(provider).logo}
|
||||
src={getLogoUrl(row, provider)}
|
||||
alt=""
|
||||
className="w-4 h-4"
|
||||
onError={(e) => {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue