Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-14 20:52:37 +00:00)

fix(tests): ensure test isolation in server mode

Propagate test IDs from client to server via HTTP headers to maintain proper test isolation when running with server-based stack configs. Without this, recorded/replayed inference requests in server mode would leak across tests.

Changes:
- Patch client _prepare_request to inject test ID into provider data header
- Sync test context from provider data on server side before storage operations
- Set LLAMA_STACK_TEST_STACK_CONFIG_TYPE env var based on stack config
- Configure console width for cleaner log output in CI
- Add SQLITE_STORE_DIR temp directory for test data isolation
Parent: bba9957edd
Commit: d5296a35f6
4 changed files with 219 additions and 106 deletions
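For orientation, a minimal sketch (not part of this commit) of the propagation path described above: the test ID rides the X-LlamaStack-Provider-Data header as JSON under the "__test_id" key and is restored into a contextvar on the receiving side. The helper names, the example test ID, and the bare header dict are illustrative assumptions; the actual mechanics are in the inference recorder hunks below.

# Illustrative sketch only; mirrors the header name and "__test_id" key used in this commit.
import json
from contextvars import ContextVar

_test_context: ContextVar[str | None] = ContextVar("_test_context", default=None)


def inject_test_id(headers: dict[str, str], test_id: str) -> dict[str, str]:
    """Client side (sketch): merge the test ID into the provider-data header."""
    provider_data = json.loads(headers.get("X-LlamaStack-Provider-Data", "{}"))
    provider_data["__test_id"] = test_id
    headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data)
    return headers


def sync_test_id(headers: dict[str, str]):
    """Server side (sketch): read the header back and expose the test ID via a contextvar."""
    provider_data = json.loads(headers.get("X-LlamaStack-Provider-Data", "{}"))
    test_id = provider_data.get("__test_id")
    return _test_context.set(test_id) if test_id else None


headers = inject_test_id({}, "tests/integration/inference/test_x.py::test_y")  # hypothetical test ID
token = sync_test_id(headers)
assert _test_context.get() == "tests/integration/inference/test_x.py::test_y"
if token:
    _test_context.reset(token)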
@@ -128,7 +128,10 @@ def strip_rich_markup(text):
class CustomRichHandler(RichHandler):
    def __init__(self, *args, **kwargs):
-        kwargs["console"] = Console()
+        # Set a reasonable default width for console output, especially when redirected to files
+        console_width = int(os.environ.get("LLAMA_STACK_LOG_WIDTH", "120"))
+        # Don't force terminal codes to avoid ANSI escape codes in log files
+        kwargs["console"] = Console(width=console_width)
        super().__init__(*args, **kwargs)

    def emit(self, record):
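The width override above matters mainly when output is redirected (as in CI, where server.log is captured) and rich cannot probe a terminal. A small sketch of the effect, assuming only rich's standard Console API:

# Sketch: a fixed-width Console wraps at the configured column even when stdout is a file.
import os

from rich.console import Console

width = int(os.environ.get("LLAMA_STACK_LOG_WIDTH", "120"))
console = Console(width=width)
console.print("x " * 200)  # wraps at `width` columns instead of a detected terminal size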
@@ -15,7 +15,7 @@ from enum import StrEnum
from pathlib import Path
from typing import Any, Literal, cast

-from openai import NOT_GIVEN
+from openai import NOT_GIVEN, OpenAI

from llama_stack.log import get_logger
@@ -79,6 +79,96 @@ def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict
    return hashlib.sha256(normalized_json.encode()).hexdigest()


+def _sync_test_context_from_provider_data():
+    """In server mode, sync test ID from provider_data to _test_context.
+
+    This ensures that storage operations (which read from _test_context) work correctly
+    in server mode where the test ID arrives via HTTP header → provider_data.
+
+    Returns a token to reset _test_context, or None if no sync was needed.
+    """
+    stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
+
+    if stack_config_type != "server":
+        return None
+
+    try:
+        from llama_stack.core.request_headers import PROVIDER_DATA_VAR
+
+        provider_data = PROVIDER_DATA_VAR.get()
+
+        if provider_data and "__test_id" in provider_data:
+            test_id = provider_data["__test_id"]
+            return _test_context.set(test_id)
+    except ImportError:
+        pass
+
+    return None
+
+
+def patch_httpx_for_test_id():
+    """Patch client _prepare_request methods to inject test ID into provider data header.
+
+    This is needed for server mode where the test ID must be transported from
+    client to server via HTTP headers. In library_client mode, this patch is a no-op
+    since everything runs in the same process.
+
+    We use the _prepare_request hook that Stainless clients provide for mutating
+    requests after construction but before sending.
+    """
+    from llama_stack_client import LlamaStackClient
+
+    if "llama_stack_client_prepare_request" in _original_methods:
+        return
+
+    _original_methods["llama_stack_client_prepare_request"] = LlamaStackClient._prepare_request
+    _original_methods["openai_prepare_request"] = OpenAI._prepare_request
+
+    def patched_prepare_request(self, request):
+        # Call original first (it's a sync method that returns None)
+        # Determine which original to call based on client type
+        if "llama_stack_client" in self.__class__.__module__:
+            _original_methods["llama_stack_client_prepare_request"](self, request)
+        else:
+            _original_methods["openai_prepare_request"](self, request)
+
+        # Only inject test ID in server mode
+        stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
+        test_id = _test_context.get()
+
+        if stack_config_type == "server" and test_id:
+            provider_data_header = request.headers.get("X-LlamaStack-Provider-Data")
+
+            if provider_data_header:
+                provider_data = json.loads(provider_data_header)
+            else:
+                provider_data = {}
+
+            provider_data["__test_id"] = test_id
+            request.headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data)
+
+        return None
+
+    LlamaStackClient._prepare_request = patched_prepare_request
+    OpenAI._prepare_request = patched_prepare_request
+
+
+# currently, unpatch is never called
+def unpatch_httpx_for_test_id():
+    """Remove client _prepare_request patches for test ID injection."""
+    if "llama_stack_client_prepare_request" not in _original_methods:
+        return
+
+    from llama_stack_client import LlamaStackClient
+
+    LlamaStackClient._prepare_request = _original_methods["llama_stack_client_prepare_request"]
+    del _original_methods["llama_stack_client_prepare_request"]
+
+    # Also restore OpenAI client if it was patched
+    if "openai_prepare_request" in _original_methods:
+        OpenAI._prepare_request = _original_methods["openai_prepare_request"]
+        del _original_methods["openai_prepare_request"]
+
+
def get_inference_mode() -> InferenceMode:
    return InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "replay").lower())
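A hedged usage sketch of the client-side patch above in server mode: the function name, env var, header name, and "__test_id" key come from this diff; the base URL and the assumption that _test_context already holds a test ID are illustrative.

# Sketch only; not prescriptive wiring for the test harness.
import os

from llama_stack.testing.inference_recorder import patch_httpx_for_test_id
from llama_stack_client import LlamaStackClient

os.environ["LLAMA_STACK_TEST_STACK_CONFIG_TYPE"] = "server"  # set by this commit's scripts/conftest
patch_httpx_for_test_id()  # idempotent: returns early if already patched

# Assumed local endpoint for illustration. Any request sent while _test_context holds a
# test ID now carries X-LlamaStack-Provider-Data: {"__test_id": "<current test id>"}.
client = LlamaStackClient(base_url="http://localhost:8321")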
@@ -244,7 +334,7 @@ class ResponseStorage:
        with open(response_path, "w") as f:
            json.dump(
                {
-                    "test_id": _test_context.get(),  # Include for debugging
+                    "test_id": _test_context.get(),
                    "request": request,
                    "response": serialized_response,
                },
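For context, a sketch of the JSON shape being written above, using only keys visible in this diff ("test_id", "request", "response", and the request_data/response_data fields in the next hunk); the concrete values are hypothetical.

# Illustrative recording payload; values and response body shape are made up.
import json

example = {
    "test_id": "tests/integration/inference/test_x.py::test_y",  # hypothetical test ID
    "request": {
        "method": "POST",
        "url": "http://localhost:11434/v1/chat/completions",  # hypothetical URL
        "headers": {},
        "body": {"model": "example-model", "stream": False},
        "endpoint": "/v1/chat/completions",
        "model": "example-model",
    },
    "response": {"body": {}, "is_streaming": False},  # body shape depends on the provider
}

print(json.dumps(example, indent=2))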
@@ -386,108 +476,115 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
    else:
        return await original_method(self, *args, **kwargs)

+    # In server mode, sync test ID from provider_data to _test_context for storage operations
+    test_context_token = _sync_test_context_from_provider_data()
+
+    try:
        # Get base URL based on client type
        if client_type == "openai":
            base_url = str(self._client.base_url)

            # the OpenAI client methods may pass NOT_GIVEN for unset parameters; filter these out
            kwargs = {k: v for k, v in kwargs.items() if v is not NOT_GIVEN}
        elif client_type == "ollama":
            # Get base URL from the client (Ollama client uses host attribute)
            base_url = getattr(self, "host", "http://localhost:11434")
            if not base_url.startswith("http"):
                base_url = f"http://{base_url}"
        else:
            raise ValueError(f"Unknown client type: {client_type}")

        url = base_url.rstrip("/") + endpoint
        # Special handling for Databricks URLs to avoid leaking workspace info
        # e.g. https://adb-1234567890123456.7.cloud.databricks.com -> https://...cloud.databricks.com
        if "cloud.databricks.com" in url:
            url = "__databricks__" + url.split("cloud.databricks.com")[-1]
        method = "POST"
        headers = {}
        body = kwargs

        request_hash = normalize_request(method, url, headers, body)

        # Try to find existing recording for REPLAY or RECORD_IF_MISSING modes
        recording = None
        if mode == InferenceMode.REPLAY or mode == InferenceMode.RECORD_IF_MISSING:
            # Special handling for model-list endpoints: merge all recordings with this hash
            if endpoint in ("/api/tags", "/v1/models"):
                records = storage._model_list_responses(request_hash)
                recording = _combine_model_list_responses(endpoint, records)
            else:
                recording = storage.find_recording(request_hash)

        if recording:
            response_body = recording["response"]["body"]

            if recording["response"].get("is_streaming", False):

                async def replay_stream():
                    for chunk in response_body:
                        yield chunk

                return replay_stream()
            else:
                return response_body
        elif mode == InferenceMode.REPLAY:
            # REPLAY mode requires recording to exist
            raise RuntimeError(
                f"No recorded response found for request hash: {request_hash}\n"
                f"Request: {method} {url} {body}\n"
                f"Model: {body.get('model', 'unknown')}\n"
                f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
            )

        if mode == InferenceMode.RECORD or (mode == InferenceMode.RECORD_IF_MISSING and not recording):
            if endpoint == "/v1/models":
                response = original_method(self, *args, **kwargs)
            else:
                response = await original_method(self, *args, **kwargs)

            # we want to store the result of the iterator, not the iterator itself
            if endpoint == "/v1/models":
                response = [m async for m in response]

            request_data = {
                "method": method,
                "url": url,
                "headers": headers,
                "body": body,
                "endpoint": endpoint,
                "model": body.get("model", ""),
            }

            # Determine if this is a streaming request based on request parameters
            is_streaming = body.get("stream", False)

            if is_streaming:
                # For streaming responses, we need to collect all chunks immediately before yielding
                # This ensures the recording is saved even if the generator isn't fully consumed
                chunks = []
                async for chunk in response:
                    chunks.append(chunk)

                # Store the recording immediately
                response_data = {"body": chunks, "is_streaming": True}
                storage.store_recording(request_hash, request_data, response_data)

                # Return a generator that replays the stored chunks
                async def replay_recorded_stream():
                    for chunk in chunks:
                        yield chunk

                return replay_recorded_stream()
            else:
                response_data = {"body": response, "is_streaming": False}
                storage.store_recording(request_hash, request_data, response_data)
                return response
        else:
            raise AssertionError(f"Invalid mode: {mode}")
+    finally:
+        if test_context_token:
+            _test_context.reset(test_context_token)


def patch_inference_clients():
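The try/finally added in the hunk above is the standard ContextVar token pattern: set the var for the duration of one operation, then always reset it so the test ID never leaks into the next request. A minimal, self-contained sketch of that pattern (names here are illustrative, not the recorder's):

from contextvars import ContextVar

_ctx: ContextVar[str | None] = ContextVar("_ctx", default=None)


def run_with_id(value: str | None, fn):
    # Set the contextvar only when a value is available, mirroring the
    # "token or None" return of _sync_test_context_from_provider_data above.
    token = _ctx.set(value) if value else None
    try:
        return fn()
    finally:
        # Always reset so the value never outlives the operation.
        if token:
            _ctx.reset(token)


assert run_with_id("test-a", _ctx.get) == "test-a"
assert _ctx.get() is None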
@@ -124,12 +124,6 @@ echo ""
echo "Checking llama packages"
uv pip list | grep llama

-# Check storage and memory before tests
-echo "=== System Resources Before Tests ==="
-free -h 2>/dev/null || echo "free command not available"
-df -h
-echo ""
-
# Set environment variables
export LLAMA_STACK_CLIENT_TIMEOUT=300
@@ -144,6 +138,17 @@ echo "=== Applying Setup Environment Variables ==="

# the server needs this
export LLAMA_STACK_TEST_INFERENCE_MODE="$INFERENCE_MODE"
+export SQLITE_STORE_DIR=$(mktemp -d)
+echo "Setting SQLITE_STORE_DIR: $SQLITE_STORE_DIR"
+
+# Determine stack config type for api_recorder test isolation
+if [[ "$STACK_CONFIG" == server:* ]]; then
+  export LLAMA_STACK_TEST_STACK_CONFIG_TYPE="server"
+  echo "Setting stack config type: server"
+else
+  export LLAMA_STACK_TEST_STACK_CONFIG_TYPE="library_client"
+  echo "Setting stack config type: library_client"
+fi

SETUP_ENV=$(PYTHONPATH=$THIS_DIR/.. python "$THIS_DIR/get_setup_env.py" --suite "$TEST_SUITE" --setup "$TEST_SETUP" --format bash)
echo "Setting up environment variables:"
@@ -186,6 +191,8 @@ if [[ "$STACK_CONFIG" == *"server:"* ]]; then
  echo "Llama Stack Server is already running, skipping start"
else
  echo "=== Starting Llama Stack Server ==="
+  # Set a reasonable log width for better readability in server.log
+  export LLAMA_STACK_LOG_WIDTH=120
  nohup llama stack run ci-tests --image-type venv > server.log 2>&1 &

  echo "Waiting for Llama Stack Server to start..."
@@ -277,11 +284,5 @@ else
  exit 1
fi

-# Check storage and memory after tests
-echo ""
-echo "=== System Resources After Tests ==="
-free -h 2>/dev/null || echo "free command not available"
-df -h
-
echo ""
echo "=== Integration Tests Complete ==="
@@ -35,6 +35,18 @@ def pytest_sessionstart(session):
    if "LLAMA_STACK_TEST_INFERENCE_MODE" not in os.environ:
        os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "replay"

+    stack_config = session.config.getoption("--stack-config", default=None)
+    if stack_config and stack_config.startswith("server:"):
+        os.environ["LLAMA_STACK_TEST_STACK_CONFIG_TYPE"] = "server"
+        logger.info(f"Test stack config type: server (stack_config={stack_config})")
+    else:
+        os.environ["LLAMA_STACK_TEST_STACK_CONFIG_TYPE"] = "library_client"
+        logger.info(f"Test stack config type: library_client (stack_config={stack_config})")
+
+    from llama_stack.testing.inference_recorder import patch_httpx_for_test_id
+
+    patch_httpx_for_test_id()
+

@pytest.fixture(autouse=True)
def _track_test_context(request):