Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 08:44:44 +00:00)
commit ea0ca7454a ("address feedback")
parent 72dab3e4bf

2 changed files with 17 additions and 9 deletions
First changed file (the tools API definitions):

@@ -11,8 +11,7 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, Field
 from typing_extensions import Protocol, runtime_checkable
 
-from llama_stack.apis.inference import InterleavedContent
-
+from llama_stack.apis.common.content_types import InterleavedContent, URL
 from llama_stack.apis.resource import Resource, ResourceType
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 
@@ -50,8 +49,13 @@ class ToolDef(BaseModel):
 
 @json_schema_type
 class MCPToolGroup(BaseModel):
-    type: Literal["mcp"] = "mcp"
-    endpoint: str
+    """
+    A tool group that is defined by in a model context protocol server.
+    Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information.
+    """
+
+    type: Literal["model_context_protocol"] = "model_context_protocol"
+    endpoint: URL
 
 
 @json_schema_type
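For orientation, a minimal usage sketch (not part of the diff) of how the updated MCPToolGroup might be instantiated. It assumes URL exposes a single uri field, as in llama_stack.apis.common.content_types, and that MCPToolGroup is re-exported from llama_stack.apis.tools; the endpoint value is made up.

# Hypothetical usage sketch, not part of the commit. Assumes URL carries a
# `uri` field and that MCPToolGroup is importable from llama_stack.apis.tools.
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import MCPToolGroup

mcp_group = MCPToolGroup(endpoint=URL(uri="http://localhost:8000/sse"))  # made-up endpoint
assert mcp_group.type == "model_context_protocol"  # new discriminator value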
@@ -67,7 +71,7 @@ ToolGroup = register_schema(
 
 
 @json_schema_type
-class InvokeToolResult(BaseModel):
+class ToolInvocationResult(BaseModel):
     content: InterleavedContent
     error_message: Optional[str] = None
     error_code: Optional[int] = None
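A short sketch (again, not in the commit) of how the renamed ToolInvocationResult might be populated by a runtime. It assumes InterleavedContent accepts a plain string, as it does elsewhere in the llama-stack APIs; the values are illustrative only.

# Hypothetical sketch, not part of the commit. Assumes InterleavedContent
# accepts a plain string; field values are invented for illustration.
from llama_stack.apis.tools import ToolInvocationResult

ok = ToolInvocationResult(content="42")  # successful invocation
failed = ToolInvocationResult(
    content="", error_message="tool not found", error_code=404
)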
@@ -80,7 +84,7 @@ class ToolStore(Protocol):
 @runtime_checkable
 @trace_protocol
 class Tools(Protocol):
-    @webmethod(route="/tool-groups/register", method="POST")
+    @webmethod(route="/toolgroups/register", method="POST")
     async def register_tool_group(
         self,
         name: str,
@@ -112,6 +116,8 @@ class ToolRuntime(Protocol):
     tool_store: ToolStore
 
     @webmethod(route="/tool-runtime/invoke", method="POST")
-    async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> InvokeToolResult:
+    async def invoke_tool(
+        self, tool_id: str, args: Dict[str, Any]
+    ) -> ToolInvocationResult:
         """Run a tool with the given arguments"""
         ...
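To make the protocol change concrete, a hedged caller-side sketch follows; `runtime` stands for any object satisfying ToolRuntime, and the tool id and arguments are invented for illustration.

# Hypothetical caller-side sketch, not part of the commit. `runtime` is any
# ToolRuntime-compatible object; tool id and args are made up.
async def run_tool(runtime) -> None:
    result = await runtime.invoke_tool(tool_id="web_search", args={"query": "llama stack"})
    if result.error_code is None:
        print(result.content)
    else:
        print(f"tool failed ({result.error_code}): {result.error_message}")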
Second changed file (the meta-reference tool runtime implementation):

@@ -6,7 +6,7 @@
 
 from typing import Any, Dict
 
-from llama_stack.apis.tools import InvokeToolResult, Tool, ToolRuntime
+from llama_stack.apis.tools import Tool, ToolInvocationResult, ToolRuntime
 from llama_stack.providers.datatypes import ToolsProtocolPrivate
 
 from .config import MetaReferenceToolRuntimeConfig
@@ -26,5 +26,7 @@ class MetaReferenceToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
     async def unregister_tool(self, tool_id: str) -> None:
         pass
 
-    async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> InvokeToolResult:
+    async def invoke_tool(
+        self, tool_id: str, args: Dict[str, Any]
+    ) -> ToolInvocationResult:
         pass
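The meta-reference invoke_tool above remains a stub (pass). As a rough sketch only, under the assumption that a concrete runtime would return ToolInvocationResult, something like the following could work; the in-memory tool table and echo behaviour are invented and are not part of llama-stack.

# Hypothetical provider sketch, not part of the commit. The `_tools` table and
# echo tool are invented; the real meta-reference implementation is still a stub.
from typing import Any, Dict

from llama_stack.apis.tools import ToolInvocationResult


class InMemoryToolRuntime:
    def __init__(self) -> None:
        # tool_id -> callable taking the args dict and returning a string
        self._tools: Dict[str, Any] = {"echo": lambda args: str(args)}

    async def invoke_tool(
        self, tool_id: str, args: Dict[str, Any]
    ) -> ToolInvocationResult:
        tool = self._tools.get(tool_id)
        if tool is None:
            return ToolInvocationResult(
                content="", error_message=f"unknown tool: {tool_id}", error_code=1
            )
        return ToolInvocationResult(content=tool(args))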