fix mcp import

This commit is contained in:
Ishaan Jaff 2025-03-24 21:02:54 -07:00
parent df1789902d
commit dcc2edbd4d
2 changed files with 95 additions and 86 deletions

View file

@ -1062,7 +1062,6 @@ jobs:
pip install jinja2 pip install jinja2
pip install "tokenizers==0.20.0" pip install "tokenizers==0.20.0"
pip install "uvloop==0.21.0" pip install "uvloop==0.21.0"
pip install "mcp==1.5.0"
pip install jsonschema pip install jsonschema
- run: - run:
name: Run tests name: Run tests

View file

@ -3,108 +3,118 @@ LiteLLM MCP Server Routes
""" """
import asyncio import asyncio
import logging
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union
from anyio import BrokenResourceError from anyio import BrokenResourceError
from fastapi import APIRouter, HTTPException, Request from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from mcp.server import NotificationOptions, Server
from mcp.server.models import InitializationOptions
from mcp.types import EmbeddedResource as MCPEmbeddedResource
from mcp.types import ImageContent as MCPImageContent
from mcp.types import TextContent as MCPTextContent
from mcp.types import Tool as MCPTool
from pydantic import ValidationError from pydantic import ValidationError
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.proxy._experimental.mcp_server.tool_registry import (
global_mcp_tool_registry,
)
from .sse_transport import SseServerTransport # Check if MCP is available
# "mcp" requires python 3.10 or higher, but several litellm users use python 3.8
# We're making this conditional import to avoid breaking users who use python 3.8.
try:
import mcp
######################################################## MCP_AVAILABLE = True
############ Initialize the MCP Server ################# except ImportError as e:
######################################################## verbose_logger.debug(f"MCP module not found: {e}")
router = APIRouter( MCP_AVAILABLE = False
prefix="/mcp",
tags=["mcp"],
)
server: Server = Server("litellm-mcp-server")
sse: SseServerTransport = SseServerTransport("/mcp/sse/messages")
########################################################
############### MCP Server Routes #######################
########################################################
@server.list_tools() if MCP_AVAILABLE:
async def list_tools() -> list[MCPTool]: from mcp.server import NotificationOptions, Server
""" from mcp.server.models import InitializationOptions
List all available tools from mcp.types import EmbeddedResource as MCPEmbeddedResource
""" from mcp.types import ImageContent as MCPImageContent
tools = [] from mcp.types import TextContent as MCPTextContent
for tool in global_mcp_tool_registry.list_tools(): from mcp.types import Tool as MCPTool
tools.append(
MCPTool( from .sse_transport import SseServerTransport
name=tool.name, from .tool_registry import global_mcp_tool_registry
description=tool.description,
inputSchema=tool.input_schema, ########################################################
############ Initialize the MCP Server #################
########################################################
router = APIRouter(
prefix="/mcp",
tags=["mcp"],
)
server: Server = Server("litellm-mcp-server")
sse: SseServerTransport = SseServerTransport("/mcp/sse/messages")
########################################################
############### MCP Server Routes #######################
########################################################
@server.list_tools()
async def list_tools() -> list[MCPTool]:
"""
List all available tools
"""
tools = []
for tool in global_mcp_tool_registry.list_tools():
tools.append(
MCPTool(
name=tool.name,
description=tool.description,
inputSchema=tool.input_schema,
)
) )
)
return tools return tools
@server.call_tool()
async def handle_call_tool(
name: str, arguments: Dict[str, Any] | None
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
"""
Call a specific tool with the provided arguments
"""
tool = global_mcp_tool_registry.get_tool(name)
if not tool:
raise HTTPException(status_code=404, detail=f"Tool '{name}' not found")
if arguments is None:
raise HTTPException(
status_code=400, detail="Request arguments are required"
)
@server.call_tool()
async def handle_call_tool(
name: str, arguments: Dict[str, Any] | None
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
"""
Call a specific tool with the provided arguments
"""
tool = global_mcp_tool_registry.get_tool(name)
if not tool:
raise HTTPException(status_code=404, detail=f"Tool '{name}' not found")
if arguments is None:
raise HTTPException(status_code=400, detail="Request arguments are required")
try:
result = tool.handler(**arguments)
return [MCPTextContent(text=str(result), type="text")]
except Exception as e:
return [MCPTextContent(text=f"Error: {str(e)}", type="text")]
@router.get("/", response_class=StreamingResponse)
async def handle_sse(request: Request):
verbose_logger.info("new incoming SSE connection established")
async with sse.connect_sse(request) as streams:
try: try:
await server.run(streams[0], streams[1], options) result = tool.handler(**arguments)
except BrokenResourceError: return [MCPTextContent(text=str(result), type="text")]
pass except Exception as e:
except asyncio.CancelledError: return [MCPTextContent(text=f"Error: {str(e)}", type="text")]
pass
except ValidationError:
pass
except Exception:
raise
await request.close()
@router.get("/", response_class=StreamingResponse)
async def handle_sse(request: Request):
verbose_logger.info("new incoming SSE connection established")
async with sse.connect_sse(request) as streams:
try:
await server.run(streams[0], streams[1], options)
except BrokenResourceError:
pass
except asyncio.CancelledError:
pass
except ValidationError:
pass
except Exception:
raise
await request.close()
@router.post("/sse/messages") @router.post("/sse/messages")
async def handle_messages(request: Request): async def handle_messages(request: Request):
verbose_logger.info("incoming SSE message received") verbose_logger.info("incoming SSE message received")
await sse.handle_post_message(request.scope, request.receive, request._send) await sse.handle_post_message(request.scope, request.receive, request._send)
await request.close() await request.close()
options = InitializationOptions(
options = InitializationOptions( server_name="litellm-mcp-server",
server_name="litellm-mcp-server", server_version="0.1.0",
server_version="0.1.0", capabilities=server.get_capabilities(
capabilities=server.get_capabilities( notification_options=NotificationOptions(),
notification_options=NotificationOptions(), experimental_capabilities={},
experimental_capabilities={}, ),
), )
)