From bb9bf7edee3720192288ca4b3ec4e478d0c9a689 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Wed, 26 Mar 2025 12:14:35 -0700
Subject: [PATCH] update agents test

---
 tests/integration/agents/test_agents.py | 115 ++++++++++++++----------
 1 file changed, 67 insertions(+), 48 deletions(-)

diff --git a/tests/integration/agents/test_agents.py b/tests/integration/agents/test_agents.py
index 7011dc02d..480e88d42 100644
--- a/tests/integration/agents/test_agents.py
+++ b/tests/integration/agents/test_agents.py
@@ -8,15 +8,13 @@
 from typing import Any, Dict
 from uuid import uuid4
 
 import pytest
-from llama_stack_client import Agent, AgentEventLogger, Document
-from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
 
 from llama_stack.apis.agents.agents import (
     AgentConfig as Server__AgentConfig,
-)
-from llama_stack.apis.agents.agents import (
     ToolChoice,
 )
+from llama_stack_client import Agent, AgentEventLogger, Document
+from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
 
 
 def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
@@ -36,7 +34,9 @@ def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
     return -1
 
 
-def get_boiling_point_with_metadata(liquid_name: str, celcius: bool = True) -> Dict[str, Any]:
+def get_boiling_point_with_metadata(
+    liquid_name: str, celcius: bool = True
+) -> Dict[str, Any]:
     """
     Returns the boiling point of a liquid in Celcius or Fahrenheit
@@ -56,7 +56,10 @@
 
 @pytest.fixture(scope="session")
 def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
-    available_shields = [shield.identifier for shield in llama_stack_client_with_mocked_inference.shields.list()]
+    available_shields = [
+        shield.identifier
+        for shield in llama_stack_client_with_mocked_inference.shields.list()
+    ]
     available_shields = available_shields[:1]
     agent_config = dict(
         model=text_model_id,
@@ -109,7 +112,9 @@ def test_agent_simple(llama_stack_client_with_mocked_inference, agent_config):
         session_id=session_id,
     )
 
-    logs = [str(log) for log in AgentEventLogger().log(bomb_response) if log is not None]
+    logs = [
+        str(log) for log in AgentEventLogger().log(bomb_response) if log is not None
+    ]
     logs_str = "".join(logs)
 
     assert "I can't" in logs_str
@@ -170,9 +175,12 @@ def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
     Server__AgentConfig(**agent_config)
 
 
-def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent_config):
+def test_builtin_tool_web_search(
+    llama_stack_client_with_mocked_inference, agent_config
+):
     agent_config = {
         **agent_config,
+        "instructions": "You are a helpful assistant that can use web search to answer questions.",
         "tools": [
             "builtin::websearch",
         ],
@@ -184,23 +192,25 @@ def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent
         messages=[
             {
                 "role": "user",
-                "content": "Search the web and tell me who the founder of Meta is.",
+                "content": "Search the web and tell me what is the local time in Tokyo currently.",
             }
         ],
         session_id=session_id,
+        stream=False,
     )
 
-    logs = [str(log) for log in AgentEventLogger().log(response) if log is not None]
-    logs_str = "".join(logs)
-
-    assert "tool_execution>" in logs_str
-    assert "Tool:brave_search Response:" in logs_str
-    assert "mark zuckerberg" in logs_str.lower()
-    if len(agent_config["output_shields"]) > 0:
-        assert "No Violation" in logs_str
+    found_tool_execution = False
+    for step in response.steps:
+        if step.step_type == "tool_execution":
+            assert step.tool_calls[0].tool_name == "brave_search"
+            found_tool_execution = True
+            break
+    assert found_tool_execution
 
 
-def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent_config):
+def test_builtin_tool_code_execution(
+    llama_stack_client_with_mocked_inference, agent_config
+):
     agent_config = {
         **agent_config,
         "tools": [
@@ -229,7 +239,9 @@
 # This test must be run in an environment where `bwrap` is available. If you are running against a
 # server, this means the _server_ must have `bwrap` available. If you are using library client, then
 # you must have `bwrap` available in test's environment.
-def test_code_interpreter_for_attachments(llama_stack_client_with_mocked_inference, agent_config):
+def test_code_interpreter_for_attachments(
+    llama_stack_client_with_mocked_inference, agent_config
+):
     agent_config = {
         **agent_config,
         "tools": [
@@ -291,7 +303,9 @@ def test_custom_tool(llama_stack_client_with_mocked_inference, agent_config):
     assert "get_boiling_point" in logs_str
 
 
-def test_custom_tool_infinite_loop(llama_stack_client_with_mocked_inference, agent_config):
+def test_custom_tool_infinite_loop(
+    llama_stack_client_with_mocked_inference, agent_config
+):
     client_tool = get_boiling_point
     agent_config = {
         **agent_config,
@@ -314,7 +328,9 @@ def test_custom_tool_infinite_loop(llama_stack_client_with_mocked_inference, age
         stream=False,
     )
 
-    num_tool_calls = sum([1 if step.step_type == "tool_execution" else 0 for step in response.steps])
+    num_tool_calls = sum(
+        [1 if step.step_type == "tool_execution" else 0 for step in response.steps]
+    )
     assert num_tool_calls <= 5
 
@@ -326,18 +342,25 @@ def test_tool_choice_required(llama_stack_client_with_mocked_inference, agent_co
 
 
 def test_tool_choice_none(llama_stack_client_with_mocked_inference, agent_config):
-    tool_execution_steps = run_agent_with_tool_choice(llama_stack_client_with_mocked_inference, agent_config, "none")
+    tool_execution_steps = run_agent_with_tool_choice(
+        llama_stack_client_with_mocked_inference, agent_config, "none"
+    )
     assert len(tool_execution_steps) == 0
 
 
-def test_tool_choice_get_boiling_point(llama_stack_client_with_mocked_inference, agent_config):
+def test_tool_choice_get_boiling_point(
+    llama_stack_client_with_mocked_inference, agent_config
+):
     if "llama" not in agent_config["model"].lower():
         pytest.xfail("NotImplemented for non-llama models")
 
     tool_execution_steps = run_agent_with_tool_choice(
         llama_stack_client_with_mocked_inference, agent_config, "get_boiling_point"
     )
-    assert len(tool_execution_steps) >= 1 and tool_execution_steps[0].tool_calls[0].tool_name == "get_boiling_point"
+    assert (
+        len(tool_execution_steps) >= 1
+        and tool_execution_steps[0].tool_calls[0].tool_name == "get_boiling_point"
+    )
 
 
 def run_agent_with_tool_choice(client, agent_config, tool_choice):
@@ -367,8 +390,12 @@ def run_agent_with_tool_choice(client, agent_config, tool_choice):
     return [step for step in response.steps if step.step_type == "tool_execution"]
 
 
-@pytest.mark.parametrize("rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"])
-def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_tool_name):
+@pytest.mark.parametrize(
+    "rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"]
+)
+def test_rag_agent(
+    llama_stack_client_with_mocked_inference, agent_config, rag_tool_name
+):
     urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
     documents = [
         Document(
@@ -417,29 +444,22 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
         stream=False,
     )
     # rag is called
-    tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
+    tool_execution_step = next(
+        step for step in response.steps if step.step_type == "tool_execution"
+    )
     assert tool_execution_step.tool_calls[0].tool_name == "knowledge_search"
 
     # document ids are present in metadata
     assert all(
-        doc_id.startswith("num-") for doc_id in tool_execution_step.tool_responses[0].metadata["document_ids"]
+        doc_id.startswith("num-")
+        for doc_id in tool_execution_step.tool_responses[0].metadata["document_ids"]
     )
 
     if expected_kw:
         assert expected_kw in response.output_message.content.lower()
 
 
-@pytest.mark.parametrize(
-    "tool",
-    [
-        dict(
-            name="builtin::rag/knowledge_search",
-            args={
-                "vector_db_ids": [],
-            },
-        ),
-        "builtin::rag/knowledge_search",
-    ],
-)
-def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config, tool):
+def test_rag_agent_with_attachments(
+    llama_stack_client_with_mocked_inference, agent_config
+):
     urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
     documents = [
         Document(
             document_id=f"num-{i}",
             content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
             mime_type="text/plain",
             metadata={},
         )
         for i, url in enumerate(urls)
@@ -452,7 +472,6 @@ def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, ag
     ]
     agent_config = {
         **agent_config,
-        "tools": [tool],
     }
     rag_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = rag_agent.create_session(f"test-session-{uuid4()}")
@@ -486,10 +505,6 @@ def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, ag
         stream=False,
     )
 
-    # rag is called
-    tool_execution_step = [step for step in response.steps if step.step_type == "tool_execution"]
-    assert len(tool_execution_step) >= 1
-    assert tool_execution_step[0].tool_calls[0].tool_name == "knowledge_search"
     assert "lora" in response.output_message.content.lower()
 
 
@@ -571,7 +586,9 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
             documents=docs,
             stream=False,
         )
-        tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
+        tool_execution_step = next(
+            step for step in response.steps if step.step_type == "tool_execution"
+        )
         assert tool_execution_step.tool_calls[0].tool_name == tool_name
         if expected_kw:
             assert expected_kw in response.output_message.content.lower()
@@ -581,7 +598,9 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
     "client_tools",
     [(get_boiling_point, False), (get_boiling_point_with_metadata, True)],
 )
-def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_config, client_tools):
+def test_create_turn_response(
+    llama_stack_client_with_mocked_inference, agent_config, client_tools
+):
     client_tool, expects_metadata = client_tools
     agent_config = {
         **agent_config,