diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index de4d82646..c6b59e948 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -11,8 +11,7 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmeth from pydantic import BaseModel, Field from typing_extensions import Protocol, runtime_checkable -from llama_stack.apis.inference import InterleavedContent - +from llama_stack.apis.common.content_types import InterleavedContent, URL from llama_stack.apis.resource import Resource, ResourceType from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol @@ -50,8 +49,13 @@ class ToolDef(BaseModel): @json_schema_type class MCPToolGroup(BaseModel): - type: Literal["mcp"] = "mcp" - endpoint: str + """ + A tool group that is defined by in a model context protocol server. + Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information. + """ + + type: Literal["model_context_protocol"] = "model_context_protocol" + endpoint: URL @json_schema_type @@ -67,7 +71,7 @@ ToolGroup = register_schema( @json_schema_type -class InvokeToolResult(BaseModel): +class ToolInvocationResult(BaseModel): content: InterleavedContent error_message: Optional[str] = None error_code: Optional[int] = None @@ -80,7 +84,7 @@ class ToolStore(Protocol): @runtime_checkable @trace_protocol class Tools(Protocol): - @webmethod(route="/tool-groups/register", method="POST") + @webmethod(route="/toolgroups/register", method="POST") async def register_tool_group( self, name: str, @@ -112,6 +116,8 @@ class ToolRuntime(Protocol): tool_store: ToolStore @webmethod(route="/tool-runtime/invoke", method="POST") - async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> InvokeToolResult: + async def invoke_tool( + self, tool_id: str, args: Dict[str, Any] + ) -> ToolInvocationResult: """Run a tool with the given arguments""" ... diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py b/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py index df5352221..087fd918d 100644 --- a/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py +++ b/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py @@ -6,7 +6,7 @@ from typing import Any, Dict -from llama_stack.apis.tools import InvokeToolResult, Tool, ToolRuntime +from llama_stack.apis.tools import Tool, ToolInvocationResult, ToolRuntime from llama_stack.providers.datatypes import ToolsProtocolPrivate from .config import MetaReferenceToolRuntimeConfig @@ -26,5 +26,7 @@ class MetaReferenceToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): async def unregister_tool(self, tool_id: str) -> None: pass - async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> InvokeToolResult: + async def invoke_tool( + self, tool_id: str, args: Dict[str, Any] + ) -> ToolInvocationResult: pass