diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index ce08e041f..2a9f4b6f7 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -4521,6 +4521,31 @@
},
"content": {
"$ref": "#/components/schemas/InterleavedContent"
+ },
+ "metadata": {
+ "type": "object",
+ "additionalProperties": {
+ "oneOf": [
+ {
+ "type": "null"
+ },
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "number"
+ },
+ {
+ "type": "string"
+ },
+ {
+ "type": "array"
+ },
+ {
+ "type": "object"
+ }
+ ]
+ }
}
},
"additionalProperties": false,
@@ -6746,6 +6771,31 @@
},
"error_code": {
"type": "integer"
+ },
+ "metadata": {
+ "type": "object",
+ "additionalProperties": {
+ "oneOf": [
+ {
+ "type": "null"
+ },
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "number"
+ },
+ {
+ "type": "string"
+ },
+ {
+ "type": "array"
+ },
+ {
+ "type": "object"
+ }
+ ]
+ }
}
},
"additionalProperties": false,
@@ -7595,9 +7645,37 @@
"properties": {
"content": {
"$ref": "#/components/schemas/InterleavedContent"
+ },
+ "metadata": {
+ "type": "object",
+ "additionalProperties": {
+ "oneOf": [
+ {
+ "type": "null"
+ },
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "number"
+ },
+ {
+ "type": "string"
+ },
+ {
+ "type": "array"
+ },
+ {
+ "type": "object"
+ }
+ ]
+ }
}
},
"additionalProperties": false,
+ "required": [
+ "metadata"
+ ],
"title": "RAGQueryResult"
},
"QueryChunksRequest": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 0e4955a5c..a2329e47a 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -2945,6 +2945,16 @@ components:
- type: string
content:
$ref: '#/components/schemas/InterleavedContent'
+ metadata:
+ type: object
+ additionalProperties:
+ oneOf:
+ - type: 'null'
+ - type: boolean
+ - type: number
+ - type: string
+ - type: array
+ - type: object
additionalProperties: false
required:
- call_id
@@ -4381,6 +4391,16 @@ components:
type: string
error_code:
type: integer
+ metadata:
+ type: object
+ additionalProperties:
+ oneOf:
+ - type: 'null'
+ - type: boolean
+ - type: number
+ - type: string
+ - type: array
+ - type: object
additionalProperties: false
required:
- content
@@ -4954,7 +4974,19 @@ components:
properties:
content:
$ref: '#/components/schemas/InterleavedContent'
+ metadata:
+ type: object
+ additionalProperties:
+ oneOf:
+ - type: 'null'
+ - type: boolean
+ - type: number
+ - type: string
+ - type: array
+ - type: object
additionalProperties: false
+ required:
+ - metadata
title: RAGQueryResult
QueryChunksRequest:
type: object
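Note the asymmetry both spec files encode: `metadata` stays optional on `ToolResponse` and `ToolInvocationResult`, but is listed under `required` for `RAGQueryResult`, matching the `default_factory=dict` on the Pydantic model further down. A quick sanity-check sketch, assuming PyYAML is available and the repo-relative path from this diff:

```python
# Sketch: confirm the regenerated YAML marks RAGQueryResult.metadata as required.
import yaml

with open("docs/_static/llama-stack-spec.yaml") as f:
    spec = yaml.safe_load(f)

rag_query_result = spec["components"]["schemas"]["RAGQueryResult"]
assert "metadata" in rag_query_result["required"]
assert rag_query_result["properties"]["metadata"]["type"] == "object"
```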
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index d83506dd4..e517d9c3c 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -165,6 +165,7 @@ class ToolResponse(BaseModel):
call_id: str
tool_name: Union[BuiltinTool, str]
content: InterleavedContent
+ metadata: Optional[Dict[str, Any]] = None
@field_validator("tool_name", mode="before")
@classmethod
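`ToolResponse` now carries an optional `metadata` dict next to the tool output; it defaults to `None`, so existing callers are unaffected. A minimal construction sketch (assuming `ToolResponse` is importable from `llama_stack.apis.inference` and that a plain string is accepted as `InterleavedContent`):

```python
# Sketch: a ToolResponse that forwards provider metadata alongside the content.
from llama_stack.apis.inference import ToolResponse

response = ToolResponse(
    call_id="call-123",                    # hypothetical call id
    tool_name="query_from_memory",
    content="...retrieved context...",
    metadata={"document_ids": ["num-0"]},  # optional; omit to keep the default None
)
```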
diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py
index cff8eeefe..2b9ef10d8 100644
--- a/llama_stack/apis/tools/rag_tool.py
+++ b/llama_stack/apis/tools/rag_tool.py
@@ -26,6 +26,7 @@ class RAGDocument(BaseModel):
@json_schema_type
class RAGQueryResult(BaseModel):
content: Optional[InterleavedContent] = None
+ metadata: Dict[str, Any] = Field(default_factory=dict)
@json_schema_type
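On `RAGQueryResult` the new field uses `default_factory=dict` rather than `Optional[...] = None`, so every instance gets its own empty dict and serialized results always include a `metadata` key, which is why the spec marks it required. A short sketch, assuming `RAGQueryResult` is importable from `llama_stack.apis.tools`:

```python
# Sketch: default_factory gives each result an independent, always-present dict.
from llama_stack.apis.tools import RAGQueryResult

empty = RAGQueryResult()
assert empty.metadata == {}

populated = RAGQueryResult(
    content="retrieved text",
    metadata={"document_ids": ["num-0"]},
)
assert "metadata" in populated.model_dump()
```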
diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py
index b83be127f..a4d84edbe 100644
--- a/llama_stack/apis/tools/tools.py
+++ b/llama_stack/apis/tools/tools.py
@@ -72,6 +72,7 @@ class ToolInvocationResult(BaseModel):
content: InterleavedContent
error_message: Optional[str] = None
error_code: Optional[int] = None
+ metadata: Optional[Dict[str, Any]] = None
class ToolStore(Protocol):
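`ToolInvocationResult` gains the same optional `metadata`, which is the channel a tool runtime uses to surface provenance to the agent loop (the RAG memory tool below returns `document_ids` through it). A hedged sketch of a runtime returning it:

```python
# Sketch: a tool runtime handing provenance back with its content.
from llama_stack.apis.tools import ToolInvocationResult

def invoke_fake_tool() -> ToolInvocationResult:
    # hypothetical stand-in for a real ToolRuntime.invoke_tool implementation
    return ToolInvocationResult(
        content="tool output",
        metadata={"document_ids": ["num-0", "num-1"]},
    )
```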
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index edd253356..560215b25 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -62,7 +62,7 @@ from llama_stack.apis.inference import (
UserMessage,
)
from llama_stack.apis.safety import Safety
-from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
+from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.models.llama.datatypes import (
BuiltinTool,
@@ -587,6 +587,7 @@ class ChatAgent(ShieldRunnerMixin):
call_id="",
tool_name=MEMORY_QUERY_TOOL,
content=retrieved_context or [],
+ metadata=result.metadata,
)
],
),
@@ -795,13 +796,21 @@ class ChatAgent(ShieldRunnerMixin):
},
) as span:
tool_execution_start_time = datetime.now()
- result_messages = await execute_tool_call_maybe(
+ tool_call = message.tool_calls[0]
+ tool_result = await execute_tool_call_maybe(
self.tool_runtime_api,
session_id,
- [message],
+ tool_call,
toolgroup_args,
tool_to_group,
)
+ result_messages = [
+ ToolResponseMessage(
+ call_id=tool_call.call_id,
+ tool_name=tool_call.tool_name,
+ content=tool_result.content,
+ )
+ ]
assert len(result_messages) == 1, "Currently not supporting multiple messages"
result_message = result_messages[0]
span.set_attribute("output", result_message.model_dump_json())
@@ -820,6 +829,7 @@ class ChatAgent(ShieldRunnerMixin):
call_id=result_message.call_id,
tool_name=result_message.tool_name,
content=result_message.content,
+ metadata=tool_result.metadata,
)
],
started_at=tool_execution_start_time,
@@ -1058,19 +1068,10 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
async def execute_tool_call_maybe(
tool_runtime_api: ToolRuntime,
session_id: str,
- messages: List[CompletionMessage],
+ tool_call: ToolCall,
toolgroup_args: Dict[str, Dict[str, Any]],
tool_to_group: Dict[str, str],
-) -> List[ToolResponseMessage]:
- # While Tools.run interface takes a list of messages,
- # All tools currently only run on a single message
- # When this changes, we can drop this assert
- # Whether to call tools on each message and aggregate
- # or aggregate and call tool once, reamins to be seen.
- assert len(messages) == 1, "Expected single message"
- message = messages[0]
-
- tool_call = message.tool_calls[0]
+) -> ToolInvocationResult:
name = tool_call.tool_name
group_name = tool_to_group.get(name, None)
if group_name is None:
@@ -1091,14 +1092,7 @@ async def execute_tool_call_maybe(
**tool_call_args,
),
)
-
- return [
- ToolResponseMessage(
- call_id=tool_call.call_id,
- tool_name=tool_call.tool_name,
- content=result.content,
- )
- ]
+ return result
def _interpret_content_as_attachment(
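The shape of `execute_tool_call_maybe` changes here: it now takes the single `ToolCall` directly and returns the raw `ToolInvocationResult`, leaving the caller to wrap `content` in a `ToolResponseMessage` for inference and to copy `metadata` onto the `ToolResponse` recorded in the tool-execution step. A simplified sketch of the new calling pattern; the import locations mirror this diff's own import blocks, `execute_tool_call_maybe` is assumed to be the module-level helper refactored above, and the surrounding `ChatAgent` state is elided:

```python
# Sketch of the new flow: one ToolCall in, content and metadata routed separately.
from typing import Any, Dict

from llama_stack.apis.inference import ToolResponse, ToolResponseMessage
from llama_stack.apis.tools import ToolRuntime
from llama_stack.models.llama.datatypes import ToolCall


async def run_one_tool_call(
    tool_runtime_api: ToolRuntime,
    session_id: str,
    tool_call: ToolCall,
    toolgroup_args: Dict[str, Dict[str, Any]],
    tool_to_group: Dict[str, str],
) -> ToolResponse:
    # execute_tool_call_maybe is the helper refactored in this diff (assumed in scope)
    tool_result = await execute_tool_call_maybe(
        tool_runtime_api, session_id, tool_call, toolgroup_args, tool_to_group
    )
    # The message fed back to inference carries only the content...
    message = ToolResponseMessage(
        call_id=tool_call.call_id,
        tool_name=tool_call.tool_name,
        content=tool_result.content,
    )
    # ...while the recorded ToolResponse also keeps the provider metadata.
    return ToolResponse(
        call_id=message.call_id,
        tool_name=message.tool_name,
        content=message.content,
        metadata=tool_result.metadata,
    )
```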
diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py
index a6cd57923..306bd78a6 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py
@@ -119,10 +119,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
# sort by score
chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)
-
+ chunks = chunks[: query_config.max_chunks]
tokens = 0
picked = []
- for c in chunks[: query_config.max_chunks]:
+ for c in chunks:
metadata = c.metadata
tokens += metadata["token_count"]
if tokens > query_config.max_tokens_in_context:
@@ -146,6 +146,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
text="\n=== END-RETRIEVED-CONTEXT ===\n",
),
],
+ metadata={
+ "document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
+ },
)
async def list_runtime_tools(
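Two related changes in the memory tool: the chunk list is truncated to `query_config.max_chunks` before the token-budget loop, and the result's `metadata["document_ids"]` is derived from `chunks[: len(picked)]`. That slice is valid because the loop only consumes a prefix of `chunks` (it breaks once the budget is exceeded rather than skipping entries). A stand-alone sketch of the selection logic, with plain dicts standing in for real chunk objects:

```python
# Sketch of the selection logic, using plain dicts instead of real Chunk objects.
def pick_chunks(chunks, max_chunks, max_tokens_in_context):
    chunks = chunks[:max_chunks]               # hard cap first
    tokens, picked = 0, []
    for c in chunks:
        tokens += c["metadata"]["token_count"]
        if tokens > max_tokens_in_context:
            break                              # only a prefix of `chunks` is consumed
        picked.append(c["content"])
    # document ids line up one-to-one with the picked prefix
    document_ids = [c["metadata"]["document_id"] for c in chunks[: len(picked)]]
    return picked, document_ids


picked, doc_ids = pick_chunks(
    [
        {"content": "a", "metadata": {"document_id": "num-0", "token_count": 100}},
        {"content": "b", "metadata": {"document_id": "num-1", "token_count": 5000}},
    ],
    max_chunks=5,
    max_tokens_in_context=4096,
)
assert doc_ids == ["num-0"]
```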
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index 781095d2b..23ae601e4 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -457,6 +457,7 @@ def test_rag_agent(llama_stack_client, agent_config):
vector_db_id=vector_db_id,
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
+ provider_id="faiss",
)
llama_stack_client.tool_runtime.rag_tool.insert(
documents=documents,
@@ -492,11 +493,13 @@ def test_rag_agent(llama_stack_client, agent_config):
response = rag_agent.create_turn(
messages=[{"role": "user", "content": prompt}],
session_id=session_id,
+ stream=False,
)
- logs = [str(log) for log in EventLogger().log(response) if log is not None]
- logs_str = "".join(logs)
- assert "Tool:query_from_memory" in logs_str
- assert expected_kw in logs_str.lower()
+ # rag is called
+ assert response.steps[0].tool_calls[0].tool_name == "query_from_memory"
+ # document ids are present in metadata
+ assert "num-0" in response.steps[0].tool_responses[0].metadata["document_ids"]
+ assert expected_kw in response.output_message.content
def test_rag_and_code_agent(llama_stack_client, agent_config):