From bb2690f176f56b760770a0921fe71cd94715b3ac Mon Sep 17 00:00:00 2001
From: ehhuang
Date: Wed, 26 Feb 2025 13:04:52 -0800
Subject: [PATCH] feat: remove special handling of builtin::rag tool (#1015)

Summary:

Lets the model decide which tool it needs to call to respond to a query.

Test Plan:

```
LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/ --safety-shield meta-llama/Llama-Guard-3-8B
```

Also evaluated on a small benchmark of 20 questions from HotpotQA. With
this PR and some prompting, recall improves from 50% to 77%.

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with [ReviewStack](https://reviewstack.dev/meta-llama/llama-stack/pull/1015).
* #1268
* #1239
* __->__ #1015
---
 llama_stack/distribution/routers/routers.py   |   2 +-
 .../agents/meta_reference/agent_instance.py   | 108 ++----------------
 .../inline/tool_runtime/rag/memory.py         |  60 +++++++---
 tests/client-sdk/agents/test_agents.py        |  57 ++++++---
 4 files changed, 94 insertions(+), 133 deletions(-)

diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index a7c0d63e5..b0cb50e42 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -441,7 +441,7 @@ class ToolRuntimeRouter(ToolRuntime):
         vector_db_ids: List[str],
         query_config: Optional[RAGQueryConfig] = None,
     ) -> RAGQueryResult:
-        return await self.routing_table.get_provider_impl("query_from_memory").query(
+        return await self.routing_table.get_provider_impl("knowledge_search").query(
             content, vector_db_ids, query_config
         )
 
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 4a1421245..64cd41636 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -17,7 +17,6 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
 from urllib.parse import urlparse
 
 import httpx
-from pydantic import TypeAdapter
 
 from llama_stack.apis.agents import (
     AgentConfig,
@@ -62,7 +61,7 @@ from llama_stack.apis.inference import (
     UserMessage,
 )
 from llama_stack.apis.safety import Safety
-from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolInvocationResult, ToolRuntime
+from llama_stack.apis.tools import RAGDocument, ToolGroups, ToolInvocationResult, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
@@ -70,7 +69,6 @@ from llama_stack.models.llama.datatypes import (
     ToolParamDefinition,
 )
 from llama_stack.providers.utils.kvstore import KVStore
-from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 from llama_stack.providers.utils.telemetry import tracing
 
 from .persistence import AgentPersistence
@@ -84,7 +82,7 @@ def make_random_string(length: int = 8):
 
 
 TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
-MEMORY_QUERY_TOOL = "query_from_memory"
+MEMORY_QUERY_TOOL = "knowledge_search"
 WEB_SEARCH_TOOL = "web_search"
 RAG_TOOL_GROUP = "builtin::rag"
 
@@ -517,93 +515,6 @@ class ChatAgent(ShieldRunnerMixin):
         if documents:
             await self.handle_documents(session_id, documents, input_messages, tool_defs)
 
-        if RAG_TOOL_GROUP in toolgroups and len(input_messages) > 0:
-            with tracing.span(MEMORY_QUERY_TOOL) as span:
-                step_id = str(uuid.uuid4())
-                yield AgentTurnResponseStreamChunk(
-                    event=AgentTurnResponseEvent(
-                        payload=AgentTurnResponseStepStartPayload(
-                            step_type=StepType.tool_execution.value,
-                            step_id=step_id,
-                        )
-                    )
-                )
-
-                args = toolgroup_args.get(RAG_TOOL_GROUP, {})
-                vector_db_ids = args.get("vector_db_ids", [])
-                query_config = args.get("query_config")
-                if query_config:
-                    query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
-                else:
-                    # handle someone passing an empty dict
-                    query_config = RAGQueryConfig()
-
-                session_info = await self.storage.get_session_info(session_id)
-
-                # if the session has a memory bank id, let the memory tool use it
-                if session_info.vector_db_id:
-                    vector_db_ids.append(session_info.vector_db_id)
-
-                yield AgentTurnResponseStreamChunk(
-                    event=AgentTurnResponseEvent(
-                        payload=AgentTurnResponseStepProgressPayload(
-                            step_type=StepType.tool_execution.value,
-                            step_id=step_id,
-                            delta=ToolCallDelta(
-                                parse_status=ToolCallParseStatus.succeeded,
-                                tool_call=ToolCall(
-                                    call_id="",
-                                    tool_name=MEMORY_QUERY_TOOL,
-                                    arguments={},
-                                ),
-                            ),
-                        )
-                    )
-                )
-                result = await self.tool_runtime_api.rag_tool.query(
-                    content=concat_interleaved_content([msg.content for msg in input_messages]),
-                    vector_db_ids=vector_db_ids,
-                    query_config=query_config,
-                )
-                retrieved_context = result.content
-
-                yield AgentTurnResponseStreamChunk(
-                    event=AgentTurnResponseEvent(
-                        payload=AgentTurnResponseStepCompletePayload(
-                            step_type=StepType.tool_execution.value,
-                            step_id=step_id,
-                            step_details=ToolExecutionStep(
-                                step_id=step_id,
-                                turn_id=turn_id,
-                                tool_calls=[
-                                    ToolCall(
-                                        call_id="",
-                                        tool_name=MEMORY_QUERY_TOOL,
-                                        arguments={},
-                                    )
-                                ],
-                                tool_responses=[
-                                    ToolResponse(
-                                        call_id="",
-                                        tool_name=MEMORY_QUERY_TOOL,
-                                        content=retrieved_context or [],
-                                        metadata=result.metadata,
-                                    )
-                                ],
-                            ),
-                        )
-                    )
-                )
-                span.set_attribute("input", [m.model_dump_json() for m in input_messages])
-                span.set_attribute("output", retrieved_context)
-                span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
-
-                # append retrieved_context to the last user message
-                for message in input_messages[::-1]:
-                    if isinstance(message, UserMessage):
-                        message.context = retrieved_context
-                        break
-
         output_attachments = []
 
         n_iter = 0
@@ -631,9 +542,7 @@ class ChatAgent(ShieldRunnerMixin):
             async for chunk in await self.inference_api.chat_completion(
                 self.agent_config.model,
                 input_messages,
-                tools=[
-                    tool for tool in tool_defs.values() if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
-                ],
+                tools=[tool for tool in tool_defs.values()],
                 tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
                 response_format=self.agent_config.response_format,
                 stream=True,
@@ -845,8 +754,9 @@ class ChatAgent(ShieldRunnerMixin):
 
             # TODO: add tool-input touchpoint and a "start" event for this step also
             # but that needs a lot more refactoring of Tool code potentially
-
-            if out_attachment := _interpret_content_as_attachment(result_message.content):
+            if (type(result_message.content) is str) and (
+                out_attachment := _interpret_content_as_attachment(result_message.content)
+            ):
                 # NOTE: when we push this message back to the model, the model may ignore the
                 # attached file path etc. since the model is trained to only provide a user message
                 # with the summary. We keep all generated attachments and then attach them to final message
@@ -1072,7 +982,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
         else:
             raise ValueError(f"Unsupported URL {url}")
 
-        content.append(TextContentItem(text=f'# There is a file accessible to you at "{filepath}"\n'))
+        content.append(
+            TextContentItem(
+                text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.'
+            )
+        )
 
     return ToolResponseMessage(
         call_id="",
diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py
index 306bd78a6..4b3f7d9e7 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py
@@ -10,6 +10,8 @@ import secrets
 import string
 from typing import Any, Dict, List, Optional
 
+from pydantic import TypeAdapter
+
 from llama_stack.apis.common.content_types import (
     URL,
     InterleavedContent,
@@ -23,6 +25,7 @@ from llama_stack.apis.tools import (
     RAGToolRuntime,
     ToolDef,
     ToolInvocationResult,
+    ToolParameter,
     ToolRuntime,
 )
 from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
@@ -120,9 +123,14 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         # sort by score
         chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)
         chunks = chunks[: query_config.max_chunks]
+
         tokens = 0
-        picked = []
-        for c in chunks:
+        picked = [
+            TextContentItem(
+                text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
+            )
+        ]
+        for i, c in enumerate(chunks):
             metadata = c.metadata
             tokens += metadata["token_count"]
             if tokens > query_config.max_tokens_in_context:
@@ -132,20 +140,13 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
                 break
             picked.append(
                 TextContentItem(
-                    text=f"id:{metadata['document_id']}; content:{c.content}",
+                    text=f"Result {i + 1}:\nDocument_id:{metadata['document_id'][:5]}\nContent: {c.content}\n",
                 )
             )
+        picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
 
         return RAGQueryResult(
-            content=[
-                TextContentItem(
-                    text="Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
-                ),
-                *picked,
-                TextContentItem(
-                    text="\n=== END-RETRIEVED-CONTEXT ===\n",
-                ),
-            ],
+            content=picked,
             metadata={
                 "document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
             },
@@ -158,17 +159,40 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         # by the LLM. The method is only implemented so things like /tools can list without
         # encountering fatals.
         return [
-            ToolDef(
-                name="query_from_memory",
-                description="Retrieve context from memory",
-            ),
             ToolDef(
                 name="insert_into_memory",
                 description="Insert documents into memory",
             ),
+            ToolDef(
+                name="knowledge_search",
+                description="Search for information in a database.",
+                parameters=[
+                    ToolParameter(
+                        name="query",
+                        description="The query to search for. Can be a natural language sentence or keywords.",
+                        parameter_type="string",
+                    ),
+                ],
+            ),
         ]
 
     async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
-        raise RuntimeError(
-            "This toolgroup should not be called generically but only through specific methods of the RAGToolRuntime protocol"
+        vector_db_ids = kwargs.get("vector_db_ids", [])
+        query_config = kwargs.get("query_config")
+        if query_config:
+            query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
+        else:
+            # handle someone passing an empty dict
+            query_config = RAGQueryConfig()
+
+        query = kwargs["query"]
+        result = await self.query(
+            content=query,
+            vector_db_ids=vector_db_ids,
+            query_config=query_config,
+        )
+
+        return ToolInvocationResult(
+            content=result.content,
+            metadata=result.metadata,
         )
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index 876a9baf9..8e2c793e6 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -96,7 +96,7 @@ def agent_config(llama_stack_client, text_model_id):
         sampling_params={
             "strategy": {
                 "type": "top_p",
-                "temperature": 1.0,
+                "temperature": 0.0001,
                 "top_p": 0.9,
             },
         },
@@ -496,23 +496,36 @@ def test_rag_agent(llama_stack_client, agent_config):
         )
         # rag is called
         tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
-        assert tool_execution_step.tool_calls[0].tool_name == "query_from_memory"
+        assert tool_execution_step.tool_calls[0].tool_name == "knowledge_search"
         # document ids are present in metadata
-        assert "num-0" in tool_execution_step.tool_responses[0].metadata["document_ids"]
-        assert expected_kw in response.output_message.content.lower()
+        assert all(
+            doc_id.startswith("num-") for doc_id in tool_execution_step.tool_responses[0].metadata["document_ids"]
+        )
+        if expected_kw:
+            assert expected_kw in response.output_message.content.lower()
 
 
 def test_rag_and_code_agent(llama_stack_client, agent_config):
-    urls = ["chat.rst"]
-    documents = [
+    documents = []
+    documents.append(
         Document(
-            document_id=f"num-{i}",
-            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
-            mime_type="text/plain",
+            document_id="nba_wiki",
+            content="The NBA was created on August 3, 1949, with the merger of the Basketball Association of America (BAA) and the National Basketball League (NBL).",
             metadata={},
         )
-        for i, url in enumerate(urls)
-    ]
+    )
+    documents.append(
+        Document(
+            document_id="perplexity_wiki",
+            content="""Perplexity the company was founded in 2022 by Aravind Srinivas, Andy Konwinski, Denis Yarats and Johnny Ho, engineers with backgrounds in back-end systems, artificial intelligence (AI) and machine learning:
+
+    Srinivas, the CEO, worked at OpenAI as an AI researcher.
+    Konwinski was among the founding team at Databricks.
+    Yarats, the CTO, was an AI research scientist at Meta.
+    Ho, the CSO, worked as an engineer at Quora, then as a quantitative trader on Wall Street.[5]""",
+            metadata={},
+        )
+    )
     vector_db_id = f"test-vector-db-{uuid4()}"
     llama_stack_client.vector_dbs.register(
         vector_db_id=vector_db_id,
@@ -546,24 +559,34 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
             "Here is a csv file, can you describe it?",
             [inflation_doc],
             "code_interpreter",
+            "",
         ),
         (
-            "What are the top 5 topics that were explained? Only list succinct bullet points.",
+            "when was Perplexity the company founded?",
            [],
-            "query_from_memory",
+            "knowledge_search",
+            "2022",
+        ),
+        (
+            "when was the nba created?",
+            [],
+            "knowledge_search",
+            "1949",
         ),
     ]
 
-    for prompt, docs, tool_name in user_prompts:
+    for prompt, docs, tool_name, expected_kw in user_prompts:
         session_id = agent.create_session(f"test-session-{uuid4()}")
         response = agent.create_turn(
             messages=[{"role": "user", "content": prompt}],
             session_id=session_id,
             documents=docs,
+            stream=False,
         )
-        logs = [str(log) for log in EventLogger().log(response) if log is not None]
-        logs_str = "".join(logs)
-        assert f"Tool:{tool_name}" in logs_str
+        tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
+        assert tool_execution_step.tool_calls[0].tool_name == tool_name
+        if expected_kw:
+            assert expected_kw in response.output_message.content.lower()
 
 
 def test_create_turn_response(llama_stack_client, agent_config):
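---

To see the change end to end, here is a minimal client-side sketch of the flow this patch enables: `builtin::rag` is registered as an ordinary toolgroup, and the model itself decides when to emit a `knowledge_search` call, mirroring `test_rag_and_code_agent` above. The server URL, model id, and embedding settings below are placeholders rather than values taken from this patch:

```python
from uuid import uuid4

from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent

client = LlamaStackClient(base_url="http://localhost:8321")  # placeholder server URL

# Register a vector DB and index one document (embedding settings are assumptions).
vector_db_id = f"demo-vector-db-{uuid4()}"
client.vector_dbs.register(
    vector_db_id=vector_db_id,
    embedding_model="all-MiniLM-L6-v2",
    embedding_dimension=384,
)
client.tool_runtime.rag_tool.insert(
    documents=[
        {
            "document_id": "nba_wiki",
            "content": "The NBA was created on August 3, 1949, with the merger of the BAA and the NBL.",
            "metadata": {},
        }
    ],
    vector_db_id=vector_db_id,
    chunk_size_in_tokens=512,
)

# No special-casing: builtin::rag is passed like any other toolgroup, and the
# model emits the knowledge_search tool call on its own.
agent = Agent(
    client,
    agent_config={
        "model": "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
        "instructions": "You are a helpful assistant.",
        "toolgroups": [{"name": "builtin::rag", "args": {"vector_db_ids": [vector_db_id]}}],
        "enable_session_persistence": False,
    },
)
session_id = agent.create_session(f"demo-session-{uuid4()}")
response = agent.create_turn(
    messages=[{"role": "user", "content": "When was the NBA created?"}],
    session_id=session_id,
    stream=False,
)

# As in the updated tests, retrieval now surfaces as an ordinary tool_execution
# step whose tool call is named knowledge_search.
step = next(s for s in response.steps if s.step_type == "tool_execution")
assert step.tool_calls[0].tool_name == "knowledge_search"
print(response.output_message.content)
```

Because `invoke_tool` is now implemented for `knowledge_search`, the same retrieval path should also work without an agent, e.g. `client.tool_runtime.invoke_tool(tool_name="knowledge_search", kwargs={"query": "when was the nba created?", "vector_db_ids": [vector_db_id]})`.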