diff --git a/llama_stack/log.py b/llama_stack/log.py
index 6f751b21d..8aee4c9a9 100644
--- a/llama_stack/log.py
+++ b/llama_stack/log.py
@@ -128,7 +128,10 @@ def strip_rich_markup(text):
 
 class CustomRichHandler(RichHandler):
     def __init__(self, *args, **kwargs):
-        kwargs["console"] = Console()
+        # Set a reasonable default width for console output, especially when redirected to files
+        console_width = int(os.environ.get("LLAMA_STACK_LOG_WIDTH", "120"))
+        # Don't force terminal codes to avoid ANSI escape codes in log files
+        kwargs["console"] = Console(width=console_width)
         super().__init__(*args, **kwargs)
 
     def emit(self, record):
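
Illustrative sketch (not part of the patch): the log.py hunk above pins the Rich console to a fixed width instead of probing the terminal, so output redirected to server.log wraps predictably and, because force_terminal is left unset, contains no ANSI escape codes. Assuming only the LLAMA_STACK_LOG_WIDTH variable and its 120-column default introduced above, the intended behavior is roughly:

    import os
    from rich.console import Console

    # Width comes from the environment, with the same 120-column default the patch uses.
    console_width = int(os.environ.get("LLAMA_STACK_LOG_WIDTH", "120"))
    # force_terminal is not set, so output redirected to a file carries no ANSI escape codes.
    console = Console(width=console_width)
    console.print("x" * 300)  # wraps at 120 columns instead of the detected terminal width
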
diff --git a/llama_stack/testing/inference_recorder.py b/llama_stack/testing/inference_recorder.py
index 9f8140c08..16071f80f 100644
--- a/llama_stack/testing/inference_recorder.py
+++ b/llama_stack/testing/inference_recorder.py
@@ -15,7 +15,7 @@ from enum import StrEnum
 from pathlib import Path
 from typing import Any, Literal, cast
 
-from openai import NOT_GIVEN
+from openai import NOT_GIVEN, OpenAI
 
 from llama_stack.log import get_logger
 
@@ -79,6 +79,96 @@ def normalize_request(method: str, url: str, headers: dict[str, Any], body: dic
     return hashlib.sha256(normalized_json.encode()).hexdigest()
 
 
+def _sync_test_context_from_provider_data():
+    """In server mode, sync test ID from provider_data to _test_context.
+
+    This ensures that storage operations (which read from _test_context) work correctly
+    in server mode where the test ID arrives via HTTP header → provider_data.
+
+    Returns a token to reset _test_context, or None if no sync was needed.
+    """
+    stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
+
+    if stack_config_type != "server":
+        return None
+
+    try:
+        from llama_stack.core.request_headers import PROVIDER_DATA_VAR
+
+        provider_data = PROVIDER_DATA_VAR.get()
+
+        if provider_data and "__test_id" in provider_data:
+            test_id = provider_data["__test_id"]
+            return _test_context.set(test_id)
+    except ImportError:
+        pass
+
+    return None
+
+
+def patch_httpx_for_test_id():
+    """Patch client _prepare_request methods to inject test ID into provider data header.
+
+    This is needed for server mode where the test ID must be transported from
+    client to server via HTTP headers. In library_client mode, this patch is a no-op
+    since everything runs in the same process.
+
+    We use the _prepare_request hook that Stainless clients provide for mutating
+    requests after construction but before sending.
+    """
+    from llama_stack_client import LlamaStackClient
+
+    if "llama_stack_client_prepare_request" in _original_methods:
+        return
+
+    _original_methods["llama_stack_client_prepare_request"] = LlamaStackClient._prepare_request
+    _original_methods["openai_prepare_request"] = OpenAI._prepare_request
+
+    def patched_prepare_request(self, request):
+        # Call original first (it's a sync method that returns None)
+        # Determine which original to call based on client type
+        if "llama_stack_client" in self.__class__.__module__:
+            _original_methods["llama_stack_client_prepare_request"](self, request)
+        else:
+            _original_methods["openai_prepare_request"](self, request)
+
+        # Only inject test ID in server mode
+        stack_config_type = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
+        test_id = _test_context.get()
+
+        if stack_config_type == "server" and test_id:
+            provider_data_header = request.headers.get("X-LlamaStack-Provider-Data")
+
+            if provider_data_header:
+                provider_data = json.loads(provider_data_header)
+            else:
+                provider_data = {}
+
+            provider_data["__test_id"] = test_id
+            request.headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data)
+
+        return None
+
+    LlamaStackClient._prepare_request = patched_prepare_request
+    OpenAI._prepare_request = patched_prepare_request
+
+
+# currently, unpatch is never called
+def unpatch_httpx_for_test_id():
+    """Remove client _prepare_request patches for test ID injection."""
+    if "llama_stack_client_prepare_request" not in _original_methods:
+        return
+
+    from llama_stack_client import LlamaStackClient
+
+    LlamaStackClient._prepare_request = _original_methods["llama_stack_client_prepare_request"]
+    del _original_methods["llama_stack_client_prepare_request"]
+
+    # Also restore OpenAI client if it was patched
+    if "openai_prepare_request" in _original_methods:
+        OpenAI._prepare_request = _original_methods["openai_prepare_request"]
+        del _original_methods["openai_prepare_request"]
+
+
 def get_inference_mode() -> InferenceMode:
     return InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "replay").lower())
 
@@ -244,7 +334,7 @@ class ResponseStorage:
         with open(response_path, "w") as f:
             json.dump(
                 {
-                    "test_id": _test_context.get(),  # Include for debugging
+                    "test_id": _test_context.get(),
                     "request": request,
                     "response": serialized_response,
                 },
@@ -386,108 +476,115 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         else:
             return await original_method(self, *args, **kwargs)
 
-    # Get base URL based on client type
-    if client_type == "openai":
-        base_url = str(self._client.base_url)
+    # In server mode, sync test ID from provider_data to _test_context for storage operations
+    test_context_token = _sync_test_context_from_provider_data()
 
-        # the OpenAI client methods may pass NOT_GIVEN for unset parameters; filter these out
-        kwargs = {k: v for k, v in kwargs.items() if v is not NOT_GIVEN}
-    elif client_type == "ollama":
-        # Get base URL from the client (Ollama client uses host attribute)
-        base_url = getattr(self, "host", "http://localhost:11434")
-        if not base_url.startswith("http"):
-            base_url = f"http://{base_url}"
-    else:
-        raise ValueError(f"Unknown client type: {client_type}")
+    try:
+        # Get base URL based on client type
+        if client_type == "openai":
+            base_url = str(self._client.base_url)
 
-    url = base_url.rstrip("/") + endpoint
-    # Special handling for Databricks URLs to avoid leaking workspace info
-    # e.g. https://adb-1234567890123456.7.cloud.databricks.com -> https://...cloud.databricks.com
-    if "cloud.databricks.com" in url:
-        url = "__databricks__" + url.split("cloud.databricks.com")[-1]
-    method = "POST"
-    headers = {}
-    body = kwargs
-
-    request_hash = normalize_request(method, url, headers, body)
-
-    # Try to find existing recording for REPLAY or RECORD_IF_MISSING modes
-    recording = None
-    if mode == InferenceMode.REPLAY or mode == InferenceMode.RECORD_IF_MISSING:
-        # Special handling for model-list endpoints: merge all recordings with this hash
-        if endpoint in ("/api/tags", "/v1/models"):
-            records = storage._model_list_responses(request_hash)
-            recording = _combine_model_list_responses(endpoint, records)
+            # the OpenAI client methods may pass NOT_GIVEN for unset parameters; filter these out
+            kwargs = {k: v for k, v in kwargs.items() if v is not NOT_GIVEN}
+        elif client_type == "ollama":
+            # Get base URL from the client (Ollama client uses host attribute)
+            base_url = getattr(self, "host", "http://localhost:11434")
+            if not base_url.startswith("http"):
+                base_url = f"http://{base_url}"
         else:
-            recording = storage.find_recording(request_hash)
+            raise ValueError(f"Unknown client type: {client_type}")
 
-    if recording:
-        response_body = recording["response"]["body"]
+        url = base_url.rstrip("/") + endpoint
+        # Special handling for Databricks URLs to avoid leaking workspace info
+        # e.g. https://adb-1234567890123456.7.cloud.databricks.com -> https://...cloud.databricks.com
+        if "cloud.databricks.com" in url:
+            url = "__databricks__" + url.split("cloud.databricks.com")[-1]
+        method = "POST"
+        headers = {}
+        body = kwargs
 
-        if recording["response"].get("is_streaming", False):
+        request_hash = normalize_request(method, url, headers, body)
 
-            async def replay_stream():
-                for chunk in response_body:
+        # Try to find existing recording for REPLAY or RECORD_IF_MISSING modes
+        recording = None
+        if mode == InferenceMode.REPLAY or mode == InferenceMode.RECORD_IF_MISSING:
+            # Special handling for model-list endpoints: merge all recordings with this hash
+            if endpoint in ("/api/tags", "/v1/models"):
+                records = storage._model_list_responses(request_hash)
+                recording = _combine_model_list_responses(endpoint, records)
+            else:
+                recording = storage.find_recording(request_hash)
+
+            if recording:
+                response_body = recording["response"]["body"]
+
+                if recording["response"].get("is_streaming", False):
+
+                    async def replay_stream():
+                        for chunk in response_body:
+                            yield chunk
+
+                    return replay_stream()
+                else:
+                    return response_body
+            elif mode == InferenceMode.REPLAY:
+                # REPLAY mode requires recording to exist
+                raise RuntimeError(
+                    f"No recorded response found for request hash: {request_hash}\n"
+                    f"Request: {method} {url} {body}\n"
+                    f"Model: {body.get('model', 'unknown')}\n"
+                    f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
+                )
+
+        if mode == InferenceMode.RECORD or (mode == InferenceMode.RECORD_IF_MISSING and not recording):
+            if endpoint == "/v1/models":
+                response = original_method(self, *args, **kwargs)
+            else:
+                response = await original_method(self, *args, **kwargs)
+
+            # we want to store the result of the iterator, not the iterator itself
+            if endpoint == "/v1/models":
+                response = [m async for m in response]
+
+            request_data = {
+                "method": method,
+                "url": url,
+                "headers": headers,
+                "body": body,
+                "endpoint": endpoint,
+                "model": body.get("model", ""),
+            }
+
+            # Determine if this is a streaming request based on request parameters
+            is_streaming = body.get("stream", False)
+
+            if is_streaming:
+                # For streaming responses, we need to collect all chunks immediately before yielding
+                # This ensures the recording is saved even if the generator isn't fully consumed
+                chunks = []
+                async for chunk in response:
+                    chunks.append(chunk)
+
+                # Store the recording immediately
+                response_data = {"body": chunks, "is_streaming": True}
+                storage.store_recording(request_hash, request_data, response_data)
+
+                # Return a generator that replays the stored chunks
+                async def replay_recorded_stream():
+                    for chunk in chunks:
                         yield chunk
 
-            return replay_stream()
+                return replay_recorded_stream()
             else:
-                return response_body
-        elif mode == InferenceMode.REPLAY:
-            # REPLAY mode requires recording to exist
-            raise RuntimeError(
-                f"No recorded response found for request hash: {request_hash}\n"
-                f"Request: {method} {url} {body}\n"
-                f"Model: {body.get('model', 'unknown')}\n"
-                f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
-            )
+                response_data = {"body": response, "is_streaming": False}
+                storage.store_recording(request_hash, request_data, response_data)
+                return response
 
-    if mode == InferenceMode.RECORD or (mode == InferenceMode.RECORD_IF_MISSING and not recording):
-        if endpoint == "/v1/models":
-            response = original_method(self, *args, **kwargs)
         else:
-            response = await original_method(self, *args, **kwargs)
-
-        # we want to store the result of the iterator, not the iterator itself
-        if endpoint == "/v1/models":
-            response = [m async for m in response]
-
-        request_data = {
-            "method": method,
-            "url": url,
-            "headers": headers,
-            "body": body,
-            "endpoint": endpoint,
-            "model": body.get("model", ""),
-        }
-
-        # Determine if this is a streaming request based on request parameters
-        is_streaming = body.get("stream", False)
-
-        if is_streaming:
-            # For streaming responses, we need to collect all chunks immediately before yielding
-            # This ensures the recording is saved even if the generator isn't fully consumed
-            chunks = []
-            async for chunk in response:
-                chunks.append(chunk)
-
-            # Store the recording immediately
-            response_data = {"body": chunks, "is_streaming": True}
-            storage.store_recording(request_hash, request_data, response_data)
-
-            # Return a generator that replays the stored chunks
-            async def replay_recorded_stream():
-                for chunk in chunks:
-                    yield chunk
-
-            return replay_recorded_stream()
-        else:
-            response_data = {"body": response, "is_streaming": False}
-            storage.store_recording(request_hash, request_data, response_data)
-            return response
-
-    else:
-        raise AssertionError(f"Invalid mode: {mode}")
+            raise AssertionError(f"Invalid mode: {mode}")
+    finally:
+        if test_context_token:
+            _test_context.reset(test_context_token)
 
 
 def patch_inference_clients():
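
Illustrative sketch (not part of the patch; the test ID value is invented): the test ID travels from the patched client to the server inside the X-LlamaStack-Provider-Data header, where _sync_test_context_from_provider_data() reads it back out of provider_data. The round trip the hunks above implement looks roughly like this:

    import json

    # Client side: the patched _prepare_request merges the test ID into any existing provider data.
    headers: dict[str, str] = {}
    test_id = "tests/integration/inference/test_example.py::test_hypothetical"  # made-up value
    provider_data = json.loads(headers.get("X-LlamaStack-Provider-Data", "{}"))
    provider_data["__test_id"] = test_id
    headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data)

    # Server side: the parsed provider_data dict is what the recorder copies into _test_context.
    received = json.loads(headers["X-LlamaStack-Provider-Data"])
    assert received["__test_id"] == test_id
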
diff --git a/scripts/integration-tests.sh b/scripts/integration-tests.sh
index eee60951d..b009ad696 100755
--- a/scripts/integration-tests.sh
+++ b/scripts/integration-tests.sh
@@ -124,12 +124,6 @@ echo ""
 echo "Checking llama packages"
 uv pip list | grep llama
 
-# Check storage and memory before tests
-echo "=== System Resources Before Tests ==="
-free -h 2>/dev/null || echo "free command not available"
-df -h
-echo ""
-
 # Set environment variables
 export LLAMA_STACK_CLIENT_TIMEOUT=300
 
@@ -144,6 +138,17 @@ echo "=== Applying Setup Environment Variables ==="
 
 # the server needs this
 export LLAMA_STACK_TEST_INFERENCE_MODE="$INFERENCE_MODE"
+export SQLITE_STORE_DIR=$(mktemp -d)
+echo "Setting SQLITE_STORE_DIR: $SQLITE_STORE_DIR"
+
+# Determine stack config type for api_recorder test isolation
+if [[ "$STACK_CONFIG" == server:* ]]; then
+    export LLAMA_STACK_TEST_STACK_CONFIG_TYPE="server"
+    echo "Setting stack config type: server"
+else
+    export LLAMA_STACK_TEST_STACK_CONFIG_TYPE="library_client"
+    echo "Setting stack config type: library_client"
+fi
 
 SETUP_ENV=$(PYTHONPATH=$THIS_DIR/.. python "$THIS_DIR/get_setup_env.py" --suite "$TEST_SUITE" --setup "$TEST_SETUP" --format bash)
 echo "Setting up environment variables:"
@@ -186,6 +191,8 @@ if [[ "$STACK_CONFIG" == *"server:"* ]]; then
         echo "Llama Stack Server is already running, skipping start"
     else
         echo "=== Starting Llama Stack Server ==="
+        # Set a reasonable log width for better readability in server.log
+        export LLAMA_STACK_LOG_WIDTH=120
         nohup llama stack run ci-tests --image-type venv > server.log 2>&1 &
 
         echo "Waiting for Llama Stack Server to start..."
@@ -277,11 +284,5 @@ else
     exit 1
 fi
 
-# Check storage and memory after tests
-echo ""
-echo "=== System Resources After Tests ==="
-free -h 2>/dev/null || echo "free command not available"
-df -h
-
 echo ""
 echo "=== Integration Tests Complete ==="
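
Illustrative sketch (not part of the patch; this helper is hypothetical): the bash branch above and the conftest.py hunk below make the same server-versus-library_client decision, so the rule can be summarized as a small predicate for documentation purposes:

    def stack_config_type(stack_config: str | None) -> str:
        # "server:..." style configs exercise a real HTTP server; everything else stays in-process.
        if stack_config and stack_config.startswith("server:"):
            return "server"
        return "library_client"

    assert stack_config_type("server:ci-tests") == "server"
    assert stack_config_type("ci-tests") == "library_client"
    assert stack_config_type(None) == "library_client"
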
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index 42015a608..c0eb27b98 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -35,6 +35,18 @@ def pytest_sessionstart(session):
     if "LLAMA_STACK_TEST_INFERENCE_MODE" not in os.environ:
         os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "replay"
 
+    stack_config = session.config.getoption("--stack-config", default=None)
+    if stack_config and stack_config.startswith("server:"):
+        os.environ["LLAMA_STACK_TEST_STACK_CONFIG_TYPE"] = "server"
+        logger.info(f"Test stack config type: server (stack_config={stack_config})")
+    else:
+        os.environ["LLAMA_STACK_TEST_STACK_CONFIG_TYPE"] = "library_client"
+        logger.info(f"Test stack config type: library_client (stack_config={stack_config})")
+
+    from llama_stack.testing.inference_recorder import patch_httpx_for_test_id
+
+    patch_httpx_for_test_id()
+
 
 @pytest.fixture(autouse=True)
 def _track_test_context(request):
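
Illustrative sketch (not part of the patch; names are invented): _sync_test_context_from_provider_data() returns the token from ContextVar.set() precisely so the finally block in _patched_inference_method can undo the assignment once the request finishes. The token pattern in isolation:

    from contextvars import ContextVar

    _ctx: ContextVar[str | None] = ContextVar("_ctx", default=None)

    token = _ctx.set("hypothetical-test-id")
    try:
        assert _ctx.get() == "hypothetical-test-id"
    finally:
        # reset() restores whatever value was current before set(), mirroring the patch's finally block.
        _ctx.reset(token)

    assert _ctx.get() is None
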