add inline mcp provider

This commit is contained in:
Dinesh Yeduguru 2025-01-08 23:11:43 -08:00
parent ffc6bd4805
commit 2c265d803c
16 changed files with 398 additions and 49 deletions

View file

@ -6548,6 +6548,83 @@
"model_context_protocol" "model_context_protocol"
] ]
}, },
"MCPConfig": {
"oneOf": [
{
"$ref": "#/components/schemas/MCPInlineConfig"
},
{
"$ref": "#/components/schemas/MCPRemoteConfig"
}
]
},
"MCPInlineConfig": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "inline",
"default": "inline"
},
"command": {
"type": "string"
},
"args": {
"type": "array",
"items": {
"type": "string"
}
},
"env": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,
"required": [
"type",
"command"
]
},
"MCPRemoteConfig": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "remote",
"default": "remote"
},
"mcp_endpoint": {
"$ref": "#/components/schemas/URL"
}
},
"additionalProperties": false,
"required": [
"type",
"mcp_endpoint"
]
},
"ToolGroup": { "ToolGroup": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -6565,8 +6642,8 @@
"const": "tool_group", "const": "tool_group",
"default": "tool_group" "default": "tool_group"
}, },
"mcp_endpoint": { "mcp_config": {
"$ref": "#/components/schemas/URL" "$ref": "#/components/schemas/MCPConfig"
}, },
"args": { "args": {
"type": "object", "type": "object",
@ -6916,8 +6993,8 @@
"ListRuntimeToolsRequest": { "ListRuntimeToolsRequest": {
"type": "object", "type": "object",
"properties": { "properties": {
"mcp_endpoint": { "mcp_config": {
"$ref": "#/components/schemas/URL" "$ref": "#/components/schemas/MCPConfig"
} }
}, },
"additionalProperties": false "additionalProperties": false
@ -8022,8 +8099,8 @@
"provider_id": { "provider_id": {
"type": "string" "type": "string"
}, },
"mcp_endpoint": { "mcp_config": {
"$ref": "#/components/schemas/URL" "$ref": "#/components/schemas/MCPConfig"
}, },
"args": { "args": {
"type": "object", "type": "object",
@ -8932,6 +9009,18 @@
"name": "LoraFinetuningConfig", "name": "LoraFinetuningConfig",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/LoraFinetuningConfig\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/LoraFinetuningConfig\" />"
}, },
{
"name": "MCPConfig",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/MCPConfig\" />"
},
{
"name": "MCPInlineConfig",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/MCPInlineConfig\" />"
},
{
"name": "MCPRemoteConfig",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/MCPRemoteConfig\" />"
},
{ {
"name": "Memory" "name": "Memory"
}, },
@ -9437,6 +9526,9 @@
"LogEventRequest", "LogEventRequest",
"LogSeverity", "LogSeverity",
"LoraFinetuningConfig", "LoraFinetuningConfig",
"MCPConfig",
"MCPInlineConfig",
"MCPRemoteConfig",
"MemoryBankDocument", "MemoryBankDocument",
"MemoryRetrievalStep", "MemoryRetrievalStep",
"Message", "Message",

View file

@ -1125,8 +1125,8 @@ components:
ListRuntimeToolsRequest: ListRuntimeToolsRequest:
additionalProperties: false additionalProperties: false
properties: properties:
mcp_endpoint: mcp_config:
$ref: '#/components/schemas/URL' $ref: '#/components/schemas/MCPConfig'
type: object type: object
LogEventRequest: LogEventRequest:
additionalProperties: false additionalProperties: false
@ -1184,6 +1184,50 @@ components:
- rank - rank
- alpha - alpha
type: object type: object
MCPConfig:
oneOf:
- $ref: '#/components/schemas/MCPInlineConfig'
- $ref: '#/components/schemas/MCPRemoteConfig'
MCPInlineConfig:
additionalProperties: false
properties:
args:
items:
type: string
type: array
command:
type: string
env:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
type:
const: inline
default: inline
type: string
required:
- type
- command
type: object
MCPRemoteConfig:
additionalProperties: false
properties:
mcp_endpoint:
$ref: '#/components/schemas/URL'
type:
const: remote
default: remote
type: string
required:
- type
- mcp_endpoint
type: object
MemoryBankDocument: MemoryBankDocument:
additionalProperties: false additionalProperties: false
properties: properties:
@ -1897,8 +1941,8 @@ components:
- type: array - type: array
- type: object - type: object
type: object type: object
mcp_endpoint: mcp_config:
$ref: '#/components/schemas/URL' $ref: '#/components/schemas/MCPConfig'
provider_id: provider_id:
type: string type: string
toolgroup_id: toolgroup_id:
@ -2773,8 +2817,8 @@ components:
type: object type: object
identifier: identifier:
type: string type: string
mcp_endpoint: mcp_config:
$ref: '#/components/schemas/URL' $ref: '#/components/schemas/MCPConfig'
provider_id: provider_id:
type: string type: string
provider_resource_id: provider_resource_id:
@ -5615,6 +5659,14 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/LoraFinetuningConfig" - description: <SchemaDefinition schemaRef="#/components/schemas/LoraFinetuningConfig"
/> />
name: LoraFinetuningConfig name: LoraFinetuningConfig
- description: <SchemaDefinition schemaRef="#/components/schemas/MCPConfig" />
name: MCPConfig
- description: <SchemaDefinition schemaRef="#/components/schemas/MCPInlineConfig"
/>
name: MCPInlineConfig
- description: <SchemaDefinition schemaRef="#/components/schemas/MCPRemoteConfig"
/>
name: MCPRemoteConfig
- name: Memory - name: Memory
- description: <SchemaDefinition schemaRef="#/components/schemas/MemoryBankDocument" - description: <SchemaDefinition schemaRef="#/components/schemas/MemoryBankDocument"
/> />
@ -5982,6 +6034,9 @@ x-tagGroups:
- LogEventRequest - LogEventRequest
- LogSeverity - LogSeverity
- LoraFinetuningConfig - LoraFinetuningConfig
- MCPConfig
- MCPInlineConfig
- MCPRemoteConfig
- MemoryBankDocument - MemoryBankDocument
- MemoryRetrievalStep - MemoryRetrievalStep
- Message - Message

View file

@ -5,10 +5,10 @@
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Literal, Optional from typing import Any, Dict, List, Literal, Optional, Union
from llama_models.llama3.api.datatypes import ToolPromptFormat from llama_models.llama3.api.datatypes import ToolPromptFormat
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Protocol, runtime_checkable from typing_extensions import Protocol, runtime_checkable
@ -57,18 +57,35 @@ class ToolDef(BaseModel):
) )
@json_schema_type
class MCPInlineConfig(BaseModel):
type: Literal["inline"] = "inline"
command: str
args: Optional[List[str]] = None
env: Optional[Dict[str, Any]] = None
@json_schema_type
class MCPRemoteConfig(BaseModel):
type: Literal["remote"] = "remote"
mcp_endpoint: URL
MCPConfig = register_schema(Union[MCPInlineConfig, MCPRemoteConfig], name="MCPConfig")
@json_schema_type @json_schema_type
class ToolGroupInput(BaseModel): class ToolGroupInput(BaseModel):
toolgroup_id: str toolgroup_id: str
provider_id: str provider_id: str
args: Optional[Dict[str, Any]] = None args: Optional[Dict[str, Any]] = None
mcp_endpoint: Optional[URL] = None mcp_config: Optional[MCPConfig] = None
@json_schema_type @json_schema_type
class ToolGroup(Resource): class ToolGroup(Resource):
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
mcp_endpoint: Optional[URL] = None mcp_config: Optional[MCPConfig] = None
args: Optional[Dict[str, Any]] = None args: Optional[Dict[str, Any]] = None
@ -92,7 +109,7 @@ class ToolGroups(Protocol):
self, self,
toolgroup_id: str, toolgroup_id: str,
provider_id: str, provider_id: str,
mcp_endpoint: Optional[URL] = None, mcp_config: Optional[MCPConfig] = None,
args: Optional[Dict[str, Any]] = None, args: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
"""Register a tool group""" """Register a tool group"""
@ -131,7 +148,9 @@ class ToolRuntime(Protocol):
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed. # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET") @webmethod(route="/tool-runtime/list-tools", method="GET")
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self,
tool_group_id: Optional[str] = None,
mcp_config: Optional[MCPConfig] = None,
) -> List[ToolDef]: ... ) -> List[ToolDef]: ...
@webmethod(route="/tool-runtime/invoke", method="POST") @webmethod(route="/tool-runtime/invoke", method="POST")

View file

@ -6,7 +6,7 @@
from typing import Any, AsyncGenerator, Dict, List, Optional from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack.apis.common.content_types import InterleavedContent, URL from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.eval import ( from llama_stack.apis.eval import (
AppEvalTaskConfig, AppEvalTaskConfig,
@ -38,7 +38,7 @@ from llama_stack.apis.scoring import (
ScoringFnParams, ScoringFnParams,
) )
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
from llama_stack.apis.tools import ToolDef, ToolRuntime from llama_stack.apis.tools import MCPConfig, ToolDef, ToolRuntime
from llama_stack.providers.datatypes import RoutingTable from llama_stack.providers.datatypes import RoutingTable
@ -418,8 +418,10 @@ class ToolRuntimeRouter(ToolRuntime):
) )
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self,
tool_group_id: Optional[str] = None,
mcp_config: Optional[MCPConfig] = None,
) -> List[ToolDef]: ) -> List[ToolDef]:
return await self.routing_table.get_provider_impl(tool_group_id).list_tools( return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
tool_group_id, mcp_endpoint tool_group_id, mcp_config
) )

View file

@ -26,7 +26,7 @@ from llama_stack.apis.scoring_functions import (
ScoringFunctions, ScoringFunctions,
) )
from llama_stack.apis.shields import Shield, Shields from llama_stack.apis.shields import Shield, Shields
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroups, ToolHost from llama_stack.apis.tools import MCPConfig, Tool, ToolGroup, ToolGroups, ToolHost
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
RoutableObject, RoutableObject,
RoutableObjectWithProvider, RoutableObjectWithProvider,
@ -504,15 +504,15 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
self, self,
toolgroup_id: str, toolgroup_id: str,
provider_id: str, provider_id: str,
mcp_endpoint: Optional[URL] = None, mcp_config: Optional[MCPConfig] = None,
args: Optional[Dict[str, Any]] = None, args: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
tools = [] tools = []
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools( tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
toolgroup_id, mcp_endpoint toolgroup_id, mcp_config
) )
tool_host = ( tool_host = (
ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution ToolHost.model_context_protocol if mcp_config else ToolHost.distribution
) )
for tool_def in tool_defs: for tool_def in tool_defs:
@ -547,7 +547,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
identifier=toolgroup_id, identifier=toolgroup_id,
provider_id=provider_id, provider_id=provider_id,
provider_resource_id=toolgroup_id, provider_resource_id=toolgroup_id,
mcp_endpoint=mcp_endpoint, mcp_config=mcp_config,
args=args, args=args,
) )
) )

View file

@ -9,8 +9,8 @@ import logging
import tempfile import tempfile
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
MCPConfig,
Tool, Tool,
ToolDef, ToolDef,
ToolInvocationResult, ToolInvocationResult,
@ -43,7 +43,9 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
return return
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self,
tool_group_id: Optional[str] = None,
mcp_config: Optional[MCPConfig] = None,
) -> List[ToolDef]: ) -> List[ToolDef]:
return [ return [
ToolDef( ToolDef(

View file

@ -10,11 +10,11 @@ import secrets
import string import string
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.inference import Inference, InterleavedContent from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.memory import Memory, QueryDocumentsResponse from llama_stack.apis.memory import Memory, QueryDocumentsResponse
from llama_stack.apis.memory_banks import MemoryBanks from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
MCPConfig,
ToolDef, ToolDef,
ToolInvocationResult, ToolInvocationResult,
ToolParameter, ToolParameter,
@ -52,7 +52,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
pass pass
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self,
tool_group_id: Optional[str] = None,
mcp_config: Optional[MCPConfig] = None,
) -> List[ToolDef]: ) -> List[ToolDef]:
return [ return [
ToolDef( ToolDef(

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import ModelContextProtocolConfig
from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
class ModelContextProtocolToolProviderDataValidator(BaseModel):
api_key: str
async def get_provider_impl(config: ModelContextProtocolConfig, _deps):
impl = ModelContextProtocolToolRuntimeImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class ModelContextProtocolConfig(BaseModel):
pass

View file

@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional
from mcp import ClientSession
from mcp.client.stdio import stdio_client, StdioServerParameters
from pydantic import TypeAdapter
from llama_stack.apis.tools import (
MCPConfig,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import ModelContextProtocolConfig
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
def __init__(self, config: ModelContextProtocolConfig):
self.config = config
async def initialize(self):
pass
async def list_runtime_tools(
self,
tool_group_id: Optional[str] = None,
mcp_config: Optional[MCPConfig] = None,
) -> List[ToolDef]:
if mcp_config is None:
raise ValueError("mcp_config is required")
tools = []
async with stdio_client(
StdioServerParameters(
command=mcp_config.command,
args=mcp_config.args,
env=mcp_config.env,
)
) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
tools_result = await session.list_tools()
for tool in tools_result.tools:
parameters = []
for param_name, param_schema in tool.inputSchema.get(
"properties", {}
).items():
parameters.append(
ToolParameter(
name=param_name,
parameter_type=param_schema.get("type", "string"),
description=param_schema.get("description", ""),
)
)
tools.append(
ToolDef(
name=tool.name,
description=tool.description,
parameters=parameters,
metadata={
"mcp_config": mcp_config.model_dump_json(),
},
)
)
return tools
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name)
if tool.metadata is None or tool.metadata.get("mcp_config") is None:
raise ValueError(f"Tool {tool_name} does not have metadata")
mcp_config_dict = json.loads(tool.metadata.get("mcp_config"))
mcp_config = TypeAdapter(MCPConfig).validate_python(mcp_config_dict)
async with stdio_client(
StdioServerParameters(
command=mcp_config.command,
args=mcp_config.args,
env=mcp_config.env,
)
) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
result = await session.call_tool(tool.identifier, arguments=args)
return ToolInvocationResult(
content="\n".join([result.model_dump_json() for result in result.content]),
error_code=1 if result.isError else 0,
)

View file

@ -32,6 +32,13 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.inline.tool_runtime.code_interpreter", module="llama_stack.providers.inline.tool_runtime.code_interpreter",
config_class="llama_stack.providers.inline.tool_runtime.code_interpreter.config.CodeInterpreterToolConfig", config_class="llama_stack.providers.inline.tool_runtime.code_interpreter.config.CodeInterpreterToolConfig",
), ),
InlineProviderSpec(
api=Api.tool_runtime,
provider_type="inline::model-context-protocol",
pip_packages=["mcp"],
module="llama_stack.providers.inline.tool_runtime.model_context_protocol",
config_class="llama_stack.providers.inline.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig",
),
remote_provider_spec( remote_provider_spec(
api=Api.tool_runtime, api=Api.tool_runtime,
adapter=AdapterSpec( adapter=AdapterSpec(

View file

@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional
import requests import requests
from llama_models.llama3.api.datatypes import BuiltinTool from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
MCPConfig,
Tool, Tool,
ToolDef, ToolDef,
ToolInvocationResult, ToolInvocationResult,
@ -50,7 +50,9 @@ class BraveSearchToolRuntimeImpl(
return provider_data.api_key return provider_data.api_key
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self,
tool_group_id: Optional[str] = None,
mcp_config: Optional[MCPConfig] = None,
) -> List[ToolDef]: ) -> List[ToolDef]:
return [ return [
ToolDef( ToolDef(

View file

@ -4,14 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
from mcp import ClientSession from mcp import ClientSession
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
from pydantic import TypeAdapter
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
MCPConfig,
ToolDef, ToolDef,
ToolInvocationResult, ToolInvocationResult,
ToolParameter, ToolParameter,
@ -30,13 +31,15 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
pass pass
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self,
tool_group_id: Optional[str] = None,
mcp_config: Optional[MCPConfig] = None,
) -> List[ToolDef]: ) -> List[ToolDef]:
if mcp_endpoint is None: if mcp_config is None:
raise ValueError("mcp_endpoint is required") raise ValueError("mcp_config is required")
tools = [] tools = []
async with sse_client(mcp_endpoint.uri) as streams: async with sse_client(mcp_config.mcp_endpoint.uri) as streams:
async with ClientSession(*streams) as session: async with ClientSession(*streams) as session:
await session.initialize() await session.initialize()
tools_result = await session.list_tools() tools_result = await session.list_tools()
@ -58,7 +61,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
description=tool.description, description=tool.description,
parameters=parameters, parameters=parameters,
metadata={ metadata={
"endpoint": mcp_endpoint.uri, "mcp_config": mcp_config,
}, },
) )
) )
@ -68,13 +71,12 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
self, tool_name: str, args: Dict[str, Any] self, tool_name: str, args: Dict[str, Any]
) -> ToolInvocationResult: ) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name) tool = await self.tool_store.get_tool(tool_name)
if tool.metadata is None or tool.metadata.get("endpoint") is None: if tool.metadata is None or tool.metadata.get("mcp_config") is None:
raise ValueError(f"Tool {tool_name} does not have metadata") raise ValueError(f"Tool {tool_name} does not have metadata")
endpoint = tool.metadata.get("endpoint") mcp_config_dict = json.loads(tool.metadata.get("mcp_config"))
if urlparse(endpoint).scheme not in ("http", "https"): mcp_config = TypeAdapter(MCPConfig).validate_python(mcp_config_dict)
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
async with sse_client(endpoint) as streams: async with sse_client(mcp_config.mcp_endpoint.uri) as streams:
async with ClientSession(*streams) as session: async with ClientSession(*streams) as session:
await session.initialize() await session.initialize()
result = await session.call_tool(tool.identifier, args) result = await session.call_tool(tool.identifier, args)

View file

@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional
import requests import requests
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
MCPConfig,
Tool, Tool,
ToolDef, ToolDef,
ToolInvocationResult, ToolInvocationResult,
@ -50,7 +50,9 @@ class TavilySearchToolRuntimeImpl(
return provider_data.api_key return provider_data.api_key
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self,
tool_group_id: Optional[str] = None,
mcp_config: Optional[MCPConfig] = None,
) -> List[ToolDef]: ) -> List[ToolDef]:
return [ return [
ToolDef( ToolDef(

View file

@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional
import requests import requests
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
MCPConfig,
Tool, Tool,
ToolDef, ToolDef,
ToolInvocationResult, ToolInvocationResult,
@ -51,7 +51,9 @@ class WolframAlphaToolRuntimeImpl(
return provider_data.api_key return provider_data.api_key
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self,
tool_group_id: Optional[str] = None,
mcp_config: Optional[MCPConfig] = None,
) -> List[ToolDef]: ) -> List[ToolDef]:
return [ return [
ToolDef( ToolDef(

View file

@ -5,6 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import json import json
import os
from typing import Dict, List from typing import Dict, List
from uuid import uuid4 from uuid import uuid4
@ -324,3 +325,35 @@ def test_rag_agent(llama_stack_client, agent_config):
logs = [str(log) for log in EventLogger().log(response) if log is not None] logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs) logs_str = "".join(logs)
assert "Tool:query_memory" in logs_str assert "Tool:query_memory" in logs_str
def test_mcp_agent(llama_stack_client, agent_config):
llama_stack_client.toolgroups.register(
toolgroup_id="brave-search",
provider_id="model-context-protocol",
mcp_config=dict(
type="inline",
command="/Users/dineshyv/homebrew/bin/npx",
args=["-y", "@modelcontextprotocol/server-brave-search"],
env={
"BRAVE_API_KEY": os.environ["BRAVE_SEARCH_API_KEY"],
"PATH": os.environ["PATH"],
},
),
)
agent_config = {
**agent_config,
"toolgroups": [
"brave-search",
],
}
agent = Agent(llama_stack_client, agent_config)
session_id = agent.create_session("test-session")
response = agent.create_turn(
messages=[{"role": "user", "content": "what won the NBA playoffs in 2024?"}],
session_id=session_id,
)
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
assert "Tool:brave_web_search" in logs_str
assert "celtics" in logs_str.lower()