mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 18:22:41 +00:00
add inline mcp provider
This commit is contained in:
parent
ffc6bd4805
commit
2c265d803c
16 changed files with 398 additions and 49 deletions
|
@ -6548,6 +6548,83 @@
|
||||||
"model_context_protocol"
|
"model_context_protocol"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
"MCPConfig": {
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/MCPInlineConfig"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/MCPRemoteConfig"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"MCPInlineConfig": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"type": {
|
||||||
|
"type": "string",
|
||||||
|
"const": "inline",
|
||||||
|
"default": "inline"
|
||||||
|
},
|
||||||
|
"command": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"args": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"env": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": {
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"type": "null"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "boolean"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "number"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "array"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "object"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"required": [
|
||||||
|
"type",
|
||||||
|
"command"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"MCPRemoteConfig": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"type": {
|
||||||
|
"type": "string",
|
||||||
|
"const": "remote",
|
||||||
|
"default": "remote"
|
||||||
|
},
|
||||||
|
"mcp_endpoint": {
|
||||||
|
"$ref": "#/components/schemas/URL"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"required": [
|
||||||
|
"type",
|
||||||
|
"mcp_endpoint"
|
||||||
|
]
|
||||||
|
},
|
||||||
"ToolGroup": {
|
"ToolGroup": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
@ -6565,8 +6642,8 @@
|
||||||
"const": "tool_group",
|
"const": "tool_group",
|
||||||
"default": "tool_group"
|
"default": "tool_group"
|
||||||
},
|
},
|
||||||
"mcp_endpoint": {
|
"mcp_config": {
|
||||||
"$ref": "#/components/schemas/URL"
|
"$ref": "#/components/schemas/MCPConfig"
|
||||||
},
|
},
|
||||||
"args": {
|
"args": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
|
@ -6916,8 +6993,8 @@
|
||||||
"ListRuntimeToolsRequest": {
|
"ListRuntimeToolsRequest": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"mcp_endpoint": {
|
"mcp_config": {
|
||||||
"$ref": "#/components/schemas/URL"
|
"$ref": "#/components/schemas/MCPConfig"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false
|
"additionalProperties": false
|
||||||
|
@ -8022,8 +8099,8 @@
|
||||||
"provider_id": {
|
"provider_id": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"mcp_endpoint": {
|
"mcp_config": {
|
||||||
"$ref": "#/components/schemas/URL"
|
"$ref": "#/components/schemas/MCPConfig"
|
||||||
},
|
},
|
||||||
"args": {
|
"args": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
|
@ -8932,6 +9009,18 @@
|
||||||
"name": "LoraFinetuningConfig",
|
"name": "LoraFinetuningConfig",
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/LoraFinetuningConfig\" />"
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/LoraFinetuningConfig\" />"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "MCPConfig",
|
||||||
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/MCPConfig\" />"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "MCPInlineConfig",
|
||||||
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/MCPInlineConfig\" />"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "MCPRemoteConfig",
|
||||||
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/MCPRemoteConfig\" />"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "Memory"
|
"name": "Memory"
|
||||||
},
|
},
|
||||||
|
@ -9437,6 +9526,9 @@
|
||||||
"LogEventRequest",
|
"LogEventRequest",
|
||||||
"LogSeverity",
|
"LogSeverity",
|
||||||
"LoraFinetuningConfig",
|
"LoraFinetuningConfig",
|
||||||
|
"MCPConfig",
|
||||||
|
"MCPInlineConfig",
|
||||||
|
"MCPRemoteConfig",
|
||||||
"MemoryBankDocument",
|
"MemoryBankDocument",
|
||||||
"MemoryRetrievalStep",
|
"MemoryRetrievalStep",
|
||||||
"Message",
|
"Message",
|
||||||
|
|
|
@ -1125,8 +1125,8 @@ components:
|
||||||
ListRuntimeToolsRequest:
|
ListRuntimeToolsRequest:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
mcp_endpoint:
|
mcp_config:
|
||||||
$ref: '#/components/schemas/URL'
|
$ref: '#/components/schemas/MCPConfig'
|
||||||
type: object
|
type: object
|
||||||
LogEventRequest:
|
LogEventRequest:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
|
@ -1184,6 +1184,50 @@ components:
|
||||||
- rank
|
- rank
|
||||||
- alpha
|
- alpha
|
||||||
type: object
|
type: object
|
||||||
|
MCPConfig:
|
||||||
|
oneOf:
|
||||||
|
- $ref: '#/components/schemas/MCPInlineConfig'
|
||||||
|
- $ref: '#/components/schemas/MCPRemoteConfig'
|
||||||
|
MCPInlineConfig:
|
||||||
|
additionalProperties: false
|
||||||
|
properties:
|
||||||
|
args:
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
type: array
|
||||||
|
command:
|
||||||
|
type: string
|
||||||
|
env:
|
||||||
|
additionalProperties:
|
||||||
|
oneOf:
|
||||||
|
- type: 'null'
|
||||||
|
- type: boolean
|
||||||
|
- type: number
|
||||||
|
- type: string
|
||||||
|
- type: array
|
||||||
|
- type: object
|
||||||
|
type: object
|
||||||
|
type:
|
||||||
|
const: inline
|
||||||
|
default: inline
|
||||||
|
type: string
|
||||||
|
required:
|
||||||
|
- type
|
||||||
|
- command
|
||||||
|
type: object
|
||||||
|
MCPRemoteConfig:
|
||||||
|
additionalProperties: false
|
||||||
|
properties:
|
||||||
|
mcp_endpoint:
|
||||||
|
$ref: '#/components/schemas/URL'
|
||||||
|
type:
|
||||||
|
const: remote
|
||||||
|
default: remote
|
||||||
|
type: string
|
||||||
|
required:
|
||||||
|
- type
|
||||||
|
- mcp_endpoint
|
||||||
|
type: object
|
||||||
MemoryBankDocument:
|
MemoryBankDocument:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
|
@ -1897,8 +1941,8 @@ components:
|
||||||
- type: array
|
- type: array
|
||||||
- type: object
|
- type: object
|
||||||
type: object
|
type: object
|
||||||
mcp_endpoint:
|
mcp_config:
|
||||||
$ref: '#/components/schemas/URL'
|
$ref: '#/components/schemas/MCPConfig'
|
||||||
provider_id:
|
provider_id:
|
||||||
type: string
|
type: string
|
||||||
toolgroup_id:
|
toolgroup_id:
|
||||||
|
@ -2773,8 +2817,8 @@ components:
|
||||||
type: object
|
type: object
|
||||||
identifier:
|
identifier:
|
||||||
type: string
|
type: string
|
||||||
mcp_endpoint:
|
mcp_config:
|
||||||
$ref: '#/components/schemas/URL'
|
$ref: '#/components/schemas/MCPConfig'
|
||||||
provider_id:
|
provider_id:
|
||||||
type: string
|
type: string
|
||||||
provider_resource_id:
|
provider_resource_id:
|
||||||
|
@ -5615,6 +5659,14 @@ tags:
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/LoraFinetuningConfig"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/LoraFinetuningConfig"
|
||||||
/>
|
/>
|
||||||
name: LoraFinetuningConfig
|
name: LoraFinetuningConfig
|
||||||
|
- description: <SchemaDefinition schemaRef="#/components/schemas/MCPConfig" />
|
||||||
|
name: MCPConfig
|
||||||
|
- description: <SchemaDefinition schemaRef="#/components/schemas/MCPInlineConfig"
|
||||||
|
/>
|
||||||
|
name: MCPInlineConfig
|
||||||
|
- description: <SchemaDefinition schemaRef="#/components/schemas/MCPRemoteConfig"
|
||||||
|
/>
|
||||||
|
name: MCPRemoteConfig
|
||||||
- name: Memory
|
- name: Memory
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/MemoryBankDocument"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/MemoryBankDocument"
|
||||||
/>
|
/>
|
||||||
|
@ -5982,6 +6034,9 @@ x-tagGroups:
|
||||||
- LogEventRequest
|
- LogEventRequest
|
||||||
- LogSeverity
|
- LogSeverity
|
||||||
- LoraFinetuningConfig
|
- LoraFinetuningConfig
|
||||||
|
- MCPConfig
|
||||||
|
- MCPInlineConfig
|
||||||
|
- MCPRemoteConfig
|
||||||
- MemoryBankDocument
|
- MemoryBankDocument
|
||||||
- MemoryRetrievalStep
|
- MemoryRetrievalStep
|
||||||
- Message
|
- Message
|
||||||
|
|
|
@ -5,10 +5,10 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Literal, Optional
|
from typing import Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import ToolPromptFormat
|
from llama_models.llama3.api.datatypes import ToolPromptFormat
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Protocol, runtime_checkable
|
from typing_extensions import Protocol, runtime_checkable
|
||||||
|
|
||||||
|
@ -57,18 +57,35 @@ class ToolDef(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class MCPInlineConfig(BaseModel):
|
||||||
|
type: Literal["inline"] = "inline"
|
||||||
|
command: str
|
||||||
|
args: Optional[List[str]] = None
|
||||||
|
env: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class MCPRemoteConfig(BaseModel):
|
||||||
|
type: Literal["remote"] = "remote"
|
||||||
|
mcp_endpoint: URL
|
||||||
|
|
||||||
|
|
||||||
|
MCPConfig = register_schema(Union[MCPInlineConfig, MCPRemoteConfig], name="MCPConfig")
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ToolGroupInput(BaseModel):
|
class ToolGroupInput(BaseModel):
|
||||||
toolgroup_id: str
|
toolgroup_id: str
|
||||||
provider_id: str
|
provider_id: str
|
||||||
args: Optional[Dict[str, Any]] = None
|
args: Optional[Dict[str, Any]] = None
|
||||||
mcp_endpoint: Optional[URL] = None
|
mcp_config: Optional[MCPConfig] = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ToolGroup(Resource):
|
class ToolGroup(Resource):
|
||||||
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
|
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
|
||||||
mcp_endpoint: Optional[URL] = None
|
mcp_config: Optional[MCPConfig] = None
|
||||||
args: Optional[Dict[str, Any]] = None
|
args: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@ -92,7 +109,7 @@ class ToolGroups(Protocol):
|
||||||
self,
|
self,
|
||||||
toolgroup_id: str,
|
toolgroup_id: str,
|
||||||
provider_id: str,
|
provider_id: str,
|
||||||
mcp_endpoint: Optional[URL] = None,
|
mcp_config: Optional[MCPConfig] = None,
|
||||||
args: Optional[Dict[str, Any]] = None,
|
args: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Register a tool group"""
|
"""Register a tool group"""
|
||||||
|
@ -131,7 +148,9 @@ class ToolRuntime(Protocol):
|
||||||
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
|
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
|
||||||
@webmethod(route="/tool-runtime/list-tools", method="GET")
|
@webmethod(route="/tool-runtime/list-tools", method="GET")
|
||||||
async def list_runtime_tools(
|
async def list_runtime_tools(
|
||||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
self,
|
||||||
|
tool_group_id: Optional[str] = None,
|
||||||
|
mcp_config: Optional[MCPConfig] = None,
|
||||||
) -> List[ToolDef]: ...
|
) -> List[ToolDef]: ...
|
||||||
|
|
||||||
@webmethod(route="/tool-runtime/invoke", method="POST")
|
@webmethod(route="/tool-runtime/invoke", method="POST")
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import InterleavedContent, URL
|
from llama_stack.apis.common.content_types import InterleavedContent
|
||||||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
||||||
from llama_stack.apis.eval import (
|
from llama_stack.apis.eval import (
|
||||||
AppEvalTaskConfig,
|
AppEvalTaskConfig,
|
||||||
|
@ -38,7 +38,7 @@ from llama_stack.apis.scoring import (
|
||||||
ScoringFnParams,
|
ScoringFnParams,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.shields import Shield
|
from llama_stack.apis.shields import Shield
|
||||||
from llama_stack.apis.tools import ToolDef, ToolRuntime
|
from llama_stack.apis.tools import MCPConfig, ToolDef, ToolRuntime
|
||||||
from llama_stack.providers.datatypes import RoutingTable
|
from llama_stack.providers.datatypes import RoutingTable
|
||||||
|
|
||||||
|
|
||||||
|
@ -418,8 +418,10 @@ class ToolRuntimeRouter(ToolRuntime):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def list_runtime_tools(
|
async def list_runtime_tools(
|
||||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
self,
|
||||||
|
tool_group_id: Optional[str] = None,
|
||||||
|
mcp_config: Optional[MCPConfig] = None,
|
||||||
) -> List[ToolDef]:
|
) -> List[ToolDef]:
|
||||||
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
|
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
|
||||||
tool_group_id, mcp_endpoint
|
tool_group_id, mcp_config
|
||||||
)
|
)
|
||||||
|
|
|
@ -26,7 +26,7 @@ from llama_stack.apis.scoring_functions import (
|
||||||
ScoringFunctions,
|
ScoringFunctions,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.shields import Shield, Shields
|
from llama_stack.apis.shields import Shield, Shields
|
||||||
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroups, ToolHost
|
from llama_stack.apis.tools import MCPConfig, Tool, ToolGroup, ToolGroups, ToolHost
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
RoutableObject,
|
RoutableObject,
|
||||||
RoutableObjectWithProvider,
|
RoutableObjectWithProvider,
|
||||||
|
@ -504,15 +504,15 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
self,
|
self,
|
||||||
toolgroup_id: str,
|
toolgroup_id: str,
|
||||||
provider_id: str,
|
provider_id: str,
|
||||||
mcp_endpoint: Optional[URL] = None,
|
mcp_config: Optional[MCPConfig] = None,
|
||||||
args: Optional[Dict[str, Any]] = None,
|
args: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
tools = []
|
tools = []
|
||||||
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
|
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
|
||||||
toolgroup_id, mcp_endpoint
|
toolgroup_id, mcp_config
|
||||||
)
|
)
|
||||||
tool_host = (
|
tool_host = (
|
||||||
ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
|
ToolHost.model_context_protocol if mcp_config else ToolHost.distribution
|
||||||
)
|
)
|
||||||
|
|
||||||
for tool_def in tool_defs:
|
for tool_def in tool_defs:
|
||||||
|
@ -547,7 +547,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
identifier=toolgroup_id,
|
identifier=toolgroup_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
provider_resource_id=toolgroup_id,
|
provider_resource_id=toolgroup_id,
|
||||||
mcp_endpoint=mcp_endpoint,
|
mcp_config=mcp_config,
|
||||||
args=args,
|
args=args,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -9,8 +9,8 @@ import logging
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
|
MCPConfig,
|
||||||
Tool,
|
Tool,
|
||||||
ToolDef,
|
ToolDef,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
|
@ -43,7 +43,9 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||||
return
|
return
|
||||||
|
|
||||||
async def list_runtime_tools(
|
async def list_runtime_tools(
|
||||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
self,
|
||||||
|
tool_group_id: Optional[str] = None,
|
||||||
|
mcp_config: Optional[MCPConfig] = None,
|
||||||
) -> List[ToolDef]:
|
) -> List[ToolDef]:
|
||||||
return [
|
return [
|
||||||
ToolDef(
|
ToolDef(
|
||||||
|
|
|
@ -10,11 +10,11 @@ import secrets
|
||||||
import string
|
import string
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
|
||||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||||
from llama_stack.apis.memory import Memory, QueryDocumentsResponse
|
from llama_stack.apis.memory import Memory, QueryDocumentsResponse
|
||||||
from llama_stack.apis.memory_banks import MemoryBanks
|
from llama_stack.apis.memory_banks import MemoryBanks
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
|
MCPConfig,
|
||||||
ToolDef,
|
ToolDef,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
|
@ -52,7 +52,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def list_runtime_tools(
|
async def list_runtime_tools(
|
||||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
self,
|
||||||
|
tool_group_id: Optional[str] = None,
|
||||||
|
mcp_config: Optional[MCPConfig] = None,
|
||||||
) -> List[ToolDef]:
|
) -> List[ToolDef]:
|
||||||
return [
|
return [
|
||||||
ToolDef(
|
ToolDef(
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from .config import ModelContextProtocolConfig
|
||||||
|
from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
|
||||||
|
|
||||||
|
|
||||||
|
class ModelContextProtocolToolProviderDataValidator(BaseModel):
|
||||||
|
api_key: str
|
||||||
|
|
||||||
|
|
||||||
|
async def get_provider_impl(config: ModelContextProtocolConfig, _deps):
|
||||||
|
impl = ModelContextProtocolToolRuntimeImpl(config)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
|
@ -0,0 +1,11 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class ModelContextProtocolConfig(BaseModel):
|
||||||
|
pass
|
|
@ -0,0 +1,98 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from mcp import ClientSession
|
||||||
|
from mcp.client.stdio import stdio_client, StdioServerParameters
|
||||||
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
|
from llama_stack.apis.tools import (
|
||||||
|
MCPConfig,
|
||||||
|
ToolDef,
|
||||||
|
ToolInvocationResult,
|
||||||
|
ToolParameter,
|
||||||
|
ToolRuntime,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||||
|
|
||||||
|
from .config import ModelContextProtocolConfig
|
||||||
|
|
||||||
|
|
||||||
|
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||||
|
def __init__(self, config: ModelContextProtocolConfig):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def list_runtime_tools(
|
||||||
|
self,
|
||||||
|
tool_group_id: Optional[str] = None,
|
||||||
|
mcp_config: Optional[MCPConfig] = None,
|
||||||
|
) -> List[ToolDef]:
|
||||||
|
if mcp_config is None:
|
||||||
|
raise ValueError("mcp_config is required")
|
||||||
|
|
||||||
|
tools = []
|
||||||
|
async with stdio_client(
|
||||||
|
StdioServerParameters(
|
||||||
|
command=mcp_config.command,
|
||||||
|
args=mcp_config.args,
|
||||||
|
env=mcp_config.env,
|
||||||
|
)
|
||||||
|
) as streams:
|
||||||
|
async with ClientSession(*streams) as session:
|
||||||
|
await session.initialize()
|
||||||
|
tools_result = await session.list_tools()
|
||||||
|
for tool in tools_result.tools:
|
||||||
|
parameters = []
|
||||||
|
for param_name, param_schema in tool.inputSchema.get(
|
||||||
|
"properties", {}
|
||||||
|
).items():
|
||||||
|
parameters.append(
|
||||||
|
ToolParameter(
|
||||||
|
name=param_name,
|
||||||
|
parameter_type=param_schema.get("type", "string"),
|
||||||
|
description=param_schema.get("description", ""),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
tools.append(
|
||||||
|
ToolDef(
|
||||||
|
name=tool.name,
|
||||||
|
description=tool.description,
|
||||||
|
parameters=parameters,
|
||||||
|
metadata={
|
||||||
|
"mcp_config": mcp_config.model_dump_json(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return tools
|
||||||
|
|
||||||
|
async def invoke_tool(
|
||||||
|
self, tool_name: str, args: Dict[str, Any]
|
||||||
|
) -> ToolInvocationResult:
|
||||||
|
tool = await self.tool_store.get_tool(tool_name)
|
||||||
|
if tool.metadata is None or tool.metadata.get("mcp_config") is None:
|
||||||
|
raise ValueError(f"Tool {tool_name} does not have metadata")
|
||||||
|
mcp_config_dict = json.loads(tool.metadata.get("mcp_config"))
|
||||||
|
mcp_config = TypeAdapter(MCPConfig).validate_python(mcp_config_dict)
|
||||||
|
async with stdio_client(
|
||||||
|
StdioServerParameters(
|
||||||
|
command=mcp_config.command,
|
||||||
|
args=mcp_config.args,
|
||||||
|
env=mcp_config.env,
|
||||||
|
)
|
||||||
|
) as streams:
|
||||||
|
async with ClientSession(*streams) as session:
|
||||||
|
await session.initialize()
|
||||||
|
result = await session.call_tool(tool.identifier, arguments=args)
|
||||||
|
|
||||||
|
return ToolInvocationResult(
|
||||||
|
content="\n".join([result.model_dump_json() for result in result.content]),
|
||||||
|
error_code=1 if result.isError else 0,
|
||||||
|
)
|
|
@ -32,6 +32,13 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
module="llama_stack.providers.inline.tool_runtime.code_interpreter",
|
module="llama_stack.providers.inline.tool_runtime.code_interpreter",
|
||||||
config_class="llama_stack.providers.inline.tool_runtime.code_interpreter.config.CodeInterpreterToolConfig",
|
config_class="llama_stack.providers.inline.tool_runtime.code_interpreter.config.CodeInterpreterToolConfig",
|
||||||
),
|
),
|
||||||
|
InlineProviderSpec(
|
||||||
|
api=Api.tool_runtime,
|
||||||
|
provider_type="inline::model-context-protocol",
|
||||||
|
pip_packages=["mcp"],
|
||||||
|
module="llama_stack.providers.inline.tool_runtime.model_context_protocol",
|
||||||
|
config_class="llama_stack.providers.inline.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig",
|
||||||
|
),
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
api=Api.tool_runtime,
|
api=Api.tool_runtime,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
|
|
|
@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional
|
||||||
import requests
|
import requests
|
||||||
from llama_models.llama3.api.datatypes import BuiltinTool
|
from llama_models.llama3.api.datatypes import BuiltinTool
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
|
MCPConfig,
|
||||||
Tool,
|
Tool,
|
||||||
ToolDef,
|
ToolDef,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
|
@ -50,7 +50,9 @@ class BraveSearchToolRuntimeImpl(
|
||||||
return provider_data.api_key
|
return provider_data.api_key
|
||||||
|
|
||||||
async def list_runtime_tools(
|
async def list_runtime_tools(
|
||||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
self,
|
||||||
|
tool_group_id: Optional[str] = None,
|
||||||
|
mcp_config: Optional[MCPConfig] = None,
|
||||||
) -> List[ToolDef]:
|
) -> List[ToolDef]:
|
||||||
return [
|
return [
|
||||||
ToolDef(
|
ToolDef(
|
||||||
|
|
|
@ -4,14 +4,15 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
from mcp import ClientSession
|
from mcp import ClientSession
|
||||||
from mcp.client.sse import sse_client
|
from mcp.client.sse import sse_client
|
||||||
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
|
MCPConfig,
|
||||||
ToolDef,
|
ToolDef,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
|
@ -30,13 +31,15 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def list_runtime_tools(
|
async def list_runtime_tools(
|
||||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
self,
|
||||||
|
tool_group_id: Optional[str] = None,
|
||||||
|
mcp_config: Optional[MCPConfig] = None,
|
||||||
) -> List[ToolDef]:
|
) -> List[ToolDef]:
|
||||||
if mcp_endpoint is None:
|
if mcp_config is None:
|
||||||
raise ValueError("mcp_endpoint is required")
|
raise ValueError("mcp_config is required")
|
||||||
|
|
||||||
tools = []
|
tools = []
|
||||||
async with sse_client(mcp_endpoint.uri) as streams:
|
async with sse_client(mcp_config.mcp_endpoint.uri) as streams:
|
||||||
async with ClientSession(*streams) as session:
|
async with ClientSession(*streams) as session:
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
tools_result = await session.list_tools()
|
tools_result = await session.list_tools()
|
||||||
|
@ -58,7 +61,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||||
description=tool.description,
|
description=tool.description,
|
||||||
parameters=parameters,
|
parameters=parameters,
|
||||||
metadata={
|
metadata={
|
||||||
"endpoint": mcp_endpoint.uri,
|
"mcp_config": mcp_config,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -68,13 +71,12 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||||
self, tool_name: str, args: Dict[str, Any]
|
self, tool_name: str, args: Dict[str, Any]
|
||||||
) -> ToolInvocationResult:
|
) -> ToolInvocationResult:
|
||||||
tool = await self.tool_store.get_tool(tool_name)
|
tool = await self.tool_store.get_tool(tool_name)
|
||||||
if tool.metadata is None or tool.metadata.get("endpoint") is None:
|
if tool.metadata is None or tool.metadata.get("mcp_config") is None:
|
||||||
raise ValueError(f"Tool {tool_name} does not have metadata")
|
raise ValueError(f"Tool {tool_name} does not have metadata")
|
||||||
endpoint = tool.metadata.get("endpoint")
|
mcp_config_dict = json.loads(tool.metadata.get("mcp_config"))
|
||||||
if urlparse(endpoint).scheme not in ("http", "https"):
|
mcp_config = TypeAdapter(MCPConfig).validate_python(mcp_config_dict)
|
||||||
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
|
|
||||||
|
|
||||||
async with sse_client(endpoint) as streams:
|
async with sse_client(mcp_config.mcp_endpoint.uri) as streams:
|
||||||
async with ClientSession(*streams) as session:
|
async with ClientSession(*streams) as session:
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
result = await session.call_tool(tool.identifier, args)
|
result = await session.call_tool(tool.identifier, args)
|
||||||
|
|
|
@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
|
MCPConfig,
|
||||||
Tool,
|
Tool,
|
||||||
ToolDef,
|
ToolDef,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
|
@ -50,7 +50,9 @@ class TavilySearchToolRuntimeImpl(
|
||||||
return provider_data.api_key
|
return provider_data.api_key
|
||||||
|
|
||||||
async def list_runtime_tools(
|
async def list_runtime_tools(
|
||||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
self,
|
||||||
|
tool_group_id: Optional[str] = None,
|
||||||
|
mcp_config: Optional[MCPConfig] = None,
|
||||||
) -> List[ToolDef]:
|
) -> List[ToolDef]:
|
||||||
return [
|
return [
|
||||||
ToolDef(
|
ToolDef(
|
||||||
|
|
|
@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
|
MCPConfig,
|
||||||
Tool,
|
Tool,
|
||||||
ToolDef,
|
ToolDef,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
|
@ -51,7 +51,9 @@ class WolframAlphaToolRuntimeImpl(
|
||||||
return provider_data.api_key
|
return provider_data.api_key
|
||||||
|
|
||||||
async def list_runtime_tools(
|
async def list_runtime_tools(
|
||||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
self,
|
||||||
|
tool_group_id: Optional[str] = None,
|
||||||
|
mcp_config: Optional[MCPConfig] = None,
|
||||||
) -> List[ToolDef]:
|
) -> List[ToolDef]:
|
||||||
return [
|
return [
|
||||||
ToolDef(
|
ToolDef(
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
@ -324,3 +325,35 @@ def test_rag_agent(llama_stack_client, agent_config):
|
||||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||||
logs_str = "".join(logs)
|
logs_str = "".join(logs)
|
||||||
assert "Tool:query_memory" in logs_str
|
assert "Tool:query_memory" in logs_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_mcp_agent(llama_stack_client, agent_config):
|
||||||
|
llama_stack_client.toolgroups.register(
|
||||||
|
toolgroup_id="brave-search",
|
||||||
|
provider_id="model-context-protocol",
|
||||||
|
mcp_config=dict(
|
||||||
|
type="inline",
|
||||||
|
command="/Users/dineshyv/homebrew/bin/npx",
|
||||||
|
args=["-y", "@modelcontextprotocol/server-brave-search"],
|
||||||
|
env={
|
||||||
|
"BRAVE_API_KEY": os.environ["BRAVE_SEARCH_API_KEY"],
|
||||||
|
"PATH": os.environ["PATH"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
agent_config = {
|
||||||
|
**agent_config,
|
||||||
|
"toolgroups": [
|
||||||
|
"brave-search",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
agent = Agent(llama_stack_client, agent_config)
|
||||||
|
session_id = agent.create_session("test-session")
|
||||||
|
response = agent.create_turn(
|
||||||
|
messages=[{"role": "user", "content": "what won the NBA playoffs in 2024?"}],
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||||
|
logs_str = "".join(logs)
|
||||||
|
assert "Tool:brave_web_search" in logs_str
|
||||||
|
assert "celtics" in logs_str.lower()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue