fix(api): don't return list for runtime tools (#1686)

# What does this PR do?

Don't return list for runtime tools. Instead return Response object for
pagination and consistency with other APIs.

---------

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-04-01 03:53:11 -04:00 committed by GitHub
parent b440a1dc42
commit 0a895c70d1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 170 additions and 108 deletions

View file

@ -2688,9 +2688,9 @@
"200": {
"description": "OK",
"content": {
"application/jsonl": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ToolDef"
"$ref": "#/components/schemas/ListToolDefsResponse"
}
}
}
@ -8328,6 +8328,22 @@
],
"title": "ListRoutesResponse"
},
"ListToolDefsResponse": {
"type": "object",
"properties": {
"data": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ToolDef"
}
}
},
"additionalProperties": false,
"required": [
"data"
],
"title": "ListToolDefsResponse"
},
"ListScoringFunctionsResponse": {
"type": "object",
"properties": {

View file

@ -1855,9 +1855,9 @@ paths:
'200':
description: OK
content:
application/jsonl:
application/json:
schema:
$ref: '#/components/schemas/ToolDef'
$ref: '#/components/schemas/ListToolDefsResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@ -5732,6 +5732,17 @@ components:
required:
- data
title: ListRoutesResponse
ListToolDefsResponse:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/ToolDef'
additionalProperties: false
required:
- data
title: ListToolDefsResponse
ListScoringFunctionsResponse:
type: object
properties:

View file

@ -132,7 +132,18 @@ def _validate_api_method_return_type(method) -> str | None:
return_type = hints['return']
if is_optional_type(return_type):
return "returns Optional type"
return "returns Optional type where a return value is mandatory"
def _validate_api_method_doesnt_return_list(method) -> str | None:
hints = get_type_hints(method)
if 'return' not in hints:
return "has no return type annotation"
return_type = hints['return']
if get_origin(return_type) is list:
return "returns a list where a PaginatedResponse or List*Response object is expected"
def _validate_api_delete_method_returns_none(method) -> str | None:
@ -143,7 +154,7 @@ def _validate_api_delete_method_returns_none(method) -> str | None:
return_type = hints['return']
if return_type is not None and return_type is not type(None):
return "does not return None"
return "does not return None where None is mandatory"
def _validate_list_parameters_contain_data(method) -> str | None:
@ -160,13 +171,14 @@ def _validate_list_parameters_contain_data(method) -> str | None:
return
if 'data' not in return_type.model_fields:
return "does not have data attribute"
return "does not have a mandatory data attribute containing the list of objects"
_VALIDATORS = {
"GET": [
_validate_api_method_return_type,
_validate_list_parameters_contain_data,
_validate_api_method_doesnt_return_list,
],
"DELETE": [
_validate_api_delete_method_returns_none,

View file

@ -88,6 +88,10 @@ class ListToolsResponse(BaseModel):
data: List[Tool]
class ListToolDefsResponse(BaseModel):
data: list[ToolDef]
@runtime_checkable
@trace_protocol
class ToolGroups(Protocol):
@ -148,7 +152,7 @@ class ToolRuntime(Protocol):
@webmethod(route="/tool-runtime/list-tools", method="GET")
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]: ...
) -> ListToolDefsResponse: ...
@webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:

View file

@ -46,11 +46,11 @@ from llama_stack.apis.scoring import (
from llama_stack.apis.shields import Shield
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
from llama_stack.apis.tools import (
ListToolDefsResponse,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolDef,
ToolRuntime,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
@ -707,6 +707,6 @@ class ToolRuntimeRouter(ToolRuntime):
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
) -> ListToolDefsResponse:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)

View file

@ -568,7 +568,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
for tool_def in tool_defs:
for tool_def in tool_defs.data:
tools.append(
ToolWithACL(
identifier=tool_def.name,

View file

@ -9,10 +9,11 @@ import asyncio
import logging
import os
import tempfile
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool,
ToolDef,
ToolInvocationResult,
@ -46,8 +47,9 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
) -> ListToolDefsResponse:
return ListToolDefsResponse(
data=[
ToolDef(
name="code_interpreter",
description="Execute code",
@ -60,6 +62,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
],
)
]
)
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
script = kwargs["code"]

View file

@ -20,6 +20,7 @@ from llama_stack.apis.common.content_types import (
)
from llama_stack.apis.inference import Inference
from llama_stack.apis.tools import (
ListToolDefsResponse,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
@ -162,11 +163,12 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
) -> ListToolDefsResponse:
# Parameters are not listed since these methods are not yet invoked automatically
# by the LLM. The method is only implemented so things like /tools can list without
# encountering fatals.
return [
return ListToolDefsResponse(
data=[
ToolDef(
name="insert_into_memory",
description="Insert documents into memory",
@ -183,6 +185,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
],
),
]
)
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
vector_db_ids = kwargs.get("vector_db_ids", [])

View file

@ -5,12 +5,13 @@
# the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional
import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool,
ToolDef,
ToolInvocationResult,
@ -50,8 +51,9 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
) -> ListToolDefsResponse:
return ListToolDefsResponse(
data=[
ToolDef(
name="web_search",
description="Search the web using Bing Search API",
@ -64,6 +66,7 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
],
)
]
)
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
api_key = self._get_api_key()

View file

@ -4,12 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional
import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool,
ToolDef,
ToolInvocationResult,
@ -49,8 +50,9 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequest
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
) -> ListToolDefsResponse:
return ListToolDefsResponse(
data=[
ToolDef(
name="web_search",
description="Search the web for information",
@ -64,6 +66,7 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequest
built_in_type=BuiltinTool.brave_search,
)
]
)
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
api_key = self._get_api_key()

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional
from urllib.parse import urlparse
from mcp import ClientSession
@ -12,6 +12,7 @@ from mcp.client.sse import sse_client
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
ToolDef,
ToolInvocationResult,
ToolParameter,
@ -31,7 +32,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
) -> ListToolDefsResponse:
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
@ -60,7 +61,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
},
)
)
return tools
return ListToolDefsResponse(data=tools)
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name)

View file

@ -5,12 +5,13 @@
# the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional
import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool,
ToolDef,
ToolInvocationResult,
@ -49,8 +50,9 @@ class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
) -> ListToolDefsResponse:
return ListToolDefsResponse(
data=[
ToolDef(
name="web_search",
description="Search the web for information",
@ -63,6 +65,7 @@ class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
],
)
]
)
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
api_key = self._get_api_key()

View file

@ -5,12 +5,13 @@
# the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional
import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool,
ToolDef,
ToolInvocationResult,
@ -50,8 +51,9 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
) -> ListToolDefsResponse:
return ListToolDefsResponse(
data=[
ToolDef(
name="wolfram_alpha",
description="Query WolframAlpha for computational knowledge",
@ -64,6 +66,7 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
],
)
]
)
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
api_key = self._get_api_key()