# NOTE(review): recovered from a whitespace-collapsed git patch against
# tests/client-sdk/agents/test_agents.py (added after test_rag_agent). The same
# patch also edits
# llama_stack/providers/inline/agents/meta_reference/agent_instance.py, removing
# the `if retrieved_context:` branch that assigned
# `input_messages[-1].context = retrieved_context`. That hunk holds no complete
# definition, so it is summarized here rather than reproduced.
def test_rag_and_code_agent(llama_stack_client, agent_config):
    """Check that one agent session can use both RAG and the code interpreter.

    Registers a vector DB, inserts the torchtune ``chat.rst`` tutorial document
    into it, creates an agent whose toolgroups are ``builtin::rag`` (bound to
    that vector DB) and ``builtin::code_interpreter``, then runs two turns in
    the same session. Each turn asserts that the stringified event log contains
    ``Tool:<expected tool name>``.

    Args:
        llama_stack_client: Client used to register the vector DB, insert
            documents, and back the agent (presumably a pytest fixture --
            confirm against conftest).
        agent_config: Base agent configuration mapping; it is copied and
            extended with the ``toolgroups`` entry, never mutated.
    """
    urls = ["chat.rst"]
    documents = [
        Document(
            document_id=f"num-{i}",
            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
            mime_type="text/plain",
            metadata={},
        )
        for i, url in enumerate(urls)
    ]

    vector_db_id = "test-vector-db"
    llama_stack_client.vector_dbs.register(
        vector_db_id=vector_db_id,
        embedding_model="all-MiniLM-L6-v2",
        embedding_dimension=384,
    )
    llama_stack_client.tool_runtime.rag_tool.insert(
        documents=documents,
        vector_db_id=vector_db_id,
        chunk_size_in_tokens=512,
    )

    # Build a new mapping under a new name so the fixture argument is neither
    # mutated nor shadowed.
    rag_and_code_config = {
        **agent_config,
        "toolgroups": [
            {"name": "builtin::rag", "args": {"vector_db_ids": [vector_db_id]}},
            "builtin::code_interpreter",
        ],
    }
    agent = Agent(llama_stack_client, rag_and_code_config)
    session_id = agent.create_session("test-session")

    inflation_doc = Document(
        document_id="test_csv",
        content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
        mime_type="text/csv",
        metadata={},
    )

    # (prompt, documents attached to the turn, tool expected in the event log)
    user_prompts = [
        (
            "What are the top 5 topics that were explained? Only list succinct bullet points.",
            documents,
            "query_from_memory",
        ),
        ("What is the average yearly inflation?", [inflation_doc], "code_interpreter"),
    ]

    for prompt, docs, tool_name in user_prompts:
        print(f"User> {prompt}")
        response = agent.create_turn(
            messages=[{"role": "user", "content": prompt}],
            session_id=session_id,
            documents=docs,
        )
        # Skip None events, as test_rag_agent does: an unguarded None would add
        # the text "None" to the log and break any method call on the event.
        logs = [str(log) for log in EventLogger().log(response) if log is not None]
        logs_str = "".join(logs)
        assert f"Tool:{tool_name}" in logs_str