add docs, fix brokenness of llama_stack/ui/package.json

Ashwin Bharambe 2025-07-29 12:30:45 -07:00
parent 9b3a860beb
commit d7970f813c
4 changed files with 44 additions and 26 deletions

View file

@@ -308,14 +308,14 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig
 async def construct_stack(
     run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
 ) -> dict[Api, Any]:
-    inference_mode = os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower()
-    if inference_mode in ["record", "replay"]:
-        global TEST_RECORDING_CONTEXT
+    if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
         from llama_stack.testing.inference_recorder import setup_inference_recording

+        global TEST_RECORDING_CONTEXT
         TEST_RECORDING_CONTEXT = setup_inference_recording()
-        TEST_RECORDING_CONTEXT.__enter__()
-        logger.info(f"Inference recording enabled: mode={inference_mode}")
+        if TEST_RECORDING_CONTEXT:
+            TEST_RECORDING_CONTEXT.__enter__()
+            logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")

     dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
     policy = run_config.server.auth.access_policy if run_config.server.auth else []
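
The net effect of the hunk above: the recorder is set up whenever LLAMA_STACK_TEST_INFERENCE_MODE is present at all, and the context is entered only when setup_inference_recording() actually returns one (it returns None in live mode, per the recorder change below). A minimal sketch of driving this path; the directory value is illustrative:

import os

# Illustrative values; the two variable names come from this commit.
os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "replay"
os.environ["LLAMA_STACK_TEST_RECORDING_DIR"] = "/tmp/llama-stack-recordings"

# construct_stack() now enters the recording context by itself; with the
# variable unset (or the mode set to "live") no context is entered.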

View file

@@ -12,6 +12,7 @@ import os
 import sqlite3
 from collections.abc import Generator
 from contextlib import contextmanager
+from enum import StrEnum
 from pathlib import Path
 from typing import Any, Literal, cast
@@ -31,6 +32,12 @@ CompletionChoice.model_fields["finish_reason"].annotation = Literal["stop", "len
 CompletionChoice.model_rebuild()


+class InferenceMode(StrEnum):
+    LIVE = "live"
+    RECORD = "record"
+    REPLAY = "replay"
+
+
 def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
     """Create a normalized hash of the request for consistent matching."""
     # Extract just the endpoint path
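
Because InferenceMode subclasses StrEnum, its members are str instances and compare equal to their lowercase values, so call sites that still pass plain strings keep working. A quick illustration (not part of the diff):

assert InferenceMode("record") is InferenceMode.RECORD
assert InferenceMode.RECORD == "record"   # StrEnum members are str
assert f"mode={InferenceMode.REPLAY}" == "mode=replay"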
@@ -44,23 +51,33 @@ def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
     return hashlib.sha256(normalized_json.encode()).hexdigest()


-def get_inference_mode() -> str:
-    return os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower()
+def get_inference_mode() -> InferenceMode:
+    return InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower())


 def setup_inference_recording():
+    """
+    Returns a context manager that can be used to record or replay inference requests. This is to be used in tests
+    to increase their reliability and reduce reliance on expensive, external services.
+
+    Currently, this is only supported for OpenAI and Ollama clients. These should cover the vast majority of use cases.
+    Calls to the /models endpoint are not currently trapped. We probably need to add support for this.
+
+    Two environment variables are required:
+    - LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', or 'replay'.
+    - LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in.
+
+    The recordings are stored in a SQLite database and a JSON file for each request. The SQLite database is used to
+    quickly find the correct recording for a given request. The JSON files are used to store the request and response
+    bodies.
+    """
     mode = get_inference_mode()

-    if mode not in ["live", "record", "replay"]:
+    if mode not in InferenceMode:
         raise ValueError(f"Invalid LLAMA_STACK_TEST_INFERENCE_MODE: {mode}. Must be 'live', 'record', or 'replay'")

-    if mode == "live":
-        # Return a no-op context manager for live mode
-        @contextmanager
-        def live_mode():
-            yield
-
-        return live_mode()
+    if mode == InferenceMode.LIVE:
+        return None

     if "LLAMA_STACK_TEST_RECORDING_DIR" not in os.environ:
         raise ValueError("LLAMA_STACK_TEST_RECORDING_DIR must be set for recording or replaying")
@@ -203,7 +220,7 @@ class ResponseStorage:
 async def _patched_inference_method(original_method, self, client_type, method_name=None, *args, **kwargs):
     global _current_mode, _current_storage

-    if _current_mode == "live" or _current_storage is None:
+    if _current_mode == InferenceMode.LIVE or _current_storage is None:
         # Normal operation
         return await original_method(self, *args, **kwargs)
@@ -261,7 +278,7 @@ async def _patched_inference_method(original_method, self, client_type, method_name=None, *args, **kwargs):
     request_hash = normalize_request(method, url, headers, body)

-    if _current_mode == "replay":
+    if _current_mode == InferenceMode.REPLAY:
         recording = _current_storage.find_recording(request_hash)
         if recording:
             response_body = recording["response"]["body"]
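
Replay lookups key on normalize_request(), which, per the "extract just the endpoint path" comment earlier, reduces the URL to its path before hashing, so the same logical request should match across hosts. A sketch with made-up request bodies:

body = {"model": "llama3.2:3b", "messages": [{"role": "user", "content": "hi"}]}
h1 = normalize_request("POST", "http://localhost:11434/v1/chat/completions", {}, body)
h2 = normalize_request("POST", "http://other-host:8080/v1/chat/completions", {}, body)
assert h1 == h2  # only the host differs, so the hashes should agree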
@@ -283,7 +300,7 @@ async def _patched_inference_method(original_method, self, client_type, method_name=None, *args, **kwargs):
                 f"To record this response, run with LLAMA_STACK_INFERENCE_MODE=record"
             )

-    elif _current_mode == "record":
+    elif _current_mode == InferenceMode.RECORD:
         response = await original_method(self, *args, **kwargs)

         request_data = {

View file

@@ -20,7 +20,7 @@
     "@radix-ui/react-tooltip": "^1.2.6",
     "class-variance-authority": "^0.7.1",
     "clsx": "^2.1.1",
-    "llama-stack-client": ""0.2.16",
+    "llama-stack-client": "0.2.16",
     "lucide-react": "^0.510.0",
     "next": "15.3.3",
     "next-auth": "^4.24.11",

View file

@@ -22,6 +22,7 @@ from llama_stack.apis.inference import (
     OpenAIEmbeddingUsage,
 )
 from llama_stack.testing.inference_recorder import (
+    InferenceMode,
     ResponseStorage,
     inference_recording,
     normalize_request,
@@ -169,7 +170,7 @@ class TestInferenceRecording:
         temp_storage_dir = temp_storage_dir / "test_recording_mode"
         with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
-            with inference_recording(mode="record", storage_dir=str(temp_storage_dir)):
+            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                 response = await client.chat.completions.create(
@@ -198,7 +199,7 @@ class TestInferenceRecording:
         temp_storage_dir = temp_storage_dir / "test_replay_mode"
         # First, record a response
         with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
-            with inference_recording(mode="record", storage_dir=str(temp_storage_dir)):
+            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                 response = await client.chat.completions.create(
@@ -210,7 +211,7 @@ class TestInferenceRecording:
         # Now test replay mode - should not call the original method
         with patch("openai.resources.chat.completions.AsyncCompletions.create") as mock_create_patch:
-            with inference_recording(mode="replay", storage_dir=str(temp_storage_dir)):
+            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                 response = await client.chat.completions.create(
@@ -230,7 +231,7 @@ class TestInferenceRecording:
         """Test that replay mode fails when no recording is found."""
         temp_storage_dir = temp_storage_dir / "test_replay_missing_recording"
         with patch("openai.resources.chat.completions.AsyncCompletions.create"):
-            with inference_recording(mode="replay", storage_dir=str(temp_storage_dir)):
+            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                 with pytest.raises(RuntimeError, match="No recorded response found"):
@@ -247,7 +248,7 @@ class TestInferenceRecording:
         temp_storage_dir = temp_storage_dir / "test_embeddings_recording"
         # Record
         with patch("openai.resources.embeddings.AsyncEmbeddings.create", side_effect=mock_create):
-            with inference_recording(mode="record", storage_dir=str(temp_storage_dir)):
+            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                 response = await client.embeddings.create(
@@ -258,7 +259,7 @@ class TestInferenceRecording:
         # Replay
         with patch("openai.resources.embeddings.AsyncEmbeddings.create") as mock_create_patch:
-            with inference_recording(mode="replay", storage_dir=str(temp_storage_dir)):
+            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                 response = await client.embeddings.create(
@@ -279,7 +280,7 @@ class TestInferenceRecording:
             return real_openai_chat_response

         with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
-            with inference_recording(mode="live"):
+            with inference_recording(mode=InferenceMode.LIVE):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                 response = await client.chat.completions.create(
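
Taken together, the test changes exercise one pattern: record once against a patched client, then replay with no network traffic. A condensed sketch of that pattern (storage path illustrative):

with inference_recording(mode=InferenceMode.RECORD, storage_dir="/tmp/rec"):
    ...  # client call goes through and the response is persisted

with inference_recording(mode=InferenceMode.REPLAY, storage_dir="/tmp/rec"):
    ...  # the same call is served from the recording, no network needed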