diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index 7ace983f8..fbce67d78 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -6548,6 +6548,83 @@
"model_context_protocol"
]
},
+ "MCPConfig": {
+ "oneOf": [
+ {
+ "$ref": "#/components/schemas/MCPInlineConfig"
+ },
+ {
+ "$ref": "#/components/schemas/MCPRemoteConfig"
+ }
+ ]
+ },
+ "MCPInlineConfig": {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "inline",
+ "default": "inline"
+ },
+ "command": {
+ "type": "string"
+ },
+ "args": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ },
+ "env": {
+ "type": "object",
+ "additionalProperties": {
+ "oneOf": [
+ {
+ "type": "null"
+ },
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "number"
+ },
+ {
+ "type": "string"
+ },
+ {
+ "type": "array"
+ },
+ {
+ "type": "object"
+ }
+ ]
+ }
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type",
+ "command"
+ ]
+ },
+ "MCPRemoteConfig": {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "remote",
+ "default": "remote"
+ },
+ "mcp_endpoint": {
+ "$ref": "#/components/schemas/URL"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type",
+ "mcp_endpoint"
+ ]
+ },
"ToolGroup": {
"type": "object",
"properties": {
@@ -6565,8 +6642,8 @@
"const": "tool_group",
"default": "tool_group"
},
- "mcp_endpoint": {
- "$ref": "#/components/schemas/URL"
+ "mcp_config": {
+ "$ref": "#/components/schemas/MCPConfig"
},
"args": {
"type": "object",
@@ -6916,8 +6993,8 @@
"ListRuntimeToolsRequest": {
"type": "object",
"properties": {
- "mcp_endpoint": {
- "$ref": "#/components/schemas/URL"
+ "mcp_config": {
+ "$ref": "#/components/schemas/MCPConfig"
}
},
"additionalProperties": false
@@ -8022,8 +8099,8 @@
"provider_id": {
"type": "string"
},
- "mcp_endpoint": {
- "$ref": "#/components/schemas/URL"
+ "mcp_config": {
+ "$ref": "#/components/schemas/MCPConfig"
},
"args": {
"type": "object",
@@ -8932,6 +9009,18 @@
"name": "LoraFinetuningConfig",
"description": ""
},
+ {
+ "name": "MCPConfig",
+ "description": ""
+ },
+ {
+ "name": "MCPInlineConfig",
+ "description": ""
+ },
+ {
+ "name": "MCPRemoteConfig",
+ "description": ""
+ },
{
"name": "Memory"
},
@@ -9437,6 +9526,9 @@
"LogEventRequest",
"LogSeverity",
"LoraFinetuningConfig",
+ "MCPConfig",
+ "MCPInlineConfig",
+ "MCPRemoteConfig",
"MemoryBankDocument",
"MemoryRetrievalStep",
"Message",
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index a2f6bc005..578144e51 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -1125,8 +1125,8 @@ components:
ListRuntimeToolsRequest:
additionalProperties: false
properties:
- mcp_endpoint:
- $ref: '#/components/schemas/URL'
+ mcp_config:
+ $ref: '#/components/schemas/MCPConfig'
type: object
LogEventRequest:
additionalProperties: false
@@ -1184,6 +1184,50 @@ components:
- rank
- alpha
type: object
+ MCPConfig:
+ oneOf:
+ - $ref: '#/components/schemas/MCPInlineConfig'
+ - $ref: '#/components/schemas/MCPRemoteConfig'
+ MCPInlineConfig:
+ additionalProperties: false
+ properties:
+ args:
+ items:
+ type: string
+ type: array
+ command:
+ type: string
+ env:
+ additionalProperties:
+ oneOf:
+ - type: 'null'
+ - type: boolean
+ - type: number
+ - type: string
+ - type: array
+ - type: object
+ type: object
+ type:
+ const: inline
+ default: inline
+ type: string
+ required:
+ - type
+ - command
+ type: object
+ MCPRemoteConfig:
+ additionalProperties: false
+ properties:
+ mcp_endpoint:
+ $ref: '#/components/schemas/URL'
+ type:
+ const: remote
+ default: remote
+ type: string
+ required:
+ - type
+ - mcp_endpoint
+ type: object
MemoryBankDocument:
additionalProperties: false
properties:
@@ -1897,8 +1941,8 @@ components:
- type: array
- type: object
type: object
- mcp_endpoint:
- $ref: '#/components/schemas/URL'
+ mcp_config:
+ $ref: '#/components/schemas/MCPConfig'
provider_id:
type: string
toolgroup_id:
@@ -2773,8 +2817,8 @@ components:
type: object
identifier:
type: string
- mcp_endpoint:
- $ref: '#/components/schemas/URL'
+ mcp_config:
+ $ref: '#/components/schemas/MCPConfig'
provider_id:
type: string
provider_resource_id:
@@ -5615,6 +5659,14 @@ tags:
- description:
name: LoraFinetuningConfig
+- description:
+ name: MCPConfig
+- description:
+ name: MCPInlineConfig
+- description:
+ name: MCPRemoteConfig
- name: Memory
- description:
@@ -5982,6 +6034,9 @@ x-tagGroups:
- LogEventRequest
- LogSeverity
- LoraFinetuningConfig
+ - MCPConfig
+ - MCPInlineConfig
+ - MCPRemoteConfig
- MemoryBankDocument
- MemoryRetrievalStep
- Message
diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py
index e430ec46d..773c4a3f0 100644
--- a/llama_stack/apis/tools/tools.py
+++ b/llama_stack/apis/tools/tools.py
@@ -5,10 +5,10 @@
# the root directory of this source tree.
from enum import Enum
-from typing import Any, Dict, List, Literal, Optional
+from typing import Any, Dict, List, Literal, Optional, Union
from llama_models.llama3.api.datatypes import ToolPromptFormat
-from llama_models.schema_utils import json_schema_type, webmethod
+from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Protocol, runtime_checkable
@@ -57,18 +57,35 @@ class ToolDef(BaseModel):
)
+@json_schema_type
+class MCPInlineConfig(BaseModel):
+ type: Literal["inline"] = "inline"
+ command: str
+ args: Optional[List[str]] = None
+ env: Optional[Dict[str, Any]] = None
+
+
+@json_schema_type
+class MCPRemoteConfig(BaseModel):
+ type: Literal["remote"] = "remote"
+ mcp_endpoint: URL
+
+
+MCPConfig = register_schema(Union[MCPInlineConfig, MCPRemoteConfig], name="MCPConfig")
+
+
@json_schema_type
class ToolGroupInput(BaseModel):
toolgroup_id: str
provider_id: str
args: Optional[Dict[str, Any]] = None
- mcp_endpoint: Optional[URL] = None
+ mcp_config: Optional[MCPConfig] = None
@json_schema_type
class ToolGroup(Resource):
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
- mcp_endpoint: Optional[URL] = None
+ mcp_config: Optional[MCPConfig] = None
args: Optional[Dict[str, Any]] = None
@@ -92,7 +109,7 @@ class ToolGroups(Protocol):
self,
toolgroup_id: str,
provider_id: str,
- mcp_endpoint: Optional[URL] = None,
+ mcp_config: Optional[MCPConfig] = None,
args: Optional[Dict[str, Any]] = None,
) -> None:
"""Register a tool group"""
@@ -131,7 +148,9 @@ class ToolRuntime(Protocol):
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET")
async def list_runtime_tools(
- self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+ self,
+ tool_group_id: Optional[str] = None,
+ mcp_config: Optional[MCPConfig] = None,
) -> List[ToolDef]: ...
@webmethod(route="/tool-runtime/invoke", method="POST")
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 05d43ad4f..03d835026 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -6,7 +6,7 @@
from typing import Any, AsyncGenerator, Dict, List, Optional
-from llama_stack.apis.common.content_types import InterleavedContent, URL
+from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.eval import (
AppEvalTaskConfig,
@@ -38,7 +38,7 @@ from llama_stack.apis.scoring import (
ScoringFnParams,
)
from llama_stack.apis.shields import Shield
-from llama_stack.apis.tools import ToolDef, ToolRuntime
+from llama_stack.apis.tools import MCPConfig, ToolDef, ToolRuntime
from llama_stack.providers.datatypes import RoutingTable
@@ -418,8 +418,10 @@ class ToolRuntimeRouter(ToolRuntime):
)
async def list_runtime_tools(
- self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+ self,
+ tool_group_id: Optional[str] = None,
+ mcp_config: Optional[MCPConfig] = None,
) -> List[ToolDef]:
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
- tool_group_id, mcp_endpoint
+ tool_group_id, mcp_config
)
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index d4cb708a2..335e63d9d 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -26,7 +26,7 @@ from llama_stack.apis.scoring_functions import (
ScoringFunctions,
)
from llama_stack.apis.shields import Shield, Shields
-from llama_stack.apis.tools import Tool, ToolGroup, ToolGroups, ToolHost
+from llama_stack.apis.tools import MCPConfig, Tool, ToolGroup, ToolGroups, ToolHost
from llama_stack.distribution.datatypes import (
RoutableObject,
RoutableObjectWithProvider,
@@ -504,15 +504,15 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
self,
toolgroup_id: str,
provider_id: str,
- mcp_endpoint: Optional[URL] = None,
+ mcp_config: Optional[MCPConfig] = None,
args: Optional[Dict[str, Any]] = None,
) -> None:
tools = []
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
- toolgroup_id, mcp_endpoint
+ toolgroup_id, mcp_config
)
tool_host = (
- ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
+ ToolHost.model_context_protocol if mcp_config else ToolHost.distribution
)
for tool_def in tool_defs:
@@ -547,7 +547,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
identifier=toolgroup_id,
provider_id=provider_id,
provider_resource_id=toolgroup_id,
- mcp_endpoint=mcp_endpoint,
+ mcp_config=mcp_config,
args=args,
)
)
diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
index 361c91a92..6d006af49 100644
--- a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
+++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
@@ -9,8 +9,8 @@ import logging
import tempfile
from typing import Any, Dict, List, Optional
-from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
+ MCPConfig,
Tool,
ToolDef,
ToolInvocationResult,
@@ -43,7 +43,9 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
return
async def list_runtime_tools(
- self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+ self,
+ tool_group_id: Optional[str] = None,
+ mcp_config: Optional[MCPConfig] = None,
) -> List[ToolDef]:
return [
ToolDef(
diff --git a/llama_stack/providers/inline/tool_runtime/memory/memory.py b/llama_stack/providers/inline/tool_runtime/memory/memory.py
index fe6325abb..66dd831b4 100644
--- a/llama_stack/providers/inline/tool_runtime/memory/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/memory/memory.py
@@ -10,11 +10,11 @@ import secrets
import string
from typing import Any, Dict, List, Optional
-from llama_stack.apis.common.content_types import URL
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.memory import Memory, QueryDocumentsResponse
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.tools import (
+ MCPConfig,
ToolDef,
ToolInvocationResult,
ToolParameter,
@@ -52,7 +52,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
pass
async def list_runtime_tools(
- self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+ self,
+ tool_group_id: Optional[str] = None,
+ mcp_config: Optional[MCPConfig] = None,
) -> List[ToolDef]:
return [
ToolDef(
diff --git a/llama_stack/providers/inline/tool_runtime/model_context_protocol/__init__.py b/llama_stack/providers/inline/tool_runtime/model_context_protocol/__init__.py
new file mode 100644
index 000000000..7befa58e5
--- /dev/null
+++ b/llama_stack/providers/inline/tool_runtime/model_context_protocol/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pydantic import BaseModel
+
+from .config import ModelContextProtocolConfig
+from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
+
+
+class ModelContextProtocolToolProviderDataValidator(BaseModel):
+ api_key: str
+
+
+async def get_provider_impl(config: ModelContextProtocolConfig, _deps):
+ impl = ModelContextProtocolToolRuntimeImpl(config)
+ await impl.initialize()
+ return impl
diff --git a/llama_stack/providers/inline/tool_runtime/model_context_protocol/config.py b/llama_stack/providers/inline/tool_runtime/model_context_protocol/config.py
new file mode 100644
index 000000000..ffe4c9887
--- /dev/null
+++ b/llama_stack/providers/inline/tool_runtime/model_context_protocol/config.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pydantic import BaseModel
+
+
+class ModelContextProtocolConfig(BaseModel):
+ pass
diff --git a/llama_stack/providers/inline/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/inline/tool_runtime/model_context_protocol/model_context_protocol.py
new file mode 100644
index 000000000..9427a0bdd
--- /dev/null
+++ b/llama_stack/providers/inline/tool_runtime/model_context_protocol/model_context_protocol.py
@@ -0,0 +1,98 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+from typing import Any, Dict, List, Optional
+
+from mcp import ClientSession
+from mcp.client.stdio import stdio_client, StdioServerParameters
+from pydantic import TypeAdapter
+
+from llama_stack.apis.tools import (
+ MCPConfig,
+ ToolDef,
+ ToolInvocationResult,
+ ToolParameter,
+ ToolRuntime,
+)
+from llama_stack.providers.datatypes import ToolsProtocolPrivate
+
+from .config import ModelContextProtocolConfig
+
+
+class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
+ def __init__(self, config: ModelContextProtocolConfig):
+ self.config = config
+
+ async def initialize(self):
+ pass
+
+ async def list_runtime_tools(
+ self,
+ tool_group_id: Optional[str] = None,
+ mcp_config: Optional[MCPConfig] = None,
+ ) -> List[ToolDef]:
+        if mcp_config is None or mcp_config.type != "inline":
+            raise ValueError("an inline mcp_config (type='inline') is required")
+
+ tools = []
+ async with stdio_client(
+ StdioServerParameters(
+ command=mcp_config.command,
+                args=mcp_config.args or [],
+ env=mcp_config.env,
+ )
+ ) as streams:
+ async with ClientSession(*streams) as session:
+ await session.initialize()
+ tools_result = await session.list_tools()
+ for tool in tools_result.tools:
+ parameters = []
+ for param_name, param_schema in tool.inputSchema.get(
+ "properties", {}
+ ).items():
+ parameters.append(
+ ToolParameter(
+ name=param_name,
+ parameter_type=param_schema.get("type", "string"),
+ description=param_schema.get("description", ""),
+ )
+ )
+ tools.append(
+ ToolDef(
+ name=tool.name,
+ description=tool.description,
+ parameters=parameters,
+ metadata={
+ "mcp_config": mcp_config.model_dump_json(),
+ },
+ )
+ )
+ return tools
+
+ async def invoke_tool(
+ self, tool_name: str, args: Dict[str, Any]
+ ) -> ToolInvocationResult:
+ tool = await self.tool_store.get_tool(tool_name)
+ if tool.metadata is None or tool.metadata.get("mcp_config") is None:
+ raise ValueError(f"Tool {tool_name} does not have metadata")
+ mcp_config_dict = json.loads(tool.metadata.get("mcp_config"))
+ mcp_config = TypeAdapter(MCPConfig).validate_python(mcp_config_dict)
+ async with stdio_client(
+ StdioServerParameters(
+ command=mcp_config.command,
+                args=mcp_config.args or [],
+ env=mcp_config.env,
+ )
+ ) as streams:
+ async with ClientSession(*streams) as session:
+ await session.initialize()
+ result = await session.call_tool(tool.identifier, arguments=args)
+
+ return ToolInvocationResult(
+ content="\n".join([result.model_dump_json() for result in result.content]),
+ error_code=1 if result.isError else 0,
+ )
diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py
index 40299edad..7de3f09f4 100644
--- a/llama_stack/providers/registry/tool_runtime.py
+++ b/llama_stack/providers/registry/tool_runtime.py
@@ -32,6 +32,13 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.inline.tool_runtime.code_interpreter",
config_class="llama_stack.providers.inline.tool_runtime.code_interpreter.config.CodeInterpreterToolConfig",
),
+ InlineProviderSpec(
+ api=Api.tool_runtime,
+ provider_type="inline::model-context-protocol",
+ pip_packages=["mcp"],
+ module="llama_stack.providers.inline.tool_runtime.model_context_protocol",
+ config_class="llama_stack.providers.inline.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig",
+ ),
remote_provider_spec(
api=Api.tool_runtime,
adapter=AdapterSpec(
diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
index 259d02f1b..628bdd524 100644
--- a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
+++ b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
@@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional
import requests
from llama_models.llama3.api.datatypes import BuiltinTool
-from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
+ MCPConfig,
Tool,
ToolDef,
ToolInvocationResult,
@@ -50,7 +50,9 @@ class BraveSearchToolRuntimeImpl(
return provider_data.api_key
async def list_runtime_tools(
- self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+ self,
+ tool_group_id: Optional[str] = None,
+ mcp_config: Optional[MCPConfig] = None,
) -> List[ToolDef]:
return [
ToolDef(
diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
index a304167e9..93afcb023 100644
--- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
+++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
@@ -4,14 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+import json
from typing import Any, Dict, List, Optional
-from urllib.parse import urlparse
from mcp import ClientSession
from mcp.client.sse import sse_client
+from pydantic import TypeAdapter
-from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
+ MCPConfig,
ToolDef,
ToolInvocationResult,
ToolParameter,
@@ -30,13 +31,15 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
pass
async def list_runtime_tools(
- self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+ self,
+ tool_group_id: Optional[str] = None,
+ mcp_config: Optional[MCPConfig] = None,
) -> List[ToolDef]:
- if mcp_endpoint is None:
- raise ValueError("mcp_endpoint is required")
+        if mcp_config is None or mcp_config.type != "remote":
+            raise ValueError("a remote mcp_config (type='remote') is required")
tools = []
- async with sse_client(mcp_endpoint.uri) as streams:
+ async with sse_client(mcp_config.mcp_endpoint.uri) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
tools_result = await session.list_tools()
@@ -58,7 +61,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
description=tool.description,
parameters=parameters,
metadata={
- "endpoint": mcp_endpoint.uri,
+ "mcp_config": mcp_config,
},
)
)
@@ -68,13 +71,12 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
self, tool_name: str, args: Dict[str, Any]
) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name)
- if tool.metadata is None or tool.metadata.get("endpoint") is None:
+ if tool.metadata is None or tool.metadata.get("mcp_config") is None:
raise ValueError(f"Tool {tool_name} does not have metadata")
- endpoint = tool.metadata.get("endpoint")
- if urlparse(endpoint).scheme not in ("http", "https"):
- raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
+ mcp_config_dict = json.loads(tool.metadata.get("mcp_config"))
+ mcp_config = TypeAdapter(MCPConfig).validate_python(mcp_config_dict)
- async with sse_client(endpoint) as streams:
+ async with sse_client(mcp_config.mcp_endpoint.uri) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
result = await session.call_tool(tool.identifier, args)
diff --git a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py
index 1716f96e5..0fc3bfd31 100644
--- a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py
+++ b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py
@@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional
import requests
-from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
+ MCPConfig,
Tool,
ToolDef,
ToolInvocationResult,
@@ -50,7 +50,9 @@ class TavilySearchToolRuntimeImpl(
return provider_data.api_key
async def list_runtime_tools(
- self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+ self,
+ tool_group_id: Optional[str] = None,
+ mcp_config: Optional[MCPConfig] = None,
) -> List[ToolDef]:
return [
ToolDef(
diff --git a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
index 8d0792ca0..182c3acba 100644
--- a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
+++ b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
@@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional
import requests
-from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
+ MCPConfig,
Tool,
ToolDef,
ToolInvocationResult,
@@ -51,7 +51,9 @@ class WolframAlphaToolRuntimeImpl(
return provider_data.api_key
async def list_runtime_tools(
- self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+ self,
+ tool_group_id: Optional[str] = None,
+ mcp_config: Optional[MCPConfig] = None,
) -> List[ToolDef]:
return [
ToolDef(
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index a2ed687a4..a133ed3d3 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -5,6 +5,7 @@
# the root directory of this source tree.
import json
+import os
from typing import Dict, List
from uuid import uuid4
@@ -324,3 +325,35 @@ def test_rag_agent(llama_stack_client, agent_config):
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
assert "Tool:query_memory" in logs_str
+
+
+def test_mcp_agent(llama_stack_client, agent_config):
+ llama_stack_client.toolgroups.register(
+ toolgroup_id="brave-search",
+ provider_id="model-context-protocol",
+ mcp_config=dict(
+ type="inline",
+ command="/Users/dineshyv/homebrew/bin/npx",
+ args=["-y", "@modelcontextprotocol/server-brave-search"],
+ env={
+ "BRAVE_API_KEY": os.environ["BRAVE_SEARCH_API_KEY"],
+ "PATH": os.environ["PATH"],
+ },
+ ),
+ )
+ agent_config = {
+ **agent_config,
+ "toolgroups": [
+ "brave-search",
+ ],
+ }
+ agent = Agent(llama_stack_client, agent_config)
+ session_id = agent.create_session("test-session")
+ response = agent.create_turn(
+ messages=[{"role": "user", "content": "what won the NBA playoffs in 2024?"}],
+ session_id=session_id,
+ )
+ logs = [str(log) for log in EventLogger().log(response) if log is not None]
+ logs_str = "".join(logs)
+ assert "Tool:brave_web_search" in logs_str
+ assert "celtics" in logs_str.lower()