diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index ce08e041f..2a9f4b6f7 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -4521,6 +4521,31 @@
                 },
                 "content": {
                     "$ref": "#/components/schemas/InterleavedContent"
+                },
+                "metadata": {
+                    "type": "object",
+                    "additionalProperties": {
+                        "oneOf": [
+                            {
+                                "type": "null"
+                            },
+                            {
+                                "type": "boolean"
+                            },
+                            {
+                                "type": "number"
+                            },
+                            {
+                                "type": "string"
+                            },
+                            {
+                                "type": "array"
+                            },
+                            {
+                                "type": "object"
+                            }
+                        ]
+                    }
                 }
             },
             "additionalProperties": false,
@@ -6746,6 +6771,31 @@
                 },
                 "error_code": {
                     "type": "integer"
+                },
+                "metadata": {
+                    "type": "object",
+                    "additionalProperties": {
+                        "oneOf": [
+                            {
+                                "type": "null"
+                            },
+                            {
+                                "type": "boolean"
+                            },
+                            {
+                                "type": "number"
+                            },
+                            {
+                                "type": "string"
+                            },
+                            {
+                                "type": "array"
+                            },
+                            {
+                                "type": "object"
+                            }
+                        ]
+                    }
                 }
             },
             "additionalProperties": false,
@@ -7595,9 +7645,37 @@
             "properties": {
                 "content": {
                     "$ref": "#/components/schemas/InterleavedContent"
+                },
+                "metadata": {
+                    "type": "object",
+                    "additionalProperties": {
+                        "oneOf": [
+                            {
+                                "type": "null"
+                            },
+                            {
+                                "type": "boolean"
+                            },
+                            {
+                                "type": "number"
+                            },
+                            {
+                                "type": "string"
+                            },
+                            {
+                                "type": "array"
+                            },
+                            {
+                                "type": "object"
+                            }
+                        ]
+                    }
                 }
             },
             "additionalProperties": false,
+            "required": [
+                "metadata"
+            ],
             "title": "RAGQueryResult"
         },
         "QueryChunksRequest": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 0e4955a5c..a2329e47a 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -2945,6 +2945,16 @@ components:
             - type: string
         content:
           $ref: '#/components/schemas/InterleavedContent'
+        metadata:
+          type: object
+          additionalProperties:
+            oneOf:
+              - type: 'null'
+              - type: boolean
+              - type: number
+              - type: string
+              - type: array
+              - type: object
       additionalProperties: false
       required:
         - call_id
@@ -4381,6 +4391,16 @@ components:
           type: string
         error_code:
           type: integer
+        metadata:
+          type: object
+          additionalProperties:
+            oneOf:
+              - type: 'null'
+              - type: boolean
+              - type: number
+              - type: string
+              - type: array
+              - type: object
       additionalProperties: false
       required:
         - content
@@ -4954,7 +4974,19 @@ components:
       properties:
         content:
           $ref: '#/components/schemas/InterleavedContent'
+        metadata:
+          type: object
+          additionalProperties:
+            oneOf:
+              - type: 'null'
+              - type: boolean
+              - type: number
+              - type: string
+              - type: array
+              - type: object
       additionalProperties: false
+      required:
+        - metadata
       title: RAGQueryResult
     QueryChunksRequest:
       type: object
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index d83506dd4..e517d9c3c 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -165,6 +165,7 @@ class ToolResponse(BaseModel):
     call_id: str
     tool_name: Union[BuiltinTool, str]
     content: InterleavedContent
+    metadata: Optional[Dict[str, Any]] = None

     @field_validator("tool_name", mode="before")
     @classmethod
diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py
index cff8eeefe..2b9ef10d8 100644
--- a/llama_stack/apis/tools/rag_tool.py
+++ b/llama_stack/apis/tools/rag_tool.py
@@ -26,6 +26,7 @@ class RAGDocument(BaseModel):
 @json_schema_type
 class RAGQueryResult(BaseModel):
     content: Optional[InterleavedContent] = None
+    metadata: Dict[str, Any] = Field(default_factory=dict)


 @json_schema_type
diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py
index b83be127f..a4d84edbe 100644
--- a/llama_stack/apis/tools/tools.py
+++ b/llama_stack/apis/tools/tools.py
@@ -72,6 +72,7 @@ class ToolInvocationResult(BaseModel):
     content: InterleavedContent
     error_message: Optional[str] = None
     error_code: Optional[int] = None
+    metadata: Optional[Dict[str, Any]] = None


 class ToolStore(Protocol):
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index edd253356..560215b25 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -62,7 +62,7 @@ from llama_stack.apis.inference import (
     UserMessage,
 )
 from llama_stack.apis.safety import Safety
-from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
+from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolInvocationResult, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
@@ -587,6 +587,7 @@ class ChatAgent(ShieldRunnerMixin):
                                     call_id="",
                                     tool_name=MEMORY_QUERY_TOOL,
                                     content=retrieved_context or [],
+                                    metadata=result.metadata,
                                 )
                             ],
                         ),
@@ -795,13 +796,21 @@
                     },
                 ) as span:
                     tool_execution_start_time = datetime.now()
-                    result_messages = await execute_tool_call_maybe(
+                    tool_call = message.tool_calls[0]
+                    tool_result = await execute_tool_call_maybe(
                         self.tool_runtime_api,
                         session_id,
-                        [message],
+                        tool_call,
                         toolgroup_args,
                         tool_to_group,
                     )
+                    result_messages = [
+                        ToolResponseMessage(
+                            call_id=tool_call.call_id,
+                            tool_name=tool_call.tool_name,
+                            content=tool_result.content,
+                        )
+                    ]
                     assert len(result_messages) == 1, "Currently not supporting multiple messages"
                     result_message = result_messages[0]
                     span.set_attribute("output", result_message.model_dump_json())
@@ -820,6 +829,7 @@ class ChatAgent(ShieldRunnerMixin):
                                             call_id=result_message.call_id,
                                             tool_name=result_message.tool_name,
                                             content=result_message.content,
+                                            metadata=tool_result.metadata,
                                         )
                                     ],
                                     started_at=tool_execution_start_time,
@@ -1058,19 +1068,10 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
 async def execute_tool_call_maybe(
     tool_runtime_api: ToolRuntime,
     session_id: str,
-    messages: List[CompletionMessage],
+    tool_call: ToolCall,
     toolgroup_args: Dict[str, Dict[str, Any]],
     tool_to_group: Dict[str, str],
-) -> List[ToolResponseMessage]:
-    # While Tools.run interface takes a list of messages,
-    # All tools currently only run on a single message
-    # When this changes, we can drop this assert
-    # Whether to call tools on each message and aggregate
-    # or aggregate and call tool once, reamins to be seen.
-    assert len(messages) == 1, "Expected single message"
-    message = messages[0]
-
-    tool_call = message.tool_calls[0]
+) -> ToolInvocationResult:
     name = tool_call.tool_name
     group_name = tool_to_group.get(name, None)
     if group_name is None:
@@ -1091,14 +1092,7 @@ async def execute_tool_call_maybe(
             **tool_call_args,
         ),
     )
-
-    return [
-        ToolResponseMessage(
-            call_id=tool_call.call_id,
-            tool_name=tool_call.tool_name,
-            content=result.content,
-        )
-    ]
+    return result


 def _interpret_content_as_attachment(
diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py
index a6cd57923..306bd78a6 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py
@@ -119,10 +119,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):

         # sort by score
         chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)
-
+        chunks = chunks[: query_config.max_chunks]
         tokens = 0
         picked = []
-        for c in chunks[: query_config.max_chunks]:
+        for c in chunks:
             metadata = c.metadata
             tokens += metadata["token_count"]
             if tokens > query_config.max_tokens_in_context:
@@ -146,6 +146,9 @@
                     text="\n=== END-RETRIEVED-CONTEXT ===\n",
                 ),
             ],
+            metadata={
+                "document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
+            },
         )

     async def list_runtime_tools(
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index 781095d2b..23ae601e4 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -457,6 +457,7 @@ def test_rag_agent(llama_stack_client, agent_config):
         vector_db_id=vector_db_id,
         embedding_model="all-MiniLM-L6-v2",
         embedding_dimension=384,
+        provider_id="faiss",
     )
     llama_stack_client.tool_runtime.rag_tool.insert(
         documents=documents,
@@ -492,11 +493,13 @@
         response = rag_agent.create_turn(
             messages=[{"role": "user", "content": prompt}],
             session_id=session_id,
+            stream=False,
        )
-        logs = [str(log) for log in EventLogger().log(response) if log is not None]
-        logs_str = "".join(logs)
-        assert "Tool:query_from_memory" in logs_str
-        assert expected_kw in logs_str.lower()
+        # rag is called
+        assert response.steps[0].tool_calls[0].tool_name == "query_from_memory"
+        # document ids are present in metadata
+        assert "num-0" in response.steps[0].tool_responses[0].metadata["document_ids"]
+        assert expected_kw in response.output_message.content


 def test_rag_and_code_agent(llama_stack_client, agent_config):
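
Taken together, these changes let a tool runtime attach arbitrary JSON-serializable metadata to its result (here, the document IDs the RAG tool actually used) and let a client read it back from the turn's tool responses. The sketch below is not part of this diff; it mirrors the updated test and assumes `rag_agent`, `session_id`, and `prompt` have already been set up the same way as in test_rag_agent above.

# Sketch only: assumes `rag_agent` is an Agent wired to the builtin RAG tool,
# `session_id` is an existing session, and `prompt` comes from the test's
# user_prompts loop, as in test_rag_agent above.
response = rag_agent.create_turn(
    messages=[{"role": "user", "content": prompt}],
    session_id=session_id,
    stream=False,
)

rag_step = response.steps[0]
print(rag_step.tool_calls[0].tool_name)      # "query_from_memory"
print(rag_step.tool_responses[0].metadata)   # e.g. {"document_ids": ["num-0", ...]}

# Server side, the same field is populated by the tool runtime, e.g.:
#   RAGQueryResult(content=..., metadata={"document_ids": [...]})
#   ToolInvocationResult(content=..., metadata={...})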