diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index d2441a9e5..706dd74f1 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -66,6 +66,7 @@ from llama_stack.apis.vector_io import VectorIO
 from llama_stack.providers.utils.kvstore import KVStore
 from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 from llama_stack.providers.utils.telemetry import tracing
+
 from .persistence import AgentPersistence
 from .safety import SafetyException, ShieldRunnerMixin
 
@@ -477,6 +478,12 @@ class ChatAgent(ShieldRunnerMixin):
                 span.set_attribute("output", retrieved_context)
                 span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
 
+                # append retrieved_context to the last user message
+                for message in input_messages[::-1]:
+                    if isinstance(message, UserMessage):
+                        message.context = retrieved_context
+                        break
+
         output_attachments = []
 
         n_iter = 0
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index 374c7c74c..969f926ff 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -211,7 +211,7 @@ def test_code_interpreter_for_attachments(llama_stack_client, agent_config):
     }
 
     codex_agent = Agent(llama_stack_client, agent_config)
-    session_id = codex_agent.create_session("test-session")
+    session_id = codex_agent.create_session(f"test-session-{uuid4()}")
     inflation_doc = AgentDocument(
         content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
         mime_type="text/csv",
@@ -299,7 +299,7 @@ def test_rag_agent(llama_stack_client, agent_config):
         ],
     }
     rag_agent = Agent(llama_stack_client, agent_config)
-    session_id = rag_agent.create_session("test-session")
+    session_id = rag_agent.create_session(f"test-session-{uuid4()}")
     user_prompts = [
         "What are the top 5 topics that were explained? Only list succinct bullet points.",
     ]
@@ -334,7 +334,7 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
     llama_stack_client.tool_runtime.rag_tool.insert(
         documents=documents,
         vector_db_id=vector_db_id,
-        chunk_size_in_tokens=512,
+        chunk_size_in_tokens=128,
     )
     agent_config = {
         **agent_config,
@@ -347,7 +347,6 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
         ],
     }
     agent = Agent(llama_stack_client, agent_config)
-    session_id = agent.create_session("test-session")
     inflation_doc = Document(
         document_id="test_csv",
         content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
@@ -355,22 +354,26 @@
         metadata={},
     )
     user_prompts = [
+        (
+            "Here is a csv file, can you describe it?",
+            [inflation_doc],
+            "code_interpreter",
+        ),
         (
             "What are the top 5 topics that were explained? Only list succinct bullet points.",
             [],
             "query_from_memory",
         ),
-        ("What is the average yearly inflation?", [inflation_doc], "code_interpreter"),
     ]
 
     for prompt, docs, tool_name in user_prompts:
         print(f"User> {prompt}")
+        session_id = agent.create_session(f"test-session-{uuid4()}")
         response = agent.create_turn(
             messages=[{"role": "user", "content": prompt}],
             session_id=session_id,
             documents=docs,
         )
-
         logs = [str(log) for log in EventLogger().log(response) if log is not None]
         logs_str = "".join(logs)
         assert f"Tool:{tool_name}" in logs_str