mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
Merge pull request #9642 from BerriAI/litellm_mcp_improvements_expose_sse_urls
[Feat] - MCP improvements, add support for using SSE MCP servers
This commit is contained in:
commit
f5c0afcf96
25 changed files with 1220 additions and 30 deletions
153
litellm/proxy/_experimental/mcp_server/mcp_server_manager.py
Normal file
153
litellm/proxy/_experimental/mcp_server/mcp_server_manager.py
Normal file
|
@ -0,0 +1,153 @@
|
|||
"""
|
||||
MCP Client Manager
|
||||
|
||||
This class is responsible for managing MCP SSE clients.
|
||||
|
||||
This is a Proxy
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.types import Tool as MCPTool
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPSSEServer
|
||||
|
||||
|
||||
class MCPServerManager:
    """
    Manages MCP SSE client connections for the proxy.

    Tracks the SSE MCP servers loaded from the proxy config and maintains a
    mapping from tool name -> server name so tool calls can be routed to the
    server that exposes the tool.
    """

    def __init__(self):
        # Configured SSE servers, e.g.:
        # [
        #     {"name": "zapier_mcp_server", "url": "https://actions.zapier.com/mcp/<key>/sse"},
        #     {"name": "google_drive_mcp_server", "url": "https://actions.zapier.com/mcp/<key>/sse"},
        # ]
        self.mcp_servers: List[MCPSSEServer] = []
        # Tool name -> server name, e.g. {"gmail_send_email": "zapier_mcp_server"}
        self.tool_name_to_mcp_server_name_mapping: Dict[str, str] = {}

    def load_servers_from_config(self, mcp_servers_config: Dict[str, Any]):
        """
        Load the MCP Servers from the config.

        Args:
            mcp_servers_config: Mapping of server name -> server settings;
                each entry must contain "url" and may contain "mcp_info".
        """
        for server_name, server_config in mcp_servers_config.items():
            _mcp_info: dict = server_config.get("mcp_info", None) or {}
            mcp_info = MCPInfo(**_mcp_info)
            mcp_info["server_name"] = server_name
            self.mcp_servers.append(
                MCPSSEServer(
                    name=server_name,
                    url=server_config["url"],
                    mcp_info=mcp_info,
                )
            )
        verbose_logger.debug(
            f"Loaded MCP Servers: {json.dumps(self.mcp_servers, indent=4, default=str)}"
        )

        self.initialize_tool_name_to_mcp_server_name_mapping()

    async def list_tools(self) -> List[MCPTool]:
        """
        List all tools available across all MCP Servers.

        Returns:
            List[MCPTool]: Combined list of tools from all servers
        """
        list_tools_result: List[MCPTool] = []
        verbose_logger.debug("SSE SERVER MANAGER LISTING TOOLS")

        for server in self.mcp_servers:
            tools = await self._get_tools_from_server(server)
            list_tools_result.extend(tools)

        return list_tools_result

    async def _get_tools_from_server(self, server: MCPSSEServer) -> List[MCPTool]:
        """
        Helper method to get tools from a single MCP server.

        Also records each discovered tool in
        `tool_name_to_mcp_server_name_mapping` so later `call_tool`
        invocations route to the right server.

        Args:
            server (MCPSSEServer): The server to query tools from

        Returns:
            List[MCPTool]: List of tools available on the server
        """
        verbose_logger.debug(f"Connecting to url: {server.url}")

        async with sse_client(url=server.url) as (read, write):
            async with ClientSession(read, write) as session:
                await session.initialize()

                tools_result = await session.list_tools()
                verbose_logger.debug(f"Tools from {server.name}: {tools_result}")

                # Update tool to server mapping
                for tool in tools_result.tools:
                    self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name

                return tools_result.tools

    def initialize_tool_name_to_mcp_server_name_mapping(self):
        """
        On startup, initialize the tool name to MCP server name mapping.

        Fix: `asyncio.get_running_loop()` either returns the running loop
        (always truthy) or raises RuntimeError -- it never returns a falsy
        value -- so the previous `if asyncio.get_running_loop():` guard was
        redundant. The RuntimeError path is the only "no loop" signal.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError as e:  # no running event loop
            verbose_logger.exception(
                f"No running event loop - skipping tool name to MCP server name mapping initialization: {str(e)}"
            )
        else:
            # Populate the mapping in the background; list_tools requires
            # network round-trips to each server.
            asyncio.create_task(
                self._initialize_tool_name_to_mcp_server_name_mapping()
            )

    async def _initialize_tool_name_to_mcp_server_name_mapping(self):
        """
        Call list_tools for each server and update the tool name to MCP server name mapping
        """
        for server in self.mcp_servers:
            tools = await self._get_tools_from_server(server)
            for tool in tools:
                self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name

    async def call_tool(self, name: str, arguments: Dict[str, Any]):
        """
        Call a tool with the given name and arguments on the server that
        exposes it.

        Raises:
            ValueError: If no known server exposes the tool.
        """
        mcp_server = self._get_mcp_server_from_tool_name(name)
        if mcp_server is None:
            raise ValueError(f"Tool {name} not found")
        async with sse_client(url=mcp_server.url) as (read, write):
            async with ClientSession(read, write) as session:
                await session.initialize()
                return await session.call_tool(name, arguments)

    def _get_mcp_server_from_tool_name(self, tool_name: str) -> Optional[MCPSSEServer]:
        """
        Get the MCP Server that exposes `tool_name`, or None if unknown.
        """
        # Single dict lookup instead of membership test + second lookup.
        server_name = self.tool_name_to_mcp_server_name_mapping.get(tool_name)
        if server_name is None:
            return None
        for server in self.mcp_servers:
            if server.name == server_name:
                return server
        return None
|
||||
|
||||
|
||||
global_mcp_server_manager: MCPServerManager = MCPServerManager()
|
|
@ -3,14 +3,21 @@ LiteLLM MCP Server Routes
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from anyio import BrokenResourceError
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import ValidationError
|
||||
from pydantic import ConfigDict, ValidationError
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import MCP_TOOL_NAME_PREFIX
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPInfo
|
||||
from litellm.types.utils import StandardLoggingMCPToolCall
|
||||
from litellm.utils import client
|
||||
|
||||
# Check if MCP is available
|
||||
# "mcp" requires python 3.10 or higher, but several litellm users use python 3.8
|
||||
|
@ -36,9 +43,23 @@ if MCP_AVAILABLE:
|
|||
from mcp.types import TextContent as MCPTextContent
|
||||
from mcp.types import Tool as MCPTool
|
||||
|
||||
from .mcp_server_manager import global_mcp_server_manager
|
||||
from .sse_transport import SseServerTransport
|
||||
from .tool_registry import global_mcp_tool_registry
|
||||
|
||||
######################################################
|
||||
############ MCP Tools List REST API Response Object #
|
||||
# Defined here because we don't want to add `mcp` as a
|
||||
# required dependency for `litellm` pip package
|
||||
######################################################
|
||||
class ListMCPToolsRestAPIResponseObject(MCPTool):
    """
    Response schema for the `/tools/list` REST route: an MCP tool plus
    optional metadata about the server that provides it.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Metadata (server name, logo URL) of the MCP server exposing this tool.
    mcp_info: Optional[MCPInfo] = None
|
||||
|
||||
########################################################
|
||||
############ Initialize the MCP Server #################
|
||||
########################################################
|
||||
|
@ -52,9 +73,14 @@ if MCP_AVAILABLE:
|
|||
########################################################
|
||||
############### MCP Server Routes #######################
|
||||
########################################################
|
||||
|
||||
@server.list_tools()
async def list_tools() -> list[MCPTool]:
    """List every tool currently available to the MCP server."""
    # Delegate to the shared helper so the REST route and the MCP protocol
    # route return the same tool set.
    return await _list_mcp_tools()
|
||||
|
||||
async def _list_mcp_tools() -> List[MCPTool]:
|
||||
"""
|
||||
List all available tools
|
||||
"""
|
||||
|
@ -67,24 +93,116 @@ if MCP_AVAILABLE:
|
|||
inputSchema=tool.input_schema,
|
||||
)
|
||||
)
|
||||
|
||||
verbose_logger.debug(
|
||||
"GLOBAL MCP TOOLS: %s", global_mcp_tool_registry.list_tools()
|
||||
)
|
||||
sse_tools: List[MCPTool] = await global_mcp_server_manager.list_tools()
|
||||
verbose_logger.debug("SSE TOOLS: %s", sse_tools)
|
||||
if sse_tools is not None:
|
||||
tools.extend(sse_tools)
|
||||
return tools
|
||||
|
||||
@server.call_tool()
async def mcp_server_tool_call(
    name: str, arguments: Dict[str, Any] | None
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
    """
    Execute the named tool with the provided arguments.

    Args:
        name (str): Name of the tool to call
        arguments (Dict[str, Any] | None): Arguments to pass to the tool

    Returns:
        List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]: Tool execution results

    Raises:
        HTTPException: If tool not found or arguments missing
    """
    # All validation and routing lives in call_mcp_tool.
    return await call_mcp_tool(name=name, arguments=arguments)
|
||||
|
||||
@client
async def call_mcp_tool(
    name: str, arguments: Optional[Dict[str, Any]] = None, **kwargs: Any
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
    """
    Call a specific tool with the provided arguments.

    Routes the call to a managed MCP (SSE) server when the tool name is known
    to the server manager, otherwise falls back to the local tool registry.

    Bug fix: the previous implementation raised a 404 when the tool was
    missing from the *local* registry before ever consulting the managed
    server mapping, so tools that only exist on managed SSE servers could
    never be called. Unknown tools still 404 via `_handle_local_mcp_tool`.

    Raises:
        HTTPException: 400 if arguments are missing; 404 if the tool is
            unknown to both the managed servers and the local registry.
    """
    if arguments is None:
        raise HTTPException(
            status_code=400, detail="Request arguments are required"
        )

    standard_logging_mcp_tool_call: StandardLoggingMCPToolCall = (
        _get_standard_logging_mcp_tool_call(
            name=name,
            arguments=arguments,
        )
    )
    litellm_logging_obj: Optional[LiteLLMLoggingObj] = kwargs.get(
        "litellm_logging_obj", None
    )
    if litellm_logging_obj:
        # Attach MCP metadata so spend/usage logging can attribute this
        # request to the tool + server that served it.
        litellm_logging_obj.model_call_details["mcp_tool_call_metadata"] = (
            standard_logging_mcp_tool_call
        )
        litellm_logging_obj.model_call_details["model"] = (
            f"{MCP_TOOL_NAME_PREFIX}: {standard_logging_mcp_tool_call.get('name') or ''}"
        )
        litellm_logging_obj.model_call_details["custom_llm_provider"] = (
            standard_logging_mcp_tool_call.get("mcp_server_name")
        )

    # Try managed server tool first
    if name in global_mcp_server_manager.tool_name_to_mcp_server_name_mapping:
        return await _handle_managed_mcp_tool(name, arguments)

    # Fall back to local tool registry (raises 404 for unknown tools)
    return await _handle_local_mcp_tool(name, arguments)
|
||||
|
||||
def _get_standard_logging_mcp_tool_call(
    name: str,
    arguments: Dict[str, Any],
) -> StandardLoggingMCPToolCall:
    """
    Build the logging payload for an MCP tool call, including server
    metadata (name / logo URL) when the tool belongs to a managed server.
    """
    mcp_server = global_mcp_server_manager._get_mcp_server_from_tool_name(name)
    if mcp_server is None:
        # Tool is not on a managed server -- log name + arguments only.
        return StandardLoggingMCPToolCall(
            name=name,
            arguments=arguments,
        )
    mcp_info = mcp_server.mcp_info or {}
    return StandardLoggingMCPToolCall(
        name=name,
        arguments=arguments,
        mcp_server_name=mcp_info.get("server_name"),
        mcp_server_logo_url=mcp_info.get("logo_url"),
    )
|
||||
|
||||
async def _handle_managed_mcp_tool(
    name: str, arguments: Dict[str, Any]
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
    """Handle tool execution for managed server tools"""
    result = await global_mcp_server_manager.call_tool(
        name=name,
        arguments=arguments,
    )
    verbose_logger.debug("CALL TOOL RESULT: %s", result)
    return result.content
|
||||
|
||||
async def _handle_local_mcp_tool(
|
||||
name: str, arguments: Dict[str, Any]
|
||||
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
|
||||
"""Handle tool execution for local registry tools"""
|
||||
tool = global_mcp_tool_registry.get_tool(name)
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail=f"Tool '{name}' not found")
|
||||
|
||||
try:
|
||||
result = tool.handler(**arguments)
|
||||
return [MCPTextContent(text=str(result), type="text")]
|
||||
|
@ -113,6 +231,74 @@ if MCP_AVAILABLE:
|
|||
await sse.handle_post_message(request.scope, request.receive, request._send)
|
||||
await request.close()
|
||||
|
||||
########################################################
|
||||
############ MCP Server REST API Routes #################
|
||||
########################################################
|
||||
@router.get("/tools/list", dependencies=[Depends(user_api_key_auth)])
async def list_tool_rest_api() -> List[ListMCPToolsRestAPIResponseObject]:
    """
    List all available tools with information about the server they belong to.

    Example response:
    Tools:
    [
        {
            "name": "create_zap",
            "description": "Create a new zap",
            "inputSchema": "tool_input_schema",
            "mcp_info": {
                "server_name": "zapier",
                "logo_url": "https://www.zapier.com/logo.png",
            }
        },
        {
            "name": "fetch_data",
            "description": "Fetch data from a URL",
            "inputSchema": "tool_input_schema",
            "mcp_info": {
                "server_name": "fetch",
                "logo_url": "https://www.fetch.com/logo.png",
            }
        }
    ]
    """
    tools_with_server_info: List[ListMCPToolsRestAPIResponseObject] = []
    for server in global_mcp_server_manager.mcp_servers:
        try:
            server_tools = await global_mcp_server_manager._get_tools_from_server(
                server
            )
            tools_with_server_info.extend(
                ListMCPToolsRestAPIResponseObject(
                    name=tool.name,
                    description=tool.description,
                    inputSchema=tool.inputSchema,
                    mcp_info=server.mcp_info,
                )
                for tool in server_tools
            )
        except Exception as e:
            # One unreachable server must not fail the whole listing.
            verbose_logger.exception(f"Error getting tools from {server.name}: {e}")
            continue
    return tools_with_server_info
|
||||
|
||||
@router.post("/tools/call", dependencies=[Depends(user_api_key_auth)])
async def call_tool_rest_api(
    request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    REST API to call a specific MCP tool with the provided arguments
    """
    # Imported lazily to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import add_litellm_data_to_request, proxy_config

    request_body = await request.json()
    enriched_body = await add_litellm_data_to_request(
        data=request_body,
        request=request,
        user_api_key_dict=user_api_key_dict,
        proxy_config=proxy_config,
    )
    return await call_mcp_tool(**enriched_body)
|
||||
|
||||
options = InitializationOptions(
|
||||
server_name="litellm-mcp-server",
|
||||
server_version="0.1.0",
|
||||
|
|
|
@ -27,6 +27,7 @@ from litellm.types.utils import (
|
|||
ModelResponse,
|
||||
ProviderField,
|
||||
StandardCallbackDynamicParams,
|
||||
StandardLoggingMCPToolCall,
|
||||
StandardLoggingPayloadErrorInformation,
|
||||
StandardLoggingPayloadStatus,
|
||||
StandardPassThroughResponseObject,
|
||||
|
@ -1928,6 +1929,7 @@ class SpendLogsMetadata(TypedDict):
|
|||
] # special param to log k,v pairs to spendlogs for a call
|
||||
requester_ip_address: Optional[str]
|
||||
applied_guardrails: Optional[List[str]]
|
||||
mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall]
|
||||
status: StandardLoggingPayloadStatus
|
||||
proxy_server_request: Optional[str]
|
||||
batch_models: Optional[List[str]]
|
||||
|
|
|
@ -1,7 +1,14 @@
|
|||
model_list:
|
||||
- model_name: fake-openai-endpoint
|
||||
- model_name: gpt-4o
|
||||
litellm_params:
|
||||
model: openai/fake
|
||||
api_key: fake-key
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
model: openai/gpt-4o
|
||||
|
||||
mcp_servers:
|
||||
{
|
||||
"Zapier MCP": {
|
||||
"url": "os.environ/ZAPIER_MCP_SERVER_URL",
|
||||
"mcp_info": {
|
||||
"logo_url": "https://espysys.com/wp-content/uploads/2024/08/zapier-logo.webp",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1960,6 +1960,14 @@ class ProxyConfig:
|
|||
if mcp_tools_config:
|
||||
global_mcp_tool_registry.load_tools_from_config(mcp_tools_config)
|
||||
|
||||
mcp_servers_config = config.get("mcp_servers", None)
|
||||
if mcp_servers_config:
|
||||
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
|
||||
global_mcp_server_manager,
|
||||
)
|
||||
|
||||
global_mcp_server_manager.load_servers_from_config(mcp_servers_config)
|
||||
|
||||
## CREDENTIALS
|
||||
credential_list_dict = self.load_credential_list(config=config)
|
||||
litellm.credential_list = credential_list_dict
|
||||
|
|
|
@ -13,7 +13,7 @@ from litellm._logging import verbose_proxy_logger
|
|||
from litellm.litellm_core_utils.core_helpers import get_litellm_metadata_from_kwargs
|
||||
from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
|
||||
from litellm.proxy.utils import PrismaClient, hash_token
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
from litellm.types.utils import StandardLoggingMCPToolCall, StandardLoggingPayload
|
||||
from litellm.utils import get_end_user_id_for_cost_tracking
|
||||
|
||||
|
||||
|
@ -38,6 +38,7 @@ def _get_spend_logs_metadata(
|
|||
metadata: Optional[dict],
|
||||
applied_guardrails: Optional[List[str]] = None,
|
||||
batch_models: Optional[List[str]] = None,
|
||||
mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] = None,
|
||||
) -> SpendLogsMetadata:
|
||||
if metadata is None:
|
||||
return SpendLogsMetadata(
|
||||
|
@ -55,6 +56,7 @@ def _get_spend_logs_metadata(
|
|||
error_information=None,
|
||||
proxy_server_request=None,
|
||||
batch_models=None,
|
||||
mcp_tool_call_metadata=None,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
"getting payload for SpendLogs, available keys in metadata: "
|
||||
|
@ -71,6 +73,7 @@ def _get_spend_logs_metadata(
|
|||
)
|
||||
clean_metadata["applied_guardrails"] = applied_guardrails
|
||||
clean_metadata["batch_models"] = batch_models
|
||||
clean_metadata["mcp_tool_call_metadata"] = mcp_tool_call_metadata
|
||||
return clean_metadata
|
||||
|
||||
|
||||
|
@ -200,6 +203,11 @@ def get_logging_payload( # noqa: PLR0915
|
|||
if standard_logging_payload is not None
|
||||
else None
|
||||
),
|
||||
mcp_tool_call_metadata=(
|
||||
standard_logging_payload["metadata"].get("mcp_tool_call_metadata", None)
|
||||
if standard_logging_payload is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue