diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 7ace983f8..fbce67d78 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -6548,6 +6548,83 @@ "model_context_protocol" ] }, + "MCPConfig": { + "oneOf": [ + { + "$ref": "#/components/schemas/MCPInlineConfig" + }, + { + "$ref": "#/components/schemas/MCPRemoteConfig" + } + ] + }, + "MCPInlineConfig": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "inline", + "default": "inline" + }, + "command": { + "type": "string" + }, + "args": { + "type": "array", + "items": { + "type": "string" + } + }, + "env": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "additionalProperties": false, + "required": [ + "type", + "command" + ] + }, + "MCPRemoteConfig": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "remote", + "default": "remote" + }, + "mcp_endpoint": { + "$ref": "#/components/schemas/URL" + } + }, + "additionalProperties": false, + "required": [ + "type", + "mcp_endpoint" + ] + }, "ToolGroup": { "type": "object", "properties": { @@ -6565,8 +6642,8 @@ "const": "tool_group", "default": "tool_group" }, - "mcp_endpoint": { - "$ref": "#/components/schemas/URL" + "mcp_config": { + "$ref": "#/components/schemas/MCPConfig" }, "args": { "type": "object", @@ -6916,8 +6993,8 @@ "ListRuntimeToolsRequest": { "type": "object", "properties": { - "mcp_endpoint": { - "$ref": "#/components/schemas/URL" + "mcp_config": { + "$ref": "#/components/schemas/MCPConfig" } }, "additionalProperties": false @@ -8022,8 +8099,8 @@ "provider_id": { "type": "string" }, - "mcp_endpoint": { - "$ref": "#/components/schemas/URL" + "mcp_config": { + "$ref": "#/components/schemas/MCPConfig" }, "args": { "type": "object", @@ -8932,6 +9009,18 @@ "name": "LoraFinetuningConfig", "description": "" }, + { + "name": "MCPConfig", + "description": "" + }, + { + "name": "MCPInlineConfig", + "description": "" + }, + { + "name": "MCPRemoteConfig", + "description": "" + }, { "name": "Memory" }, @@ -9437,6 +9526,9 @@ "LogEventRequest", "LogSeverity", "LoraFinetuningConfig", + "MCPConfig", + "MCPInlineConfig", + "MCPRemoteConfig", "MemoryBankDocument", "MemoryRetrievalStep", "Message", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index a2f6bc005..578144e51 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -1125,8 +1125,8 @@ components: ListRuntimeToolsRequest: additionalProperties: false properties: - mcp_endpoint: - $ref: '#/components/schemas/URL' + mcp_config: + $ref: '#/components/schemas/MCPConfig' type: object LogEventRequest: additionalProperties: false @@ -1184,6 +1184,50 @@ components: - rank - alpha type: object + MCPConfig: + oneOf: + - $ref: '#/components/schemas/MCPInlineConfig' + - $ref: '#/components/schemas/MCPRemoteConfig' + MCPInlineConfig: + additionalProperties: false + properties: + args: + items: + type: string + type: array + command: + type: string + env: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + type: + const: inline + default: inline + type: string + required: + - type + - command + type: object + MCPRemoteConfig: + additionalProperties: false + properties: + mcp_endpoint: + $ref: '#/components/schemas/URL' + type: + const: remote + default: remote + type: string + required: + - type + - mcp_endpoint + type: object MemoryBankDocument: additionalProperties: false properties: @@ -1897,8 +1941,8 @@ components: - type: array - type: object type: object - mcp_endpoint: - $ref: '#/components/schemas/URL' + mcp_config: + $ref: '#/components/schemas/MCPConfig' provider_id: type: string toolgroup_id: @@ -2773,8 +2817,8 @@ components: type: object identifier: type: string - mcp_endpoint: - $ref: '#/components/schemas/URL' + mcp_config: + $ref: '#/components/schemas/MCPConfig' provider_id: type: string provider_resource_id: @@ -5615,6 +5659,14 @@ tags: - description: name: LoraFinetuningConfig +- description: + name: MCPConfig +- description: + name: MCPInlineConfig +- description: + name: MCPRemoteConfig - name: Memory - description: @@ -5982,6 +6034,9 @@ x-tagGroups: - LogEventRequest - LogSeverity - LoraFinetuningConfig + - MCPConfig + - MCPInlineConfig + - MCPRemoteConfig - MemoryBankDocument - MemoryRetrievalStep - Message diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index e430ec46d..773c4a3f0 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -5,10 +5,10 @@ # the root directory of this source tree. from enum import Enum -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Dict, List, Literal, Optional, Union from llama_models.llama3.api.datatypes import ToolPromptFormat -from llama_models.schema_utils import json_schema_type, webmethod +from llama_models.schema_utils import json_schema_type, register_schema, webmethod from pydantic import BaseModel, Field from typing_extensions import Protocol, runtime_checkable @@ -57,18 +57,35 @@ class ToolDef(BaseModel): ) +@json_schema_type +class MCPInlineConfig(BaseModel): + type: Literal["inline"] = "inline" + command: str + args: Optional[List[str]] = None + env: Optional[Dict[str, Any]] = None + + +@json_schema_type +class MCPRemoteConfig(BaseModel): + type: Literal["remote"] = "remote" + mcp_endpoint: URL + + +MCPConfig = register_schema(Union[MCPInlineConfig, MCPRemoteConfig], name="MCPConfig") + + @json_schema_type class ToolGroupInput(BaseModel): toolgroup_id: str provider_id: str args: Optional[Dict[str, Any]] = None - mcp_endpoint: Optional[URL] = None + mcp_config: Optional[MCPConfig] = None @json_schema_type class ToolGroup(Resource): type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value - mcp_endpoint: Optional[URL] = None + mcp_config: Optional[MCPConfig] = None args: Optional[Dict[str, Any]] = None @@ -92,7 +109,7 @@ class ToolGroups(Protocol): self, toolgroup_id: str, provider_id: str, - mcp_endpoint: Optional[URL] = None, + mcp_config: Optional[MCPConfig] = None, args: Optional[Dict[str, Any]] = None, ) -> None: """Register a tool group""" @@ -131,7 +148,9 @@ class ToolRuntime(Protocol): # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed. @webmethod(route="/tool-runtime/list-tools", method="GET") async def list_runtime_tools( - self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + self, + tool_group_id: Optional[str] = None, + mcp_config: Optional[MCPConfig] = None, ) -> List[ToolDef]: ... @webmethod(route="/tool-runtime/invoke", method="POST") diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 05d43ad4f..03d835026 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -6,7 +6,7 @@ from typing import Any, AsyncGenerator, Dict, List, Optional -from llama_stack.apis.common.content_types import InterleavedContent, URL +from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult from llama_stack.apis.eval import ( AppEvalTaskConfig, @@ -38,7 +38,7 @@ from llama_stack.apis.scoring import ( ScoringFnParams, ) from llama_stack.apis.shields import Shield -from llama_stack.apis.tools import ToolDef, ToolRuntime +from llama_stack.apis.tools import MCPConfig, ToolDef, ToolRuntime from llama_stack.providers.datatypes import RoutingTable @@ -418,8 +418,10 @@ class ToolRuntimeRouter(ToolRuntime): ) async def list_runtime_tools( - self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + self, + tool_group_id: Optional[str] = None, + mcp_config: Optional[MCPConfig] = None, ) -> List[ToolDef]: return await self.routing_table.get_provider_impl(tool_group_id).list_tools( - tool_group_id, mcp_endpoint + tool_group_id, mcp_config ) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index d4cb708a2..335e63d9d 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -26,7 +26,7 @@ from llama_stack.apis.scoring_functions import ( ScoringFunctions, ) from llama_stack.apis.shields import Shield, Shields -from llama_stack.apis.tools import Tool, ToolGroup, ToolGroups, ToolHost +from llama_stack.apis.tools import MCPConfig, Tool, ToolGroup, ToolGroups, ToolHost from llama_stack.distribution.datatypes import ( RoutableObject, RoutableObjectWithProvider, @@ -504,15 +504,15 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): self, toolgroup_id: str, provider_id: str, - mcp_endpoint: Optional[URL] = None, + mcp_config: Optional[MCPConfig] = None, args: Optional[Dict[str, Any]] = None, ) -> None: tools = [] tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools( - toolgroup_id, mcp_endpoint + toolgroup_id, mcp_config ) tool_host = ( - ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution + ToolHost.model_context_protocol if mcp_config else ToolHost.distribution ) for tool_def in tool_defs: @@ -547,7 +547,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): identifier=toolgroup_id, provider_id=provider_id, provider_resource_id=toolgroup_id, - mcp_endpoint=mcp_endpoint, + mcp_config=mcp_config, args=args, ) ) diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py index 361c91a92..6d006af49 100644 --- a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py @@ -9,8 +9,8 @@ import logging import tempfile from typing import Any, Dict, List, Optional -from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( + MCPConfig, Tool, ToolDef, ToolInvocationResult, @@ -43,7 +43,9 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): return async def list_runtime_tools( - self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + self, + tool_group_id: Optional[str] = None, + mcp_config: Optional[MCPConfig] = None, ) -> List[ToolDef]: return [ ToolDef( diff --git a/llama_stack/providers/inline/tool_runtime/memory/memory.py b/llama_stack/providers/inline/tool_runtime/memory/memory.py index fe6325abb..66dd831b4 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/memory.py +++ b/llama_stack/providers/inline/tool_runtime/memory/memory.py @@ -10,11 +10,11 @@ import secrets import string from typing import Any, Dict, List, Optional -from llama_stack.apis.common.content_types import URL from llama_stack.apis.inference import Inference, InterleavedContent from llama_stack.apis.memory import Memory, QueryDocumentsResponse from llama_stack.apis.memory_banks import MemoryBanks from llama_stack.apis.tools import ( + MCPConfig, ToolDef, ToolInvocationResult, ToolParameter, @@ -52,7 +52,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): pass async def list_runtime_tools( - self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + self, + tool_group_id: Optional[str] = None, + mcp_config: Optional[MCPConfig] = None, ) -> List[ToolDef]: return [ ToolDef( diff --git a/llama_stack/providers/inline/tool_runtime/model_context_protocol/__init__.py b/llama_stack/providers/inline/tool_runtime/model_context_protocol/__init__.py new file mode 100644 index 000000000..7befa58e5 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/model_context_protocol/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel + +from .config import ModelContextProtocolConfig +from .model_context_protocol import ModelContextProtocolToolRuntimeImpl + + +class ModelContextProtocolToolProviderDataValidator(BaseModel): + api_key: str + + +async def get_provider_impl(config: ModelContextProtocolConfig, _deps): + impl = ModelContextProtocolToolRuntimeImpl(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/tool_runtime/model_context_protocol/config.py b/llama_stack/providers/inline/tool_runtime/model_context_protocol/config.py new file mode 100644 index 000000000..ffe4c9887 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/model_context_protocol/config.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel + + +class ModelContextProtocolConfig(BaseModel): + pass diff --git a/llama_stack/providers/inline/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/inline/tool_runtime/model_context_protocol/model_context_protocol.py new file mode 100644 index 000000000..9427a0bdd --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/model_context_protocol/model_context_protocol.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from typing import Any, Dict, List, Optional + +from mcp import ClientSession +from mcp.client.stdio import stdio_client, StdioServerParameters +from pydantic import TypeAdapter + +from llama_stack.apis.tools import ( + MCPConfig, + ToolDef, + ToolInvocationResult, + ToolParameter, + ToolRuntime, +) +from llama_stack.providers.datatypes import ToolsProtocolPrivate + +from .config import ModelContextProtocolConfig + + +class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): + def __init__(self, config: ModelContextProtocolConfig): + self.config = config + + async def initialize(self): + pass + + async def list_runtime_tools( + self, + tool_group_id: Optional[str] = None, + mcp_config: Optional[MCPConfig] = None, + ) -> List[ToolDef]: + if mcp_config is None: + raise ValueError("mcp_config is required") + + tools = [] + async with stdio_client( + StdioServerParameters( + command=mcp_config.command, + args=mcp_config.args, + env=mcp_config.env, + ) + ) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + tools_result = await session.list_tools() + for tool in tools_result.tools: + parameters = [] + for param_name, param_schema in tool.inputSchema.get( + "properties", {} + ).items(): + parameters.append( + ToolParameter( + name=param_name, + parameter_type=param_schema.get("type", "string"), + description=param_schema.get("description", ""), + ) + ) + tools.append( + ToolDef( + name=tool.name, + description=tool.description, + parameters=parameters, + metadata={ + "mcp_config": mcp_config.model_dump_json(), + }, + ) + ) + return tools + + async def invoke_tool( + self, tool_name: str, args: Dict[str, Any] + ) -> ToolInvocationResult: + tool = await self.tool_store.get_tool(tool_name) + if tool.metadata is None or tool.metadata.get("mcp_config") is None: + raise ValueError(f"Tool {tool_name} does not have metadata") + mcp_config_dict = json.loads(tool.metadata.get("mcp_config")) + mcp_config = TypeAdapter(MCPConfig).validate_python(mcp_config_dict) + async with stdio_client( + StdioServerParameters( + command=mcp_config.command, + args=mcp_config.args, + env=mcp_config.env, + ) + ) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + result = await session.call_tool(tool.identifier, arguments=args) + + return ToolInvocationResult( + content="\n".join([result.model_dump_json() for result in result.content]), + error_code=1 if result.isError else 0, + ) diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index 40299edad..7de3f09f4 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -32,6 +32,13 @@ def available_providers() -> List[ProviderSpec]: module="llama_stack.providers.inline.tool_runtime.code_interpreter", config_class="llama_stack.providers.inline.tool_runtime.code_interpreter.config.CodeInterpreterToolConfig", ), + InlineProviderSpec( + api=Api.tool_runtime, + provider_type="inline::model-context-protocol", + pip_packages=["mcp"], + module="llama_stack.providers.inline.tool_runtime.model_context_protocol", + config_class="llama_stack.providers.inline.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig", + ), remote_provider_spec( api=Api.tool_runtime, adapter=AdapterSpec( diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py index 259d02f1b..628bdd524 100644 --- a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py +++ b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py @@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional import requests from llama_models.llama3.api.datatypes import BuiltinTool -from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( + MCPConfig, Tool, ToolDef, ToolInvocationResult, @@ -50,7 +50,9 @@ class BraveSearchToolRuntimeImpl( return provider_data.api_key async def list_runtime_tools( - self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + self, + tool_group_id: Optional[str] = None, + mcp_config: Optional[MCPConfig] = None, ) -> List[ToolDef]: return [ ToolDef( diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index a304167e9..93afcb023 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -4,14 +4,15 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import json from typing import Any, Dict, List, Optional -from urllib.parse import urlparse from mcp import ClientSession from mcp.client.sse import sse_client +from pydantic import TypeAdapter -from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( + MCPConfig, ToolDef, ToolInvocationResult, ToolParameter, @@ -30,13 +31,15 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): pass async def list_runtime_tools( - self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + self, + tool_group_id: Optional[str] = None, + mcp_config: Optional[MCPConfig] = None, ) -> List[ToolDef]: - if mcp_endpoint is None: - raise ValueError("mcp_endpoint is required") + if mcp_config is None: + raise ValueError("mcp_config is required") tools = [] - async with sse_client(mcp_endpoint.uri) as streams: + async with sse_client(mcp_config.mcp_endpoint.uri) as streams: async with ClientSession(*streams) as session: await session.initialize() tools_result = await session.list_tools() @@ -58,7 +61,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): description=tool.description, parameters=parameters, metadata={ - "endpoint": mcp_endpoint.uri, + "mcp_config": mcp_config, }, ) ) @@ -68,13 +71,12 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): self, tool_name: str, args: Dict[str, Any] ) -> ToolInvocationResult: tool = await self.tool_store.get_tool(tool_name) - if tool.metadata is None or tool.metadata.get("endpoint") is None: + if tool.metadata is None or tool.metadata.get("mcp_config") is None: raise ValueError(f"Tool {tool_name} does not have metadata") - endpoint = tool.metadata.get("endpoint") - if urlparse(endpoint).scheme not in ("http", "https"): - raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL") + mcp_config_dict = json.loads(tool.metadata.get("mcp_config")) + mcp_config = TypeAdapter(MCPConfig).validate_python(mcp_config_dict) - async with sse_client(endpoint) as streams: + async with sse_client(mcp_config.mcp_endpoint.uri) as streams: async with ClientSession(*streams) as session: await session.initialize() result = await session.call_tool(tool.identifier, args) diff --git a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py index 1716f96e5..0fc3bfd31 100644 --- a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py +++ b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py @@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional import requests -from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( + MCPConfig, Tool, ToolDef, ToolInvocationResult, @@ -50,7 +50,9 @@ class TavilySearchToolRuntimeImpl( return provider_data.api_key async def list_runtime_tools( - self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + self, + tool_group_id: Optional[str] = None, + mcp_config: Optional[MCPConfig] = None, ) -> List[ToolDef]: return [ ToolDef( diff --git a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py index 8d0792ca0..182c3acba 100644 --- a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py +++ b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py @@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional import requests -from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( + MCPConfig, Tool, ToolDef, ToolInvocationResult, @@ -51,7 +51,9 @@ class WolframAlphaToolRuntimeImpl( return provider_data.api_key async def list_runtime_tools( - self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + self, + tool_group_id: Optional[str] = None, + mcp_config: Optional[MCPConfig] = None, ) -> List[ToolDef]: return [ ToolDef( diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index a2ed687a4..a133ed3d3 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -5,6 +5,7 @@ # the root directory of this source tree. import json +import os from typing import Dict, List from uuid import uuid4 @@ -324,3 +325,35 @@ def test_rag_agent(llama_stack_client, agent_config): logs = [str(log) for log in EventLogger().log(response) if log is not None] logs_str = "".join(logs) assert "Tool:query_memory" in logs_str + + +def test_mcp_agent(llama_stack_client, agent_config): + llama_stack_client.toolgroups.register( + toolgroup_id="brave-search", + provider_id="model-context-protocol", + mcp_config=dict( + type="inline", + command="/Users/dineshyv/homebrew/bin/npx", + args=["-y", "@modelcontextprotocol/server-brave-search"], + env={ + "BRAVE_API_KEY": os.environ["BRAVE_SEARCH_API_KEY"], + "PATH": os.environ["PATH"], + }, + ), + ) + agent_config = { + **agent_config, + "toolgroups": [ + "brave-search", + ], + } + agent = Agent(llama_stack_client, agent_config) + session_id = agent.create_session("test-session") + response = agent.create_turn( + messages=[{"role": "user", "content": "what won the NBA playoffs in 2024?"}], + session_id=session_id, + ) + logs = [str(log) for log in EventLogger().log(response) if log is not None] + logs_str = "".join(logs) + assert "Tool:brave_web_search" in logs_str + assert "celtics" in logs_str.lower()