diff --git a/litellm/constants.py b/litellm/constants.py index e224b3d33e..d5e0215ebf 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -418,6 +418,7 @@ RESPONSE_FORMAT_TOOL_NAME = "json_tool_call" # default tool name used when conv ########################### Logging Callback Constants ########################### AZURE_STORAGE_MSFT_VERSION = "2019-07-07" +MCP_TOOL_NAME_PREFIX = "mcp_tool" ########################### LiteLLM Proxy Specific Constants ########################### ######################################################################################## diff --git a/litellm/experimental_mcp_client/tools.py b/litellm/experimental_mcp_client/tools.py index f4ebbf4af4..cdc26af4b7 100644 --- a/litellm/experimental_mcp_client/tools.py +++ b/litellm/experimental_mcp_client/tools.py @@ -1,5 +1,5 @@ import json -from typing import List, Literal, Union +from typing import Dict, List, Literal, Union from mcp import ClientSession from mcp.types import CallToolRequestParams as MCPCallToolRequestParams @@ -76,8 +76,8 @@ def _get_function_arguments(function: FunctionDefinition) -> dict: return arguments if isinstance(arguments, dict) else {} -def _transform_openai_tool_call_to_mcp_tool_call_request( - openai_tool: ChatCompletionMessageToolCall, +def transform_openai_tool_call_request_to_mcp_tool_call_request( + openai_tool: Union[ChatCompletionMessageToolCall, Dict], ) -> MCPCallToolRequestParams: """Convert an OpenAI ChatCompletionMessageToolCall to an MCP CallToolRequestParams.""" function = openai_tool["function"] @@ -100,8 +100,10 @@ async def call_openai_tool( Returns: The result of the MCP tool call. 
""" - mcp_tool_call_request_params = _transform_openai_tool_call_to_mcp_tool_call_request( - openai_tool=openai_tool, + mcp_tool_call_request_params = ( + transform_openai_tool_call_request_to_mcp_tool_call_request( + openai_tool=openai_tool, + ) ) return await call_mcp_tool( session=session, diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index dcd3ae3a64..edb37b33a8 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -67,6 +67,7 @@ from litellm.types.utils import ( StandardCallbackDynamicParams, StandardLoggingAdditionalHeaders, StandardLoggingHiddenParams, + StandardLoggingMCPToolCall, StandardLoggingMetadata, StandardLoggingModelCostFailureDebugInformation, StandardLoggingModelInformation, @@ -1095,7 +1096,7 @@ class Logging(LiteLLMLoggingBaseClass): status="success", standard_built_in_tools_params=self.standard_built_in_tools_params, ) - elif isinstance(result, dict): # pass-through endpoints + elif isinstance(result, dict) or isinstance(result, list): ## STANDARDIZED LOGGING PAYLOAD self.model_call_details[ "standard_logging_object" @@ -3106,6 +3107,7 @@ class StandardLoggingPayloadSetup: litellm_params: Optional[dict] = None, prompt_integration: Optional[str] = None, applied_guardrails: Optional[List[str]] = None, + mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] = None, ) -> StandardLoggingMetadata: """ Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata. 
@@ -3152,6 +3154,7 @@ class StandardLoggingPayloadSetup: user_api_key_end_user_id=None, prompt_management_metadata=prompt_management_metadata, applied_guardrails=applied_guardrails, + mcp_tool_call_metadata=mcp_tool_call_metadata, ) if isinstance(metadata, dict): # Filter the metadata dictionary to include only the specified keys @@ -3478,6 +3481,7 @@ def get_standard_logging_object_payload( litellm_params=litellm_params, prompt_integration=kwargs.get("prompt_integration", None), applied_guardrails=kwargs.get("applied_guardrails", None), + mcp_tool_call_metadata=kwargs.get("mcp_tool_call_metadata", None), ) _request_body = proxy_server_request.get("body", {}) @@ -3617,6 +3621,7 @@ def get_standard_logging_metadata( user_api_key_end_user_id=None, prompt_management_metadata=None, applied_guardrails=None, + mcp_tool_call_metadata=None, ) if isinstance(metadata, dict): # Filter the metadata dictionary to include only the specified keys diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py new file mode 100644 index 0000000000..df9ae0ea57 --- /dev/null +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -0,0 +1,153 @@ +""" +MCP Client Manager + +This class is responsible for managing MCP SSE clients. + +This is a Proxy +""" + +import asyncio +import json +from typing import Any, Dict, List, Optional + +from mcp import ClientSession +from mcp.client.sse import sse_client +from mcp.types import Tool as MCPTool + +from litellm._logging import verbose_logger +from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPSSEServer + + +class MCPServerManager: + def __init__(self): + self.mcp_servers: List[MCPSSEServer] = [] + """ + eg. 
+ [ + { + "name": "zapier_mcp_server", + "url": "https://actions.zapier.com/mcp/sk-ak-example-key-redacted/sse" + }, + { + "name": "google_drive_mcp_server", + "url": "https://actions.zapier.com/mcp/sk-ak-example-key-redacted/sse" + } + ] + """ + + self.tool_name_to_mcp_server_name_mapping: Dict[str, str] = {} + """ + { + "gmail_send_email": "zapier_mcp_server", + } + """ + + def load_servers_from_config(self, mcp_servers_config: Dict[str, Any]): + """ + Load the MCP Servers from the config + """ + for server_name, server_config in mcp_servers_config.items(): + _mcp_info: dict = server_config.get("mcp_info", None) or {} + mcp_info = MCPInfo(**_mcp_info) + mcp_info["server_name"] = server_name + self.mcp_servers.append( + MCPSSEServer( + name=server_name, + url=server_config["url"], + mcp_info=mcp_info, + ) + ) + verbose_logger.debug( + f"Loaded MCP Servers: {json.dumps(self.mcp_servers, indent=4, default=str)}" + ) + + self.initialize_tool_name_to_mcp_server_name_mapping() + + async def list_tools(self) -> List[MCPTool]: + """ + List all tools available across all MCP Servers. + + Returns: + List[MCPTool]: Combined list of tools from all servers + """ + list_tools_result: List[MCPTool] = [] + verbose_logger.debug("SSE SERVER MANAGER LISTING TOOLS") + + for server in self.mcp_servers: + tools = await self._get_tools_from_server(server) + list_tools_result.extend(tools) + + return list_tools_result + + async def _get_tools_from_server(self, server: MCPSSEServer) -> List[MCPTool]: + """ + Helper method to get tools from a single MCP server. 
+ + Args: + server (MCPSSEServer): The server to query tools from + + Returns: + List[MCPTool]: List of tools available on the server + """ + verbose_logger.debug(f"Connecting to url: {server.url}") + + async with sse_client(url=server.url) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + + tools_result = await session.list_tools() + verbose_logger.debug(f"Tools from {server.name}: {tools_result}") + + # Update tool to server mapping + for tool in tools_result.tools: + self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name + + return tools_result.tools + + def initialize_tool_name_to_mcp_server_name_mapping(self): + """ + On startup, initialize the tool name to MCP server name mapping + """ + try: + if asyncio.get_running_loop(): + asyncio.create_task( + self._initialize_tool_name_to_mcp_server_name_mapping() + ) + except RuntimeError as e: # no running event loop + verbose_logger.exception( + f"No running event loop - skipping tool name to MCP server name mapping initialization: {str(e)}" + ) + + async def _initialize_tool_name_to_mcp_server_name_mapping(self): + """ + Call list_tools for each server and update the tool name to MCP server name mapping + """ + for server in self.mcp_servers: + tools = await self._get_tools_from_server(server) + for tool in tools: + self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name + + async def call_tool(self, name: str, arguments: Dict[str, Any]): + """ + Call a tool with the given name and arguments + """ + mcp_server = self._get_mcp_server_from_tool_name(name) + if mcp_server is None: + raise ValueError(f"Tool {name} not found") + async with sse_client(url=mcp_server.url) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + return await session.call_tool(name, arguments) + + def _get_mcp_server_from_tool_name(self, tool_name: str) -> Optional[MCPSSEServer]: + """ + Get the MCP Server from the tool name 
+ """ + if tool_name in self.tool_name_to_mcp_server_name_mapping: + for server in self.mcp_servers: + if server.name == self.tool_name_to_mcp_server_name_mapping[tool_name]: + return server + return None + + +global_mcp_server_manager: MCPServerManager = MCPServerManager() diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index f617312d5a..fe1eccb048 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -3,14 +3,21 @@ LiteLLM MCP Server Routes """ import asyncio -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union from anyio import BrokenResourceError -from fastapi import APIRouter, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import StreamingResponse -from pydantic import ValidationError +from pydantic import ConfigDict, ValidationError from litellm._logging import verbose_logger +from litellm.constants import MCP_TOOL_NAME_PREFIX +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.types.mcp_server.mcp_server_manager import MCPInfo +from litellm.types.utils import StandardLoggingMCPToolCall +from litellm.utils import client # Check if MCP is available # "mcp" requires python 3.10 or higher, but several litellm users use python 3.8 @@ -36,9 +43,23 @@ if MCP_AVAILABLE: from mcp.types import TextContent as MCPTextContent from mcp.types import Tool as MCPTool + from .mcp_server_manager import global_mcp_server_manager from .sse_transport import SseServerTransport from .tool_registry import global_mcp_tool_registry + ###################################################### + ############ MCP Tools List REST API Response Object # + # Defined here because we don't want to add 
`mcp` as a + # required dependency for `litellm` pip package + ###################################################### + class ListMCPToolsRestAPIResponseObject(MCPTool): + """ + Object returned by the /tools/list REST API route. + """ + + mcp_info: Optional[MCPInfo] = None + model_config = ConfigDict(arbitrary_types_allowed=True) + ######################################################## ############ Initialize the MCP Server ################# ######################################################## @@ -52,9 +73,14 @@ if MCP_AVAILABLE: ######################################################## ############### MCP Server Routes ####################### ######################################################## - @server.list_tools() async def list_tools() -> list[MCPTool]: + """ + List all available tools + """ + return await _list_mcp_tools() + + async def _list_mcp_tools() -> List[MCPTool]: """ List all available tools """ @@ -67,24 +93,116 @@ if MCP_AVAILABLE: inputSchema=tool.input_schema, ) ) - + verbose_logger.debug( + "GLOBAL MCP TOOLS: %s", global_mcp_tool_registry.list_tools() + ) + sse_tools: List[MCPTool] = await global_mcp_server_manager.list_tools() + verbose_logger.debug("SSE TOOLS: %s", sse_tools) + if sse_tools is not None: + tools.extend(sse_tools) return tools @server.call_tool() - async def handle_call_tool( + async def mcp_server_tool_call( name: str, arguments: Dict[str, Any] | None ) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]: """ Call a specific tool with the provided arguments + + Args: + name (str): Name of the tool to call + arguments (Dict[str, Any] | None): Arguments to pass to the tool + + Returns: + List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]: Tool execution results + + Raises: + HTTPException: If tool not found or arguments missing + """ + # Validate arguments + response = await call_mcp_tool( + name=name, + arguments=arguments, + ) + return response + + @client + async def call_mcp_tool( + 
name: str, arguments: Optional[Dict[str, Any]] = None, **kwargs: Any + ) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]: + """ + Call a specific tool with the provided arguments """ - tool = global_mcp_tool_registry.get_tool(name) - if not tool: - raise HTTPException(status_code=404, detail=f"Tool '{name}' not found") if arguments is None: raise HTTPException( status_code=400, detail="Request arguments are required" ) + standard_logging_mcp_tool_call: StandardLoggingMCPToolCall = ( + _get_standard_logging_mcp_tool_call( + name=name, + arguments=arguments, + ) + ) + litellm_logging_obj: Optional[LiteLLMLoggingObj] = kwargs.get( + "litellm_logging_obj", None + ) + if litellm_logging_obj: + litellm_logging_obj.model_call_details["mcp_tool_call_metadata"] = ( + standard_logging_mcp_tool_call + ) + litellm_logging_obj.model_call_details["model"] = ( + f"{MCP_TOOL_NAME_PREFIX}: {standard_logging_mcp_tool_call.get('name') or ''}" + ) + litellm_logging_obj.model_call_details["custom_llm_provider"] = ( + standard_logging_mcp_tool_call.get("mcp_server_name") + ) + + # Try managed server tool first + if name in global_mcp_server_manager.tool_name_to_mcp_server_name_mapping: + return await _handle_managed_mcp_tool(name, arguments) + + # Fall back to local tool registry + return await _handle_local_mcp_tool(name, arguments) + + def _get_standard_logging_mcp_tool_call( + name: str, + arguments: Dict[str, Any], + ) -> StandardLoggingMCPToolCall: + mcp_server = global_mcp_server_manager._get_mcp_server_from_tool_name(name) + if mcp_server: + mcp_info = mcp_server.mcp_info or {} + return StandardLoggingMCPToolCall( + name=name, + arguments=arguments, + mcp_server_name=mcp_info.get("server_name"), + mcp_server_logo_url=mcp_info.get("logo_url"), + ) + else: + return StandardLoggingMCPToolCall( + name=name, + arguments=arguments, + ) + + async def _handle_managed_mcp_tool( + name: str, arguments: Dict[str, Any] + ) -> List[Union[MCPTextContent, MCPImageContent, 
MCPEmbeddedResource]]: + """Handle tool execution for managed server tools""" + call_tool_result = await global_mcp_server_manager.call_tool( + name=name, + arguments=arguments, + ) + verbose_logger.debug("CALL TOOL RESULT: %s", call_tool_result) + return call_tool_result.content + + async def _handle_local_mcp_tool( + name: str, arguments: Dict[str, Any] + ) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]: + """Handle tool execution for local registry tools""" + tool = global_mcp_tool_registry.get_tool(name) + if not tool: + raise HTTPException(status_code=404, detail=f"Tool '{name}' not found") + try: result = tool.handler(**arguments) return [MCPTextContent(text=str(result), type="text")] @@ -113,6 +231,74 @@ if MCP_AVAILABLE: await sse.handle_post_message(request.scope, request.receive, request._send) await request.close() + ######################################################## + ############ MCP Server REST API Routes ################# + ######################################################## + @router.get("/tools/list", dependencies=[Depends(user_api_key_auth)]) + async def list_tool_rest_api() -> List[ListMCPToolsRestAPIResponseObject]: + """ + List all available tools with information about the server they belong to. 
+ + Example response: + Tools: + [ + { + "name": "create_zap", + "description": "Create a new zap", + "inputSchema": "tool_input_schema", + "mcp_info": { + "server_name": "zapier", + "logo_url": "https://www.zapier.com/logo.png", + } + }, + { + "name": "fetch_data", + "description": "Fetch data from a URL", + "inputSchema": "tool_input_schema", + "mcp_info": { + "server_name": "fetch", + "logo_url": "https://www.fetch.com/logo.png", + } + } + ] + """ + list_tools_result: List[ListMCPToolsRestAPIResponseObject] = [] + for server in global_mcp_server_manager.mcp_servers: + try: + tools = await global_mcp_server_manager._get_tools_from_server(server) + for tool in tools: + list_tools_result.append( + ListMCPToolsRestAPIResponseObject( + name=tool.name, + description=tool.description, + inputSchema=tool.inputSchema, + mcp_info=server.mcp_info, + ) + ) + except Exception as e: + verbose_logger.exception(f"Error getting tools from {server.name}: {e}") + continue + return list_tools_result + + @router.post("/tools/call", dependencies=[Depends(user_api_key_auth)]) + async def call_tool_rest_api( + request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + ): + """ + REST API to call a specific MCP tool with the provided arguments + """ + from litellm.proxy.proxy_server import add_litellm_data_to_request, proxy_config + + data = await request.json() + data = await add_litellm_data_to_request( + data=data, + request=request, + user_api_key_dict=user_api_key_dict, + proxy_config=proxy_config, + ) + return await call_mcp_tool(**data) + options = InitializationOptions( server_name="litellm-mcp-server", server_version="0.1.0", diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 16b45f3837..7f13717e29 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -27,6 +27,7 @@ from litellm.types.utils import ( ModelResponse, ProviderField, StandardCallbackDynamicParams, + StandardLoggingMCPToolCall, 
StandardLoggingPayloadErrorInformation, StandardLoggingPayloadStatus, StandardPassThroughResponseObject, @@ -1928,6 +1929,7 @@ class SpendLogsMetadata(TypedDict): ] # special param to log k,v pairs to spendlogs for a call requester_ip_address: Optional[str] applied_guardrails: Optional[List[str]] + mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] status: StandardLoggingPayloadStatus proxy_server_request: Optional[str] batch_models: Optional[List[str]] diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 106003f996..3956eab23f 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -1,7 +1,14 @@ model_list: - - model_name: fake-openai-endpoint + - model_name: gpt-4o litellm_params: - model: openai/fake - api_key: fake-key - api_base: https://exampleopenaiendpoint-production.up.railway.app/ + model: openai/gpt-4o +mcp_servers: + { + "Zapier MCP": { + "url": "os.environ/ZAPIER_MCP_SERVER_URL", + "mcp_info": { + "logo_url": "https://espysys.com/wp-content/uploads/2024/08/zapier-logo.webp", + } + } + } diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index f59d117181..99a8965f42 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1960,6 +1960,14 @@ class ProxyConfig: if mcp_tools_config: global_mcp_tool_registry.load_tools_from_config(mcp_tools_config) + mcp_servers_config = config.get("mcp_servers", None) + if mcp_servers_config: + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + global_mcp_server_manager, + ) + + global_mcp_server_manager.load_servers_from_config(mcp_servers_config) + ## CREDENTIALS credential_list_dict = self.load_credential_list(config=config) litellm.credential_list = credential_list_dict diff --git a/litellm/proxy/spend_tracking/spend_tracking_utils.py b/litellm/proxy/spend_tracking/spend_tracking_utils.py index 6e9a088077..096c5191b1 100644 --- 
a/litellm/proxy/spend_tracking/spend_tracking_utils.py +++ b/litellm/proxy/spend_tracking/spend_tracking_utils.py @@ -13,7 +13,7 @@ from litellm._logging import verbose_proxy_logger from litellm.litellm_core_utils.core_helpers import get_litellm_metadata_from_kwargs from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload from litellm.proxy.utils import PrismaClient, hash_token -from litellm.types.utils import StandardLoggingPayload +from litellm.types.utils import StandardLoggingMCPToolCall, StandardLoggingPayload from litellm.utils import get_end_user_id_for_cost_tracking @@ -38,6 +38,7 @@ def _get_spend_logs_metadata( metadata: Optional[dict], applied_guardrails: Optional[List[str]] = None, batch_models: Optional[List[str]] = None, + mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] = None, ) -> SpendLogsMetadata: if metadata is None: return SpendLogsMetadata( @@ -55,6 +56,7 @@ def _get_spend_logs_metadata( error_information=None, proxy_server_request=None, batch_models=None, + mcp_tool_call_metadata=None, ) verbose_proxy_logger.debug( "getting payload for SpendLogs, available keys in metadata: " @@ -71,6 +73,7 @@ def _get_spend_logs_metadata( ) clean_metadata["applied_guardrails"] = applied_guardrails clean_metadata["batch_models"] = batch_models + clean_metadata["mcp_tool_call_metadata"] = mcp_tool_call_metadata return clean_metadata @@ -200,6 +203,11 @@ def get_logging_payload( # noqa: PLR0915 if standard_logging_payload is not None else None ), + mcp_tool_call_metadata=( + standard_logging_payload["metadata"].get("mcp_tool_call_metadata", None) + if standard_logging_payload is not None + else None + ), ) special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"] diff --git a/litellm/types/mcp_server/mcp_server_manager.py b/litellm/types/mcp_server/mcp_server_manager.py new file mode 100644 index 0000000000..aecd11aa1a --- /dev/null +++ b/litellm/types/mcp_server/mcp_server_manager.py @@ -0,0 +1,16 @@ +from typing 
import TYPE_CHECKING, Optional + +from pydantic import BaseModel, ConfigDict +from typing_extensions import TypedDict + + +class MCPInfo(TypedDict, total=False): + server_name: str + logo_url: Optional[str] + + +class MCPSSEServer(BaseModel): + name: str + url: str + mcp_info: Optional[MCPInfo] = None + model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 7f84a41cd5..8716779d1f 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1644,6 +1644,33 @@ class StandardLoggingUserAPIKeyMetadata(TypedDict): user_api_key_end_user_id: Optional[str] +class StandardLoggingMCPToolCall(TypedDict, total=False): + name: str + """ + Name of the tool to call + """ + arguments: dict + """ + Arguments to pass to the tool + """ + result: dict + """ + Result of the tool call + """ + + mcp_server_name: Optional[str] + """ + Name of the MCP server that the tool call was made to + """ + + mcp_server_logo_url: Optional[str] + """ + Optional logo URL of the MCP server that the tool call was made to + + (this is to render the logo on the logs page on litellm ui) + """ + + class StandardBuiltInToolsParams(TypedDict, total=False): """ Standard built-in OpenAItools parameters @@ -1674,6 +1701,7 @@ class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata): requester_ip_address: Optional[str] requester_metadata: Optional[dict] prompt_management_metadata: Optional[StandardLoggingPromptManagementMetadata] + mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] applied_guardrails: Optional[List[str]] diff --git a/tests/litellm/experimental_mcp_client/test_tools.py b/tests/litellm/experimental_mcp_client/test_tools.py index 7089d83217..ec430ecc9b 100644 --- a/tests/litellm/experimental_mcp_client/test_tools.py +++ b/tests/litellm/experimental_mcp_client/test_tools.py @@ -19,11 +19,11 @@ from mcp.types import Tool as MCPTool from litellm.experimental_mcp_client.tools import ( _get_function_arguments, - 
_transform_openai_tool_call_to_mcp_tool_call_request, call_mcp_tool, call_openai_tool, load_mcp_tools, transform_mcp_tool_to_openai_tool, + transform_openai_tool_call_request_to_mcp_tool_call_request, ) @@ -76,11 +76,11 @@ def test_transform_mcp_tool_to_openai_tool(mock_mcp_tool): } -def test_transform_openai_tool_call_to_mcp_tool_call_request(mock_mcp_tool): +def testtransform_openai_tool_call_request_to_mcp_tool_call_request(mock_mcp_tool): openai_tool = { "function": {"name": "test_tool", "arguments": json.dumps({"test": "value"})} } - mcp_tool_call_request = _transform_openai_tool_call_to_mcp_tool_call_request( + mcp_tool_call_request = transform_openai_tool_call_request_to_mcp_tool_call_request( openai_tool ) assert mcp_tool_call_request.name == "test_tool" diff --git a/tests/litellm/proxy/spend_tracking/test_spend_management_endpoints.py b/tests/litellm/proxy/spend_tracking/test_spend_management_endpoints.py index c1e94138a3..78da1b5dda 100644 --- a/tests/litellm/proxy/spend_tracking/test_spend_management_endpoints.py +++ b/tests/litellm/proxy/spend_tracking/test_spend_management_endpoints.py @@ -457,7 +457,7 @@ class TestSpendLogsPayload: "model": "gpt-4o", "user": "", "team_id": "", - "metadata": '{"applied_guardrails": [], "batch_models": null, "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": null}}', + "metadata": '{"applied_guardrails": [], "batch_models": null, "mcp_tool_call_metadata": null, "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": null}}', "cache_key": "Cache OFF", "spend": 0.00022500000000000002, "total_tokens": 30, @@ -555,7 +555,7 @@ class TestSpendLogsPayload: "model": "claude-3-7-sonnet-20250219", "user": "", "team_id": "", - "metadata": '{"applied_guardrails": [], "batch_models": null, "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": {"audio_tokens": null, "cached_tokens": 0, "text_tokens": null, "image_tokens": null}, 
"cache_creation_input_tokens": 0, "cache_read_input_tokens": 0}}', + "metadata": '{"applied_guardrails": [], "batch_models": null, "mcp_tool_call_metadata": null, "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": {"audio_tokens": null, "cached_tokens": 0, "text_tokens": null, "image_tokens": null}, "cache_creation_input_tokens": 0, "cache_read_input_tokens": 0}}', "cache_key": "Cache OFF", "spend": 0.01383, "total_tokens": 2598, @@ -651,7 +651,7 @@ class TestSpendLogsPayload: "model": "claude-3-7-sonnet-20250219", "user": "", "team_id": "", - "metadata": '{"applied_guardrails": [], "batch_models": null, "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": {"audio_tokens": null, "cached_tokens": 0, "text_tokens": null, "image_tokens": null}, "cache_creation_input_tokens": 0, "cache_read_input_tokens": 0}}', + "metadata": '{"applied_guardrails": [], "batch_models": null, "mcp_tool_call_metadata": null, "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": {"audio_tokens": null, "cached_tokens": 0, "text_tokens": null, "image_tokens": null}, "cache_creation_input_tokens": 0, "cache_read_input_tokens": 0}}', "cache_key": "Cache OFF", "spend": 0.01383, "total_tokens": 2598, diff --git a/tests/logging_callback_tests/gcs_pub_sub_body/spend_logs_payload.json b/tests/logging_callback_tests/gcs_pub_sub_body/spend_logs_payload.json index a4c0f3f58b..656cb6d589 100644 --- a/tests/logging_callback_tests/gcs_pub_sub_body/spend_logs_payload.json +++ b/tests/logging_callback_tests/gcs_pub_sub_body/spend_logs_payload.json @@ -9,7 +9,7 @@ "model": "gpt-4o", "user": "", "team_id": "", - "metadata": "{\"applied_guardrails\": [], \"batch_models\": null, \"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}", + "metadata": "{\"applied_guardrails\": [], \"batch_models\": null, \"mcp_tool_call_metadata\": null, 
\"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}", "cache_key": "Cache OFF", "spend": 0.00022500000000000002, "total_tokens": 30, diff --git a/tests/logging_callback_tests/gcs_pub_sub_body/standard_logging_payload.json b/tests/logging_callback_tests/gcs_pub_sub_body/standard_logging_payload.json index eb57387120..1dc72b704f 100644 --- a/tests/logging_callback_tests/gcs_pub_sub_body/standard_logging_payload.json +++ b/tests/logging_callback_tests/gcs_pub_sub_body/standard_logging_payload.json @@ -25,7 +25,8 @@ "requester_metadata": null, "user_api_key_end_user_id": null, "prompt_management_metadata": null, - "applied_guardrails": [] + "applied_guardrails": [], + "mcp_tool_call_metadata": null }, "cache_key": null, "response_cost": 0.00022500000000000002, diff --git a/tests/logging_callback_tests/test_otel_logging.py b/tests/logging_callback_tests/test_otel_logging.py index aeec20be23..2e102ec46c 100644 --- a/tests/logging_callback_tests/test_otel_logging.py +++ b/tests/logging_callback_tests/test_otel_logging.py @@ -274,6 +274,7 @@ def validate_redacted_message_span_attributes(span): "metadata.user_api_key_end_user_id", "metadata.user_api_key_user_email", "metadata.applied_guardrails", + "metadata.mcp_tool_call_metadata", ] _all_attributes = set( diff --git a/tests/mcp_tests/test_mcp_server.py b/tests/mcp_tests/test_mcp_server.py new file mode 100644 index 0000000000..2cf9193871 --- /dev/null +++ b/tests/mcp_tests/test_mcp_server.py @@ -0,0 +1,35 @@ +# Create server parameters for stdio connection +import os +import sys +import pytest + +sys.path.insert( + 0, os.path.abspath("../../..") +) # Adds the parent directory to the system path + +from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + MCPSSEServer, +) + + +mcp_server_manager = MCPServerManager() + + +@pytest.mark.asyncio +@pytest.mark.skip(reason="Local only test") +async def test_mcp_server_manager(): + 
mcp_server_manager.load_servers_from_config( + { + "zapier_mcp_server": { + "url": os.environ.get("ZAPIER_MCP_SERVER_URL"), + } + } + ) + tools = await mcp_server_manager.list_tools() + print("TOOLS FROM MCP SERVER MANAGER== ", tools) + + result = await mcp_server_manager.call_tool( + name="gmail_send_email", arguments={"body": "Test"} + ) + print("RESULT FROM CALLING TOOL FROM MCP SERVER MANAGER== ", result) diff --git a/ui/litellm-dashboard/src/app/page.tsx b/ui/litellm-dashboard/src/app/page.tsx index f480501b58..a4256b3f4b 100644 --- a/ui/litellm-dashboard/src/app/page.tsx +++ b/ui/litellm-dashboard/src/app/page.tsx @@ -32,6 +32,8 @@ import GuardrailsPanel from "@/components/guardrails"; import TransformRequestPanel from "@/components/transform_request"; import { fetchUserModels } from "@/components/create_key_button"; import { fetchTeams } from "@/components/common_components/fetch_teams"; +import MCPToolsViewer from "@/components/mcp_tools"; + function getCookie(name: string) { const cookieValue = document.cookie .split("; ") @@ -347,6 +349,12 @@ export default function CreateKeyPage() { accessToken={accessToken} allTeams={teams as Team[] ?? []} /> + ) : page == "mcp-tools" ? ( + ) : page == "new_usage" ? 
( , roles: all_admin_roles }, { key: "11", page: "guardrails", label: "Guardrails", icon: , roles: all_admin_roles }, { key: "12", page: "new_usage", label: "New Usage", icon: , roles: all_admin_roles }, + { key: "18", page: "mcp-tools", label: "MCP Tools", icon: , roles: all_admin_roles }, ] }, { diff --git a/ui/litellm-dashboard/src/components/mcp_tools/code-example.tsx b/ui/litellm-dashboard/src/components/mcp_tools/code-example.tsx new file mode 100644 index 0000000000..5e0b170530 --- /dev/null +++ b/ui/litellm-dashboard/src/components/mcp_tools/code-example.tsx @@ -0,0 +1,92 @@ +import React from 'react'; + +const codeString = `import asyncio +from openai import AsyncOpenAI +from openai.types.chat import ChatCompletionUserMessageParam +from mcp import ClientSession +from mcp.client.sse import sse_client +from litellm.experimental_mcp_client.tools import ( + transform_mcp_tool_to_openai_tool, + transform_openai_tool_call_request_to_mcp_tool_call_request, +) + +async def main(): + # Initialize clients + client = AsyncOpenAI( + api_key="sk-1234", + base_url="http://localhost:4000" + ) + + # Connect to MCP + async with sse_client("http://localhost:4000/mcp/") as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + mcp_tools = await session.list_tools() + print("List of MCP tools for MCP server:", mcp_tools.tools) + + # Create message + messages = [ + ChatCompletionUserMessageParam( + content="Send an email about LiteLLM supporting MCP", + role="user" + ) + ] + + # Request with tools + response = await client.chat.completions.create( + model="gpt-4o", + messages=messages, + tools=[transform_mcp_tool_to_openai_tool(tool) for tool in mcp_tools.tools], + tool_choice="auto" + ) + + # Handle tool call + if response.choices[0].message.tool_calls: + tool_call = response.choices[0].message.tool_calls[0] + if tool_call: + # Convert format + mcp_call = transform_openai_tool_call_request_to_mcp_tool_call_request( + 
openai_tool=tool_call.model_dump() + ) + + # Execute tool + result = await session.call_tool( + name=mcp_call.name, + arguments=mcp_call.arguments + ) + + print("Result:", result) + +# Run it +asyncio.run(main())`; + +export const CodeExample: React.FC = () => { + return ( +
+
+

Using MCP Tools

+
+
+
+
+
Python integration
+
+ +
+ +
+
+            {codeString}
+          
+
+
+
+ ); +}; \ No newline at end of file diff --git a/ui/litellm-dashboard/src/components/mcp_tools/columns.tsx b/ui/litellm-dashboard/src/components/mcp_tools/columns.tsx new file mode 100644 index 0000000000..b99ecbd9f0 --- /dev/null +++ b/ui/litellm-dashboard/src/components/mcp_tools/columns.tsx @@ -0,0 +1,309 @@ +import React from "react"; +import { ColumnDef } from "@tanstack/react-table"; +import { MCPTool, InputSchema } from "./types"; +import { Button } from "@tremor/react" + +export const columns: ColumnDef[] = [ + { + accessorKey: "mcp_info.server_name", + header: "Provider", + cell: ({ row }) => { + const serverName = row.original.mcp_info.server_name; + const logoUrl = row.original.mcp_info.logo_url; + + return ( +
+ {logoUrl && ( + {`${serverName} + )} + {serverName} +
+ ); + }, + }, + { + accessorKey: "name", + header: "Tool Name", + cell: ({ row }) => { + const name = row.getValue("name") as string; + return ( +
+ {name} +
+ ); + }, + }, + { + accessorKey: "description", + header: "Description", + cell: ({ row }) => { + const description = row.getValue("description") as string; + return ( +
+ {description} +
+ ); + }, + }, + { + id: "actions", + header: "Actions", + cell: ({ row }) => { + const tool = row.original; + + return ( +
+ +
+ ); + }, + }, +]; + +// Tool Panel component to display when a tool is selected +export function ToolTestPanel({ + tool, + onSubmit, + isLoading, + result, + error, + onClose +}: { + tool: MCPTool; + onSubmit: (args: Record) => void; + isLoading: boolean; + result: any | null; + error: Error | null; + onClose: () => void; +}) { + const [formState, setFormState] = React.useState>({}); + + // Create a placeholder schema if we only have the "tool_input_schema" string + const schema: InputSchema = React.useMemo(() => { + if (typeof tool.inputSchema === 'string') { + // Default schema with a single text field + return { + type: "object", + properties: { + input: { + type: "string", + description: "Input for this tool" + } + }, + required: ["input"] + }; + } + return tool.inputSchema as InputSchema; + }, [tool.inputSchema]); + + const handleInputChange = (key: string, value: any) => { + setFormState(prev => ({ + ...prev, + [key]: value + })); + }; + + const handleSubmit = (e: React.FormEvent) => { + e.preventDefault(); + onSubmit(formState); + }; + + return ( +
+
+
+

Test Tool: {tool.name}

+

{tool.description}

+

Provider: {tool.mcp_info.server_name}

+
+ +
+ +
+ {/* Form Section */} +
+

Input Parameters

+
+ {typeof tool.inputSchema === 'string' ? ( +
+

This tool uses a dynamic input schema.

+
+ + handleInputChange("input", e.target.value)} + required + className="w-full px-3 py-2 border border-gray-300 rounded-md shadow-sm focus:outline-none focus:ring-blue-500 focus:border-blue-500 sm:text-sm" + /> +
+
+ ) : ( + Object.entries(schema.properties).map(([key, prop]) => ( +
+ + {prop.description && ( +

{prop.description}

+ )} + + {/* Render appropriate input based on type */} + {prop.type === "string" && ( + handleInputChange(key, e.target.value)} + required={schema.required?.includes(key)} + className="w-full px-3 py-2 border border-gray-300 rounded-md shadow-sm focus:outline-none focus:ring-blue-500 focus:border-blue-500 sm:text-sm" + /> + )} + + {prop.type === "number" && ( + handleInputChange(key, parseFloat(e.target.value))} + required={schema.required?.includes(key)} + className="w-full px-3 py-2 border border-gray-300 rounded-md shadow-sm focus:outline-none focus:ring-blue-500 focus:border-blue-500 sm:text-sm" + /> + )} + + {prop.type === "boolean" && ( +
+ handleInputChange(key, e.target.checked)} + className="h-4 w-4 text-blue-600 focus:ring-blue-500 border-gray-300 rounded" + /> + Enable +
+ )} +
+ )) + )} + +
+ +
+
+
+ + {/* Result Section */} +
+

Result

+ + {isLoading && ( +
+
+
+ )} + + {error && ( +
+

Error

+
{error.message}
+
+ )} + + {result && !isLoading && !error && ( +
+ {result.map((content: any, idx: number) => ( +
+ {content.type === "text" && ( +
+

{content.text}

+
+ )} + + {content.type === "image" && content.url && ( +
+ Tool result +
+ )} + + {content.type === "embedded_resource" && ( +
+

Embedded Resource

+

Type: {content.resource_type}

+ {content.url && ( + + View Resource + + )} +
+ )} +
+ ))} + +
+
+ Raw JSON Response +
+                    {JSON.stringify(result, null, 2)}
+                  
+
+
+
+ )} + + {!result && !isLoading && !error && ( +
+

The result will appear here after you call the tool.

+
+ )} +
+
+
+ ); +} \ No newline at end of file diff --git a/ui/litellm-dashboard/src/components/mcp_tools/index.tsx b/ui/litellm-dashboard/src/components/mcp_tools/index.tsx new file mode 100644 index 0000000000..ae3d4fac62 --- /dev/null +++ b/ui/litellm-dashboard/src/components/mcp_tools/index.tsx @@ -0,0 +1,172 @@ +import React, { useState } from 'react'; +import { useQuery, useMutation } from '@tanstack/react-query'; +import { DataTable } from '../view_logs/table'; +import { columns, ToolTestPanel } from './columns'; +import { MCPTool, MCPToolsViewerProps, CallMCPToolResponse } from './types'; +import { listMCPTools, callMCPTool } from '../networking'; + +// Wrapper to handle the type mismatch between MCPTool and DataTable's expected type +function DataTableWrapper({ + columns, + data, + isLoading, +}: { + columns: any; + data: MCPTool[]; + isLoading: boolean; +}) { + // Create a dummy renderSubComponent and getRowCanExpand function + const renderSubComponent = () =>
; + const getRowCanExpand = () => false; + + return ( + + ); +} + +export default function MCPToolsViewer({ + accessToken, + userRole, + userID, +}: MCPToolsViewerProps) { + const [searchTerm, setSearchTerm] = useState(''); + const [selectedTool, setSelectedTool] = useState(null); + const [toolResult, setToolResult] = useState(null); + const [toolError, setToolError] = useState(null); + + // Query to fetch MCP tools + const { data: mcpTools, isLoading: isLoadingTools } = useQuery({ + queryKey: ['mcpTools'], + queryFn: () => { + if (!accessToken) throw new Error('Access Token required'); + return listMCPTools(accessToken); + }, + enabled: !!accessToken, + }); + + // Mutation for calling a tool + const { mutate: executeTool, isPending: isCallingTool } = useMutation({ + mutationFn: (args: { tool: MCPTool; arguments: Record }) => { + if (!accessToken) throw new Error('Access Token required'); + return callMCPTool( + accessToken, + args.tool.name, + args.arguments + ); + }, + onSuccess: (data) => { + setToolResult(data); + setToolError(null); + }, + onError: (error: Error) => { + setToolError(error); + setToolResult(null); + }, + }); + + // Add onToolSelect handler to each tool + const toolsData = React.useMemo(() => { + if (!mcpTools) return []; + + return mcpTools.map((tool: MCPTool) => ({ + ...tool, + onToolSelect: (tool: MCPTool) => { + setSelectedTool(tool); + setToolResult(null); + setToolError(null); + } + })); + }, [mcpTools]); + + // Filter tools based on search term + const filteredTools = React.useMemo(() => { + return toolsData.filter((tool: MCPTool) => { + const searchLower = searchTerm.toLowerCase(); + return ( + tool.name.toLowerCase().includes(searchLower) || + tool.description.toLowerCase().includes(searchLower) || + tool.mcp_info.server_name.toLowerCase().includes(searchLower) + ); + }); + }, [toolsData, searchTerm]); + + // Handle tool call submission + const handleToolSubmit = (args: Record) => { + if (!selectedTool) return; + + executeTool({ + tool: 
selectedTool, + arguments: args, + }); + }; + + if (!accessToken || !userRole || !userID) { + return
Missing required authentication parameters.
; + } + + return ( +
+
+

MCP Tools

+
+ +
+
+
+
+ setSearchTerm(e.target.value)} + /> + + + +
+
+ {filteredTools.length} tool{filteredTools.length !== 1 ? "s" : ""} available +
+
+
+ + +
+ + {/* Tool Test Panel - Show when a tool is selected */} + {selectedTool && ( +
+ setSelectedTool(null)} + /> +
+ )} +
+ ); +} \ No newline at end of file diff --git a/ui/litellm-dashboard/src/components/mcp_tools/types.tsx b/ui/litellm-dashboard/src/components/mcp_tools/types.tsx new file mode 100644 index 0000000000..7bbb76fa23 --- /dev/null +++ b/ui/litellm-dashboard/src/components/mcp_tools/types.tsx @@ -0,0 +1,71 @@ +// Define the structure for tool input schema properties +export interface InputSchemaProperty { + type: string; + description?: string; + } + + // Define the structure for the input schema of a tool + export interface InputSchema { + type: "object"; + properties: Record; + required?: string[]; + } + + // Define MCP provider info + export interface MCPInfo { + server_name: string; + logo_url?: string; + } + + // Define the structure for a single MCP tool + export interface MCPTool { + name: string; + description: string; + inputSchema: InputSchema | string; // API returns string "tool_input_schema" or the actual schema + mcp_info: MCPInfo; + // Function to select a tool (added in the component) + onToolSelect?: (tool: MCPTool) => void; + } + + // Define the response structure for the listMCPTools endpoint - now a flat array + export type ListMCPToolsResponse = MCPTool[]; + + // Define the argument structure for calling an MCP tool + export interface CallMCPToolArgs { + name: string; + arguments: Record | null; + server_name?: string; // Now using server_name from mcp_info + } + + // Define the possible content types in the response + export interface MCPTextContent { + type: "text"; + text: string; + annotations?: any; + } + + export interface MCPImageContent { + type: "image"; + url?: string; + data?: string; + } + + export interface MCPEmbeddedResource { + type: "embedded_resource"; + resource_type?: string; + url?: string; + data?: any; + } + + // Define the union type for the content array in the response + export type MCPContent = MCPTextContent | MCPImageContent | MCPEmbeddedResource; + + // Define the response structure for the callMCPTool endpoint + export 
type CallMCPToolResponse = MCPContent[]; + + // Props for the main component + export interface MCPToolsViewerProps { + accessToken: string | null; + userRole: string | null; + userID: string | null; + } \ No newline at end of file diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx index 83d258e532..9190319cc9 100644 --- a/ui/litellm-dashboard/src/components/networking.tsx +++ b/ui/litellm-dashboard/src/components/networking.tsx @@ -4084,3 +4084,73 @@ export const updateInternalUserSettings = async (accessToken: string, settings: throw error; } }; + + +export const listMCPTools = async (accessToken: string) => { + try { + // Construct base URL + let url = proxyBaseUrl + ? `${proxyBaseUrl}/mcp/tools/list` + : `/mcp/tools/list`; + + console.log("Fetching MCP tools from:", url); + + const response = await fetch(url, { + method: "GET", + headers: { + [globalLitellmHeaderName]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + }); + + if (!response.ok) { + const errorData = await response.text(); + handleError(errorData); + throw new Error("Network response was not ok"); + } + + const data = await response.json(); + console.log("Fetched MCP tools:", data); + return data; + } catch (error) { + console.error("Failed to fetch MCP tools:", error); + throw error; + } +}; + + +export const callMCPTool = async (accessToken: string, toolName: string, toolArguments: Record) => { + try { + // Construct base URL + let url = proxyBaseUrl + ? 
`${proxyBaseUrl}/mcp/tools/call` + : `/mcp/tools/call`; + + console.log("Calling MCP tool:", toolName, "with arguments:", toolArguments); + + const response = await fetch(url, { + method: "POST", + headers: { + [globalLitellmHeaderName]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + body: JSON.stringify({ + name: toolName, + arguments: toolArguments, + }), + }); + + if (!response.ok) { + const errorData = await response.text(); + handleError(errorData); + throw new Error("Network response was not ok"); + } + + const data = await response.json(); + console.log("MCP tool call response:", data); + return data; + } catch (error) { + console.error("Failed to call MCP tool:", error); + throw error; + } +}; \ No newline at end of file diff --git a/ui/litellm-dashboard/src/components/view_logs/columns.tsx b/ui/litellm-dashboard/src/components/view_logs/columns.tsx index 98f5bfbc9d..2732fdaa77 100644 --- a/ui/litellm-dashboard/src/components/view_logs/columns.tsx +++ b/ui/litellm-dashboard/src/components/view_logs/columns.tsx @@ -8,6 +8,19 @@ import { Tooltip } from "antd"; import { TimeCell } from "./time_cell"; import { Button } from "@tremor/react"; +// Helper to get the appropriate logo URL +const getLogoUrl = ( + row: LogEntry, + provider: string +) => { + // Check if mcp_tool_call_metadata exists and contains mcp_server_logo_url + if (row.metadata?.mcp_tool_call_metadata?.mcp_server_logo_url) { + return row.metadata.mcp_tool_call_metadata.mcp_server_logo_url; + } + // Fall back to default provider logo + return provider ? getProviderLogoAndName(provider).logo : ''; +}; + export type LogEntry = { request_id: string; api_key: string; @@ -177,7 +190,7 @@ export const columns: ColumnDef[] = [
{provider && ( {