feat: remove special handling of builtin::rag tool (#1015)
Summary: Lets the model decide which tool it needs to call to respond to a query.

Test Plan:

```
LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/ --safety-shield meta-llama/Llama-Guard-3-8B
```

Also evaluated on a small benchmark with 20 questions from HotpotQA. With this PR and some prompting, performance reaches 77% recall, compared to 50% currently.

---

[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/meta-llama/llama-stack/pull/1015).
* #1268
* #1239
* __->__ #1015
Commit: bb2690f176
Parent: c64f0d5888
4 changed files with 94 additions and 133 deletions
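For context, a minimal sketch of the behavior this PR enables. This is not part of the diff: the `Agent` import path and model id are assumptions, and `llama_stack_client` is an initialized client; the `builtin::rag` toolgroup shape with `vector_db_ids` matches what the agent code in this diff reads.

```python
# Sketch (assumptions noted inline): the model now sees knowledge_search as an
# ordinary tool and decides on its own when to call it.
from uuid import uuid4

from llama_stack_client.lib.agents.agent import Agent  # assumed import path

agent = Agent(
    llama_stack_client,  # an initialized LlamaStackClient (assumed)
    agent_config={
        "model": "meta-llama/Llama-3.3-70B-Instruct",  # hypothetical model id
        "instructions": "You are a helpful assistant.",
        "toolgroups": [
            {
                "name": "builtin::rag",
                "args": {"vector_db_ids": [vector_db_id]},  # a previously registered vector db
            }
        ],
    },
)
session_id = agent.create_session(f"demo-{uuid4()}")
response = agent.create_turn(
    messages=[{"role": "user", "content": "when was the nba created?"}],
    session_id=session_id,
    stream=False,
)
# If the model chooses to retrieve, response.steps contains a tool_execution
# step whose tool_calls[0].tool_name == "knowledge_search".
```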
```diff
@@ -441,7 +441,7 @@ class ToolRuntimeRouter(ToolRuntime):
         vector_db_ids: List[str],
         query_config: Optional[RAGQueryConfig] = None,
     ) -> RAGQueryResult:
-        return await self.routing_table.get_provider_impl("query_from_memory").query(
+        return await self.routing_table.get_provider_impl("knowledge_search").query(
             content, vector_db_ids, query_config
         )
```
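The explicit `RAGToolRuntime.query` path is unchanged apart from the provider lookup key. A hedged sketch of calling it directly through the router, where `tool_runtime_api` is an assumed handle to the router and the db id is illustrative:

```python
# Sketch: direct retrieval via the RAGToolRuntime protocol still works; only
# the provider key changed from "query_from_memory" to "knowledge_search".
result = await tool_runtime_api.rag_tool.query(  # tool_runtime_api: assumed router handle
    content="when was the nba created?",
    vector_db_ids=["test-vector-db"],  # assumed registered vector db id
    query_config=None,  # defaults are applied downstream
)
print(result.content, result.metadata["document_ids"])
```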
```diff
@@ -17,7 +17,6 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
 from urllib.parse import urlparse
 
 import httpx
-from pydantic import TypeAdapter
 
 from llama_stack.apis.agents import (
     AgentConfig,
@@ -62,7 +61,7 @@ from llama_stack.apis.inference import (
     UserMessage,
 )
 from llama_stack.apis.safety import Safety
-from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolInvocationResult, ToolRuntime
+from llama_stack.apis.tools import RAGDocument, ToolGroups, ToolInvocationResult, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
@@ -70,7 +69,6 @@ from llama_stack.models.llama.datatypes import (
     ToolParamDefinition,
 )
 from llama_stack.providers.utils.kvstore import KVStore
-from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 from llama_stack.providers.utils.telemetry import tracing
 
 from .persistence import AgentPersistence
@@ -84,7 +82,7 @@ def make_random_string(length: int = 8):
 
 
 TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
-MEMORY_QUERY_TOOL = "query_from_memory"
+MEMORY_QUERY_TOOL = "knowledge_search"
 WEB_SEARCH_TOOL = "web_search"
 RAG_TOOL_GROUP = "builtin::rag"
 
@@ -517,93 +515,6 @@ class ChatAgent(ShieldRunnerMixin):
         if documents:
             await self.handle_documents(session_id, documents, input_messages, tool_defs)
 
-        if RAG_TOOL_GROUP in toolgroups and len(input_messages) > 0:
-            with tracing.span(MEMORY_QUERY_TOOL) as span:
-                step_id = str(uuid.uuid4())
-                yield AgentTurnResponseStreamChunk(
-                    event=AgentTurnResponseEvent(
-                        payload=AgentTurnResponseStepStartPayload(
-                            step_type=StepType.tool_execution.value,
-                            step_id=step_id,
-                        )
-                    )
-                )
-
-                args = toolgroup_args.get(RAG_TOOL_GROUP, {})
-                vector_db_ids = args.get("vector_db_ids", [])
-                query_config = args.get("query_config")
-                if query_config:
-                    query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
-                else:
-                    # handle someone passing an empty dict
-                    query_config = RAGQueryConfig()
-
-                session_info = await self.storage.get_session_info(session_id)
-
-                # if the session has a memory bank id, let the memory tool use it
-                if session_info.vector_db_id:
-                    vector_db_ids.append(session_info.vector_db_id)
-
-                yield AgentTurnResponseStreamChunk(
-                    event=AgentTurnResponseEvent(
-                        payload=AgentTurnResponseStepProgressPayload(
-                            step_type=StepType.tool_execution.value,
-                            step_id=step_id,
-                            delta=ToolCallDelta(
-                                parse_status=ToolCallParseStatus.succeeded,
-                                tool_call=ToolCall(
-                                    call_id="",
-                                    tool_name=MEMORY_QUERY_TOOL,
-                                    arguments={},
-                                ),
-                            ),
-                        )
-                    )
-                )
-                result = await self.tool_runtime_api.rag_tool.query(
-                    content=concat_interleaved_content([msg.content for msg in input_messages]),
-                    vector_db_ids=vector_db_ids,
-                    query_config=query_config,
-                )
-                retrieved_context = result.content
-
-                yield AgentTurnResponseStreamChunk(
-                    event=AgentTurnResponseEvent(
-                        payload=AgentTurnResponseStepCompletePayload(
-                            step_type=StepType.tool_execution.value,
-                            step_id=step_id,
-                            step_details=ToolExecutionStep(
-                                step_id=step_id,
-                                turn_id=turn_id,
-                                tool_calls=[
-                                    ToolCall(
-                                        call_id="",
-                                        tool_name=MEMORY_QUERY_TOOL,
-                                        arguments={},
-                                    )
-                                ],
-                                tool_responses=[
-                                    ToolResponse(
-                                        call_id="",
-                                        tool_name=MEMORY_QUERY_TOOL,
-                                        content=retrieved_context or [],
-                                        metadata=result.metadata,
-                                    )
-                                ],
-                            ),
-                        )
-                    )
-                )
-                span.set_attribute("input", [m.model_dump_json() for m in input_messages])
-                span.set_attribute("output", retrieved_context)
-                span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
-
-                # append retrieved_context to the last user message
-                for message in input_messages[::-1]:
-                    if isinstance(message, UserMessage):
-                        message.context = retrieved_context
-                        break
-
         output_attachments = []
 
         n_iter = 0
@@ -631,9 +542,7 @@ class ChatAgent(ShieldRunnerMixin):
             async for chunk in await self.inference_api.chat_completion(
                 self.agent_config.model,
                 input_messages,
-                tools=[
-                    tool for tool in tool_defs.values() if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
-                ],
+                tools=[tool for tool in tool_defs.values()],
                 tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
                 response_format=self.agent_config.response_format,
                 stream=True,
@@ -845,8 +754,9 @@ class ChatAgent(ShieldRunnerMixin):
 
                 # TODO: add tool-input touchpoint and a "start" event for this step also
                 # but that needs a lot more refactoring of Tool code potentially
-                if out_attachment := _interpret_content_as_attachment(result_message.content):
+                if (type(result_message.content) is str) and (
+                    out_attachment := _interpret_content_as_attachment(result_message.content)
+                ):
                     # NOTE: when we push this message back to the model, the model may ignore the
                     # attached file path etc. since the model is trained to only provide a user message
                     # with the summary. We keep all generated attachments and then attach them to final message
@@ -1072,7 +982,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
     else:
         raise ValueError(f"Unsupported URL {url}")
 
-    content.append(TextContentItem(text=f'# There is a file accessible to you at "{filepath}"\n'))
+    content.append(
+        TextContentItem(
+            text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.'
+        )
+    )
 
     return ToolResponseMessage(
         call_id="",
```
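With the eager RAG step gone, retrieval surfaces in the turn stream as a regular model-initiated tool call rather than the synthetic `ToolCall` removed above. A sketch of the shape, where the `call_id` and arguments are illustrative:

```python
# Sketch: what the model emits now, instead of the hard-coded empty-id call.
ToolCall(
    call_id="call-1",  # assigned during parsing; the removed code hard-coded ""
    tool_name="knowledge_search",
    arguments={"query": "when was the nba created?"},
)
```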
```diff
@@ -10,6 +10,8 @@ import secrets
 import string
 from typing import Any, Dict, List, Optional
 
+from pydantic import TypeAdapter
+
 from llama_stack.apis.common.content_types import (
     URL,
     InterleavedContent,
@@ -23,6 +25,7 @@ from llama_stack.apis.tools import (
     RAGToolRuntime,
     ToolDef,
     ToolInvocationResult,
+    ToolParameter,
     ToolRuntime,
 )
 from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
@@ -120,9 +123,14 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         # sort by score
         chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)
         chunks = chunks[: query_config.max_chunks]
 
         tokens = 0
-        picked = []
-        for c in chunks:
+        picked = [
+            TextContentItem(
+                text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
+            )
+        ]
+        for i, c in enumerate(chunks):
             metadata = c.metadata
             tokens += metadata["token_count"]
             if tokens > query_config.max_tokens_in_context:
@@ -132,20 +140,13 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
                 break
             picked.append(
                 TextContentItem(
-                    text=f"id:{metadata['document_id']}; content:{c.content}",
+                    text=f"Result {i + 1}:\nDocument_id:{metadata['document_id'][:5]}\nContent: {c.content}\n",
                 )
             )
+        picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
 
         return RAGQueryResult(
-            content=[
-                TextContentItem(
-                    text="Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
-                ),
-                *picked,
-                TextContentItem(
-                    text="\n=== END-RETRIEVED-CONTEXT ===\n",
-                ),
-            ],
+            content=picked,
             metadata={
                 "document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
             },
@@ -158,17 +159,40 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         # by the LLM. The method is only implemented so things like /tools can list without
         # encountering fatals.
         return [
-            ToolDef(
-                name="query_from_memory",
-                description="Retrieve context from memory",
-            ),
             ToolDef(
                 name="insert_into_memory",
                 description="Insert documents into memory",
             ),
+            ToolDef(
+                name="knowledge_search",
+                description="Search for information in a database.",
+                parameters=[
+                    ToolParameter(
+                        name="query",
+                        description="The query to search for. Can be a natural language sentence or keywords.",
+                        parameter_type="string",
+                    ),
+                ],
+            ),
         ]
 
     async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
-        raise RuntimeError(
-            "This toolgroup should not be called generically but only through specific methods of the RAGToolRuntime protocol"
-        )
+        vector_db_ids = kwargs.get("vector_db_ids", [])
+        query_config = kwargs.get("query_config")
+        if query_config:
+            query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
+        else:
+            # handle someone passing an empty dict
+            query_config = RAGQueryConfig()
+
+        query = kwargs["query"]
+        result = await self.query(
+            content=query,
+            vector_db_ids=vector_db_ids,
+            query_config=query_config,
+        )
+
+        return ToolInvocationResult(
+            content=result.content,
+            metadata=result.metadata,
+        )
```
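A hedged sketch of the new generic dispatch, mirroring the kwargs that `invoke_tool` reads above; `memory_tool_runtime` (an instance of `MemoryToolRuntimeImpl`) and the db id are assumptions:

```python
# Sketch: knowledge_search can now be invoked generically, the way the agent's
# tool-execution loop invokes any other tool.
result = await memory_tool_runtime.invoke_tool(  # memory_tool_runtime: assumed instance
    tool_name="knowledge_search",
    kwargs={
        "query": "when was the nba created?",  # required; read via kwargs["query"]
        "vector_db_ids": ["test-vector-db"],   # optional; defaults to []
        "query_config": {},                    # an empty dict is coerced to RAGQueryConfig()
    },
)
assert result.content is not None  # formatted knowledge_search results
```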
```diff
@@ -96,7 +96,7 @@ def agent_config(llama_stack_client, text_model_id):
         sampling_params={
             "strategy": {
                 "type": "top_p",
-                "temperature": 1.0,
+                "temperature": 0.0001,
                 "top_p": 0.9,
             },
         },
@@ -496,23 +496,36 @@ def test_rag_agent(llama_stack_client, agent_config):
         )
         # rag is called
         tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
-        assert tool_execution_step.tool_calls[0].tool_name == "query_from_memory"
+        assert tool_execution_step.tool_calls[0].tool_name == "knowledge_search"
         # document ids are present in metadata
-        assert "num-0" in tool_execution_step.tool_responses[0].metadata["document_ids"]
-        assert expected_kw in response.output_message.content.lower()
+        assert all(
+            doc_id.startswith("num-") for doc_id in tool_execution_step.tool_responses[0].metadata["document_ids"]
+        )
+        if expected_kw:
+            assert expected_kw in response.output_message.content.lower()
 
 
 def test_rag_and_code_agent(llama_stack_client, agent_config):
-    urls = ["chat.rst"]
-    documents = [
-        Document(
-            document_id=f"num-{i}",
-            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
-            mime_type="text/plain",
-            metadata={},
-        )
-        for i, url in enumerate(urls)
-    ]
+    documents = []
+    documents.append(
+        Document(
+            document_id="nba_wiki",
+            content="The NBA was created on August 3, 1949, with the merger of the Basketball Association of America (BAA) and the National Basketball League (NBL).",
+            metadata={},
+        )
+    )
+    documents.append(
+        Document(
+            document_id="perplexity_wiki",
+            content="""Perplexity the company was founded in 2022 by Aravind Srinivas, Andy Konwinski, Denis Yarats and Johnny Ho, engineers with backgrounds in back-end systems, artificial intelligence (AI) and machine learning:
+
+    Srinivas, the CEO, worked at OpenAI as an AI researcher.
+    Konwinski was among the founding team at Databricks.
+    Yarats, the CTO, was an AI research scientist at Meta.
+    Ho, the CSO, worked as an engineer at Quora, then as a quantitative trader on Wall Street.[5]""",
+            metadata={},
+        )
+    )
     vector_db_id = f"test-vector-db-{uuid4()}"
     llama_stack_client.vector_dbs.register(
         vector_db_id=vector_db_id,
@@ -546,24 +559,34 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
             "Here is a csv file, can you describe it?",
             [inflation_doc],
             "code_interpreter",
+            "",
         ),
         (
-            "What are the top 5 topics that were explained? Only list succinct bullet points.",
+            "when was Perplexity the company founded?",
             [],
-            "query_from_memory",
+            "knowledge_search",
+            "2022",
+        ),
+        (
+            "when was the nba created?",
+            [],
+            "knowledge_search",
+            "1949",
         ),
     ]
 
-    for prompt, docs, tool_name in user_prompts:
+    for prompt, docs, tool_name, expected_kw in user_prompts:
         session_id = agent.create_session(f"test-session-{uuid4()}")
         response = agent.create_turn(
             messages=[{"role": "user", "content": prompt}],
             session_id=session_id,
             documents=docs,
+            stream=False,
         )
-        logs = [str(log) for log in EventLogger().log(response) if log is not None]
-        logs_str = "".join(logs)
-        assert f"Tool:{tool_name}" in logs_str
+        tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
+        assert tool_execution_step.tool_calls[0].tool_name == tool_name
+        if expected_kw:
+            assert expected_kw in response.output_message.content.lower()
 
 
 def test_create_turn_response(llama_stack_client, agent_config):
```