mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
309 lines
11 KiB
Python
309 lines
11 KiB
Python
"""
|
|
LiteLLM MCP Server Routes
|
|
"""
|
|
|
|
import asyncio
|
|
from typing import Any, Dict, List, Optional, Union
|
|
|
|
from anyio import BrokenResourceError
|
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import ConfigDict, ValidationError
|
|
|
|
from litellm._logging import verbose_logger
|
|
from litellm.constants import MCP_TOOL_NAME_PREFIX
|
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
|
from litellm.proxy._types import UserAPIKeyAuth
|
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
|
from litellm.types.mcp_server.mcp_server_manager import MCPInfo
|
|
from litellm.types.utils import StandardLoggingMCPToolCall
|
|
from litellm.utils import client
|
|
|
|
# Detect whether the optional `mcp` package can be imported.
# `mcp` needs Python >= 3.10 while some litellm users are still on Python 3.8,
# so the import is guarded instead of being a hard dependency.
try:
    from mcp.server import Server
except ImportError as e:
    verbose_logger.debug(f"MCP module not found: {e}")
    MCP_AVAILABLE = False
    # Fallback router so that `router` is always defined for importers of this
    # module; the MCP routes in this file are only registered when MCP is available.
    router = APIRouter(prefix="/mcp", tags=["mcp"])
else:
    MCP_AVAILABLE = True
|
|
|
|
|
|
if MCP_AVAILABLE:
|
|
from mcp.server import NotificationOptions, Server
|
|
from mcp.server.models import InitializationOptions
|
|
from mcp.types import EmbeddedResource as MCPEmbeddedResource
|
|
from mcp.types import ImageContent as MCPImageContent
|
|
from mcp.types import TextContent as MCPTextContent
|
|
from mcp.types import Tool as MCPTool
|
|
|
|
from .mcp_server_manager import global_mcp_server_manager
|
|
from .sse_transport import SseServerTransport
|
|
from .tool_registry import global_mcp_tool_registry
|
|
|
|
######################################################
############ MCP Tools List REST API Response Object #
# Defined here (inside the MCP_AVAILABLE guard) because
# `mcp` must not become a required dependency of the
# `litellm` pip package.
######################################################
class ListMCPToolsRestAPIResponseObject(MCPTool):
    """
    Item returned by the `/mcp/tools/list` REST route.

    An `MCPTool` extended with an optional `mcp_info` describing the MCP server
    the tool comes from.
    """

    # Permit field types that pydantic cannot validate natively.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Metadata about the owning MCP server (e.g. server name / logo URL); defaults to None.
    mcp_info: Optional[MCPInfo] = None
|
|
|
|
########################################################
############ Initialize the MCP Server #################
########################################################
# Router that every MCP route below is attached to (all paths under `/mcp`).
router = APIRouter(prefix="/mcp", tags=["mcp"])

# MCP protocol server; the `@server.list_tools()` / `@server.call_tool()`
# handlers below are registered on it.
server: Server = Server("litellm-mcp-server")

# SSE transport. Its path mirrors the `POST /mcp/sse/messages` route defined
# below — presumably the endpoint advertised to clients; keep them in sync.
sse: SseServerTransport = SseServerTransport("/mcp/sse/messages")
|
|
|
|
########################################################
############### MCP Server Routes #######################
########################################################
@server.list_tools()
async def list_tools() -> list[MCPTool]:
    """
    MCP `tools/list` handler.

    Delegates to `_list_mcp_tools`, which merges local-registry tools with
    tools from managed MCP servers.
    """
    available_tools = await _list_mcp_tools()
    return available_tools
|
|
|
|
async def _list_mcp_tools() -> List[MCPTool]:
    """
    Collect every tool this proxy exposes over MCP.

    Tools from the local `global_mcp_tool_registry` are converted to `MCPTool`
    objects; tools returned by `global_mcp_server_manager.list_tools()` are
    appended after them.

    Returns:
        List[MCPTool]: local registry tools first, then managed-server tools.
    """
    # Query the registry once and reuse the result for conversion and logging
    # (previously it was queried a second time just for the debug log).
    registry_tools = global_mcp_tool_registry.list_tools()
    tools: List[MCPTool] = [
        MCPTool(
            name=tool.name,
            description=tool.description,
            inputSchema=tool.input_schema,
        )
        for tool in registry_tools
    ]
    verbose_logger.debug("GLOBAL MCP TOOLS: %s", registry_tools)

    sse_tools: List[MCPTool] = await global_mcp_server_manager.list_tools()
    verbose_logger.debug("SSE TOOLS: %s", sse_tools)
    # Defensive: tolerate a manager that yields None instead of an empty list.
    if sse_tools is not None:
        tools.extend(sse_tools)
    return tools
|
|
|
|
@server.call_tool()
async def mcp_server_tool_call(
    name: str, arguments: Dict[str, Any] | None
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
    """
    MCP `tools/call` handler: run the named tool through `call_mcp_tool`.

    Args:
        name (str): Name of the tool to invoke.
        arguments (Dict[str, Any] | None): Keyword arguments for the tool.

    Returns:
        List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
        Content items produced by the tool.

    Raises:
        HTTPException: Propagated from `call_mcp_tool` (e.g. arguments missing,
        or an unknown local tool).
    """
    # No validation here; `call_mcp_tool` checks the arguments.
    return await call_mcp_tool(name=name, arguments=arguments)
|
|
|
|
@client
async def call_mcp_tool(
    name: str, arguments: Optional[Dict[str, Any]] = None, **kwargs: Any
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
    """
    Execute an MCP tool by name and return its content items.

    Names present in
    `global_mcp_server_manager.tool_name_to_mcp_server_name_mapping` are sent to
    `_handle_managed_mcp_tool`; all other names go to the local registry via
    `_handle_local_mcp_tool`.

    Args:
        name: Tool name.
        arguments: Keyword arguments for the tool; must not be None.
        **kwargs: Extra request data. When `litellm_logging_obj` is present, its
            `model_call_details` are filled with the MCP tool-call metadata.

    Raises:
        HTTPException: 400 when `arguments` is None; other errors come from the
            selected handler.
    """
    if arguments is None:
        raise HTTPException(
            status_code=400, detail="Request arguments are required"
        )

    tool_call_log: StandardLoggingMCPToolCall = _get_standard_logging_mcp_tool_call(
        name=name,
        arguments=arguments,
    )

    logging_obj: Optional[LiteLLMLoggingObj] = kwargs.get("litellm_logging_obj", None)
    if logging_obj:
        call_details = logging_obj.model_call_details
        call_details["mcp_tool_call_metadata"] = tool_call_log
        call_details["model"] = (
            f"{MCP_TOOL_NAME_PREFIX}: {tool_call_log.get('name') or ''}"
        )
        call_details["custom_llm_provider"] = tool_call_log.get("mcp_server_name")

    # Managed-server tools take precedence; anything else is looked up locally.
    is_managed_tool = (
        name in global_mcp_server_manager.tool_name_to_mcp_server_name_mapping
    )
    handler = _handle_managed_mcp_tool if is_managed_tool else _handle_local_mcp_tool
    return await handler(name, arguments)
|
|
|
|
def _get_standard_logging_mcp_tool_call(
    name: str,
    arguments: Dict[str, Any],
) -> StandardLoggingMCPToolCall:
    """
    Build the `StandardLoggingMCPToolCall` payload for one tool invocation.

    If `global_mcp_server_manager` maps `name` to a managed MCP server, the
    server's `mcp_info` `server_name` and `logo_url` are included; otherwise
    only `name` and `arguments` are set.
    """
    owning_server = global_mcp_server_manager._get_mcp_server_from_tool_name(name)
    if not owning_server:
        return StandardLoggingMCPToolCall(name=name, arguments=arguments)

    server_info = owning_server.mcp_info or {}
    return StandardLoggingMCPToolCall(
        name=name,
        arguments=arguments,
        mcp_server_name=server_info.get("server_name"),
        mcp_server_logo_url=server_info.get("logo_url"),
    )
|
|
|
|
async def _handle_managed_mcp_tool(
    name: str, arguments: Dict[str, Any]
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
    """
    Run a tool hosted on a managed MCP server.

    Forwards the call to `global_mcp_server_manager.call_tool` and returns the
    `content` of the result it yields.
    """
    tool_result = await global_mcp_server_manager.call_tool(
        name=name, arguments=arguments
    )
    verbose_logger.debug("CALL TOOL RESULT: %s", tool_result)
    return tool_result.content
|
|
|
|
async def _handle_local_mcp_tool(
    name: str, arguments: Dict[str, Any]
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
    """
    Run a tool from the local `global_mcp_tool_registry`.

    The tool's handler is called with `arguments` as keyword arguments. If it
    returns a coroutine (an `async def` handler), that coroutine is awaited.
    The result is stringified into a single text content item.

    Raises:
        HTTPException: 404 when no tool named `name` is registered.

    Exceptions raised by the handler are not propagated. They are logged and
    returned to the caller as an "Error: ..." text content item.
    """
    tool = global_mcp_tool_registry.get_tool(name)
    if not tool:
        raise HTTPException(status_code=404, detail=f"Tool '{name}' not found")

    try:
        result = tool.handler(**arguments)
        # Support async handlers. Without this, the un-awaited coroutine object
        # would be stringified into the response and the tool would never run.
        if asyncio.iscoroutine(result):
            result = await result
        return [MCPTextContent(text=str(result), type="text")]
    except Exception as e:
        # Deliberate best-effort: tool failures are reported to the client as
        # content rather than raised, but log them so they are not silent.
        verbose_logger.exception("Error calling local MCP tool %s: %s", name, e)
        return [MCPTextContent(text=f"Error: {str(e)}", type="text")]
|
|
|
|
@router.get("/", response_class=StreamingResponse)
async def handle_sse(request: Request):
    """
    Open an MCP SSE session for `request` and serve it with the MCP `server`.

    The session ends quietly on a broken stream (`BrokenResourceError`), task
    cancellation (`asyncio.CancelledError`) or a pydantic `ValidationError`.
    Each of these is logged at debug level instead of raised. Any other
    exception propagates.
    """
    verbose_logger.info("new incoming SSE connection established")
    async with sse.connect_sse(request) as streams:
        try:
            await server.run(streams[0], streams[1], options)
        except (BrokenResourceError, asyncio.CancelledError, ValidationError) as e:
            # Previously swallowed with a bare `pass`; keep the same non-raising
            # behavior but leave a trace of why the session ended.
            # NOTE(review): swallowing CancelledError stops cancellation from
            # propagating to the caller; kept to preserve existing behavior.
            verbose_logger.debug(
                "MCP SSE session ended: %s: %s", type(e).__name__, e
            )
    await request.close()
|
|
|
|
@router.post("/sse/messages")
async def handle_messages(request: Request):
    """
    Forward a client-to-server MCP message (HTTP POST) to the SSE transport.

    Hands the raw ASGI scope, receive and send callables of `request` to
    `sse.handle_post_message`, then closes the request.
    """
    verbose_logger.info("incoming SSE message received")
    asgi_scope, asgi_receive, asgi_send = request.scope, request.receive, request._send
    await sse.handle_post_message(asgi_scope, asgi_receive, asgi_send)
    await request.close()
|
|
|
|
########################################################
############ MCP Server REST API Routes #################
########################################################
@router.get("/tools/list", dependencies=[Depends(user_api_key_auth)])
async def list_tool_rest_api() -> List[ListMCPToolsRestAPIResponseObject]:
    """
    List all available tools with information about the server they belong to.

    Managed MCP servers are queried one at a time. If fetching or converting a
    server's tools raises, the error is logged and that server is skipped, so
    one failing server does not fail the whole listing.

    Example response:
    Tools:
    [
        {
            "name": "create_zap",
            "description": "Create a new zap",
            "inputSchema": "tool_input_schema",
            "mcp_info": {
                "server_name": "zapier",
                "logo_url": "https://www.zapier.com/logo.png"
            }
        },
        {
            "name": "fetch_data",
            "description": "Fetch data from a URL",
            "inputSchema": "tool_input_schema",
            "mcp_info": {
                "server_name": "fetch",
                "logo_url": "https://www.fetch.com/logo.png"
            }
        }
    ]
    """
    list_tools_result: List[ListMCPToolsRestAPIResponseObject] = []
    # Named `mcp_server` (not `server`) so it does not shadow the module-level
    # MCP `server` object.
    for mcp_server in global_mcp_server_manager.mcp_servers:
        try:
            tools = await global_mcp_server_manager._get_tools_from_server(mcp_server)
            for tool in tools:
                list_tools_result.append(
                    ListMCPToolsRestAPIResponseObject(
                        name=tool.name,
                        description=tool.description,
                        inputSchema=tool.inputSchema,
                        mcp_info=mcp_server.mcp_info,
                    )
                )
        except Exception as e:
            verbose_logger.exception(
                f"Error getting tools from {mcp_server.name}: {e}"
            )
    return list_tools_result
|
|
|
|
@router.post("/tools/call", dependencies=[Depends(user_api_key_auth)])
async def call_tool_rest_api(
    request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    REST API to call a specific MCP tool with the provided arguments.

    Expects a JSON object body with at least a `name` key (and normally
    `arguments`). The body is enriched with `add_litellm_data_to_request` and
    then passed as keyword arguments to `call_mcp_tool`.

    Raises:
        HTTPException: 400 if the body is not valid JSON, is not a JSON object,
            or lacks a `name` field. Otherwise, whatever `call_mcp_tool` raises.
    """
    # Imported lazily, presumably to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import add_litellm_data_to_request, proxy_config

    try:
        data = await request.json()
    except ValueError as e:  # json.JSONDecodeError and UnicodeDecodeError subclass ValueError
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON"
        ) from e

    # `call_mcp_tool(**data)` needs a mapping containing `name`. Without this
    # check, a malformed body surfaced as a TypeError (HTTP 500) instead of a
    # client error.
    if not isinstance(data, dict) or "name" not in data:
        raise HTTPException(
            status_code=400,
            detail="Request body must be a JSON object with a 'name' field",
        )

    data = await add_litellm_data_to_request(
        data=data,
        request=request,
        user_api_key_dict=user_api_key_dict,
        proxy_config=proxy_config,
    )
    return await call_mcp_tool(**data)
|
|
|
|
# Initialization options passed to `server.run(...)` for each SSE session
# opened by `handle_sse`.
options = InitializationOptions(
    capabilities=server.get_capabilities(
        experimental_capabilities={},
        notification_options=NotificationOptions(),
    ),
    server_name="litellm-mcp-server",
    server_version="0.1.0",
)
|