forked from phoenix-oss/llama-stack-mirror
fix(api): don't return list for runtime tools (#1686)
# What does this PR do? Don't return list for runtime tools. Instead return Response object for pagination and consistency with other APIs. --------- Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
b440a1dc42
commit
0a895c70d1
13 changed files with 170 additions and 108 deletions
20
docs/_static/llama-stack-spec.html
vendored
20
docs/_static/llama-stack-spec.html
vendored
|
@ -2688,9 +2688,9 @@
|
|||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/jsonl": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ToolDef"
|
||||
"$ref": "#/components/schemas/ListToolDefsResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -8328,6 +8328,22 @@
|
|||
],
|
||||
"title": "ListRoutesResponse"
|
||||
},
|
||||
"ListToolDefsResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"data": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/ToolDef"
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"data"
|
||||
],
|
||||
"title": "ListToolDefsResponse"
|
||||
},
|
||||
"ListScoringFunctionsResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
|
15
docs/_static/llama-stack-spec.yaml
vendored
15
docs/_static/llama-stack-spec.yaml
vendored
|
@ -1855,9 +1855,9 @@ paths:
|
|||
'200':
|
||||
description: OK
|
||||
content:
|
||||
application/jsonl:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ToolDef'
|
||||
$ref: '#/components/schemas/ListToolDefsResponse'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
|
@ -5732,6 +5732,17 @@ components:
|
|||
required:
|
||||
- data
|
||||
title: ListRoutesResponse
|
||||
ListToolDefsResponse:
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/ToolDef'
|
||||
additionalProperties: false
|
||||
required:
|
||||
- data
|
||||
title: ListToolDefsResponse
|
||||
ListScoringFunctionsResponse:
|
||||
type: object
|
||||
properties:
|
||||
|
|
|
@ -132,7 +132,18 @@ def _validate_api_method_return_type(method) -> str | None:
|
|||
|
||||
return_type = hints['return']
|
||||
if is_optional_type(return_type):
|
||||
return "returns Optional type"
|
||||
return "returns Optional type where a return value is mandatory"
|
||||
|
||||
|
||||
def _validate_api_method_doesnt_return_list(method) -> str | None:
|
||||
hints = get_type_hints(method)
|
||||
|
||||
if 'return' not in hints:
|
||||
return "has no return type annotation"
|
||||
|
||||
return_type = hints['return']
|
||||
if get_origin(return_type) is list:
|
||||
return "returns a list where a PaginatedResponse or List*Response object is expected"
|
||||
|
||||
|
||||
def _validate_api_delete_method_returns_none(method) -> str | None:
|
||||
|
@ -143,7 +154,7 @@ def _validate_api_delete_method_returns_none(method) -> str | None:
|
|||
|
||||
return_type = hints['return']
|
||||
if return_type is not None and return_type is not type(None):
|
||||
return "does not return None"
|
||||
return "does not return None where None is mandatory"
|
||||
|
||||
|
||||
def _validate_list_parameters_contain_data(method) -> str | None:
|
||||
|
@ -160,13 +171,14 @@ def _validate_list_parameters_contain_data(method) -> str | None:
|
|||
return
|
||||
|
||||
if 'data' not in return_type.model_fields:
|
||||
return "does not have data attribute"
|
||||
return "does not have a mandatory data attribute containing the list of objects"
|
||||
|
||||
|
||||
_VALIDATORS = {
|
||||
"GET": [
|
||||
_validate_api_method_return_type,
|
||||
_validate_list_parameters_contain_data,
|
||||
_validate_api_method_doesnt_return_list,
|
||||
],
|
||||
"DELETE": [
|
||||
_validate_api_delete_method_returns_none,
|
||||
|
|
|
@ -88,6 +88,10 @@ class ListToolsResponse(BaseModel):
|
|||
data: List[Tool]
|
||||
|
||||
|
||||
class ListToolDefsResponse(BaseModel):
|
||||
data: list[ToolDef]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class ToolGroups(Protocol):
|
||||
|
@ -148,7 +152,7 @@ class ToolRuntime(Protocol):
|
|||
@webmethod(route="/tool-runtime/list-tools", method="GET")
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]: ...
|
||||
) -> ListToolDefsResponse: ...
|
||||
|
||||
@webmethod(route="/tool-runtime/invoke", method="POST")
|
||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
|
||||
|
|
|
@ -46,11 +46,11 @@ from llama_stack.apis.scoring import (
|
|||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
RAGDocument,
|
||||
RAGQueryConfig,
|
||||
RAGQueryResult,
|
||||
RAGToolRuntime,
|
||||
ToolDef,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
|
@ -707,6 +707,6 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
) -> ListToolDefsResponse:
|
||||
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
||||
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
|
||||
|
|
|
@ -568,7 +568,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
|
||||
tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
|
||||
|
||||
for tool_def in tool_defs:
|
||||
for tool_def in tool_defs.data:
|
||||
tools.append(
|
||||
ToolWithACL(
|
||||
identifier=tool_def.name,
|
||||
|
|
|
@ -9,10 +9,11 @@ import asyncio
|
|||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
|
@ -46,20 +47,22 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
return [
|
||||
ToolDef(
|
||||
name="code_interpreter",
|
||||
description="Execute code",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="code",
|
||||
description="The code to execute",
|
||||
parameter_type="string",
|
||||
),
|
||||
],
|
||||
)
|
||||
]
|
||||
) -> ListToolDefsResponse:
|
||||
return ListToolDefsResponse(
|
||||
data=[
|
||||
ToolDef(
|
||||
name="code_interpreter",
|
||||
description="Execute code",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="code",
|
||||
description="The code to execute",
|
||||
parameter_type="string",
|
||||
),
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
|
||||
script = kwargs["code"]
|
||||
|
|
|
@ -20,6 +20,7 @@ from llama_stack.apis.common.content_types import (
|
|||
)
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
RAGDocument,
|
||||
RAGQueryConfig,
|
||||
RAGQueryResult,
|
||||
|
@ -162,27 +163,29 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
) -> ListToolDefsResponse:
|
||||
# Parameters are not listed since these methods are not yet invoked automatically
|
||||
# by the LLM. The method is only implemented so things like /tools can list without
|
||||
# encountering fatals.
|
||||
return [
|
||||
ToolDef(
|
||||
name="insert_into_memory",
|
||||
description="Insert documents into memory",
|
||||
),
|
||||
ToolDef(
|
||||
name="knowledge_search",
|
||||
description="Search for information in a database.",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to search for. Can be a natural language sentence or keywords.",
|
||||
parameter_type="string",
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
return ListToolDefsResponse(
|
||||
data=[
|
||||
ToolDef(
|
||||
name="insert_into_memory",
|
||||
description="Insert documents into memory",
|
||||
),
|
||||
ToolDef(
|
||||
name="knowledge_search",
|
||||
description="Search for information in a database.",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to search for. Can be a natural language sentence or keywords.",
|
||||
parameter_type="string",
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
|
||||
vector_db_ids = kwargs.get("vector_db_ids", [])
|
||||
|
|
|
@ -5,12 +5,13 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
|
@ -50,20 +51,22 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
|
|||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
return [
|
||||
ToolDef(
|
||||
name="web_search",
|
||||
description="Search the web using Bing Search API",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to search for",
|
||||
parameter_type="string",
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
) -> ListToolDefsResponse:
|
||||
return ListToolDefsResponse(
|
||||
data=[
|
||||
ToolDef(
|
||||
name="web_search",
|
||||
description="Search the web using Bing Search API",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to search for",
|
||||
parameter_type="string",
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
|
||||
api_key = self._get_api_key()
|
||||
|
|
|
@ -4,12 +4,13 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
|
@ -49,21 +50,23 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequest
|
|||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
return [
|
||||
ToolDef(
|
||||
name="web_search",
|
||||
description="Search the web for information",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to search for",
|
||||
parameter_type="string",
|
||||
)
|
||||
],
|
||||
built_in_type=BuiltinTool.brave_search,
|
||||
)
|
||||
]
|
||||
) -> ListToolDefsResponse:
|
||||
return ListToolDefsResponse(
|
||||
data=[
|
||||
ToolDef(
|
||||
name="web_search",
|
||||
description="Search the web for information",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to search for",
|
||||
parameter_type="string",
|
||||
)
|
||||
],
|
||||
built_in_type=BuiltinTool.brave_search,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
|
||||
api_key = self._get_api_key()
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from mcp import ClientSession
|
||||
|
@ -12,6 +12,7 @@ from mcp.client.sse import sse_client
|
|||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
|
@ -31,7 +32,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
) -> ListToolDefsResponse:
|
||||
if mcp_endpoint is None:
|
||||
raise ValueError("mcp_endpoint is required")
|
||||
|
||||
|
@ -60,7 +61,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
},
|
||||
)
|
||||
)
|
||||
return tools
|
||||
return ListToolDefsResponse(data=tools)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
|
||||
tool = await self.tool_store.get_tool(tool_name)
|
||||
|
|
|
@ -5,12 +5,13 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
|
@ -49,20 +50,22 @@ class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
|
|||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
return [
|
||||
ToolDef(
|
||||
name="web_search",
|
||||
description="Search the web for information",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to search for",
|
||||
parameter_type="string",
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
) -> ListToolDefsResponse:
|
||||
return ListToolDefsResponse(
|
||||
data=[
|
||||
ToolDef(
|
||||
name="web_search",
|
||||
description="Search the web for information",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to search for",
|
||||
parameter_type="string",
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
|
||||
api_key = self._get_api_key()
|
||||
|
|
|
@ -5,12 +5,13 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
|
@ -50,20 +51,22 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
|
|||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
return [
|
||||
ToolDef(
|
||||
name="wolfram_alpha",
|
||||
description="Query WolframAlpha for computational knowledge",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to compute",
|
||||
parameter_type="string",
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
) -> ListToolDefsResponse:
|
||||
return ListToolDefsResponse(
|
||||
data=[
|
||||
ToolDef(
|
||||
name="wolfram_alpha",
|
||||
description="Query WolframAlpha for computational knowledge",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to compute",
|
||||
parameter_type="string",
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
|
||||
api_key = self._get_api_key()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue