llama-stack/tests/integration/fixtures/common.py
Ashwin Bharambe 205661bc78
fix: Use re-entrancy and concurrency safe context managers for provider data (#1498)
Concurrent requests should not trample (or reuse) each other's provider
data. Provider data should be scoped to each request.

## Test Plan

Set the uvicorn server to have a single worker process + thread by
updating the config:
```python
    uvicorn_config = {
        ...
        "workers": 1,
        "loop": "asyncio",
    }
```

Then perform the following steps on `origin/main` (without this change).

(1) Run the server using `llama stack run dev` without having
`FIREWORKS_API_KEY` in the environment.

(2) Run a test by specifying the FIREWORKS_API_KEY env var so it gets
stored in the thread local
```
pytest -s -v tests/integration/inference/test_text_inference.py \
    --stack-config http://localhost:8321 \
    --text-model accounts/fireworks/models/llama-v3p1-8b-instruct \
    -k test_text_chat_completion_with_tool_calling_and_streaming \
     --env FIREWORKS_API_KEY=<...>
``` 
Ensure you don't have any other API keys in the environment (otherwise
the bug will not reproduce due to other specifics in our testing code.)
Verify this works.

(3) Run the same command again without specifying FIREWORKS_API_KEY. See
that the request actually succeeds when it *should have failed*.


----
Now do the same tests on this branch, verify step (3) results in
failure.

Finally, run the full `test_text_inference.py` test suite with this
change, verify it succeeds.
2025-03-08 22:56:30 -08:00

207 lines
7.6 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import copy
import inspect
import logging
import os
import tempfile
from pathlib import Path
import pytest
import yaml
from llama_stack_client import LlamaStackClient
from llama_stack import LlamaStackAsLibraryClient
from llama_stack.apis.datatypes import Api
from llama_stack.distribution.stack import run_config_from_adhoc_config_spec
from llama_stack.env import get_env_or_fail
from .recordable_mock import RecordableMock
@pytest.fixture(scope="session")
def provider_data():
    """Collect provider API keys that are present in the environment.

    Returns:
        A dict mapping provider-data field names (for example
        ``fireworks_api_key``) to the value of the matching environment
        variable. Variables that are unset or empty are left out.
    """
    # TODO: this needs to be generalized so each provider can have a sample provider data just
    # like sample run config on which we can do replace_env_vars()
    env_var_to_field = {
        "TAVILY_SEARCH_API_KEY": "tavily_search_api_key",
        "BRAVE_SEARCH_API_KEY": "brave_search_api_key",
        "FIREWORKS_API_KEY": "fireworks_api_key",
        "GEMINI_API_KEY": "gemini_api_key",
        "OPENAI_API_KEY": "openai_api_key",
        "TOGETHER_API_KEY": "together_api_key",
        "ANTHROPIC_API_KEY": "anthropic_api_key",
        "GROQ_API_KEY": "groq_api_key",
        "WOLFRAM_ALPHA_API_KEY": "wolfram_alpha_api_key",
    }
    return {
        field_name: os.environ[env_var]
        for env_var, field_name in env_var_to_field.items()
        if os.environ.get(env_var)
    }
@pytest.fixture(scope="session")
def llama_stack_client_with_mocked_inference(llama_stack_client, request):
    """
    Build a client whose inference and tool-runtime calls go through RecordableMock.

    Recorded responses are used by default. If --record-responses is passed, the
    real APIs are called and their responses are recorded.

    Only LlamaStackAsLibraryClient instances can be mocked; any other client is
    returned unchanged after a warning is logged.
    """
    if not isinstance(llama_stack_client, LlamaStackAsLibraryClient):
        logging.warning(
            "llama_stack_client_with_mocked_inference is not supported for this client, returning original client without mocking"
        )
        return llama_stack_client

    should_record = request.config.getoption("--record-responses")
    recordings_dir = Path(__file__).parent / "recorded_responses"

    def recordable(target, label):
        # Wrap one API method; every mock shares the same cache dir and record flag.
        return RecordableMock(target, recordings_dir, label, record=should_record)

    # Shallow copy, so the attribute assignment on the client at the end of this
    # fixture lands on the copy rather than on the session client.
    # NOTE(review): the agents implementation reached through async_client is
    # presumably still shared with the original client - confirm.
    mocked_client = copy.copy(llama_stack_client)

    # The agents implementation holds the inference and tool-runtime APIs to swap.
    agents_impl = mocked_client.async_client.impls[Api.agents]

    real_inference = agents_impl.inference_api
    inference_mock = copy.copy(real_inference)
    # (method on the inference API, name passed to RecordableMock)
    inference_methods = (
        ("chat_completion", "chat_completion"),
        ("completion", "text_completion"),
        ("embeddings", "embeddings"),
    )
    for method_name, label in inference_methods:
        setattr(inference_mock, method_name, recordable(getattr(real_inference, method_name), label))
    agents_impl.inference_api = inference_mock

    real_tool_runtime = agents_impl.tool_runtime_api
    tool_runtime_mock = copy.copy(real_tool_runtime)
    tool_runtime_mock.invoke_tool = recordable(real_tool_runtime.invoke_tool, "invoke_tool")
    agents_impl.tool_runtime_api = tool_runtime_mock

    # Keep client.inference consistent with what the agents implementation uses.
    mocked_client.inference = inference_mock
    return mocked_client
@pytest.fixture(scope="session")
def inference_provider_type(llama_stack_client):
    """Return the provider_type of the first provider whose api is "inference".

    Fails with an assertion error when the client lists no inference provider.
    """
    matches = []
    for entry in llama_stack_client.providers.list():
        if entry.api == "inference":
            matches.append(entry)
    assert matches, "No inference providers found"
    first_match, *_ = matches
    return first_match.provider_type
@pytest.fixture(scope="session")
def client_with_models(
    llama_stack_client,
    text_model_id,
    vision_model_id,
    embedding_model_id,
    embedding_dimension,
    judge_model_id,
):
    """Return the session client after registering any requested model it lacks.

    A model counts as already present when its id matches either the
    ``identifier`` or the ``provider_resource_id`` of a listed model.

    Text, vision and judge models are registered with the first inference
    provider that is not ``inline::sentence-transformers``. The embedding model
    is registered with the ``inline::sentence-transformers`` provider when one
    exists, otherwise with the first inference provider, using
    ``embedding_dimension`` (default 384) as metadata.

    Empty model ids are ignored. Fails with an assertion error when the client
    lists no inference provider.
    """
    client = llama_stack_client

    providers = [p for p in client.providers.list() if p.api == "inference"]
    assert len(providers) > 0, "No inference providers found"
    inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"]

    # List the models once and collect both kinds of id in a single pass,
    # instead of calling the listing API a second time for the same data.
    model_ids = set()
    for model in client.models.list():
        model_ids.add(model.identifier)
        model_ids.add(model.provider_resource_id)

    # inference_providers[0] is only evaluated when a registration is actually
    # needed, so a stack with every model pre-registered never touches it.
    for model_id in (text_model_id, vision_model_id, judge_model_id):
        if model_id and model_id not in model_ids:
            client.models.register(model_id=model_id, provider_id=inference_providers[0])

    if embedding_model_id and embedding_model_id not in model_ids:
        # Prefer sentence-transformers for embeddings; if it is not available,
        # fall back to the first inference provider.
        selected_provider = next(
            (p for p in providers if p.provider_type == "inline::sentence-transformers"),
            None,
        )
        selected_provider = selected_provider or providers[0]
        client.models.register(
            model_id=embedding_model_id,
            provider_id=selected_provider.provider_id,
            model_type="embedding",
            metadata={"embedding_dimension": embedding_dimension or 384},
        )
    return client
@pytest.fixture(scope="session")
def available_shields(llama_stack_client):
    """Return a list with the identifier of every shield the client lists."""
    identifiers = []
    for shield in llama_stack_client.shields.list():
        identifiers.append(shield.identifier)
    return identifiers
@pytest.fixture(scope="session")
def model_providers(llama_stack_client):
    """Return the set of provider_ids for providers whose api is "inference"."""
    provider_ids = set()
    for provider in llama_stack_client.providers.list():
        if provider.api == "inference":
            provider_ids.add(provider.provider_id)
    return provider_ids
@pytest.fixture(autouse=True)
def skip_if_no_model(request):
    """Skip the current test when a model-id fixture it takes as a parameter is empty.

    Only the four model-id fixtures are considered. A fixture is checked only
    if it appears both in the test function's signature and in
    ``request.fixturenames``.
    """
    declared_params = inspect.signature(request.node.function).parameters.keys()
    for fixture_name in ("text_model_id", "vision_model_id", "embedding_model_id", "judge_model_id"):
        # Only check fixtures that are actually in the test function's signature
        if fixture_name not in declared_params:
            continue
        if fixture_name not in request.fixturenames:
            continue
        # The fixture value is resolved last, and only for fixtures the test uses.
        if request.getfixturevalue(fixture_name):
            continue
        pytest.skip(f"{fixture_name} empty - skipping test")
@pytest.fixture(scope="session")
def llama_stack_client(request, provider_data, text_model_id):
    """Create the session-wide client from --stack-config or LLAMA_STACK_CONFIG.

    The config value is interpreted as follows:

    * If it starts with "http" or contains "//", it is treated as a server URL
      and a ``LlamaStackClient`` pointing at it is returned.
    * If it contains "=", it is treated as an ad-hoc config spec. The resulting
      run config is dumped to a temporary YAML file, and that file's path is
      used as the config for a library client.
    * Otherwise it is passed as-is to ``LlamaStackAsLibraryClient``.

    ``text_model_id`` is requested as a fixture but not used in the body.

    Raises:
        ValueError: if no config value could be determined.
        RuntimeError: if the library client fails to initialize.
    """
    config = request.config.getoption("--stack-config")
    if not config:
        config = get_env_or_fail("LLAMA_STACK_CONFIG")

    if not config:
        raise ValueError("You must specify either --stack-config or LLAMA_STACK_CONFIG")

    # check if this looks like a URL
    if config.startswith("http") or "//" in config:
        return LlamaStackClient(
            base_url=config,
            provider_data=provider_data,
        )

    if "=" in config:
        run_config = run_config_from_adhoc_config_spec(config)
        # Write through the temp-file handle and let the with-block close it,
        # so no descriptor is leaked and the file is not reopened while open.
        # delete=False keeps the file on disk after closing, because its path is
        # handed to LlamaStackAsLibraryClient below.
        with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".yaml") as run_config_file:
            yaml.dump(run_config.model_dump(), run_config_file)
        config = run_config_file.name

    client = LlamaStackAsLibraryClient(
        config,
        provider_data=provider_data,
        skip_logger_removal=True,
    )
    if not client.initialize():
        raise RuntimeError("Initialization failed")

    return client