test: introduce recordable mocks for Agent tests (#1268)

Summary:

Agent tests shouldn't need to run inference and tool calls repeatedly.
This PR introduces a way to record inference/tool calls and reuse them
in subsequent test runs, which makes the tests more reliable and cuts
costs.
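
The core mechanism is a record/replay cache keyed by the request. As a rough illustration of the idea only (a minimal standalone sketch — `RecordableMock`, the cache format, and the file layout here are made up for clarity and are not this PR's actual implementation):

```python
import hashlib
import json
import os


class RecordableMock:
    """Wrap a callable: record its JSON-serializable results keyed by the
    arguments, then replay them on later runs instead of calling the real thing."""

    def __init__(self, real_fn, cache_path, record=False):
        self.real_fn = real_fn
        self.cache_path = cache_path
        self.record = record
        # Load any previously recorded responses.
        if os.path.exists(cache_path):
            with open(cache_path) as f:
                self.cache = json.load(f)
        else:
            self.cache = {}

    def _key(self, args, kwargs):
        # Hash the request so each distinct call maps to one recording.
        payload = json.dumps({"args": args, "kwargs": kwargs}, sort_keys=True, default=str)
        return hashlib.sha256(payload.encode()).hexdigest()

    def __call__(self, *args, **kwargs):
        key = self._key(args, kwargs)
        if key in self.cache:
            return self.cache[key]  # replay a prior recording
        if not self.record:
            raise KeyError("no recording for this request; re-run with recording enabled")
        result = self.real_fn(*args, **kwargs)  # hit the real endpoint once
        self.cache[key] = result
        with open(self.cache_path, "w") as f:
            json.dump(self.cache, f, indent=2)
        return result
```

Keying on a hash of the full request is what makes replay deterministic across runs: identical test inputs find their recording, and anything new fails loudly unless recording is enabled.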

Test Plan:
Run before any recorded calls have been created (fails):
```
LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B
```

Run with `--record-responses` to record calls:
```
LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B --record-responses
```

Run again without `--record-responses` (succeeds, replaying the recordings):
```
LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B
```
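
For orientation, the flag and the `llama_stack_client_with_mocked_inference` fixture used throughout the diff below fit together roughly as follows (a sketch only: the option name and fixture name come from this PR, but the fixture body is illustrative and the actual mock plumbing is elided):

```python
# conftest.py (hypothetical sketch)
import pytest


def pytest_addoption(parser):
    parser.addoption(
        "--record-responses",
        action="store_true",
        default=False,
        help="Record new inference/tool-call responses instead of failing on cache misses",
    )


@pytest.fixture(scope="session")
def llama_stack_client_with_mocked_inference(llama_stack_client, request):
    record = request.config.getoption("--record-responses")
    # Wrap the client's inference and tool-runtime endpoints with record/replay
    # mocks here, recording fresh responses only when `record` is True.
    ...
    return llama_stack_client
```
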
Commit 386c806c70 (parent 816fdf289a)
Author: ehhuang · 2025-03-03 14:48:32 -08:00, committed by GitHub
7 changed files with 6893 additions and 29 deletions

```diff
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -41,8 +41,8 @@ def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
 @pytest.fixture(scope="session")
-def agent_config(llama_stack_client, text_model_id):
-    available_shields = [shield.identifier for shield in llama_stack_client.shields.list()]
+def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
+    available_shields = [shield.identifier for shield in llama_stack_client_with_mocked_inference.shields.list()]
     available_shields = available_shields[:1]
     agent_config = AgentConfig(
         model=text_model_id,
@@ -62,8 +62,8 @@ def agent_config(llama_stack_client, text_model_id):
     return agent_config


-def test_agent_simple(llama_stack_client, agent_config):
-    agent = Agent(llama_stack_client, agent_config)
+def test_agent_simple(llama_stack_client_with_mocked_inference, agent_config):
+    agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")

     simple_hello = agent.create_turn(
@@ -100,7 +100,7 @@ def test_agent_simple(llama_stack_client, agent_config):
     assert "I can't" in logs_str


-def test_tool_config(llama_stack_client, agent_config):
+def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
     common_params = dict(
         model="meta-llama/Llama-3.2-3B-Instruct",
         instructions="You are a helpful assistant",
@@ -156,14 +156,14 @@ def test_tool_config(llama_stack_client, agent_config):
     Server__AgentConfig(**agent_config)


-def test_builtin_tool_web_search(llama_stack_client, agent_config):
+def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent_config):
     agent_config = {
         **agent_config,
         "toolgroups": [
             "builtin::websearch",
         ],
     }
-    agent = Agent(llama_stack_client, agent_config)
+    agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")

     response = agent.create_turn(
@@ -186,14 +186,14 @@ def test_builtin_tool_web_search(llama_stack_client, agent_config):
     assert "No Violation" in logs_str


-def test_builtin_tool_code_execution(llama_stack_client, agent_config):
+def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent_config):
     agent_config = {
         **agent_config,
         "toolgroups": [
             "builtin::code_interpreter",
         ],
     }
-    agent = Agent(llama_stack_client, agent_config)
+    agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")

     response = agent.create_turn(
@@ -215,7 +215,7 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
 # This test must be run in an environment where `bwrap` is available. If you are running against a
 # server, this means the _server_ must have `bwrap` available. If you are using library client, then
 # you must have `bwrap` available in test's environment.
-def test_code_interpreter_for_attachments(llama_stack_client, agent_config):
+def test_code_interpreter_for_attachments(llama_stack_client_with_mocked_inference, agent_config):
     agent_config = {
         **agent_config,
         "toolgroups": [
@@ -223,7 +223,7 @@ def test_code_interpreter_for_attachments(llama_stack_client, agent_config):
         ],
     }

-    codex_agent = Agent(llama_stack_client, agent_config)
+    codex_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
     session_id = codex_agent.create_session(f"test-session-{uuid4()}")
     inflation_doc = AgentDocument(
         content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
@@ -251,7 +251,7 @@ def test_code_interpreter_for_attachments(llama_stack_client, agent_config):
         assert "Tool:code_interpreter" in logs_str


-def test_custom_tool(llama_stack_client, agent_config):
+def test_custom_tool(llama_stack_client_with_mocked_inference, agent_config):
     client_tool = get_boiling_point
     agent_config = {
         **agent_config,
@@ -259,7 +259,7 @@ def test_custom_tool(llama_stack_client, agent_config):
         "client_tools": [client_tool.get_tool_definition()],
     }

-    agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
+    agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
     session_id = agent.create_session(f"test-session-{uuid4()}")

     response = agent.create_turn(
@@ -278,7 +278,7 @@ def test_custom_tool(llama_stack_client, agent_config):
     assert "get_boiling_point" in logs_str


-def test_custom_tool_infinite_loop(llama_stack_client, agent_config):
+def test_custom_tool_infinite_loop(llama_stack_client_with_mocked_inference, agent_config):
     client_tool = get_boiling_point
     agent_config = {
         **agent_config,
@@ -287,7 +287,7 @@ def test_custom_tool_infinite_loop(llama_stack_client, agent_config):
         "max_infer_iters": 5,
     }

-    agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
+    agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
     session_id = agent.create_session(f"test-session-{uuid4()}")

     response = agent.create_turn(
@@ -305,7 +305,7 @@ def test_custom_tool_infinite_loop(llama_stack_client, agent_config):
     assert num_tool_calls <= 5


-def test_tool_choice(llama_stack_client, agent_config):
+def test_tool_choice(llama_stack_client_with_mocked_inference, agent_config):
     def run_agent(tool_choice):
         client_tool = get_boiling_point
@@ -315,7 +315,7 @@ def test_tool_choice(llama_stack_client, agent_config):
             "client_tools": [client_tool.get_tool_definition()],
         }

-        agent = Agent(llama_stack_client, test_agent_config, client_tools=(client_tool,))
+        agent = Agent(llama_stack_client_with_mocked_inference, test_agent_config, client_tools=(client_tool,))
         session_id = agent.create_session(f"test-session-{uuid4()}")

         response = agent.create_turn(
@@ -342,7 +342,7 @@ def test_tool_choice(llama_stack_client, agent_config):
 @pytest.mark.parametrize("rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"])
-def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
+def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_tool_name):
     urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
     documents = [
         Document(
@@ -354,12 +354,12 @@ def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
         for i, url in enumerate(urls)
     ]
     vector_db_id = f"test-vector-db-{uuid4()}"
-    llama_stack_client.vector_dbs.register(
+    llama_stack_client_with_mocked_inference.vector_dbs.register(
         vector_db_id=vector_db_id,
         embedding_model="all-MiniLM-L6-v2",
         embedding_dimension=384,
     )
-    llama_stack_client.tool_runtime.rag_tool.insert(
+    llama_stack_client_with_mocked_inference.tool_runtime.rag_tool.insert(
         documents=documents,
         vector_db_id=vector_db_id,
         # small chunks help to get specific info out of the docs
@@ -376,7 +376,7 @@ def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
             )
         ],
     }
-    rag_agent = Agent(llama_stack_client, agent_config)
+    rag_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
     session_id = rag_agent.create_session(f"test-session-{uuid4()}")
     user_prompts = [
         (
@@ -401,7 +401,7 @@ def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
         assert expected_kw in response.output_message.content.lower()


-def test_rag_agent_with_attachments(llama_stack_client, agent_config):
+def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config):
     urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
     documents = [
         Document(
@@ -423,7 +423,7 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config):
             )
         ],
     }
-    rag_agent = Agent(llama_stack_client, agent_config)
+    rag_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
     session_id = rag_agent.create_session(f"test-session-{uuid4()}")
     user_prompts = [
         (
@@ -462,7 +462,7 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config):
     assert "lora" in response.output_message.content.lower()


-def test_rag_and_code_agent(llama_stack_client, agent_config):
+def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_config):
     documents = []
     documents.append(
         Document(
@@ -484,12 +484,12 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
         )
     )
     vector_db_id = f"test-vector-db-{uuid4()}"
-    llama_stack_client.vector_dbs.register(
+    llama_stack_client_with_mocked_inference.vector_dbs.register(
         vector_db_id=vector_db_id,
         embedding_model="all-MiniLM-L6-v2",
         embedding_dimension=384,
     )
-    llama_stack_client.tool_runtime.rag_tool.insert(
+    llama_stack_client_with_mocked_inference.tool_runtime.rag_tool.insert(
         documents=documents,
         vector_db_id=vector_db_id,
         chunk_size_in_tokens=128,
@@ -504,7 +504,7 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
             "builtin::code_interpreter",
         ],
     }
-    agent = Agent(llama_stack_client, agent_config)
+    agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
     inflation_doc = Document(
         document_id="test_csv",
         content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
@@ -546,7 +546,7 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
         assert expected_kw in response.output_message.content.lower()


-def test_create_turn_response(llama_stack_client, agent_config):
+def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_config):
     client_tool = get_boiling_point
     agent_config = {
         **agent_config,
@@ -555,7 +555,7 @@ def test_create_turn_response(llama_stack_client, agent_config):
         "client_tools": [client_tool.get_tool_definition()],
     }

-    agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
+    agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
     session_id = agent.create_session(f"test-session-{uuid4()}")

     response = agent.create_turn(
```