fix: RAG with documents (#1337)

Summary:
RAG with attached documents was broken by
https://github.com/meta-llama/llama-stack/pull/1015/files#r1975394190
(the session's vector DB was no longer passed to the RAG tool, so knowledge_search never saw the attached documents).

Test Plan:

Added an e2e test (test_rag_agent_with_attachments) that attaches documents to a turn and verifies the knowledge_search tool is invoked.
Author: ehhuang
Date: 2025-02-28 16:51:00 -08:00
Commit: 21ec67356c (parent 7854af8b52), committed by GitHub
2 changed files with 66 additions and 0 deletions

@@ -533,6 +533,11 @@ class ChatAgent(ShieldRunnerMixin):
        if documents:
            await self.handle_documents(session_id, documents, input_messages, tool_defs)

        session_info = await self.storage.get_session_info(session_id)
        # if the session has a memory bank id, let the memory tool use it
        if session_info and session_info.vector_db_id:
            toolgroup_args[RAG_TOOL_GROUP]["vector_db_ids"].append(session_info.vector_db_id)

        output_attachments = []
        n_iter = 0

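As context for the hunk above, here is a minimal, self-contained sketch of what the added lines accomplish: the session-scoped vector DB created when documents are attached is appended to the RAG tool group's vector_db_ids, so knowledge_search also queries the attached documents. The SessionInfo dataclass, the helper name, and the RAG_TOOL_GROUP value below are illustrative assumptions, not the actual llama-stack definitions.

from dataclasses import dataclass
from typing import Optional

RAG_TOOL_GROUP = "builtin::rag"  # assumed toolgroup key, for illustration only


@dataclass
class SessionInfo:
    # Stand-in for the stored session record; vector_db_id is set once
    # documents have been attached and indexed for this session.
    session_id: str
    vector_db_id: Optional[str] = None


def merge_session_vector_db(toolgroup_args: dict, session_info: Optional[SessionInfo]) -> dict:
    # Start from whatever vector DBs the caller configured for the RAG tool.
    args = toolgroup_args.setdefault(RAG_TOOL_GROUP, {"vector_db_ids": []})
    # If the session has its own vector DB, let knowledge_search query it too.
    if session_info and session_info.vector_db_id:
        args["vector_db_ids"].append(session_info.vector_db_id)
    return toolgroup_args


# A turn configured with an empty vector_db_ids list now also searches the
# session's own vector DB.
merged = merge_session_vector_db(
    {RAG_TOOL_GROUP: {"vector_db_ids": []}},
    SessionInfo(session_id="sess-123", vector_db_id="vector_db_sess-123"),
)
assert merged[RAG_TOOL_GROUP]["vector_db_ids"] == ["vector_db_sess-123"]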
@@ -374,6 +374,67 @@ def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
    assert expected_kw in response.output_message.content.lower()

def test_rag_agent_with_attachments(llama_stack_client, agent_config):
    urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
    documents = [
        Document(
            document_id=f"num-{i}",
            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
            mime_type="text/plain",
            metadata={},
        )
        for i, url in enumerate(urls)
    ]
    agent_config = {
        **agent_config,
        "toolgroups": [
            dict(
                name="builtin::rag/knowledge_search",
                args={
                    "vector_db_ids": [],
                },
            )
        ],
    }
    rag_agent = Agent(llama_stack_client, agent_config)
    session_id = rag_agent.create_session(f"test-session-{uuid4()}")
    user_prompts = [
        (
            "I am attaching some documentation for Torchtune. Help me answer questions I will ask next.",
            documents,
        ),
        (
            "Tell me how to use LoRA",
            None,
        ),
    ]

    for prompt in user_prompts:
        response = rag_agent.create_turn(
            messages=[
                {
                    "role": "user",
                    "content": prompt[0],
                }
            ],
            documents=prompt[1],
            session_id=session_id,
            stream=False,
        )

    # rag is called
    tool_execution_step = [step for step in response.steps if step.step_type == "tool_execution"]
    assert len(tool_execution_step) >= 1
    assert tool_execution_step[0].tool_calls[0].tool_name == "knowledge_search"
    assert "lora" in response.output_message.content.lower()


def test_rag_and_code_agent(llama_stack_client, agent_config):
    documents = []
    documents.append(