mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-14 09:06:10 +00:00)
test: introduce recordable mocks for Agent tests (#1268)
Summary:

Agent tests shouldn't need to run inference and tool calls against live services repeatedly. This PR introduces a way to record inference/tool calls and replay the recordings in subsequent test runs, which makes the tests more reliable and saves costs. A sketch of the record/replay idea follows the commit metadata below.

Test Plan:

Run before any calls have been recorded (fails):

```
LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B
```

Run with `--record-responses` to record the calls:

```
LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B --record-responses
```

Run again without `--record-responses` (succeeds):

```
LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B
```
This commit is contained in:
parent 816fdf289a
commit 386c806c70
7 changed files with 6893 additions and 29 deletions
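For context, here is a minimal sketch of how a record/replay mock of this kind can work. It assumes responses are JSON-serializable and that a call is identified by its keyword arguments; the class name `RecordableMock` and the cache-file layout are illustrative assumptions, not the implementation in this commit. Only the `--record-responses` flag and the `llama_stack_client_with_mocked_inference` fixture (used throughout the diff below) come from the PR itself.

```python
# Hypothetical sketch of a record/replay mock in the spirit of this PR.
# Assumes responses are JSON-serializable and calls are keyed by their
# keyword arguments; names and file layout here are illustrative.
import json
import os


class RecordableMock:
    """Wrap a real callable: record its responses on a --record-responses
    run, replay them (with no live calls) on every later run."""

    def __init__(self, real_fn, cache_path: str, record: bool = False):
        self.real_fn = real_fn
        self.cache_path = cache_path
        self.record = record
        if os.path.exists(cache_path):
            with open(cache_path) as f:
                self.cache = json.load(f)
        else:
            self.cache = {}

    def __call__(self, **kwargs):
        # Key each call by its arguments so identical requests replay
        # identical responses across test runs.
        key = json.dumps(kwargs, sort_keys=True, default=str)
        if key in self.cache:
            return self.cache[key]
        if not self.record:
            raise KeyError(
                f"No recorded response for {key!r}; rerun with --record-responses"
            )
        response = self.real_fn(**kwargs)
        self.cache[key] = response
        with open(self.cache_path, "w") as f:
            json.dump(self.cache, f, indent=2)
        return response
```

Under a scheme like this, a fixture such as `llama_stack_client_with_mocked_inference` would hand tests a client whose inference and tool-runtime endpoints are wrapped this way, which is why the first run in the test plan fails, the `--record-responses` run populates the cache, and the final run passes without touching the live provider.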
tests/client-sdk/agents/test_agents.py

```diff
@@ -41,8 +41,8 @@ def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:


 @pytest.fixture(scope="session")
-def agent_config(llama_stack_client, text_model_id):
-    available_shields = [shield.identifier for shield in llama_stack_client.shields.list()]
+def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
+    available_shields = [shield.identifier for shield in llama_stack_client_with_mocked_inference.shields.list()]
     available_shields = available_shields[:1]
     agent_config = AgentConfig(
         model=text_model_id,
@@ -62,8 +62,8 @@ def agent_config(llama_stack_client, text_model_id):
     return agent_config


-def test_agent_simple(llama_stack_client, agent_config):
-    agent = Agent(llama_stack_client, agent_config)
+def test_agent_simple(llama_stack_client_with_mocked_inference, agent_config):
+    agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")

     simple_hello = agent.create_turn(
@@ -100,7 +100,7 @@ def test_agent_simple(llama_stack_client, agent_config):
     assert "I can't" in logs_str


-def test_tool_config(llama_stack_client, agent_config):
+def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
     common_params = dict(
         model="meta-llama/Llama-3.2-3B-Instruct",
         instructions="You are a helpful assistant",
@@ -156,14 +156,14 @@ def test_tool_config(llama_stack_client, agent_config):
         Server__AgentConfig(**agent_config)


-def test_builtin_tool_web_search(llama_stack_client, agent_config):
+def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent_config):
     agent_config = {
         **agent_config,
         "toolgroups": [
             "builtin::websearch",
         ],
     }
-    agent = Agent(llama_stack_client, agent_config)
+    agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")

     response = agent.create_turn(
@@ -186,14 +186,14 @@ def test_builtin_tool_web_search(llama_stack_client, agent_config):
     assert "No Violation" in logs_str


-def test_builtin_tool_code_execution(llama_stack_client, agent_config):
+def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent_config):
     agent_config = {
         **agent_config,
         "toolgroups": [
             "builtin::code_interpreter",
         ],
     }
-    agent = Agent(llama_stack_client, agent_config)
+    agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")

     response = agent.create_turn(
@@ -215,7 +215,7 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
 # This test must be run in an environment where `bwrap` is available. If you are running against a
 # server, this means the _server_ must have `bwrap` available. If you are using library client, then
 # you must have `bwrap` available in test's environment.
-def test_code_interpreter_for_attachments(llama_stack_client, agent_config):
+def test_code_interpreter_for_attachments(llama_stack_client_with_mocked_inference, agent_config):
     agent_config = {
         **agent_config,
         "toolgroups": [
@@ -223,7 +223,7 @@ def test_code_interpreter_for_attachments(llama_stack_client, agent_config):
         ],
     }

-    codex_agent = Agent(llama_stack_client, agent_config)
+    codex_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
     session_id = codex_agent.create_session(f"test-session-{uuid4()}")
     inflation_doc = AgentDocument(
         content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
@@ -251,7 +251,7 @@ def test_code_interpreter_for_attachments(llama_stack_client, agent_config):
     assert "Tool:code_interpreter" in logs_str


-def test_custom_tool(llama_stack_client, agent_config):
+def test_custom_tool(llama_stack_client_with_mocked_inference, agent_config):
     client_tool = get_boiling_point
     agent_config = {
         **agent_config,
@@ -259,7 +259,7 @@ def test_custom_tool(llama_stack_client, agent_config):
         "client_tools": [client_tool.get_tool_definition()],
     }

-    agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
+    agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
     session_id = agent.create_session(f"test-session-{uuid4()}")

     response = agent.create_turn(
@@ -278,7 +278,7 @@ def test_custom_tool(llama_stack_client, agent_config):
     assert "get_boiling_point" in logs_str


-def test_custom_tool_infinite_loop(llama_stack_client, agent_config):
+def test_custom_tool_infinite_loop(llama_stack_client_with_mocked_inference, agent_config):
     client_tool = get_boiling_point
     agent_config = {
         **agent_config,
@@ -287,7 +287,7 @@ def test_custom_tool_infinite_loop(llama_stack_client, agent_config):
         "max_infer_iters": 5,
     }

-    agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
+    agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
     session_id = agent.create_session(f"test-session-{uuid4()}")

     response = agent.create_turn(
@@ -305,7 +305,7 @@ def test_custom_tool_infinite_loop(llama_stack_client, agent_config):
     assert num_tool_calls <= 5


-def test_tool_choice(llama_stack_client, agent_config):
+def test_tool_choice(llama_stack_client_with_mocked_inference, agent_config):
     def run_agent(tool_choice):
         client_tool = get_boiling_point

@@ -315,7 +315,7 @@ def test_tool_choice(llama_stack_client, agent_config):
             "client_tools": [client_tool.get_tool_definition()],
         }

-        agent = Agent(llama_stack_client, test_agent_config, client_tools=(client_tool,))
+        agent = Agent(llama_stack_client_with_mocked_inference, test_agent_config, client_tools=(client_tool,))
         session_id = agent.create_session(f"test-session-{uuid4()}")

         response = agent.create_turn(
@@ -342,7 +342,7 @@ def test_tool_choice(llama_stack_client, agent_config):


 @pytest.mark.parametrize("rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"])
-def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
+def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_tool_name):
     urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
     documents = [
         Document(
@@ -354,12 +354,12 @@ def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
         for i, url in enumerate(urls)
     ]
     vector_db_id = f"test-vector-db-{uuid4()}"
-    llama_stack_client.vector_dbs.register(
+    llama_stack_client_with_mocked_inference.vector_dbs.register(
         vector_db_id=vector_db_id,
         embedding_model="all-MiniLM-L6-v2",
         embedding_dimension=384,
     )
-    llama_stack_client.tool_runtime.rag_tool.insert(
+    llama_stack_client_with_mocked_inference.tool_runtime.rag_tool.insert(
         documents=documents,
         vector_db_id=vector_db_id,
         # small chunks help to get specific info out of the docs
@@ -376,7 +376,7 @@ def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
             )
         ],
     }
-    rag_agent = Agent(llama_stack_client, agent_config)
+    rag_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
     session_id = rag_agent.create_session(f"test-session-{uuid4()}")
     user_prompts = [
         (
@@ -401,7 +401,7 @@ def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
         assert expected_kw in response.output_message.content.lower()


-def test_rag_agent_with_attachments(llama_stack_client, agent_config):
+def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config):
     urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
     documents = [
         Document(
@@ -423,7 +423,7 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config):
             )
         ],
     }
-    rag_agent = Agent(llama_stack_client, agent_config)
+    rag_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
     session_id = rag_agent.create_session(f"test-session-{uuid4()}")
     user_prompts = [
         (
@@ -462,7 +462,7 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config):
     assert "lora" in response.output_message.content.lower()


-def test_rag_and_code_agent(llama_stack_client, agent_config):
+def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_config):
    documents = []
    documents.append(
        Document(
@@ -484,12 +484,12 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
         )
     )
     vector_db_id = f"test-vector-db-{uuid4()}"
-    llama_stack_client.vector_dbs.register(
+    llama_stack_client_with_mocked_inference.vector_dbs.register(
         vector_db_id=vector_db_id,
         embedding_model="all-MiniLM-L6-v2",
         embedding_dimension=384,
     )
-    llama_stack_client.tool_runtime.rag_tool.insert(
+    llama_stack_client_with_mocked_inference.tool_runtime.rag_tool.insert(
         documents=documents,
         vector_db_id=vector_db_id,
         chunk_size_in_tokens=128,
@@ -504,7 +504,7 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
             "builtin::code_interpreter",
         ],
     }
-    agent = Agent(llama_stack_client, agent_config)
+    agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
     inflation_doc = Document(
         document_id="test_csv",
         content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
@@ -546,7 +546,7 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
         assert expected_kw in response.output_message.content.lower()


-def test_create_turn_response(llama_stack_client, agent_config):
+def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_config):
     client_tool = get_boiling_point
     agent_config = {
         **agent_config,
@@ -555,7 +555,7 @@ def test_create_turn_response(llama_stack_client, agent_config):
         "client_tools": [client_tool.get_tool_definition()],
     }

-    agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
+    agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
     session_id = agent.create_session(f"test-session-{uuid4()}")

     response = agent.create_turn(
```