From cfd30d2ad5e88732aeaa034920177fbbdc1fe912 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Wed, 26 Mar 2025 22:00:43 -0700
Subject: [PATCH] fix: update agents test (#1796)

# What does this PR do?
- we no longer query vector db when uploading documents as attachments

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan
```
pytest --stack-config="http://localhost:8321" -v tests/integration/agents/test_agents.py --text-model meta-llama/Llama-3.3-70B-Instruct
```

```
pytest --stack-config=fireworks -v tests/integration/agents/test_agents.py --text-model meta-llama/Llama-3.3-70B-Instruct --record-responses
```

[//]: # (## Documentation)
---
 tests/integration/agents/test_agents.py | 44 ++++++++-----------------
 1 file changed, 13 insertions(+), 31 deletions(-)

diff --git a/tests/integration/agents/test_agents.py b/tests/integration/agents/test_agents.py
index 7011dc02d..2bf9baa80 100644
--- a/tests/integration/agents/test_agents.py
+++ b/tests/integration/agents/test_agents.py
@@ -8,15 +8,13 @@ from typing import Any, Dict
 from uuid import uuid4
 
 import pytest
-from llama_stack_client import Agent, AgentEventLogger, Document
-from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
 
 from llama_stack.apis.agents.agents import (
     AgentConfig as Server__AgentConfig,
-)
-from llama_stack.apis.agents.agents import (
     ToolChoice,
 )
+from llama_stack_client import Agent, AgentEventLogger, Document
+from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
 
 
 def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
@@ -173,6 +171,7 @@ def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
 def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent_config):
     agent_config = {
         **agent_config,
+        "instructions": "You are a helpful assistant that can use web search to answer questions.",
         "tools": [
             "builtin::websearch",
         ],
@@ -184,20 +183,20 @@ def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent
         messages=[
             {
                 "role": "user",
-                "content": "Search the web and tell me who the founder of Meta is.",
+                "content": "Search the web and tell me what is the local time in Tokyo currently.",
             }
         ],
         session_id=session_id,
+        stream=False,
     )
 
-    logs = [str(log) for log in AgentEventLogger().log(response) if log is not None]
-    logs_str = "".join(logs)
-
-    assert "tool_execution>" in logs_str
-    assert "Tool:brave_search Response:" in logs_str
-    assert "mark zuckerberg" in logs_str.lower()
-    if len(agent_config["output_shields"]) > 0:
-        assert "No Violation" in logs_str
+    found_tool_execution = False
+    for step in response.steps:
+        if step.step_type == "tool_execution":
+            assert step.tool_calls[0].tool_name == "brave_search"
+            found_tool_execution = True
+            break
+    assert found_tool_execution
 
 
 def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent_config):
@@ -427,19 +426,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
         assert expected_kw in response.output_message.content.lower()
 
 
-@pytest.mark.parametrize(
-    "tool",
-    [
-        dict(
-            name="builtin::rag/knowledge_search",
-            args={
-                "vector_db_ids": [],
-            },
-        ),
-        "builtin::rag/knowledge_search",
-    ],
-)
-def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config, tool):
+def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config):
     urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
     documents = [
         Document(
@@ -452,7 +439,6 @@ def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, ag
     ]
     agent_config = {
         **agent_config,
-        "tools": [tool],
     }
     rag_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = rag_agent.create_session(f"test-session-{uuid4()}")
@@ -486,10 +472,6 @@ def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, ag
         stream=False,
     )
 
-    # rag is called
-    tool_execution_step = [step for step in response.steps if step.step_type == "tool_execution"]
-    assert len(tool_execution_step) >= 1
-    assert tool_execution_step[0].tool_calls[0].tool_name == "knowledge_search"
     assert "lora" in response.output_message.content.lower()
 
 