feat: remove special handling of builtin::rag tool (#1015)

Summary:

Lets the model decide which tool it needs to call to respond to a query.

Test Plan:
```
LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/ --safety-shield meta-llama/Llama-Guard-3-8B
```

Also evaluated on a small benchmark with 20 questions from HotpotQA.
With this PR and some prompting, the performance is 77% recall compared
to 50% currently.

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with
[ReviewStack](https://reviewstack.dev/meta-llama/llama-stack/pull/1015).
* #1268
* #1239
* __->__ #1015
This commit is contained in:
ehhuang 2025-02-26 13:04:52 -08:00 committed by GitHub
parent c64f0d5888
commit bb2690f176
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 94 additions and 133 deletions

View file

@ -441,7 +441,7 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
) -> RAGQueryResult:
return await self.routing_table.get_provider_impl("query_from_memory").query(
return await self.routing_table.get_provider_impl("knowledge_search").query(
content, vector_db_ids, query_config
)

View file

@ -17,7 +17,6 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from urllib.parse import urlparse
import httpx
from pydantic import TypeAdapter
from llama_stack.apis.agents import (
AgentConfig,
@ -62,7 +61,7 @@ from llama_stack.apis.inference import (
UserMessage,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.tools import RAGDocument, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.models.llama.datatypes import (
BuiltinTool,
@ -70,7 +69,6 @@ from llama_stack.models.llama.datatypes import (
ToolParamDefinition,
)
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
from llama_stack.providers.utils.telemetry import tracing
from .persistence import AgentPersistence
@ -84,7 +82,7 @@ def make_random_string(length: int = 8):
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
MEMORY_QUERY_TOOL = "query_from_memory"
MEMORY_QUERY_TOOL = "knowledge_search"
WEB_SEARCH_TOOL = "web_search"
RAG_TOOL_GROUP = "builtin::rag"
@ -517,93 +515,6 @@ class ChatAgent(ShieldRunnerMixin):
if documents:
await self.handle_documents(session_id, documents, input_messages, tool_defs)
if RAG_TOOL_GROUP in toolgroups and len(input_messages) > 0:
with tracing.span(MEMORY_QUERY_TOOL) as span:
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
)
)
)
args = toolgroup_args.get(RAG_TOOL_GROUP, {})
vector_db_ids = args.get("vector_db_ids", [])
query_config = args.get("query_config")
if query_config:
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
else:
# handle someone passing an empty dict
query_config = RAGQueryConfig()
session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it
if session_info.vector_db_id:
vector_db_ids.append(session_info.vector_db_id)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
delta=ToolCallDelta(
parse_status=ToolCallParseStatus.succeeded,
tool_call=ToolCall(
call_id="",
tool_name=MEMORY_QUERY_TOOL,
arguments={},
),
),
)
)
)
result = await self.tool_runtime_api.rag_tool.query(
content=concat_interleaved_content([msg.content for msg in input_messages]),
vector_db_ids=vector_db_ids,
query_config=query_config,
)
retrieved_context = result.content
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
step_details=ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[
ToolCall(
call_id="",
tool_name=MEMORY_QUERY_TOOL,
arguments={},
)
],
tool_responses=[
ToolResponse(
call_id="",
tool_name=MEMORY_QUERY_TOOL,
content=retrieved_context or [],
metadata=result.metadata,
)
],
),
)
)
)
span.set_attribute("input", [m.model_dump_json() for m in input_messages])
span.set_attribute("output", retrieved_context)
span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
# append retrieved_context to the last user message
for message in input_messages[::-1]:
if isinstance(message, UserMessage):
message.context = retrieved_context
break
output_attachments = []
n_iter = 0
@ -631,9 +542,7 @@ class ChatAgent(ShieldRunnerMixin):
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=[
tool for tool in tool_defs.values() if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
],
tools=[tool for tool in tool_defs.values()],
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
response_format=self.agent_config.response_format,
stream=True,
@ -845,8 +754,9 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially
if out_attachment := _interpret_content_as_attachment(result_message.content):
if (type(result_message.content) is str) and (
out_attachment := _interpret_content_as_attachment(result_message.content)
):
# NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message
# with the summary. We keep all generated attachments and then attach them to final message
@ -1072,7 +982,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
else:
raise ValueError(f"Unsupported URL {url}")
content.append(TextContentItem(text=f'# There is a file accessible to you at "{filepath}"\n'))
content.append(
TextContentItem(
text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.'
)
)
return ToolResponseMessage(
call_id="",

View file

@ -10,6 +10,8 @@ import secrets
import string
from typing import Any, Dict, List, Optional
from pydantic import TypeAdapter
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
@ -23,6 +25,7 @@ from llama_stack.apis.tools import (
RAGToolRuntime,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
@ -120,9 +123,14 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
# sort by score
chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)
chunks = chunks[: query_config.max_chunks]
tokens = 0
picked = []
for c in chunks:
picked = [
TextContentItem(
text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
)
]
for i, c in enumerate(chunks):
metadata = c.metadata
tokens += metadata["token_count"]
if tokens > query_config.max_tokens_in_context:
@ -132,20 +140,13 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
break
picked.append(
TextContentItem(
text=f"id:{metadata['document_id']}; content:{c.content}",
text=f"Result {i + 1}:\nDocument_id:{metadata['document_id'][:5]}\nContent: {c.content}\n",
)
)
picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
return RAGQueryResult(
content=[
TextContentItem(
text="Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
),
*picked,
TextContentItem(
text="\n=== END-RETRIEVED-CONTEXT ===\n",
),
],
content=picked,
metadata={
"document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
},
@ -158,17 +159,40 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
# by the LLM. The method is only implemented so things like /tools can list without
# encountering fatals.
return [
ToolDef(
name="query_from_memory",
description="Retrieve context from memory",
),
ToolDef(
name="insert_into_memory",
description="Insert documents into memory",
),
ToolDef(
name="knowledge_search",
description="Search for information in a database.",
parameters=[
ToolParameter(
name="query",
description="The query to search for. Can be a natural language sentence or keywords.",
parameter_type="string",
),
],
),
]
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
raise RuntimeError(
"This toolgroup should not be called generically but only through specific methods of the RAGToolRuntime protocol"
vector_db_ids = kwargs.get("vector_db_ids", [])
query_config = kwargs.get("query_config")
if query_config:
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
else:
# handle someone passing an empty dict
query_config = RAGQueryConfig()
query = kwargs["query"]
result = await self.query(
content=query,
vector_db_ids=vector_db_ids,
query_config=query_config,
)
return ToolInvocationResult(
content=result.content,
metadata=result.metadata,
)