mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-22 18:46:16 +00:00)
add docs, fix brokenness of llama_stack/ui/package.json
parent 9b3a860beb
commit d7970f813c
4 changed files with 44 additions and 26 deletions
@@ -308,14 +308,14 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
 async def construct_stack(
     run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
 ) -> dict[Api, Any]:
-    inference_mode = os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower()
-    if inference_mode in ["record", "replay"]:
-        global TEST_RECORDING_CONTEXT
+    if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
         from llama_stack.testing.inference_recorder import setup_inference_recording
+
+        global TEST_RECORDING_CONTEXT
         TEST_RECORDING_CONTEXT = setup_inference_recording()
         if TEST_RECORDING_CONTEXT:
             TEST_RECORDING_CONTEXT.__enter__()
-            logger.info(f"Inference recording enabled: mode={inference_mode}")
+            logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
 
     dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
     policy = run_config.server.auth.access_policy if run_config.server.auth else []
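The hunk above changes when the recorder is installed: the presence check on the env var replaces the explicit mode check, and it now relies on setup_inference_recording() returning None in live mode (see the next file). A minimal sketch of the resulting startup flow, assuming only what the diff shows; the wrapper function is illustrative, not part of the codebase:

# Sketch of the startup flow construct_stack now follows; names mirror the diff,
# the surrounding helper is an illustrative assumption.
import os

TEST_RECORDING_CONTEXT = None

def maybe_enable_recording() -> None:
    global TEST_RECORDING_CONTEXT
    if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
        from llama_stack.testing.inference_recorder import setup_inference_recording

        TEST_RECORDING_CONTEXT = setup_inference_recording()
        if TEST_RECORDING_CONTEXT:  # None (falsy) when the mode is "live"
            TEST_RECORDING_CONTEXT.__enter__()  # recorder stays active for the server's lifetime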
@@ -12,6 +12,7 @@ import os
 import sqlite3
 from collections.abc import Generator
 from contextlib import contextmanager
+from enum import StrEnum
 from pathlib import Path
 from typing import Any, Literal, cast
 
@@ -31,6 +32,12 @@ CompletionChoice.model_fields["finish_reason"].annotation = Literal["stop", "len
 CompletionChoice.model_rebuild()
 
 
+class InferenceMode(StrEnum):
+    LIVE = "live"
+    RECORD = "record"
+    REPLAY = "replay"
+
+
 def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
     """Create a normalized hash of the request for consistent matching."""
     # Extract just the endpoint path
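For reference, a short illustrative sketch (not part of the diff) of the StrEnum semantics the later hunks depend on when parsing, comparing, and validating modes:

# Illustrative only: how the new enum behaves for the checks used below.
from enum import StrEnum  # Python 3.11+

class InferenceMode(StrEnum):
    LIVE = "live"
    RECORD = "record"
    REPLAY = "replay"

assert InferenceMode("record") is InferenceMode.RECORD  # how get_inference_mode() parses the env var value
assert InferenceMode.LIVE == "live"                     # members compare equal to their string values
assert InferenceMode.RECORD in InferenceMode            # the membership check used in setup_inference_recording()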
@@ -44,23 +51,33 @@ def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict
     return hashlib.sha256(normalized_json.encode()).hexdigest()
 
 
-def get_inference_mode() -> str:
-    return os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower()
+def get_inference_mode() -> InferenceMode:
+    return InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower())
 
 
 def setup_inference_recording():
+    """
+    Returns a context manager that can be used to record or replay inference requests. This is to be used in tests
+    to increase their reliability and reduce reliance on expensive, external services.
+
+    Currently, this is only supported for OpenAI and Ollama clients. These should cover the vast majority of use cases.
+    Calls to the /models endpoint are not currently trapped. We probably need to add support for this.
+
+    Two environment variables are required:
+    - LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', or 'replay'.
+    - LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in.
+
+    The recordings are stored in a SQLite database and a JSON file for each request. The SQLite database is used to
+    quickly find the correct recording for a given request. The JSON files are used to store the request and response
+    bodies.
+    """
     mode = get_inference_mode()
 
-    if mode not in ["live", "record", "replay"]:
+    if mode not in InferenceMode:
         raise ValueError(f"Invalid LLAMA_STACK_TEST_INFERENCE_MODE: {mode}. Must be 'live', 'record', or 'replay'")
 
-    if mode == "live":
-        # Return a no-op context manager for live mode
-        @contextmanager
-        def live_mode():
-            yield
-
-        return live_mode()
+    if mode == InferenceMode.LIVE:
+        return None
 
     if "LLAMA_STACK_TEST_RECORDING_DIR" not in os.environ:
         raise ValueError("LLAMA_STACK_TEST_RECORDING_DIR must be set for recording or replaying")
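The new docstring above describes the env-var-driven workflow. A minimal usage sketch inferred from it; the directory path and mode value are illustrative assumptions, not taken from this diff:

# Minimal usage sketch based on the docstring above.
import os

os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "replay"
os.environ["LLAMA_STACK_TEST_RECORDING_DIR"] = "/tmp/llama-stack-recordings"  # illustrative path

from llama_stack.testing.inference_recorder import setup_inference_recording

ctx = setup_inference_recording()  # returns None when the mode is "live"
if ctx is not None:
    with ctx:
        ...  # inference calls made here are recorded to, or replayed from, the recording store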
@@ -203,7 +220,7 @@ class ResponseStorage:
 async def _patched_inference_method(original_method, self, client_type, method_name=None, *args, **kwargs):
     global _current_mode, _current_storage
 
-    if _current_mode == "live" or _current_storage is None:
+    if _current_mode == InferenceMode.LIVE or _current_storage is None:
         # Normal operation
         return await original_method(self, *args, **kwargs)
 
@@ -261,7 +278,7 @@ async def _patched_inference_method(original_method, self, client_type, method_n
 
     request_hash = normalize_request(method, url, headers, body)
 
-    if _current_mode == "replay":
+    if _current_mode == InferenceMode.REPLAY:
         recording = _current_storage.find_recording(request_hash)
         if recording:
             response_body = recording["response"]["body"]
@@ -283,7 +300,7 @@ async def _patched_inference_method(original_method, self, client_type, method_n
                 f"To record this response, run with LLAMA_STACK_INFERENCE_MODE=record"
             )
 
-    elif _current_mode == "record":
+    elif _current_mode == InferenceMode.RECORD:
         response = await original_method(self, *args, **kwargs)
 
         request_data = {
@@ -20,7 +20,7 @@
     "@radix-ui/react-tooltip": "^1.2.6",
     "class-variance-authority": "^0.7.1",
     "clsx": "^2.1.1",
-    "llama-stack-client": ""0.2.16",
+    "llama-stack-client": "0.2.16",
     "lucide-react": "^0.510.0",
     "next": "15.3.3",
     "next-auth": "^4.24.11",
@@ -22,6 +22,7 @@ from llama_stack.apis.inference import (
     OpenAIEmbeddingUsage,
 )
 from llama_stack.testing.inference_recorder import (
+    InferenceMode,
     ResponseStorage,
     inference_recording,
     normalize_request,
@@ -169,7 +170,7 @@ class TestInferenceRecording:
 
         temp_storage_dir = temp_storage_dir / "test_recording_mode"
         with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
-            with inference_recording(mode="record", storage_dir=str(temp_storage_dir)):
+            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
 
                 response = await client.chat.completions.create(
@@ -198,7 +199,7 @@ class TestInferenceRecording:
         temp_storage_dir = temp_storage_dir / "test_replay_mode"
         # First, record a response
         with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
-            with inference_recording(mode="record", storage_dir=str(temp_storage_dir)):
+            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
 
                 response = await client.chat.completions.create(
@@ -210,7 +211,7 @@ class TestInferenceRecording:
 
         # Now test replay mode - should not call the original method
         with patch("openai.resources.chat.completions.AsyncCompletions.create") as mock_create_patch:
-            with inference_recording(mode="replay", storage_dir=str(temp_storage_dir)):
+            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
 
                 response = await client.chat.completions.create(
@@ -230,7 +231,7 @@ class TestInferenceRecording:
         """Test that replay mode fails when no recording is found."""
         temp_storage_dir = temp_storage_dir / "test_replay_missing_recording"
         with patch("openai.resources.chat.completions.AsyncCompletions.create"):
-            with inference_recording(mode="replay", storage_dir=str(temp_storage_dir)):
+            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
 
                 with pytest.raises(RuntimeError, match="No recorded response found"):
@@ -247,7 +248,7 @@ class TestInferenceRecording:
         temp_storage_dir = temp_storage_dir / "test_embeddings_recording"
         # Record
         with patch("openai.resources.embeddings.AsyncEmbeddings.create", side_effect=mock_create):
-            with inference_recording(mode="record", storage_dir=str(temp_storage_dir)):
+            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
 
                 response = await client.embeddings.create(
@@ -258,7 +259,7 @@ class TestInferenceRecording:
 
         # Replay
         with patch("openai.resources.embeddings.AsyncEmbeddings.create") as mock_create_patch:
-            with inference_recording(mode="replay", storage_dir=str(temp_storage_dir)):
+            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
 
                 response = await client.embeddings.create(
@@ -279,7 +280,7 @@ class TestInferenceRecording:
             return real_openai_chat_response
 
         with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
-            with inference_recording(mode="live"):
+            with inference_recording(mode=InferenceMode.LIVE):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
 
                 response = await client.chat.completions.create(