Merge pull request #9642 from BerriAI/litellm_mcp_improvements_expose_sse_urls

[Feat] - MCP improvements, add support for using SSE MCP servers
This commit is contained in:
Ishaan Jaff 2025-03-29 19:37:57 -07:00
parent 0865e52db3
commit 5daf40ce24
25 changed files with 1210 additions and 29 deletions

View file

@ -414,6 +414,7 @@ RESPONSE_FORMAT_TOOL_NAME = "json_tool_call" # default tool name used when conv
########################### Logging Callback Constants ###########################
AZURE_STORAGE_MSFT_VERSION = "2019-07-07"
MCP_TOOL_NAME_PREFIX = "mcp_tool"
########################### LiteLLM Proxy Specific Constants ###########################
########################################################################################

View file

@ -1,5 +1,5 @@
import json
from typing import List, Literal, Union
from typing import Dict, List, Literal, Union
from mcp import ClientSession
from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
@ -76,8 +76,8 @@ def _get_function_arguments(function: FunctionDefinition) -> dict:
return arguments if isinstance(arguments, dict) else {}
def _transform_openai_tool_call_to_mcp_tool_call_request(
openai_tool: ChatCompletionMessageToolCall,
def transform_openai_tool_call_request_to_mcp_tool_call_request(
openai_tool: Union[ChatCompletionMessageToolCall, Dict],
) -> MCPCallToolRequestParams:
"""Convert an OpenAI ChatCompletionMessageToolCall to an MCP CallToolRequestParams."""
function = openai_tool["function"]
@ -100,8 +100,10 @@ async def call_openai_tool(
Returns:
The result of the MCP tool call.
"""
mcp_tool_call_request_params = _transform_openai_tool_call_to_mcp_tool_call_request(
openai_tool=openai_tool,
mcp_tool_call_request_params = (
transform_openai_tool_call_request_to_mcp_tool_call_request(
openai_tool=openai_tool,
)
)
return await call_mcp_tool(
session=session,

View file

@ -67,6 +67,7 @@ from litellm.types.utils import (
StandardCallbackDynamicParams,
StandardLoggingAdditionalHeaders,
StandardLoggingHiddenParams,
StandardLoggingMCPToolCall,
StandardLoggingMetadata,
StandardLoggingModelCostFailureDebugInformation,
StandardLoggingModelInformation,
@ -1099,7 +1100,7 @@ class Logging(LiteLLMLoggingBaseClass):
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
)
elif isinstance(result, dict): # pass-through endpoints
elif isinstance(result, dict) or isinstance(result, list):
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
@ -3114,6 +3115,7 @@ class StandardLoggingPayloadSetup:
litellm_params: Optional[dict] = None,
prompt_integration: Optional[str] = None,
applied_guardrails: Optional[List[str]] = None,
mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] = None,
) -> StandardLoggingMetadata:
"""
Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
@ -3160,6 +3162,7 @@ class StandardLoggingPayloadSetup:
user_api_key_end_user_id=None,
prompt_management_metadata=prompt_management_metadata,
applied_guardrails=applied_guardrails,
mcp_tool_call_metadata=mcp_tool_call_metadata,
)
if isinstance(metadata, dict):
# Filter the metadata dictionary to include only the specified keys
@ -3486,6 +3489,7 @@ def get_standard_logging_object_payload(
litellm_params=litellm_params,
prompt_integration=kwargs.get("prompt_integration", None),
applied_guardrails=kwargs.get("applied_guardrails", None),
mcp_tool_call_metadata=kwargs.get("mcp_tool_call_metadata", None),
)
_request_body = proxy_server_request.get("body", {})
@ -3626,6 +3630,7 @@ def get_standard_logging_metadata(
user_api_key_end_user_id=None,
prompt_management_metadata=None,
applied_guardrails=None,
mcp_tool_call_metadata=None,
)
if isinstance(metadata, dict):
# Filter the metadata dictionary to include only the specified keys

View file

@ -0,0 +1,153 @@
"""
MCP Client Manager
This class is responsible for managing MCP SSE clients.
This is a Proxy
"""
import asyncio
import json
from typing import Any, Dict, List, Optional
from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.types import Tool as MCPTool
from litellm._logging import verbose_logger
from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPSSEServer
class MCPServerManager:
    """Registry of SSE-based MCP servers configured on the LiteLLM proxy.

    Holds the list of configured servers plus a tool-name -> server-name
    mapping so an incoming tool call can be routed to the server that
    actually provides that tool.
    """

    def __init__(self):
        # Configured SSE MCP servers, e.g.
        # [MCPSSEServer(name="zapier_mcp_server", url="https://.../sse"), ...]
        self.mcp_servers: List[MCPSSEServer] = []
        # Maps a tool name to the name of the server providing it, e.g.
        # {"gmail_send_email": "zapier_mcp_server"}
        self.tool_name_to_mcp_server_name_mapping: Dict[str, str] = {}

    def load_servers_from_config(self, mcp_servers_config: Dict[str, Any]):
        """
        Load the MCP Servers from the config.

        Args:
            mcp_servers_config: mapping of server name -> server config.
                Each config must contain a "url" key and may carry an
                optional "mcp_info" dict (server_name / logo_url metadata).
        """
        for server_name, server_config in mcp_servers_config.items():
            _mcp_info: dict = server_config.get("mcp_info", None) or {}
            mcp_info = MCPInfo(**_mcp_info)
            mcp_info["server_name"] = server_name
            self.mcp_servers.append(
                MCPSSEServer(
                    name=server_name,
                    url=server_config["url"],
                    mcp_info=mcp_info,
                )
            )
        # SECURITY: log only server names. MCP SSE urls commonly embed API
        # keys (e.g. Zapier's ".../mcp/sk-ak-.../sse" urls) and must not
        # end up in log output.
        verbose_logger.debug(
            "Loaded MCP Servers: %s",
            json.dumps([server.name for server in self.mcp_servers], indent=4),
        )

        self.initialize_tool_name_to_mcp_server_name_mapping()

    async def list_tools(self) -> List[MCPTool]:
        """
        List all tools available across all MCP Servers.

        Returns:
            List[MCPTool]: Combined list of tools from all servers
        """
        list_tools_result: List[MCPTool] = []
        verbose_logger.debug("SSE SERVER MANAGER LISTING TOOLS")

        for server in self.mcp_servers:
            tools = await self._get_tools_from_server(server)
            list_tools_result.extend(tools)

        return list_tools_result

    async def _get_tools_from_server(self, server: MCPSSEServer) -> List[MCPTool]:
        """
        Helper method to get tools from a single MCP server.

        Also records every discovered tool in
        ``tool_name_to_mcp_server_name_mapping`` so later calls can be routed.

        Args:
            server (MCPSSEServer): The server to query tools from

        Returns:
            List[MCPTool]: List of tools available on the server
        """
        # SECURITY: do not log server.url - it may contain credentials.
        verbose_logger.debug(f"Connecting to MCP server: {server.name}")

        async with sse_client(url=server.url) as (read, write):
            async with ClientSession(read, write) as session:
                await session.initialize()

                tools_result = await session.list_tools()
                verbose_logger.debug(f"Tools from {server.name}: {tools_result}")

                # Update tool to server mapping
                for tool in tools_result.tools:
                    self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name

                return tools_result.tools

    def initialize_tool_name_to_mcp_server_name_mapping(self):
        """
        On startup, initialize the tool name to MCP server name mapping.

        Scheduled as a background task because listing tools requires async
        network calls; skipped (with a log line) when no event loop is
        running, e.g. when the config is loaded from a sync context.
        """
        try:
            # get_running_loop() raises RuntimeError when no loop is running,
            # so reaching create_task() means a loop is available.
            asyncio.get_running_loop()
            asyncio.create_task(
                self._initialize_tool_name_to_mcp_server_name_mapping()
            )
        except RuntimeError as e:  # no running event loop
            verbose_logger.exception(
                f"No running event loop - skipping tool name to MCP server name mapping initialization: {str(e)}"
            )

    async def _initialize_tool_name_to_mcp_server_name_mapping(self):
        """
        Call list_tools for each server and update the tool name to MCP server name mapping
        """
        for server in self.mcp_servers:
            tools = await self._get_tools_from_server(server)
            for tool in tools:
                self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name

    async def call_tool(self, name: str, arguments: Dict[str, Any]):
        """
        Call a tool with the given name and arguments.

        Raises:
            ValueError: if no configured server provides the tool.
        """
        mcp_server = self._get_mcp_server_from_tool_name(name)
        if mcp_server is None:
            raise ValueError(f"Tool {name} not found")
        async with sse_client(url=mcp_server.url) as (read, write):
            async with ClientSession(read, write) as session:
                await session.initialize()
                return await session.call_tool(name, arguments)

    def _get_mcp_server_from_tool_name(self, tool_name: str) -> Optional[MCPSSEServer]:
        """
        Get the MCP Server that provides the given tool, or None if unknown.
        """
        # Single dict lookup instead of a membership test followed by a
        # second lookup inside the loop.
        server_name = self.tool_name_to_mcp_server_name_mapping.get(tool_name)
        if server_name is not None:
            for server in self.mcp_servers:
                if server.name == server_name:
                    return server
        return None


global_mcp_server_manager: MCPServerManager = MCPServerManager()

View file

@ -3,14 +3,21 @@ LiteLLM MCP Server Routes
"""
import asyncio
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union
from anyio import BrokenResourceError
from fastapi import APIRouter, HTTPException, Request
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import ValidationError
from pydantic import ConfigDict, ValidationError
from litellm._logging import verbose_logger
from litellm.constants import MCP_TOOL_NAME_PREFIX
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.types.mcp_server.mcp_server_manager import MCPInfo
from litellm.types.utils import StandardLoggingMCPToolCall
from litellm.utils import client
# Check if MCP is available
# "mcp" requires python 3.10 or higher, but several litellm users use python 3.8
@ -36,9 +43,23 @@ if MCP_AVAILABLE:
from mcp.types import TextContent as MCPTextContent
from mcp.types import Tool as MCPTool
from .mcp_server_manager import global_mcp_server_manager
from .sse_transport import SseServerTransport
from .tool_registry import global_mcp_tool_registry
######################################################
############ MCP Tools List REST API Response Object #
# Defined here because we don't want to add `mcp` as a
# required dependency for `litellm` pip package
######################################################
class ListMCPToolsRestAPIResponseObject(MCPTool):
"""
Object returned by the /tools/list REST API route.
"""
mcp_info: Optional[MCPInfo] = None
model_config = ConfigDict(arbitrary_types_allowed=True)
########################################################
############ Initialize the MCP Server #################
########################################################
@ -52,9 +73,14 @@ if MCP_AVAILABLE:
########################################################
############### MCP Server Routes #######################
########################################################
@server.list_tools()
async def list_tools() -> list[MCPTool]:
"""
List all available tools
"""
return await _list_mcp_tools()
async def _list_mcp_tools() -> List[MCPTool]:
"""
List all available tools
"""
@ -67,24 +93,116 @@ if MCP_AVAILABLE:
inputSchema=tool.input_schema,
)
)
verbose_logger.debug(
"GLOBAL MCP TOOLS: %s", global_mcp_tool_registry.list_tools()
)
sse_tools: List[MCPTool] = await global_mcp_server_manager.list_tools()
verbose_logger.debug("SSE TOOLS: %s", sse_tools)
if sse_tools is not None:
tools.extend(sse_tools)
return tools
@server.call_tool()
async def handle_call_tool(
async def mcp_server_tool_call(
name: str, arguments: Dict[str, Any] | None
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
"""
Call a specific tool with the provided arguments
Args:
name (str): Name of the tool to call
arguments (Dict[str, Any] | None): Arguments to pass to the tool
Returns:
List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]: Tool execution results
Raises:
HTTPException: If tool not found or arguments missing
"""
# Validate arguments
response = await call_mcp_tool(
name=name,
arguments=arguments,
)
return response
@client
async def call_mcp_tool(
name: str, arguments: Optional[Dict[str, Any]] = None, **kwargs: Any
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
"""
Call a specific tool with the provided arguments
"""
tool = global_mcp_tool_registry.get_tool(name)
if not tool:
raise HTTPException(status_code=404, detail=f"Tool '{name}' not found")
if arguments is None:
raise HTTPException(
status_code=400, detail="Request arguments are required"
)
standard_logging_mcp_tool_call: StandardLoggingMCPToolCall = (
_get_standard_logging_mcp_tool_call(
name=name,
arguments=arguments,
)
)
litellm_logging_obj: Optional[LiteLLMLoggingObj] = kwargs.get(
"litellm_logging_obj", None
)
if litellm_logging_obj:
litellm_logging_obj.model_call_details["mcp_tool_call_metadata"] = (
standard_logging_mcp_tool_call
)
litellm_logging_obj.model_call_details["model"] = (
f"{MCP_TOOL_NAME_PREFIX}: {standard_logging_mcp_tool_call.get('name') or ''}"
)
litellm_logging_obj.model_call_details["custom_llm_provider"] = (
standard_logging_mcp_tool_call.get("mcp_server_name")
)
# Try managed server tool first
if name in global_mcp_server_manager.tool_name_to_mcp_server_name_mapping:
return await _handle_managed_mcp_tool(name, arguments)
# Fall back to local tool registry
return await _handle_local_mcp_tool(name, arguments)
def _get_standard_logging_mcp_tool_call(
name: str,
arguments: Dict[str, Any],
) -> StandardLoggingMCPToolCall:
mcp_server = global_mcp_server_manager._get_mcp_server_from_tool_name(name)
if mcp_server:
mcp_info = mcp_server.mcp_info or {}
return StandardLoggingMCPToolCall(
name=name,
arguments=arguments,
mcp_server_name=mcp_info.get("server_name"),
mcp_server_logo_url=mcp_info.get("logo_url"),
)
else:
return StandardLoggingMCPToolCall(
name=name,
arguments=arguments,
)
async def _handle_managed_mcp_tool(
name: str, arguments: Dict[str, Any]
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
"""Handle tool execution for managed server tools"""
call_tool_result = await global_mcp_server_manager.call_tool(
name=name,
arguments=arguments,
)
verbose_logger.debug("CALL TOOL RESULT: %s", call_tool_result)
return call_tool_result.content
async def _handle_local_mcp_tool(
name: str, arguments: Dict[str, Any]
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
"""Handle tool execution for local registry tools"""
tool = global_mcp_tool_registry.get_tool(name)
if not tool:
raise HTTPException(status_code=404, detail=f"Tool '{name}' not found")
try:
result = tool.handler(**arguments)
return [MCPTextContent(text=str(result), type="text")]
@ -113,6 +231,74 @@ if MCP_AVAILABLE:
await sse.handle_post_message(request.scope, request.receive, request._send)
await request.close()
########################################################
############ MCP Server REST API Routes #################
########################################################
@router.get("/tools/list", dependencies=[Depends(user_api_key_auth)])
async def list_tool_rest_api() -> List[ListMCPToolsRestAPIResponseObject]:
"""
List all available tools with information about the server they belong to.
Example response:
Tools:
[
{
"name": "create_zap",
"description": "Create a new zap",
"inputSchema": "tool_input_schema",
"mcp_info": {
"server_name": "zapier",
"logo_url": "https://www.zapier.com/logo.png",
}
},
{
"name": "fetch_data",
"description": "Fetch data from a URL",
"inputSchema": "tool_input_schema",
"mcp_info": {
"server_name": "fetch",
"logo_url": "https://www.fetch.com/logo.png",
}
}
]
"""
list_tools_result: List[ListMCPToolsRestAPIResponseObject] = []
for server in global_mcp_server_manager.mcp_servers:
try:
tools = await global_mcp_server_manager._get_tools_from_server(server)
for tool in tools:
list_tools_result.append(
ListMCPToolsRestAPIResponseObject(
name=tool.name,
description=tool.description,
inputSchema=tool.inputSchema,
mcp_info=server.mcp_info,
)
)
except Exception as e:
verbose_logger.exception(f"Error getting tools from {server.name}: {e}")
continue
return list_tools_result
@router.post("/tools/call", dependencies=[Depends(user_api_key_auth)])
async def call_tool_rest_api(
request: Request,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
REST API to call a specific MCP tool with the provided arguments
"""
from litellm.proxy.proxy_server import add_litellm_data_to_request, proxy_config
data = await request.json()
data = await add_litellm_data_to_request(
data=data,
request=request,
user_api_key_dict=user_api_key_dict,
proxy_config=proxy_config,
)
return await call_mcp_tool(**data)
options = InitializationOptions(
server_name="litellm-mcp-server",
server_version="0.1.0",

View file

@ -27,6 +27,7 @@ from litellm.types.utils import (
ModelResponse,
ProviderField,
StandardCallbackDynamicParams,
StandardLoggingMCPToolCall,
StandardLoggingPayloadErrorInformation,
StandardLoggingPayloadStatus,
StandardPassThroughResponseObject,
@ -1913,6 +1914,7 @@ class SpendLogsMetadata(TypedDict):
] # special param to log k,v pairs to spendlogs for a call
requester_ip_address: Optional[str]
applied_guardrails: Optional[List[str]]
mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall]
status: StandardLoggingPayloadStatus
proxy_server_request: Optional[str]
batch_models: Optional[List[str]]

View file

@ -1,9 +1,7 @@
model_list:
- model_name: fake-openai-endpoint
- model_name: gpt-4o
litellm_params:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
general_settings:
allow_requests_on_db_unavailable: True

View file

@ -2165,6 +2165,14 @@ class ProxyConfig:
if mcp_tools_config:
global_mcp_tool_registry.load_tools_from_config(mcp_tools_config)
mcp_servers_config = config.get("mcp_servers", None)
if mcp_servers_config:
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
global_mcp_server_manager,
)
global_mcp_server_manager.load_servers_from_config(mcp_servers_config)
## CREDENTIALS
credential_list_dict = self.load_credential_list(config=config)
litellm.credential_list = credential_list_dict

View file

@ -13,7 +13,7 @@ from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.core_helpers import get_litellm_metadata_from_kwargs
from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
from litellm.proxy.utils import PrismaClient, hash_token
from litellm.types.utils import StandardLoggingPayload
from litellm.types.utils import StandardLoggingMCPToolCall, StandardLoggingPayload
from litellm.utils import get_end_user_id_for_cost_tracking
@ -38,6 +38,7 @@ def _get_spend_logs_metadata(
metadata: Optional[dict],
applied_guardrails: Optional[List[str]] = None,
batch_models: Optional[List[str]] = None,
mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] = None,
) -> SpendLogsMetadata:
if metadata is None:
return SpendLogsMetadata(
@ -55,6 +56,7 @@ def _get_spend_logs_metadata(
error_information=None,
proxy_server_request=None,
batch_models=None,
mcp_tool_call_metadata=None,
)
verbose_proxy_logger.debug(
"getting payload for SpendLogs, available keys in metadata: "
@ -71,6 +73,7 @@ def _get_spend_logs_metadata(
)
clean_metadata["applied_guardrails"] = applied_guardrails
clean_metadata["batch_models"] = batch_models
clean_metadata["mcp_tool_call_metadata"] = mcp_tool_call_metadata
return clean_metadata
@ -200,6 +203,11 @@ def get_logging_payload( # noqa: PLR0915
if standard_logging_payload is not None
else None
),
mcp_tool_call_metadata=(
standard_logging_payload["metadata"].get("mcp_tool_call_metadata", None)
if standard_logging_payload is not None
else None
),
)
special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]

View file

@ -0,0 +1,16 @@
from typing import TYPE_CHECKING, Optional
from pydantic import BaseModel, ConfigDict
from typing_extensions import TypedDict
class MCPInfo(TypedDict, total=False):
    """Display metadata for an MCP server (all keys optional, total=False)."""

    # Human-readable server name shown in logs / the LiteLLM UI.
    server_name: str
    # Optional logo URL rendered next to the server's tools on the UI.
    logo_url: Optional[str]
class MCPSSEServer(BaseModel):
    """A single SSE-based MCP server the proxy can route tool calls to."""

    name: str
    # SSE endpoint URL; may embed credentials in the path, treat as a secret.
    url: str
    mcp_info: Optional[MCPInfo] = None

    # MCPInfo is a TypedDict, not a pydantic model - allow arbitrary types.
    model_config = ConfigDict(arbitrary_types_allowed=True)

View file

@ -1629,6 +1629,33 @@ class StandardLoggingUserAPIKeyMetadata(TypedDict):
user_api_key_end_user_id: Optional[str]
class StandardLoggingMCPToolCall(TypedDict, total=False):
name: str
"""
Name of the tool to call
"""
arguments: dict
"""
Arguments to pass to the tool
"""
result: dict
"""
Result of the tool call
"""
mcp_server_name: Optional[str]
"""
Name of the MCP server that the tool call was made to
"""
mcp_server_logo_url: Optional[str]
"""
Optional logo URL of the MCP server that the tool call was made to
(this is to render the logo on the logs page on litellm ui)
"""
class StandardBuiltInToolsParams(TypedDict, total=False):
"""
Standard built-in OpenAItools parameters
@ -1659,6 +1686,7 @@ class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):
requester_ip_address: Optional[str]
requester_metadata: Optional[dict]
prompt_management_metadata: Optional[StandardLoggingPromptManagementMetadata]
mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall]
applied_guardrails: Optional[List[str]]

View file

@ -19,11 +19,11 @@ from mcp.types import Tool as MCPTool
from litellm.experimental_mcp_client.tools import (
_get_function_arguments,
_transform_openai_tool_call_to_mcp_tool_call_request,
call_mcp_tool,
call_openai_tool,
load_mcp_tools,
transform_mcp_tool_to_openai_tool,
transform_openai_tool_call_request_to_mcp_tool_call_request,
)
@ -76,11 +76,11 @@ def test_transform_mcp_tool_to_openai_tool(mock_mcp_tool):
}
def test_transform_openai_tool_call_to_mcp_tool_call_request(mock_mcp_tool):
def testtransform_openai_tool_call_request_to_mcp_tool_call_request(mock_mcp_tool):
openai_tool = {
"function": {"name": "test_tool", "arguments": json.dumps({"test": "value"})}
}
mcp_tool_call_request = _transform_openai_tool_call_to_mcp_tool_call_request(
mcp_tool_call_request = transform_openai_tool_call_request_to_mcp_tool_call_request(
openai_tool
)
assert mcp_tool_call_request.name == "test_tool"

View file

@ -456,7 +456,7 @@ class TestSpendLogsPayload:
"model": "gpt-4o",
"user": "",
"team_id": "",
"metadata": '{"applied_guardrails": [], "batch_models": null, "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": null}}',
"metadata": '{"applied_guardrails": [], "batch_models": null, "mcp_tool_call_metadata": null, "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": null}}',
"cache_key": "Cache OFF",
"spend": 0.00022500000000000002,
"total_tokens": 30,
@ -553,7 +553,7 @@ class TestSpendLogsPayload:
"model": "claude-3-7-sonnet-20250219",
"user": "",
"team_id": "",
"metadata": '{"applied_guardrails": [], "batch_models": null, "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": {"audio_tokens": null, "cached_tokens": 0, "text_tokens": null, "image_tokens": null}, "cache_creation_input_tokens": 0, "cache_read_input_tokens": 0}}',
"metadata": '{"applied_guardrails": [], "batch_models": null, "mcp_tool_call_metadata": null, "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": {"audio_tokens": null, "cached_tokens": 0, "text_tokens": null, "image_tokens": null}, "cache_creation_input_tokens": 0, "cache_read_input_tokens": 0}}',
"cache_key": "Cache OFF",
"spend": 0.01383,
"total_tokens": 2598,
@ -648,7 +648,7 @@ class TestSpendLogsPayload:
"model": "claude-3-7-sonnet-20250219",
"user": "",
"team_id": "",
"metadata": '{"applied_guardrails": [], "batch_models": null, "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": {"audio_tokens": null, "cached_tokens": 0, "text_tokens": null, "image_tokens": null}, "cache_creation_input_tokens": 0, "cache_read_input_tokens": 0}}',
"metadata": '{"applied_guardrails": [], "batch_models": null, "mcp_tool_call_metadata": null, "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": {"audio_tokens": null, "cached_tokens": 0, "text_tokens": null, "image_tokens": null}, "cache_creation_input_tokens": 0, "cache_read_input_tokens": 0}}',
"cache_key": "Cache OFF",
"spend": 0.01383,
"total_tokens": 2598,

View file

@ -9,7 +9,7 @@
"model": "gpt-4o",
"user": "",
"team_id": "",
"metadata": "{\"applied_guardrails\": [], \"batch_models\": null, \"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}",
"metadata": "{\"applied_guardrails\": [], \"batch_models\": null, \"mcp_tool_call_metadata\": null, \"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}",
"cache_key": "Cache OFF",
"spend": 0.00022500000000000002,
"total_tokens": 30,

View file

@ -25,7 +25,8 @@
"requester_metadata": null,
"user_api_key_end_user_id": null,
"prompt_management_metadata": null,
"applied_guardrails": []
"applied_guardrails": [],
"mcp_tool_call_metadata": null
},
"cache_key": null,
"response_cost": 0.00022500000000000002,

View file

@ -274,6 +274,7 @@ def validate_redacted_message_span_attributes(span):
"metadata.user_api_key_end_user_id",
"metadata.user_api_key_user_email",
"metadata.applied_guardrails",
"metadata.mcp_tool_call_metadata",
]
_all_attributes = set(

View file

@ -0,0 +1,35 @@
# Create server parameters for stdio connection
import os
import sys
import pytest
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
MCPServerManager,
MCPSSEServer,
)
mcp_server_manager = MCPServerManager()
@pytest.mark.asyncio
@pytest.mark.skip(reason="Local only test")
async def test_mcp_server_manager():
    """Local-only integration test against a live Zapier MCP SSE server.

    Requires ZAPIER_MCP_SERVER_URL in the environment. Loads the server
    into the manager, lists its tools, then invokes the gmail_send_email
    tool (this sends real traffic, hence the skip marker).
    """
    mcp_server_manager.load_servers_from_config(
        {
            "zapier_mcp_server": {
                "url": os.environ.get("ZAPIER_MCP_SERVER_URL"),
            }
        }
    )

    tools = await mcp_server_manager.list_tools()
    print("TOOLS FROM MCP SERVER MANAGER== ", tools)

    result = await mcp_server_manager.call_tool(
        name="gmail_send_email", arguments={"body": "Test"}
    )
    print("RESULT FROM CALLING TOOL FROM MCP SERVER MANAGER== ", result)

View file

@ -32,6 +32,8 @@ import GuardrailsPanel from "@/components/guardrails";
import TransformRequestPanel from "@/components/transform_request";
import { fetchUserModels } from "@/components/create_key_button";
import { fetchTeams } from "@/components/common_components/fetch_teams";
import MCPToolsViewer from "@/components/mcp_tools";
function getCookie(name: string) {
const cookieValue = document.cookie
.split("; ")
@ -347,6 +349,12 @@ export default function CreateKeyPage() {
accessToken={accessToken}
allTeams={teams as Team[] ?? []}
/>
) : page == "mcp-tools" ? (
<MCPToolsViewer
accessToken={accessToken}
userRole={userRole}
userID={userID}
/>
) : page == "new_usage" ? (
<NewUsagePage
userID={userID}

View file

@ -20,7 +20,8 @@ import {
SafetyOutlined,
ExperimentOutlined,
ThunderboltOutlined,
LockOutlined
LockOutlined,
ToolOutlined,
} from '@ant-design/icons';
import { old_admin_roles, v2_admin_role_names, all_admin_roles, rolesAllowedToSeeUsage, rolesWithWriteAccess } from '../utils/roles';
@ -69,6 +70,7 @@ const menuItems: MenuItem[] = [
{ key: "10", page: "budgets", label: "Budgets", icon: <BankOutlined />, roles: all_admin_roles },
{ key: "11", page: "guardrails", label: "Guardrails", icon: <SafetyOutlined />, roles: all_admin_roles },
{ key: "12", page: "new_usage", label: "New Usage", icon: <BarChartOutlined />, roles: all_admin_roles },
{ key: "18", page: "mcp-tools", label: "MCP Tools", icon: <ToolOutlined />, roles: all_admin_roles },
]
},
{

View file

@ -0,0 +1,92 @@
import React from 'react';
// Copy-pasteable Python example shown on the MCP tools page:
// connects to the proxy's MCP SSE endpoint, exposes the MCP tools to an
// OpenAI chat completion, then executes the resulting tool call.
// NOTE: this is rendered verbatim and copied to the clipboard - keep it
// valid Python and do not add comments inside the template literal.
const codeString = `import asyncio
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionUserMessageParam
from mcp import ClientSession
from mcp.client.sse import sse_client
from litellm.experimental_mcp_client.tools import (
    transform_mcp_tool_to_openai_tool,
    transform_openai_tool_call_request_to_mcp_tool_call_request,
)


async def main():
    # Initialize clients
    client = AsyncOpenAI(
        api_key="sk-1234",
        base_url="http://localhost:4000"
    )

    # Connect to MCP
    async with sse_client("http://localhost:4000/mcp/") as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            mcp_tools = await session.list_tools()
            print("List of MCP tools for MCP server:", mcp_tools.tools)

            # Create message
            messages = [
                ChatCompletionUserMessageParam(
                    content="Send an email about LiteLLM supporting MCP",
                    role="user"
                )
            ]

            # Request with tools
            response = await client.chat.completions.create(
                model="gpt-4o",
                messages=messages,
                tools=[transform_mcp_tool_to_openai_tool(tool) for tool in mcp_tools.tools],
                tool_choice="auto"
            )

            # Handle tool call
            if response.choices[0].message.tool_calls:
                tool_call = response.choices[0].message.tool_calls[0]
                if tool_call:
                    # Convert format
                    mcp_call = transform_openai_tool_call_request_to_mcp_tool_call_request(
                        openai_tool=tool_call.model_dump()
                    )

                    # Execute tool
                    result = await session.call_tool(
                        name=mcp_call.name,
                        arguments=mcp_call.arguments
                    )

                    print("Result:", result)


# Run it
asyncio.run(main())`;

// Read-only card that displays the Python example with a copy-to-clipboard
// button. Stateless; the snippet lives in the module-level codeString above.
export const CodeExample: React.FC = () => {
  return (
    <div className="bg-white rounded-lg shadow h-full">
      <div className="border-b px-4 py-3">
        <h3 className="text-base font-medium text-gray-900">Using MCP Tools</h3>
      </div>
      <div className="p-4">
        <div className="flex items-center gap-2 mb-2">
          <div className="flex-1">
            <div className="text-sm font-medium text-gray-700">Python integration</div>
          </div>
          {/* Copies the raw snippet; no success feedback is shown */}
          <button
            className="text-xs bg-gray-100 hover:bg-gray-200 text-gray-600 px-2 py-1 rounded-md transition-colors"
            onClick={() => {
              navigator.clipboard.writeText(codeString);
            }}
          >
            Copy
          </button>
        </div>
        <div className="overflow-auto rounded-md bg-gray-50 border" style={{ maxHeight: "calc(100vh - 280px)" }}>
          <pre className="p-3 text-xs font-mono text-gray-800 whitespace-pre overflow-x-auto">
            {codeString}
          </pre>
        </div>
      </div>
    </div>
  );
};

View file

@ -0,0 +1,309 @@
import React from "react";
import { ColumnDef } from "@tanstack/react-table";
import { MCPTool, InputSchema } from "./types";
import { Button } from "@tremor/react"
// Column definitions for the MCP tools table:
// Provider (server logo + name), Tool Name, Description, and a
// "Test Tool" action that opens the tool test panel via onToolSelect.
export const columns: ColumnDef<MCPTool>[] = [
  {
    accessorKey: "mcp_info.server_name",
    header: "Provider",
    cell: ({ row }) => {
      // NOTE(review): assumes mcp_info is always present on MCPTool rows;
      // the backend type marks it optional - confirm against the API.
      const serverName = row.original.mcp_info.server_name;
      const logoUrl = row.original.mcp_info.logo_url;
      return (
        <div className="flex items-center space-x-2">
          {/* Logo is optional; render it only when a URL is provided */}
          {logoUrl && (
            <img
              src={logoUrl}
              alt={`${serverName} logo`}
              className="h-5 w-5 object-contain"
            />
          )}
          <span className="font-medium">{serverName}</span>
        </div>
      );
    },
  },
  {
    accessorKey: "name",
    header: "Tool Name",
    cell: ({ row }) => {
      const name = row.getValue("name") as string;
      return (
        <div>
          <span className="font-mono text-sm">{name}</span>
        </div>
      );
    },
  },
  {
    accessorKey: "description",
    header: "Description",
    cell: ({ row }) => {
      const description = row.getValue("description") as string;
      return (
        <div className="max-w-md">
          <span className="text-sm text-gray-700">{description}</span>
        </div>
      );
    },
  },
  {
    id: "actions",
    header: "Actions",
    cell: ({ row }) => {
      const tool = row.original;
      return (
        <div className="flex items-center space-x-2">
          <Button
            size="xs"
            variant="light"
            className="font-mono text-blue-500 bg-blue-50 hover:bg-blue-100 text-xs font-normal px-2 py-0.5 text-left overflow-hidden truncate max-w-[200px]"
            onClick={() => {
              // onToolSelect is injected onto each row by the parent viewer;
              // guard in case a row was built without it.
              if (typeof row.original.onToolSelect === 'function') {
                row.original.onToolSelect(tool);
              }
            }}
          >
            Test Tool
          </Button>
        </div>
      );
    },
  },
];
// Tool Panel component to display when a tool is selected
export function ToolTestPanel({
tool,
onSubmit,
isLoading,
result,
error,
onClose
}: {
tool: MCPTool;
onSubmit: (args: Record<string, any>) => void;
isLoading: boolean;
result: any | null;
error: Error | null;
onClose: () => void;
}) {
const [formState, setFormState] = React.useState<Record<string, any>>({});
// Create a placeholder schema if we only have the "tool_input_schema" string
const schema: InputSchema = React.useMemo(() => {
if (typeof tool.inputSchema === 'string') {
// Default schema with a single text field
return {
type: "object",
properties: {
input: {
type: "string",
description: "Input for this tool"
}
},
required: ["input"]
};
}
return tool.inputSchema as InputSchema;
}, [tool.inputSchema]);
const handleInputChange = (key: string, value: any) => {
setFormState(prev => ({
...prev,
[key]: value
}));
};
const handleSubmit = (e: React.FormEvent) => {
e.preventDefault();
onSubmit(formState);
};
return (
<div className="bg-white rounded-lg shadow-lg border p-6 max-w-4xl w-full">
<div className="flex justify-between items-start mb-4">
<div>
<h2 className="text-xl font-bold">Test Tool: <span className="font-mono">{tool.name}</span></h2>
<p className="text-gray-600">{tool.description}</p>
<p className="text-sm text-gray-500 mt-1">Provider: {tool.mcp_info.server_name}</p>
</div>
<button
onClick={onClose}
className="p-1 rounded-full hover:bg-gray-200"
>
<svg
xmlns="http://www.w3.org/2000/svg"
width="20"
height="20"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
>
<line x1="18" y1="6" x2="6" y2="18"></line>
<line x1="6" y1="6" x2="18" y2="18"></line>
</svg>
</button>
</div>
<div className="grid grid-cols-1 md:grid-cols-2 gap-6">
{/* Form Section */}
<div className="bg-gray-50 p-4 rounded-lg">
<h3 className="font-medium mb-4">Input Parameters</h3>
<form onSubmit={handleSubmit}>
{typeof tool.inputSchema === 'string' ? (
<div className="mb-4">
<p className="text-xs text-gray-500 mb-1">This tool uses a dynamic input schema.</p>
<div className="mb-4">
<label className="block text-sm font-medium text-gray-700 mb-1">
Input <span className="text-red-500">*</span>
</label>
<input
type="text"
value={formState.input || ""}
onChange={(e) => handleInputChange("input", e.target.value)}
required
className="w-full px-3 py-2 border border-gray-300 rounded-md shadow-sm focus:outline-none focus:ring-blue-500 focus:border-blue-500 sm:text-sm"
/>
</div>
</div>
) : (
Object.entries(schema.properties).map(([key, prop]) => (
<div key={key} className="mb-4">
<label className="block text-sm font-medium text-gray-700 mb-1">
{key}{" "}
{schema.required?.includes(key) && (
<span className="text-red-500">*</span>
)}
</label>
{prop.description && (
<p className="text-xs text-gray-500 mb-1">{prop.description}</p>
)}
{/* Render appropriate input based on type */}
{prop.type === "string" && (
<input
type="text"
value={formState[key] || ""}
onChange={(e) => handleInputChange(key, e.target.value)}
required={schema.required?.includes(key)}
className="w-full px-3 py-2 border border-gray-300 rounded-md shadow-sm focus:outline-none focus:ring-blue-500 focus:border-blue-500 sm:text-sm"
/>
)}
{prop.type === "number" && (
<input
type="number"
value={formState[key] || ""}
onChange={(e) => handleInputChange(key, parseFloat(e.target.value))}
required={schema.required?.includes(key)}
className="w-full px-3 py-2 border border-gray-300 rounded-md shadow-sm focus:outline-none focus:ring-blue-500 focus:border-blue-500 sm:text-sm"
/>
)}
{prop.type === "boolean" && (
<div className="flex items-center">
<input
type="checkbox"
checked={formState[key] || false}
onChange={(e) => handleInputChange(key, e.target.checked)}
className="h-4 w-4 text-blue-600 focus:ring-blue-500 border-gray-300 rounded"
/>
<span className="ml-2 text-sm text-gray-600">Enable</span>
</div>
)}
</div>
))
)}
<div className="mt-6">
<Button
type="submit"
disabled={isLoading}
className="w-full px-4 py-2 border border-transparent rounded-md shadow-sm text-sm font-medium text-white"
>
{isLoading ? "Calling..." : "Call Tool"}
</Button>
</div>
</form>
</div>
{/* Result Section */}
<div className="bg-gray-50 p-4 rounded-lg overflow-auto max-h-[500px]">
<h3 className="font-medium mb-4">Result</h3>
{isLoading && (
<div className="flex justify-center items-center py-8">
<div className="animate-spin rounded-full h-8 w-8 border-b-2 border-blue-700"></div>
</div>
)}
{error && (
<div className="bg-red-50 border border-red-200 text-red-800 px-4 py-3 rounded-md">
<p className="font-medium">Error</p>
<pre className="mt-2 text-xs overflow-auto whitespace-pre-wrap">{error.message}</pre>
</div>
)}
{result && !isLoading && !error && (
<div>
{result.map((content: any, idx: number) => (
<div key={idx} className="mb-4">
{content.type === "text" && (
<div className="bg-white border p-3 rounded-md">
<p className="whitespace-pre-wrap text-sm">{content.text}</p>
</div>
)}
{content.type === "image" && content.url && (
<div className="bg-white border p-3 rounded-md">
<img src={content.url} alt="Tool result" className="max-w-full h-auto rounded" />
</div>
)}
{content.type === "embedded_resource" && (
<div className="bg-white border p-3 rounded-md">
<p className="text-sm font-medium">Embedded Resource</p>
<p className="text-xs text-gray-500">Type: {content.resource_type}</p>
{content.url && (
<a
href={content.url}
target="_blank"
rel="noopener noreferrer"
className="text-sm text-blue-600 hover:underline"
>
View Resource
</a>
)}
</div>
)}
</div>
))}
<div className="mt-2">
<details className="text-xs">
<summary className="cursor-pointer text-gray-500 hover:text-gray-700">Raw JSON Response</summary>
<pre className="mt-2 bg-gray-100 p-2 rounded-md overflow-auto max-h-[300px]">
{JSON.stringify(result, null, 2)}
</pre>
</details>
</div>
</div>
)}
{!result && !isLoading && !error && (
<div className="text-center py-8 text-gray-500">
<p>The result will appear here after you call the tool.</p>
</div>
)}
</div>
</div>
</div>
);
}

View file

@ -0,0 +1,172 @@
import React, { useState } from 'react';
import { useQuery, useMutation } from '@tanstack/react-query';
import { DataTable } from '../view_logs/table';
import { columns, ToolTestPanel } from './columns';
import { MCPTool, MCPToolsViewerProps, CallMCPToolResponse } from './types';
import { listMCPTools, callMCPTool } from '../networking';
// Wrapper to handle the type mismatch between MCPTool and DataTable's expected type
// Wrapper to handle the type mismatch between MCPTool and DataTable's expected type
function DataTableWrapper(props: {
  columns: any;
  data: MCPTool[];
  isLoading: boolean;
}) {
  const { columns, data, isLoading } = props;
  // DataTable requires expansion hooks even though tool rows never expand,
  // so supply inert implementations.
  const noSubComponent = () => <div />;
  const neverExpand = () => false;
  return (
    <DataTable
      columns={columns as any}
      data={data as any}
      isLoading={isLoading}
      renderSubComponent={noSubComponent}
      getRowCanExpand={neverExpand}
    />
  );
}
/**
 * Top-level viewer page for MCP tools exposed by the LiteLLM proxy.
 *
 * Fetches the tool list (react-query), offers client-side search over
 * name/description/server name, and opens a ToolTestPanel modal to invoke
 * a selected tool via the /mcp/tools/call endpoint.
 *
 * Props (all nullable): accessToken, userRole, userID — the component
 * renders a placeholder message unless all three are present.
 */
export default function MCPToolsViewer({
  accessToken,
  userRole,
  userID,
}: MCPToolsViewerProps) {
  // NOTE: all hooks run unconditionally before the auth guard below,
  // as required by the Rules of Hooks.
  const [searchTerm, setSearchTerm] = useState('');
  const [selectedTool, setSelectedTool] = useState<MCPTool | null>(null);
  const [toolResult, setToolResult] = useState<CallMCPToolResponse | null>(null);
  const [toolError, setToolError] = useState<Error | null>(null);
  // Query to fetch MCP tools
  const { data: mcpTools, isLoading: isLoadingTools } = useQuery({
    queryKey: ['mcpTools'],
    queryFn: () => {
      if (!accessToken) throw new Error('Access Token required');
      return listMCPTools(accessToken);
    },
    // Do not fire the request until a token is available.
    enabled: !!accessToken,
  });
  // Mutation for calling a tool
  const { mutate: executeTool, isPending: isCallingTool } = useMutation({
    mutationFn: (args: { tool: MCPTool; arguments: Record<string, any> }) => {
      if (!accessToken) throw new Error('Access Token required');
      return callMCPTool(
        accessToken,
        args.tool.name,
        args.arguments
      );
    },
    // Success and failure are mutually exclusive: each clears the other's state.
    onSuccess: (data) => {
      setToolResult(data);
      setToolError(null);
    },
    onError: (error: Error) => {
      setToolError(error);
      setToolResult(null);
    },
  });
  // Add onToolSelect handler to each tool
  const toolsData = React.useMemo(() => {
    if (!mcpTools) return [];
    return mcpTools.map((tool: MCPTool) => ({
      ...tool,
      // Selecting a tool opens the test panel and resets any prior call state.
      onToolSelect: (tool: MCPTool) => {
        setSelectedTool(tool);
        setToolResult(null);
        setToolError(null);
      }
    }));
  }, [mcpTools]);
  // Filter tools based on search term
  const filteredTools = React.useMemo(() => {
    return toolsData.filter((tool: MCPTool) => {
      // Case-insensitive substring match across name, description, and server name.
      const searchLower = searchTerm.toLowerCase();
      return (
        tool.name.toLowerCase().includes(searchLower) ||
        tool.description.toLowerCase().includes(searchLower) ||
        tool.mcp_info.server_name.toLowerCase().includes(searchLower)
      );
    });
  }, [toolsData, searchTerm]);
  // Handle tool call submission
  const handleToolSubmit = (args: Record<string, any>) => {
    if (!selectedTool) return;
    executeTool({
      tool: selectedTool,
      arguments: args,
    });
  };
  // Auth guard: all three identifiers are required to render the page.
  if (!accessToken || !userRole || !userID) {
    return <div className="p-6 text-center text-gray-500">Missing required authentication parameters.</div>;
  }
  return (
    <div className="w-full p-6">
      <div className="flex items-center justify-between mb-4">
        <h1 className="text-xl font-semibold">MCP Tools</h1>
      </div>
      <div className="bg-white rounded-lg shadow">
        <div className="border-b px-6 py-4">
          <div className="flex items-center justify-between">
            {/* Search box (client-side filtering only) */}
            <div className="relative w-64">
              <input
                type="text"
                placeholder="Search tools..."
                className="w-full px-3 py-2 pl-8 border rounded-md text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
                value={searchTerm}
                onChange={(e) => setSearchTerm(e.target.value)}
              />
              <svg
                className="absolute left-2.5 top-2.5 h-4 w-4 text-gray-500"
                fill="none"
                stroke="currentColor"
                viewBox="0 0 24 24"
              >
                <path
                  strokeLinecap="round"
                  strokeLinejoin="round"
                  strokeWidth={2}
                  d="M21 21l-6-6m2-5a7 7 0 11-14 0 7 7 0 0114 0z"
                />
              </svg>
            </div>
            <div className="text-sm text-gray-500">
              {filteredTools.length} tool{filteredTools.length !== 1 ? "s" : ""} available
            </div>
          </div>
        </div>
        <DataTableWrapper
          columns={columns}
          data={filteredTools}
          isLoading={isLoadingTools}
        />
      </div>
      {/* Tool Test Panel - Show when a tool is selected */}
      {selectedTool && (
        <div className="fixed inset-0 bg-gray-800 bg-opacity-75 flex items-center justify-center z-50 p-4">
          <ToolTestPanel
            tool={selectedTool}
            onSubmit={handleToolSubmit}
            isLoading={isCallingTool}
            result={toolResult}
            error={toolError}
            onClose={() => setSelectedTool(null)}
          />
        </div>
      )}
    </div>
  );
}

View file

@ -0,0 +1,71 @@
// Define the structure for tool input schema properties
export interface InputSchemaProperty {
  // JSON-schema primitive type name, e.g. "string" | "number" | "boolean"
  type: string;
  // Human-readable hint shown under the form field in the test panel
  description?: string;
}
// Define the structure for the input schema of a tool
export interface InputSchema {
  type: "object";
  properties: Record<string, InputSchemaProperty>;
  // Keys of `properties` that must be supplied when calling the tool
  required?: string[];
}
// Define MCP provider info
export interface MCPInfo {
  // Display name of the MCP server that exposes the tool
  server_name: string;
  // Optional logo shown next to the server name in the tools table
  logo_url?: string;
}
// Define the structure for a single MCP tool
export interface MCPTool {
  name: string;
  description: string;
  inputSchema: InputSchema | string; // API returns string "tool_input_schema" or the actual schema
  mcp_info: MCPInfo;
  // Function to select a tool (added in the component)
  onToolSelect?: (tool: MCPTool) => void;
}
// Define the response structure for the listMCPTools endpoint - now a flat array
export type ListMCPToolsResponse = MCPTool[];
// Define the argument structure for calling an MCP tool
export interface CallMCPToolArgs {
  name: string;
  arguments: Record<string, any> | null;
  server_name?: string; // Now using server_name from mcp_info
}
// Define the possible content types in the response
export interface MCPTextContent {
  type: "text";
  text: string;
  annotations?: any;
}
export interface MCPImageContent {
  type: "image";
  // Either a fetchable URL or inline data may be present
  url?: string;
  data?: string;
}
export interface MCPEmbeddedResource {
  type: "embedded_resource";
  resource_type?: string;
  url?: string;
  data?: any;
}
// Define the union type for the content array in the response
export type MCPContent = MCPTextContent | MCPImageContent | MCPEmbeddedResource;
// Define the response structure for the callMCPTool endpoint
export type CallMCPToolResponse = MCPContent[];
// Props for the main component
export interface MCPToolsViewerProps {
  accessToken: string | null;
  userRole: string | null;
  userID: string | null;
}

View file

@ -4084,3 +4084,73 @@ export const updateInternalUserSettings = async (accessToken: string, settings:
throw error;
}
};
/**
 * Fetch the list of MCP tools from the proxy's `/mcp/tools/list` endpoint.
 *
 * @param accessToken - bearer token sent in the LiteLLM auth header
 * @returns the parsed JSON body (flat array of MCP tools)
 * @throws on a non-OK HTTP response (after routing the body text through
 *         `handleError`) or on any network/parse failure
 */
export const listMCPTools = async (accessToken: string) => {
  try {
    // proxyBaseUrl is empty when the UI is served from the proxy itself,
    // in which case a relative path is used.
    const url = proxyBaseUrl
      ? `${proxyBaseUrl}/mcp/tools/list`
      : `/mcp/tools/list`;
    console.log("Fetching MCP tools from:", url);
    const response = await fetch(url, {
      method: "GET",
      headers: {
        [globalLitellmHeaderName]: `Bearer ${accessToken}`,
        "Content-Type": "application/json",
      },
    });
    if (!response.ok) {
      // Surface the server's error text to the shared handler before failing.
      const errorData = await response.text();
      handleError(errorData);
      throw new Error("Network response was not ok");
    }
    const data = await response.json();
    console.log("Fetched MCP tools:", data);
    return data;
  } catch (error) {
    // Log and re-throw so react-query (the caller) sees the failure.
    console.error("Failed to fetch MCP tools:", error);
    throw error;
  }
};
/**
 * Invoke a single MCP tool via the proxy's `/mcp/tools/call` endpoint.
 *
 * @param accessToken   - bearer token sent in the LiteLLM auth header
 * @param toolName      - the tool's `name` as returned by `listMCPTools`
 * @param toolArguments - argument record matching the tool's input schema
 * @returns the parsed JSON body (array of MCP content items)
 * @throws on a non-OK HTTP response (after routing the body text through
 *         `handleError`) or on any network/parse failure
 */
export const callMCPTool = async (accessToken: string, toolName: string, toolArguments: Record<string, any>) => {
  try {
    // proxyBaseUrl is empty when the UI is served from the proxy itself,
    // in which case a relative path is used.
    const url = proxyBaseUrl
      ? `${proxyBaseUrl}/mcp/tools/call`
      : `/mcp/tools/call`;
    console.log("Calling MCP tool:", toolName, "with arguments:", toolArguments);
    const response = await fetch(url, {
      method: "POST",
      headers: {
        [globalLitellmHeaderName]: `Bearer ${accessToken}`,
        "Content-Type": "application/json",
      },
      body: JSON.stringify({
        name: toolName,
        arguments: toolArguments,
      }),
    });
    if (!response.ok) {
      // Surface the server's error text to the shared handler before failing.
      const errorData = await response.text();
      handleError(errorData);
      throw new Error("Network response was not ok");
    }
    const data = await response.json();
    console.log("MCP tool call response:", data);
    return data;
  } catch (error) {
    // Log and re-throw so the mutation caller sees the failure.
    console.error("Failed to call MCP tool:", error);
    throw error;
  }
};

View file

@ -8,6 +8,19 @@ import { Tooltip } from "antd";
import { TimeCell } from "./time_cell";
import { Button } from "@tremor/react";
// Helper to get the appropriate logo URL
// Helper to get the appropriate logo URL
const getLogoUrl = (
  row: LogEntry,
  provider: string
) => {
  // Prefer the MCP server's own logo when the log entry carries one.
  const mcpLogo = row.metadata?.mcp_tool_call_metadata?.mcp_server_logo_url;
  if (mcpLogo) {
    return mcpLogo;
  }
  // Otherwise fall back to the standard provider logo ('' when no provider).
  return provider ? getProviderLogoAndName(provider).logo : '';
};
export type LogEntry = {
request_id: string;
api_key: string;
@ -177,7 +190,7 @@ export const columns: ColumnDef<LogEntry>[] = [
<div className="flex items-center space-x-2">
{provider && (
<img
src={getProviderLogoAndName(provider).logo}
src={getLogoUrl(row, provider)}
alt=""
className="w-4 h-4"
onError={(e) => {