forked from phoenix-oss/llama-stack-mirror
test: introduce recordable mocks for Agent tests (#1268)
Summary: Agent tests shouldn't need to run inference and tools calls repeatedly. This PR introduces a way to record inference/tool calls and reuse them in subsequent test runs, which makes the tests more reliable and saves costs. Test Plan: Run when there's no recorded calls created (fails): ``` LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B ``` Run when `--record-responses` to record calls: ``` LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B --record-responses ``` Run without `--record-responses` again (succeeds): ``` LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B ```
This commit is contained in:
parent
816fdf289a
commit
386c806c70
7 changed files with 6893 additions and 29 deletions
|
@ -3,13 +3,18 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from fixtures.recordable_mock import RecordableMock
|
||||
from llama_stack_client import LlamaStackClient
|
||||
from report import Report
|
||||
|
||||
from llama_stack import LlamaStackAsLibraryClient
|
||||
from llama_stack.apis.datatypes import Api
|
||||
from llama_stack.providers.tests.env import get_env_or_fail
|
||||
|
||||
|
||||
|
@ -66,6 +71,12 @@ def pytest_addoption(parser):
|
|||
default=384,
|
||||
help="Output dimensionality of the embedding model to use for testing",
|
||||
)
|
||||
parser.addoption(
|
||||
"--record-responses",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Record new API responses instead of using cached ones.",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
@ -101,6 +112,61 @@ def llama_stack_client(provider_data, text_model_id):
|
|||
return client
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def llama_stack_client_with_mocked_inference(llama_stack_client, request):
|
||||
"""
|
||||
Returns a client with mocked inference APIs and tool runtime APIs that use recorded responses by default.
|
||||
|
||||
If --record-responses is passed, it will call the real APIs and record the responses.
|
||||
"""
|
||||
if not isinstance(llama_stack_client, LlamaStackAsLibraryClient):
|
||||
logging.warning(
|
||||
"llama_stack_client_with_mocked_inference is not supported for this client, returning original client without mocking"
|
||||
)
|
||||
return llama_stack_client
|
||||
|
||||
record_responses = request.config.getoption("--record-responses")
|
||||
cache_dir = Path(__file__).parent / "fixtures" / "recorded_responses"
|
||||
|
||||
# Create a shallow copy of the client to avoid modifying the original
|
||||
client = copy.copy(llama_stack_client)
|
||||
|
||||
# Get the inference API used by the agents implementation
|
||||
agents_impl = client.async_client.impls[Api.agents]
|
||||
original_inference = agents_impl.inference_api
|
||||
|
||||
# Create a new inference object with the same attributes
|
||||
inference_mock = copy.copy(original_inference)
|
||||
|
||||
# Replace the methods with recordable mocks
|
||||
inference_mock.chat_completion = RecordableMock(
|
||||
original_inference.chat_completion, cache_dir, "chat_completion", record=record_responses
|
||||
)
|
||||
inference_mock.completion = RecordableMock(
|
||||
original_inference.completion, cache_dir, "text_completion", record=record_responses
|
||||
)
|
||||
inference_mock.embeddings = RecordableMock(
|
||||
original_inference.embeddings, cache_dir, "embeddings", record=record_responses
|
||||
)
|
||||
|
||||
# Replace the inference API in the agents implementation
|
||||
agents_impl.inference_api = inference_mock
|
||||
|
||||
original_tool_runtime_api = agents_impl.tool_runtime_api
|
||||
tool_runtime_mock = copy.copy(original_tool_runtime_api)
|
||||
|
||||
# Replace the methods with recordable mocks
|
||||
tool_runtime_mock.invoke_tool = RecordableMock(
|
||||
original_tool_runtime_api.invoke_tool, cache_dir, "invoke_tool", record=record_responses
|
||||
)
|
||||
agents_impl.tool_runtime_api = tool_runtime_mock
|
||||
|
||||
# Also update the client.inference for consistency
|
||||
client.inference = inference_mock
|
||||
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def inference_provider_type(llama_stack_client):
|
||||
providers = llama_stack_client.providers.list()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue