feat: tool outputs metadata (#1155)

Summary:

Allows tools to output metadata. This is useful for evaluating tool
outputs: e.g. the RAG tool will output document IDs, which can then be used to
score recall.

We will need to make a similar change on the client side to support
ClientTool outputting metadata.

Test Plan:

LLAMA_STACK_CONFIG=fireworks pytest -s -v
tests/client-sdk/agents/test_agents.py
This commit is contained in:
ehhuang 2025-02-21 13:15:31 -08:00 committed by GitHub
parent 36162c8c82
commit 25fddccfd8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 141 additions and 28 deletions

View file

@ -4521,6 +4521,31 @@
}, },
"content": { "content": {
"$ref": "#/components/schemas/InterleavedContent" "$ref": "#/components/schemas/InterleavedContent"
},
"metadata": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
} }
}, },
"additionalProperties": false, "additionalProperties": false,
@ -6746,6 +6771,31 @@
}, },
"error_code": { "error_code": {
"type": "integer" "type": "integer"
},
"metadata": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
} }
}, },
"additionalProperties": false, "additionalProperties": false,
@ -7595,9 +7645,37 @@
"properties": { "properties": {
"content": { "content": {
"$ref": "#/components/schemas/InterleavedContent" "$ref": "#/components/schemas/InterleavedContent"
},
"metadata": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
} }
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [
"metadata"
],
"title": "RAGQueryResult" "title": "RAGQueryResult"
}, },
"QueryChunksRequest": { "QueryChunksRequest": {

View file

@ -2945,6 +2945,16 @@ components:
- type: string - type: string
content: content:
$ref: '#/components/schemas/InterleavedContent' $ref: '#/components/schemas/InterleavedContent'
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
additionalProperties: false additionalProperties: false
required: required:
- call_id - call_id
@ -4381,6 +4391,16 @@ components:
type: string type: string
error_code: error_code:
type: integer type: integer
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
additionalProperties: false additionalProperties: false
required: required:
- content - content
@ -4954,7 +4974,19 @@ components:
properties: properties:
content: content:
$ref: '#/components/schemas/InterleavedContent' $ref: '#/components/schemas/InterleavedContent'
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
additionalProperties: false additionalProperties: false
required:
- metadata
title: RAGQueryResult title: RAGQueryResult
QueryChunksRequest: QueryChunksRequest:
type: object type: object

View file

@ -165,6 +165,7 @@ class ToolResponse(BaseModel):
call_id: str call_id: str
tool_name: Union[BuiltinTool, str] tool_name: Union[BuiltinTool, str]
content: InterleavedContent content: InterleavedContent
metadata: Optional[Dict[str, Any]] = None
@field_validator("tool_name", mode="before") @field_validator("tool_name", mode="before")
@classmethod @classmethod

View file

@ -26,6 +26,7 @@ class RAGDocument(BaseModel):
@json_schema_type @json_schema_type
class RAGQueryResult(BaseModel): class RAGQueryResult(BaseModel):
content: Optional[InterleavedContent] = None content: Optional[InterleavedContent] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
@json_schema_type @json_schema_type

View file

@ -72,6 +72,7 @@ class ToolInvocationResult(BaseModel):
content: InterleavedContent content: InterleavedContent
error_message: Optional[str] = None error_message: Optional[str] = None
error_code: Optional[int] = None error_code: Optional[int] = None
metadata: Optional[Dict[str, Any]] = None
class ToolStore(Protocol): class ToolStore(Protocol):

View file

@ -62,7 +62,7 @@ from llama_stack.apis.inference import (
UserMessage, UserMessage,
) )
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.models.llama.datatypes import ( from llama_stack.models.llama.datatypes import (
BuiltinTool, BuiltinTool,
@ -587,6 +587,7 @@ class ChatAgent(ShieldRunnerMixin):
call_id="", call_id="",
tool_name=MEMORY_QUERY_TOOL, tool_name=MEMORY_QUERY_TOOL,
content=retrieved_context or [], content=retrieved_context or [],
metadata=result.metadata,
) )
], ],
), ),
@ -795,13 +796,21 @@ class ChatAgent(ShieldRunnerMixin):
}, },
) as span: ) as span:
tool_execution_start_time = datetime.now() tool_execution_start_time = datetime.now()
result_messages = await execute_tool_call_maybe( tool_call = message.tool_calls[0]
tool_result = await execute_tool_call_maybe(
self.tool_runtime_api, self.tool_runtime_api,
session_id, session_id,
[message], tool_call,
toolgroup_args, toolgroup_args,
tool_to_group, tool_to_group,
) )
result_messages = [
ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=tool_result.content,
)
]
assert len(result_messages) == 1, "Currently not supporting multiple messages" assert len(result_messages) == 1, "Currently not supporting multiple messages"
result_message = result_messages[0] result_message = result_messages[0]
span.set_attribute("output", result_message.model_dump_json()) span.set_attribute("output", result_message.model_dump_json())
@ -820,6 +829,7 @@ class ChatAgent(ShieldRunnerMixin):
call_id=result_message.call_id, call_id=result_message.call_id,
tool_name=result_message.tool_name, tool_name=result_message.tool_name,
content=result_message.content, content=result_message.content,
metadata=tool_result.metadata,
) )
], ],
started_at=tool_execution_start_time, started_at=tool_execution_start_time,
@ -1058,19 +1068,10 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
async def execute_tool_call_maybe( async def execute_tool_call_maybe(
tool_runtime_api: ToolRuntime, tool_runtime_api: ToolRuntime,
session_id: str, session_id: str,
messages: List[CompletionMessage], tool_call: ToolCall,
toolgroup_args: Dict[str, Dict[str, Any]], toolgroup_args: Dict[str, Dict[str, Any]],
tool_to_group: Dict[str, str], tool_to_group: Dict[str, str],
) -> List[ToolResponseMessage]: ) -> ToolInvocationResult:
# While Tools.run interface takes a list of messages,
# All tools currently only run on a single message
# When this changes, we can drop this assert
# Whether to call tools on each message and aggregate
# or aggregate and call tool once, reamins to be seen.
assert len(messages) == 1, "Expected single message"
message = messages[0]
tool_call = message.tool_calls[0]
name = tool_call.tool_name name = tool_call.tool_name
group_name = tool_to_group.get(name, None) group_name = tool_to_group.get(name, None)
if group_name is None: if group_name is None:
@ -1091,14 +1092,7 @@ async def execute_tool_call_maybe(
**tool_call_args, **tool_call_args,
), ),
) )
return result
return [
ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=result.content,
)
]
def _interpret_content_as_attachment( def _interpret_content_as_attachment(

View file

@ -119,10 +119,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
# sort by score # sort by score
chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False) chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)
chunks = chunks[: query_config.max_chunks]
tokens = 0 tokens = 0
picked = [] picked = []
for c in chunks[: query_config.max_chunks]: for c in chunks:
metadata = c.metadata metadata = c.metadata
tokens += metadata["token_count"] tokens += metadata["token_count"]
if tokens > query_config.max_tokens_in_context: if tokens > query_config.max_tokens_in_context:
@ -146,6 +146,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
text="\n=== END-RETRIEVED-CONTEXT ===\n", text="\n=== END-RETRIEVED-CONTEXT ===\n",
), ),
], ],
metadata={
"document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
},
) )
async def list_runtime_tools( async def list_runtime_tools(

View file

@ -457,6 +457,7 @@ def test_rag_agent(llama_stack_client, agent_config):
vector_db_id=vector_db_id, vector_db_id=vector_db_id,
embedding_model="all-MiniLM-L6-v2", embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384, embedding_dimension=384,
provider_id="faiss",
) )
llama_stack_client.tool_runtime.rag_tool.insert( llama_stack_client.tool_runtime.rag_tool.insert(
documents=documents, documents=documents,
@ -492,11 +493,13 @@ def test_rag_agent(llama_stack_client, agent_config):
response = rag_agent.create_turn( response = rag_agent.create_turn(
messages=[{"role": "user", "content": prompt}], messages=[{"role": "user", "content": prompt}],
session_id=session_id, session_id=session_id,
stream=False,
) )
logs = [str(log) for log in EventLogger().log(response) if log is not None] # rag is called
logs_str = "".join(logs) assert response.steps[0].tool_calls[0].tool_name == "query_from_memory"
assert "Tool:query_from_memory" in logs_str # document ids are present in metadata
assert expected_kw in logs_str.lower() assert "num-0" in response.steps[0].tool_responses[0].metadata["document_ids"]
assert expected_kw in response.output_message.content
def test_rag_and_code_agent(llama_stack_client, agent_config): def test_rag_and_code_agent(llama_stack_client, agent_config):