test: introduce recordable mocks for Agent tests (#1268)

Summary:

Agent tests shouldn't need to run inference and tools calls repeatedly.
This PR introduces a way to record inference/tool calls and reuse them
in subsequent test runs, which makes the tests more reliable and saves
costs.

Test Plan:
Run when there's no recorded calls created (fails):
```
LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B
```

Run when `--record-responses` to record calls:
```
LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B --record-responses
```

Run without `--record-responses` again (succeeds):
```
LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B
```
This commit is contained in:
ehhuang 2025-03-03 14:48:32 -08:00 committed by GitHub
parent 816fdf289a
commit 386c806c70
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 6893 additions and 29 deletions

View file

@ -3,13 +3,18 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import copy
import logging
import os
from pathlib import Path
import pytest
from fixtures.recordable_mock import RecordableMock
from llama_stack_client import LlamaStackClient
from report import Report
from llama_stack import LlamaStackAsLibraryClient
from llama_stack.apis.datatypes import Api
from llama_stack.providers.tests.env import get_env_or_fail
@ -66,6 +71,12 @@ def pytest_addoption(parser):
default=384,
help="Output dimensionality of the embedding model to use for testing",
)
parser.addoption(
"--record-responses",
action="store_true",
default=False,
help="Record new API responses instead of using cached ones.",
)
@pytest.fixture(scope="session")
@ -101,6 +112,61 @@ def llama_stack_client(provider_data, text_model_id):
return client
@pytest.fixture(scope="session")
def llama_stack_client_with_mocked_inference(llama_stack_client, request):
"""
Returns a client with mocked inference APIs and tool runtime APIs that use recorded responses by default.
If --record-responses is passed, it will call the real APIs and record the responses.
"""
if not isinstance(llama_stack_client, LlamaStackAsLibraryClient):
logging.warning(
"llama_stack_client_with_mocked_inference is not supported for this client, returning original client without mocking"
)
return llama_stack_client
record_responses = request.config.getoption("--record-responses")
cache_dir = Path(__file__).parent / "fixtures" / "recorded_responses"
# Create a shallow copy of the client to avoid modifying the original
client = copy.copy(llama_stack_client)
# Get the inference API used by the agents implementation
agents_impl = client.async_client.impls[Api.agents]
original_inference = agents_impl.inference_api
# Create a new inference object with the same attributes
inference_mock = copy.copy(original_inference)
# Replace the methods with recordable mocks
inference_mock.chat_completion = RecordableMock(
original_inference.chat_completion, cache_dir, "chat_completion", record=record_responses
)
inference_mock.completion = RecordableMock(
original_inference.completion, cache_dir, "text_completion", record=record_responses
)
inference_mock.embeddings = RecordableMock(
original_inference.embeddings, cache_dir, "embeddings", record=record_responses
)
# Replace the inference API in the agents implementation
agents_impl.inference_api = inference_mock
original_tool_runtime_api = agents_impl.tool_runtime_api
tool_runtime_mock = copy.copy(original_tool_runtime_api)
# Replace the methods with recordable mocks
tool_runtime_mock.invoke_tool = RecordableMock(
original_tool_runtime_api.invoke_tool, cache_dir, "invoke_tool", record=record_responses
)
agents_impl.tool_runtime_api = tool_runtime_mock
# Also update the client.inference for consistency
client.inference = inference_mock
return client
@pytest.fixture(scope="session")
def inference_provider_type(llama_stack_client):
providers = llama_stack_client.providers.list()