diff --git a/tests/integration/agents/test_agents.py b/tests/integration/agents/test_agents.py
index d3a9872f1..7aeeb32d8 100644
--- a/tests/integration/agents/test_agents.py
+++ b/tests/integration/agents/test_agents.py
@@ -8,13 +8,15 @@
 from typing import Any, Dict
 from uuid import uuid4
 
 import pytest
+from llama_stack_client import Agent, AgentEventLogger, Document
+from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
 from llama_stack.apis.agents.agents import (
     AgentConfig as Server__AgentConfig,
+)
+from llama_stack.apis.agents.agents import (
     ToolChoice,
 )
-from llama_stack_client import Agent, AgentEventLogger, Document
-from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
 
 
 def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
@@ -34,9 +36,7 @@ def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
         return -1
 
 
-def get_boiling_point_with_metadata(
-    liquid_name: str, celcius: bool = True
-) -> Dict[str, Any]:
+def get_boiling_point_with_metadata(liquid_name: str, celcius: bool = True) -> Dict[str, Any]:
     """
     Returns the boiling point of a liquid in Celcius or Fahrenheit
 
@@ -56,10 +56,7 @@ def get_boiling_point_with_metadata(
 
 @pytest.fixture(scope="session")
 def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
-    available_shields = [
-        shield.identifier
-        for shield in llama_stack_client_with_mocked_inference.shields.list()
-    ]
+    available_shields = [shield.identifier for shield in llama_stack_client_with_mocked_inference.shields.list()]
     available_shields = available_shields[:1]
     agent_config = dict(
         model=text_model_id,
@@ -112,9 +109,7 @@ def test_agent_simple(llama_stack_client_with_mocked_inference, agent_config):
         session_id=session_id,
     )
 
-    logs = [
-        str(log) for log in AgentEventLogger().log(bomb_response) if log is not None
-    ]
+    logs = [str(log) for log in AgentEventLogger().log(bomb_response) if log is not None]
     logs_str = "".join(logs)
     assert "I can't" in logs_str
 
@@ -175,9 +170,7 @@ def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
     Server__AgentConfig(**agent_config)
 
 
-def test_builtin_tool_web_search(
-    llama_stack_client_with_mocked_inference, agent_config
-):
+def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent_config):
     agent_config = {
         **agent_config,
         "instructions": "You are a helpful assistant that can use web search to answer questions.",
@@ -208,9 +201,7 @@ def test_builtin_tool_web_search(
     assert found_tool_execution
 
 
-def test_builtin_tool_code_execution(
-    llama_stack_client_with_mocked_inference, agent_config
-):
+def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent_config):
     agent_config = {
         **agent_config,
         "tools": [
@@ -239,9 +230,7 @@ def test_builtin_tool_code_execution(
 # This test must be run in an environment where `bwrap` is available. If you are running against a
 # server, this means the _server_ must have `bwrap` available. If you are using library client, then
 # you must have `bwrap` available in test's environment.
-def test_code_interpreter_for_attachments(
-    llama_stack_client_with_mocked_inference, agent_config
-):
+def test_code_interpreter_for_attachments(llama_stack_client_with_mocked_inference, agent_config):
     agent_config = {
         **agent_config,
         "tools": [
@@ -303,9 +292,7 @@ def test_custom_tool(llama_stack_client_with_mocked_inference, agent_config):
     assert "get_boiling_point" in logs_str
 
 
-def test_custom_tool_infinite_loop(
-    llama_stack_client_with_mocked_inference, agent_config
-):
+def test_custom_tool_infinite_loop(llama_stack_client_with_mocked_inference, agent_config):
     client_tool = get_boiling_point
     agent_config = {
         **agent_config,
@@ -328,9 +315,7 @@ def test_custom_tool_infinite_loop(
         stream=False,
     )
 
-    num_tool_calls = sum(
-        [1 if step.step_type == "tool_execution" else 0 for step in response.steps]
-    )
+    num_tool_calls = sum([1 if step.step_type == "tool_execution" else 0 for step in response.steps])
     assert num_tool_calls <= 5
 
 
@@ -342,25 +327,18 @@ def test_tool_choice_required(llama_stack_client_with_mocked_inference, agent_co
 
 
 def test_tool_choice_none(llama_stack_client_with_mocked_inference, agent_config):
-    tool_execution_steps = run_agent_with_tool_choice(
-        llama_stack_client_with_mocked_inference, agent_config, "none"
-    )
+    tool_execution_steps = run_agent_with_tool_choice(llama_stack_client_with_mocked_inference, agent_config, "none")
     assert len(tool_execution_steps) == 0
 
 
-def test_tool_choice_get_boiling_point(
-    llama_stack_client_with_mocked_inference, agent_config
-):
+def test_tool_choice_get_boiling_point(llama_stack_client_with_mocked_inference, agent_config):
     if "llama" not in agent_config["model"].lower():
         pytest.xfail("NotImplemented for non-llama models")
 
     tool_execution_steps = run_agent_with_tool_choice(
         llama_stack_client_with_mocked_inference, agent_config, "get_boiling_point"
     )
-    assert (
-        len(tool_execution_steps) >= 1
-        and tool_execution_steps[0].tool_calls[0].tool_name == "get_boiling_point"
-    )
+    assert len(tool_execution_steps) >= 1 and tool_execution_steps[0].tool_calls[0].tool_name == "get_boiling_point"
 
 
 def run_agent_with_tool_choice(client, agent_config, tool_choice):
@@ -390,12 +368,8 @@ def run_agent_with_tool_choice(client, agent_config, tool_choice):
     return [step for step in response.steps if step.step_type == "tool_execution"]
 
 
-@pytest.mark.parametrize(
-    "rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"]
-)
-def test_rag_agent(
-    llama_stack_client_with_mocked_inference, agent_config, rag_tool_name
-):
+@pytest.mark.parametrize("rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"])
+def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_tool_name):
     urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
     documents = [
         Document(
@@ -444,22 +418,17 @@ def test_rag_agent(
             stream=False,
         )
         # rag is called
-        tool_execution_step = next(
-            step for step in response.steps if step.step_type == "tool_execution"
-        )
+        tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
         assert tool_execution_step.tool_calls[0].tool_name == "knowledge_search"
         # document ids are present in metadata
         assert all(
-            doc_id.startswith("num-")
-            for doc_id in tool_execution_step.tool_responses[0].metadata["document_ids"]
+            doc_id.startswith("num-") for doc_id in tool_execution_step.tool_responses[0].metadata["document_ids"]
         )
         if expected_kw:
             assert expected_kw in response.output_message.content.lower()
 
 
-def test_rag_agent_with_attachments(
-    llama_stack_client_with_mocked_inference, agent_config
-):
+def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config):
     urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
     documents = [
         Document(
@@ -574,9 +543,7 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
             documents=docs,
             stream=False,
         )
-        tool_execution_step = next(
-            step for step in response.steps if step.step_type == "tool_execution"
-        )
+        tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
         assert tool_execution_step.tool_calls[0].tool_name == tool_name
         if expected_kw:
             assert expected_kw in response.output_message.content.lower()
@@ -586,9 +553,7 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
     "client_tools",
     [(get_boiling_point, False), (get_boiling_point_with_metadata, True)],
 )
-def test_create_turn_response(
-    llama_stack_client_with_mocked_inference, agent_config, client_tools
-):
+def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_config, client_tools):
     client_tool, expects_metadata = client_tools
     agent_config = {
         **agent_config,