From 4c53db9f500b32447b9fe4f90b2cada3fc7f3411 Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Mon, 14 Jul 2025 13:35:57 -0700 Subject: [PATCH] fix agent test_rag --- tests/integration/agents/test_agents.py | 104 +++++------------------- 1 file changed, 20 insertions(+), 84 deletions(-) diff --git a/tests/integration/agents/test_agents.py b/tests/integration/agents/test_agents.py index 66c9ab829..05549cf18 100644 --- a/tests/integration/agents/test_agents.py +++ b/tests/integration/agents/test_agents.py @@ -77,6 +77,24 @@ def agent_config(llama_stack_client, text_model_id): return agent_config +@pytest.fixture(scope="session") +def agent_config_without_safety(text_model_id): + agent_config = dict( + model=text_model_id, + instructions="You are a helpful assistant", + sampling_params={ + "strategy": { + "type": "top_p", + "temperature": 0.0001, + "top_p": 0.9, + }, + }, + tools=[], + enable_session_persistence=False, + ) + return agent_config + + def test_agent_simple(llama_stack_client, agent_config): agent = Agent(llama_stack_client, **agent_config) session_id = agent.create_session(f"test-session-{uuid4()}") @@ -491,7 +509,7 @@ def test_rag_agent(llama_stack_client, agent_config, rag_tool_name): assert expected_kw in response.output_message.content.lower() -def test_rag_agent_with_attachments(llama_stack_client, agent_config): +def test_rag_agent_with_attachments(llama_stack_client, agent_config_without_safety): urls = ["llama3.rst", "lora_finetune.rst"] documents = [ # passign as url @@ -514,14 +532,8 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config): metadata={}, ), ] - rag_agent = Agent(llama_stack_client, **agent_config) + rag_agent = Agent(llama_stack_client, **agent_config_without_safety) session_id = rag_agent.create_session(f"test-session-{uuid4()}") - user_prompts = [ - ( - "Instead of the standard multi-head attention, what attention type does Llama3-8B use?", - "grouped", - ), - ] user_prompts = [ ( "I am attaching some documentation for Torchtune. Help me answer questions I will ask next.", @@ -549,82 +561,6 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config): assert "lora" in response.output_message.content.lower() -@pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack") -def test_rag_and_code_agent(llama_stack_client, agent_config): - if "llama-4" in agent_config["model"].lower(): - pytest.xfail("Not working for llama4") - - documents = [] - documents.append( - Document( - document_id="nba_wiki", - content="The NBA was created on August 3, 1949, with the merger of the Basketball Association of America (BAA) and the National Basketball League (NBL).", - metadata={}, - ) - ) - documents.append( - Document( - document_id="perplexity_wiki", - content="""Perplexity the company was founded in 2022 by Aravind Srinivas, Andy Konwinski, Denis Yarats and Johnny Ho, engineers with backgrounds in back-end systems, artificial intelligence (AI) and machine learning: - - Srinivas, the CEO, worked at OpenAI as an AI researcher. - Konwinski was among the founding team at Databricks. - Yarats, the CTO, was an AI research scientist at Meta. - Ho, the CSO, worked as an engineer at Quora, then as a quantitative trader on Wall Street.[5]""", - metadata={}, - ) - ) - vector_db_id = f"test-vector-db-{uuid4()}" - llama_stack_client.vector_dbs.register( - vector_db_id=vector_db_id, - embedding_model="all-MiniLM-L6-v2", - embedding_dimension=384, - ) - llama_stack_client.tool_runtime.rag_tool.insert( - documents=documents, - vector_db_id=vector_db_id, - chunk_size_in_tokens=128, - ) - agent_config = { - **agent_config, - "tools": [ - dict( - name="builtin::rag/knowledge_search", - args={"vector_db_ids": [vector_db_id]}, - ), - "builtin::code_interpreter", - ], - } - agent = Agent(llama_stack_client, **agent_config) - user_prompts = [ - ( - "when was Perplexity the company founded?", - [], - "knowledge_search", - "2022", - ), - ( - "when was the nba created?", - [], - "knowledge_search", - "1949", - ), - ] - - for prompt, docs, tool_name, expected_kw in user_prompts: - session_id = agent.create_session(f"test-session-{uuid4()}") - response = agent.create_turn( - messages=[{"role": "user", "content": prompt}], - session_id=session_id, - documents=docs, - stream=False, - ) - tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution") - assert tool_execution_step.tool_calls[0].tool_name == tool_name, f"Failed on {prompt}" - if expected_kw: - assert expected_kw in response.output_message.content.lower() - - @pytest.mark.parametrize( "client_tools", [(get_boiling_point, False), (get_boiling_point_with_metadata, True)],