fix(api): don't return list for runtime tools (#1686)

# What does this PR do?

Don't return list for runtime tools. Instead return Response object for
pagination and consistency with other APIs.

---------

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-04-01 03:53:11 -04:00 committed by GitHub
parent b440a1dc42
commit 0a895c70d1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 170 additions and 108 deletions

View file

@ -2688,9 +2688,9 @@
"200": { "200": {
"description": "OK", "description": "OK",
"content": { "content": {
"application/jsonl": { "application/json": {
"schema": { "schema": {
"$ref": "#/components/schemas/ToolDef" "$ref": "#/components/schemas/ListToolDefsResponse"
} }
} }
} }
@ -8328,6 +8328,22 @@
], ],
"title": "ListRoutesResponse" "title": "ListRoutesResponse"
}, },
"ListToolDefsResponse": {
"type": "object",
"properties": {
"data": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ToolDef"
}
}
},
"additionalProperties": false,
"required": [
"data"
],
"title": "ListToolDefsResponse"
},
"ListScoringFunctionsResponse": { "ListScoringFunctionsResponse": {
"type": "object", "type": "object",
"properties": { "properties": {

View file

@ -1855,9 +1855,9 @@ paths:
'200': '200':
description: OK description: OK
content: content:
application/jsonl: application/json:
schema: schema:
$ref: '#/components/schemas/ToolDef' $ref: '#/components/schemas/ListToolDefsResponse'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -5732,6 +5732,17 @@ components:
required: required:
- data - data
title: ListRoutesResponse title: ListRoutesResponse
ListToolDefsResponse:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/ToolDef'
additionalProperties: false
required:
- data
title: ListToolDefsResponse
ListScoringFunctionsResponse: ListScoringFunctionsResponse:
type: object type: object
properties: properties:

View file

@ -132,7 +132,18 @@ def _validate_api_method_return_type(method) -> str | None:
return_type = hints['return'] return_type = hints['return']
if is_optional_type(return_type): if is_optional_type(return_type):
return "returns Optional type" return "returns Optional type where a return value is mandatory"
def _validate_api_method_doesnt_return_list(method) -> str | None:
hints = get_type_hints(method)
if 'return' not in hints:
return "has no return type annotation"
return_type = hints['return']
if get_origin(return_type) is list:
return "returns a list where a PaginatedResponse or List*Response object is expected"
def _validate_api_delete_method_returns_none(method) -> str | None: def _validate_api_delete_method_returns_none(method) -> str | None:
@ -143,7 +154,7 @@ def _validate_api_delete_method_returns_none(method) -> str | None:
return_type = hints['return'] return_type = hints['return']
if return_type is not None and return_type is not type(None): if return_type is not None and return_type is not type(None):
return "does not return None" return "does not return None where None is mandatory"
def _validate_list_parameters_contain_data(method) -> str | None: def _validate_list_parameters_contain_data(method) -> str | None:
@ -160,13 +171,14 @@ def _validate_list_parameters_contain_data(method) -> str | None:
return return
if 'data' not in return_type.model_fields: if 'data' not in return_type.model_fields:
return "does not have data attribute" return "does not have a mandatory data attribute containing the list of objects"
_VALIDATORS = { _VALIDATORS = {
"GET": [ "GET": [
_validate_api_method_return_type, _validate_api_method_return_type,
_validate_list_parameters_contain_data, _validate_list_parameters_contain_data,
_validate_api_method_doesnt_return_list,
], ],
"DELETE": [ "DELETE": [
_validate_api_delete_method_returns_none, _validate_api_delete_method_returns_none,

View file

@ -88,6 +88,10 @@ class ListToolsResponse(BaseModel):
data: List[Tool] data: List[Tool]
class ListToolDefsResponse(BaseModel):
data: list[ToolDef]
@runtime_checkable @runtime_checkable
@trace_protocol @trace_protocol
class ToolGroups(Protocol): class ToolGroups(Protocol):
@ -148,7 +152,7 @@ class ToolRuntime(Protocol):
@webmethod(route="/tool-runtime/list-tools", method="GET") @webmethod(route="/tool-runtime/list-tools", method="GET")
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]: ... ) -> ListToolDefsResponse: ...
@webmethod(route="/tool-runtime/invoke", method="POST") @webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult: async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:

View file

@ -46,11 +46,11 @@ from llama_stack.apis.scoring import (
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
ListToolDefsResponse,
RAGDocument, RAGDocument,
RAGQueryConfig, RAGQueryConfig,
RAGQueryResult, RAGQueryResult,
RAGToolRuntime, RAGToolRuntime,
ToolDef,
ToolRuntime, ToolRuntime,
) )
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
@ -707,6 +707,6 @@ class ToolRuntimeRouter(ToolRuntime):
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]: ) -> ListToolDefsResponse:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint) return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)

View file

@ -568,7 +568,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint) tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
for tool_def in tool_defs: for tool_def in tool_defs.data:
tools.append( tools.append(
ToolWithACL( ToolWithACL(
identifier=tool_def.name, identifier=tool_def.name,

View file

@ -9,10 +9,11 @@ import asyncio
import logging import logging
import os import os
import tempfile import tempfile
from typing import Any, Dict, List, Optional from typing import Any, Dict, Optional
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool, Tool,
ToolDef, ToolDef,
ToolInvocationResult, ToolInvocationResult,
@ -46,8 +47,9 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]: ) -> ListToolDefsResponse:
return [ return ListToolDefsResponse(
data=[
ToolDef( ToolDef(
name="code_interpreter", name="code_interpreter",
description="Execute code", description="Execute code",
@ -60,6 +62,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
], ],
) )
] ]
)
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult: async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
script = kwargs["code"] script = kwargs["code"]

View file

@ -20,6 +20,7 @@ from llama_stack.apis.common.content_types import (
) )
from llama_stack.apis.inference import Inference from llama_stack.apis.inference import Inference
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
ListToolDefsResponse,
RAGDocument, RAGDocument,
RAGQueryConfig, RAGQueryConfig,
RAGQueryResult, RAGQueryResult,
@ -162,11 +163,12 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]: ) -> ListToolDefsResponse:
# Parameters are not listed since these methods are not yet invoked automatically # Parameters are not listed since these methods are not yet invoked automatically
# by the LLM. The method is only implemented so things like /tools can list without # by the LLM. The method is only implemented so things like /tools can list without
# encountering fatals. # encountering fatals.
return [ return ListToolDefsResponse(
data=[
ToolDef( ToolDef(
name="insert_into_memory", name="insert_into_memory",
description="Insert documents into memory", description="Insert documents into memory",
@ -183,6 +185,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
], ],
), ),
] ]
)
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult: async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
vector_db_ids = kwargs.get("vector_db_ids", []) vector_db_ids = kwargs.get("vector_db_ids", [])

View file

@ -5,12 +5,13 @@
# the root directory of this source tree. # the root directory of this source tree.
import json import json
from typing import Any, Dict, List, Optional from typing import Any, Dict, Optional
import httpx import httpx
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool, Tool,
ToolDef, ToolDef,
ToolInvocationResult, ToolInvocationResult,
@ -50,8 +51,9 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]: ) -> ListToolDefsResponse:
return [ return ListToolDefsResponse(
data=[
ToolDef( ToolDef(
name="web_search", name="web_search",
description="Search the web using Bing Search API", description="Search the web using Bing Search API",
@ -64,6 +66,7 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
], ],
) )
] ]
)
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult: async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
api_key = self._get_api_key() api_key = self._get_api_key()

View file

@ -4,12 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict, List, Optional from typing import Any, Dict, Optional
import httpx import httpx
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool, Tool,
ToolDef, ToolDef,
ToolInvocationResult, ToolInvocationResult,
@ -49,8 +50,9 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequest
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]: ) -> ListToolDefsResponse:
return [ return ListToolDefsResponse(
data=[
ToolDef( ToolDef(
name="web_search", name="web_search",
description="Search the web for information", description="Search the web for information",
@ -64,6 +66,7 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequest
built_in_type=BuiltinTool.brave_search, built_in_type=BuiltinTool.brave_search,
) )
] ]
)
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult: async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
api_key = self._get_api_key() api_key = self._get_api_key()

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict, List, Optional from typing import Any, Dict, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
from mcp import ClientSession from mcp import ClientSession
@ -12,6 +12,7 @@ from mcp.client.sse import sse_client
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
ListToolDefsResponse,
ToolDef, ToolDef,
ToolInvocationResult, ToolInvocationResult,
ToolParameter, ToolParameter,
@ -31,7 +32,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]: ) -> ListToolDefsResponse:
if mcp_endpoint is None: if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required") raise ValueError("mcp_endpoint is required")
@ -60,7 +61,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
}, },
) )
) )
return tools return ListToolDefsResponse(data=tools)
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult: async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name) tool = await self.tool_store.get_tool(tool_name)

View file

@ -5,12 +5,13 @@
# the root directory of this source tree. # the root directory of this source tree.
import json import json
from typing import Any, Dict, List, Optional from typing import Any, Dict, Optional
import httpx import httpx
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool, Tool,
ToolDef, ToolDef,
ToolInvocationResult, ToolInvocationResult,
@ -49,8 +50,9 @@ class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]: ) -> ListToolDefsResponse:
return [ return ListToolDefsResponse(
data=[
ToolDef( ToolDef(
name="web_search", name="web_search",
description="Search the web for information", description="Search the web for information",
@ -63,6 +65,7 @@ class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
], ],
) )
] ]
)
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult: async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
api_key = self._get_api_key() api_key = self._get_api_key()

View file

@ -5,12 +5,13 @@
# the root directory of this source tree. # the root directory of this source tree.
import json import json
from typing import Any, Dict, List, Optional from typing import Any, Dict, Optional
import httpx import httpx
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool, Tool,
ToolDef, ToolDef,
ToolInvocationResult, ToolInvocationResult,
@ -50,8 +51,9 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]: ) -> ListToolDefsResponse:
return [ return ListToolDefsResponse(
data=[
ToolDef( ToolDef(
name="wolfram_alpha", name="wolfram_alpha",
description="Query WolframAlpha for computational knowledge", description="Query WolframAlpha for computational knowledge",
@ -64,6 +66,7 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
], ],
) )
] ]
)
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult: async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
api_key = self._get_api_key() api_key = self._get_api_key()