address feedback

This commit is contained in:
Dinesh Yeduguru 2024-12-18 18:42:11 -08:00
parent 72dab3e4bf
commit ea0ca7454a
2 changed files with 17 additions and 9 deletions

View file

@ -11,8 +11,7 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmeth
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Protocol, runtime_checkable from typing_extensions import Protocol, runtime_checkable
from llama_stack.apis.inference import InterleavedContent from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@ -50,8 +49,13 @@ class ToolDef(BaseModel):
@json_schema_type @json_schema_type
class MCPToolGroup(BaseModel): class MCPToolGroup(BaseModel):
type: Literal["mcp"] = "mcp" """
endpoint: str A tool group that is defined by in a model context protocol server.
Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information.
"""
type: Literal["model_context_protocol"] = "model_context_protocol"
endpoint: URL
@json_schema_type @json_schema_type
@ -67,7 +71,7 @@ ToolGroup = register_schema(
@json_schema_type @json_schema_type
class InvokeToolResult(BaseModel): class ToolInvocationResult(BaseModel):
content: InterleavedContent content: InterleavedContent
error_message: Optional[str] = None error_message: Optional[str] = None
error_code: Optional[int] = None error_code: Optional[int] = None
@ -80,7 +84,7 @@ class ToolStore(Protocol):
@runtime_checkable @runtime_checkable
@trace_protocol @trace_protocol
class Tools(Protocol): class Tools(Protocol):
@webmethod(route="/tool-groups/register", method="POST") @webmethod(route="/toolgroups/register", method="POST")
async def register_tool_group( async def register_tool_group(
self, self,
name: str, name: str,
@ -112,6 +116,8 @@ class ToolRuntime(Protocol):
tool_store: ToolStore tool_store: ToolStore
@webmethod(route="/tool-runtime/invoke", method="POST") @webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> InvokeToolResult: async def invoke_tool(
self, tool_id: str, args: Dict[str, Any]
) -> ToolInvocationResult:
"""Run a tool with the given arguments""" """Run a tool with the given arguments"""
... ...

View file

@ -6,7 +6,7 @@
from typing import Any, Dict from typing import Any, Dict
from llama_stack.apis.tools import InvokeToolResult, Tool, ToolRuntime from llama_stack.apis.tools import Tool, ToolInvocationResult, ToolRuntime
from llama_stack.providers.datatypes import ToolsProtocolPrivate from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import MetaReferenceToolRuntimeConfig from .config import MetaReferenceToolRuntimeConfig
@ -26,5 +26,7 @@ class MetaReferenceToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async def unregister_tool(self, tool_id: str) -> None: async def unregister_tool(self, tool_id: str) -> None:
pass pass
async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> InvokeToolResult: async def invoke_tool(
self, tool_id: str, args: Dict[str, Any]
) -> ToolInvocationResult:
pass pass