mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
use global_mcp_server_manager
This commit is contained in:
parent
ecc31e7899
commit
c0d07987c4
2 changed files with 7 additions and 89 deletions
|
@ -1,89 +0,0 @@
|
|||
"""
|
||||
MCP Client Manager
|
||||
|
||||
This class is responsible for managing MCP SSE clients.
|
||||
|
||||
This is a Proxy
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
|
||||
class MCPSSEServer(BaseModel):
|
||||
name: str
|
||||
url: str
|
||||
client_session: Optional[ClientSession] = None
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class MCPServerManager:
|
||||
def __init__(self, mcp_servers: List[MCPSSEServer]):
|
||||
self.mcp_servers: List[MCPSSEServer] = mcp_servers
|
||||
"""
|
||||
eg.
|
||||
[
|
||||
{
|
||||
"name": "zapier_mcp_server",
|
||||
"url": "https://actions.zapier.com/mcp/sk-ak-2ew3bofIeQIkNoeKIdXrF1Hhhp/sse"
|
||||
},
|
||||
{
|
||||
"name": "google_drive_mcp_server",
|
||||
"url": "https://actions.zapier.com/mcp/sk-ak-2ew3bofIeQIkNoeKIdXrF1Hhhp/sse"
|
||||
}
|
||||
]
|
||||
"""
|
||||
|
||||
self.tool_name_to_mcp_server_name_mapping: Dict[str, str] = {}
|
||||
"""
|
||||
{
|
||||
"gmail_send_email": "zapier_mcp_server",
|
||||
}
|
||||
"""
|
||||
|
||||
async def list_tools(self):
|
||||
"""
|
||||
List all tools available in all the MCP Servers
|
||||
"""
|
||||
for server in self.mcp_servers:
|
||||
async with sse_client(url=server.url) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
server.client_session = session
|
||||
await server.client_session.initialize()
|
||||
list_tools_result = await server.client_session.list_tools()
|
||||
verbose_logger.debug(
|
||||
f"Tools from {server.name}: {list_tools_result}"
|
||||
)
|
||||
for tool in list_tools_result.tools:
|
||||
self.tool_name_to_mcp_server_name_mapping[tool.name] = (
|
||||
server.name
|
||||
)
|
||||
|
||||
async def call_tool(self, name: str, arguments: Dict[str, Any]):
|
||||
"""
|
||||
Call a tool with the given name and arguments
|
||||
"""
|
||||
mcp_server = self._get_mcp_server_from_tool_name(name)
|
||||
if mcp_server is None:
|
||||
raise ValueError(f"Tool {name} not found")
|
||||
async with sse_client(url=mcp_server.url) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
mcp_server.client_session = session
|
||||
await mcp_server.client_session.initialize()
|
||||
return await mcp_server.client_session.call_tool(name, arguments)
|
||||
|
||||
def _get_mcp_server_from_tool_name(self, tool_name: str) -> Optional[MCPSSEServer]:
|
||||
"""
|
||||
Get the MCP Server from the tool name
|
||||
"""
|
||||
if tool_name in self.tool_name_to_mcp_server_name_mapping:
|
||||
for server in self.mcp_servers:
|
||||
if server.name == self.tool_name_to_mcp_server_name_mapping[tool_name]:
|
||||
return server
|
||||
return None
|
|
@ -127,6 +127,9 @@ from litellm.litellm_core_utils.core_helpers import (
|
|||
from litellm.litellm_core_utils.credential_accessor import CredentialAccessor
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
|
||||
global_mcp_server_manager,
|
||||
)
|
||||
from litellm.proxy._experimental.mcp_server.server import router as mcp_router
|
||||
from litellm.proxy._experimental.mcp_server.tool_registry import (
|
||||
global_mcp_tool_registry,
|
||||
|
@ -1961,6 +1964,10 @@ class ProxyConfig:
|
|||
if mcp_tools_config:
|
||||
global_mcp_tool_registry.load_tools_from_config(mcp_tools_config)
|
||||
|
||||
mcp_servers_config = config.get("mcp_servers", None)
|
||||
if mcp_servers_config:
|
||||
global_mcp_server_manager.load_servers_from_config(mcp_servers_config)
|
||||
|
||||
## CREDENTIALS
|
||||
credential_list_dict = self.load_credential_list(config=config)
|
||||
litellm.credential_list = credential_list_dict
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue