mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
refactor location of proxy
This commit is contained in:
parent
baa5564f95
commit
ce58c53ff1
413 changed files with 2087 additions and 2088 deletions
309
litellm_proxy/_experimental/mcp_server/server.py
Normal file
309
litellm_proxy/_experimental/mcp_server/server.py
Normal file
|
@ -0,0 +1,309 @@
|
|||
"""
|
||||
LiteLLM MCP Server Routes
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from anyio import BrokenResourceError
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import ConfigDict, ValidationError
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import MCP_TOOL_NAME_PREFIX
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm_proxy._types import UserAPIKeyAuth
|
||||
from litellm_proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPInfo
|
||||
from litellm.types.utils import StandardLoggingMCPToolCall
|
||||
from litellm.utils import client
|
||||
|
||||
# Check if MCP is available
|
||||
# "mcp" requires python 3.10 or higher, but several litellm users use python 3.8
|
||||
# We're making this conditional import to avoid breaking users who use python 3.8.
|
||||
try:
|
||||
from mcp.server import Server
|
||||
|
||||
MCP_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
verbose_logger.debug(f"MCP module not found: {e}")
|
||||
MCP_AVAILABLE = False
|
||||
router = APIRouter(
|
||||
prefix="/mcp",
|
||||
tags=["mcp"],
|
||||
)
|
||||
|
||||
|
||||
if MCP_AVAILABLE:
|
||||
from mcp.server import NotificationOptions, Server
|
||||
from mcp.server.models import InitializationOptions
|
||||
from mcp.types import EmbeddedResource as MCPEmbeddedResource
|
||||
from mcp.types import ImageContent as MCPImageContent
|
||||
from mcp.types import TextContent as MCPTextContent
|
||||
from mcp.types import Tool as MCPTool
|
||||
|
||||
from .mcp_server_manager import global_mcp_server_manager
|
||||
from .sse_transport import SseServerTransport
|
||||
from .tool_registry import global_mcp_tool_registry
|
||||
|
||||
######################################################
|
||||
############ MCP Tools List REST API Response Object #
|
||||
# Defined here because we don't want to add `mcp` as a
|
||||
# required dependency for `litellm` pip package
|
||||
######################################################
|
||||
class ListMCPToolsRestAPIResponseObject(MCPTool):
|
||||
"""
|
||||
Object returned by the /tools/list REST API route.
|
||||
"""
|
||||
|
||||
mcp_info: Optional[MCPInfo] = None
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
########################################################
|
||||
############ Initialize the MCP Server #################
|
||||
########################################################
|
||||
router = APIRouter(
|
||||
prefix="/mcp",
|
||||
tags=["mcp"],
|
||||
)
|
||||
server: Server = Server("litellm-mcp-server")
|
||||
sse: SseServerTransport = SseServerTransport("/mcp/sse/messages")
|
||||
|
||||
########################################################
|
||||
############### MCP Server Routes #######################
|
||||
########################################################
|
||||
@server.list_tools()
|
||||
async def list_tools() -> list[MCPTool]:
|
||||
"""
|
||||
List all available tools
|
||||
"""
|
||||
return await _list_mcp_tools()
|
||||
|
||||
async def _list_mcp_tools() -> List[MCPTool]:
|
||||
"""
|
||||
List all available tools
|
||||
"""
|
||||
tools = []
|
||||
for tool in global_mcp_tool_registry.list_tools():
|
||||
tools.append(
|
||||
MCPTool(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
inputSchema=tool.input_schema,
|
||||
)
|
||||
)
|
||||
verbose_logger.debug(
|
||||
"GLOBAL MCP TOOLS: %s", global_mcp_tool_registry.list_tools()
|
||||
)
|
||||
sse_tools: List[MCPTool] = await global_mcp_server_manager.list_tools()
|
||||
verbose_logger.debug("SSE TOOLS: %s", sse_tools)
|
||||
if sse_tools is not None:
|
||||
tools.extend(sse_tools)
|
||||
return tools
|
||||
|
||||
@server.call_tool()
|
||||
async def mcp_server_tool_call(
|
||||
name: str, arguments: Dict[str, Any] | None
|
||||
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
|
||||
"""
|
||||
Call a specific tool with the provided arguments
|
||||
|
||||
Args:
|
||||
name (str): Name of the tool to call
|
||||
arguments (Dict[str, Any] | None): Arguments to pass to the tool
|
||||
|
||||
Returns:
|
||||
List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]: Tool execution results
|
||||
|
||||
Raises:
|
||||
HTTPException: If tool not found or arguments missing
|
||||
"""
|
||||
# Validate arguments
|
||||
response = await call_mcp_tool(
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
)
|
||||
return response
|
||||
|
||||
@client
|
||||
async def call_mcp_tool(
|
||||
name: str, arguments: Optional[Dict[str, Any]] = None, **kwargs: Any
|
||||
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
|
||||
"""
|
||||
Call a specific tool with the provided arguments
|
||||
"""
|
||||
if arguments is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Request arguments are required"
|
||||
)
|
||||
|
||||
standard_logging_mcp_tool_call: StandardLoggingMCPToolCall = (
|
||||
_get_standard_logging_mcp_tool_call(
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
)
|
||||
)
|
||||
litellm_logging_obj: Optional[LiteLLMLoggingObj] = kwargs.get(
|
||||
"litellm_logging_obj", None
|
||||
)
|
||||
if litellm_logging_obj:
|
||||
litellm_logging_obj.model_call_details["mcp_tool_call_metadata"] = (
|
||||
standard_logging_mcp_tool_call
|
||||
)
|
||||
litellm_logging_obj.model_call_details["model"] = (
|
||||
f"{MCP_TOOL_NAME_PREFIX}: {standard_logging_mcp_tool_call.get('name') or ''}"
|
||||
)
|
||||
litellm_logging_obj.model_call_details["custom_llm_provider"] = (
|
||||
standard_logging_mcp_tool_call.get("mcp_server_name")
|
||||
)
|
||||
|
||||
# Try managed server tool first
|
||||
if name in global_mcp_server_manager.tool_name_to_mcp_server_name_mapping:
|
||||
return await _handle_managed_mcp_tool(name, arguments)
|
||||
|
||||
# Fall back to local tool registry
|
||||
return await _handle_local_mcp_tool(name, arguments)
|
||||
|
||||
def _get_standard_logging_mcp_tool_call(
|
||||
name: str,
|
||||
arguments: Dict[str, Any],
|
||||
) -> StandardLoggingMCPToolCall:
|
||||
mcp_server = global_mcp_server_manager._get_mcp_server_from_tool_name(name)
|
||||
if mcp_server:
|
||||
mcp_info = mcp_server.mcp_info or {}
|
||||
return StandardLoggingMCPToolCall(
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
mcp_server_name=mcp_info.get("server_name"),
|
||||
mcp_server_logo_url=mcp_info.get("logo_url"),
|
||||
)
|
||||
else:
|
||||
return StandardLoggingMCPToolCall(
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
)
|
||||
|
||||
async def _handle_managed_mcp_tool(
|
||||
name: str, arguments: Dict[str, Any]
|
||||
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
|
||||
"""Handle tool execution for managed server tools"""
|
||||
call_tool_result = await global_mcp_server_manager.call_tool(
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
)
|
||||
verbose_logger.debug("CALL TOOL RESULT: %s", call_tool_result)
|
||||
return call_tool_result.content
|
||||
|
||||
async def _handle_local_mcp_tool(
|
||||
name: str, arguments: Dict[str, Any]
|
||||
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
|
||||
"""Handle tool execution for local registry tools"""
|
||||
tool = global_mcp_tool_registry.get_tool(name)
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail=f"Tool '{name}' not found")
|
||||
|
||||
try:
|
||||
result = tool.handler(**arguments)
|
||||
return [MCPTextContent(text=str(result), type="text")]
|
||||
except Exception as e:
|
||||
return [MCPTextContent(text=f"Error: {str(e)}", type="text")]
|
||||
|
||||
@router.get("/", response_class=StreamingResponse)
|
||||
async def handle_sse(request: Request):
|
||||
verbose_logger.info("new incoming SSE connection established")
|
||||
async with sse.connect_sse(request) as streams:
|
||||
try:
|
||||
await server.run(streams[0], streams[1], options)
|
||||
except BrokenResourceError:
|
||||
pass
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except ValidationError:
|
||||
pass
|
||||
except Exception:
|
||||
raise
|
||||
await request.close()
|
||||
|
||||
@router.post("/sse/messages")
|
||||
async def handle_messages(request: Request):
|
||||
verbose_logger.info("incoming SSE message received")
|
||||
await sse.handle_post_message(request.scope, request.receive, request._send)
|
||||
await request.close()
|
||||
|
||||
########################################################
|
||||
############ MCP Server REST API Routes #################
|
||||
########################################################
|
||||
@router.get("/tools/list", dependencies=[Depends(user_api_key_auth)])
|
||||
async def list_tool_rest_api() -> List[ListMCPToolsRestAPIResponseObject]:
|
||||
"""
|
||||
List all available tools with information about the server they belong to.
|
||||
|
||||
Example response:
|
||||
Tools:
|
||||
[
|
||||
{
|
||||
"name": "create_zap",
|
||||
"description": "Create a new zap",
|
||||
"inputSchema": "tool_input_schema",
|
||||
"mcp_info": {
|
||||
"server_name": "zapier",
|
||||
"logo_url": "https://www.zapier.com/logo.png",
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "fetch_data",
|
||||
"description": "Fetch data from a URL",
|
||||
"inputSchema": "tool_input_schema",
|
||||
"mcp_info": {
|
||||
"server_name": "fetch",
|
||||
"logo_url": "https://www.fetch.com/logo.png",
|
||||
}
|
||||
}
|
||||
]
|
||||
"""
|
||||
list_tools_result: List[ListMCPToolsRestAPIResponseObject] = []
|
||||
for server in global_mcp_server_manager.mcp_servers:
|
||||
try:
|
||||
tools = await global_mcp_server_manager._get_tools_from_server(server)
|
||||
for tool in tools:
|
||||
list_tools_result.append(
|
||||
ListMCPToolsRestAPIResponseObject(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
inputSchema=tool.inputSchema,
|
||||
mcp_info=server.mcp_info,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"Error getting tools from {server.name}: {e}")
|
||||
continue
|
||||
return list_tools_result
|
||||
|
||||
@router.post("/tools/call", dependencies=[Depends(user_api_key_auth)])
|
||||
async def call_tool_rest_api(
|
||||
request: Request,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
REST API to call a specific MCP tool with the provided arguments
|
||||
"""
|
||||
from litellm_proxy.proxy_server import add_litellm_data_to_request, proxy_config
|
||||
|
||||
data = await request.json()
|
||||
data = await add_litellm_data_to_request(
|
||||
data=data,
|
||||
request=request,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
proxy_config=proxy_config,
|
||||
)
|
||||
return await call_mcp_tool(**data)
|
||||
|
||||
options = InitializationOptions(
|
||||
server_name="litellm-mcp-server",
|
||||
server_version="0.1.0",
|
||||
capabilities=server.get_capabilities(
|
||||
notification_options=NotificationOptions(),
|
||||
experimental_capabilities={},
|
||||
),
|
||||
)
|
Loading…
Add table
Add a link
Reference in a new issue