mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-12-22 18:46:16 +00:00

add docs, fix brokenness of llama_stack/ui/package.json

This commit is contained in:
parent 9b3a860beb
commit d7970f813c

4 changed files with 44 additions and 26 deletions
@@ -308,14 +308,14 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
 async def construct_stack(
     run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
 ) -> dict[Api, Any]:
-    inference_mode = os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower()
-    if inference_mode in ["record", "replay"]:
-        global TEST_RECORDING_CONTEXT
+    if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
         from llama_stack.testing.inference_recorder import setup_inference_recording

+        global TEST_RECORDING_CONTEXT
         TEST_RECORDING_CONTEXT = setup_inference_recording()
-        TEST_RECORDING_CONTEXT.__enter__()
-        logger.info(f"Inference recording enabled: mode={inference_mode}")
+        if TEST_RECORDING_CONTEXT:
+            TEST_RECORDING_CONTEXT.__enter__()
+            logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")

     dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)

     policy = run_config.server.auth.access_policy if run_config.server.auth else []
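For context (not part of the diff): setup_inference_recording() in llama_stack/testing/inference_recorder.py now returns None when the mode is live, which is why the new code guards the __enter__() call. A minimal sketch of that call pattern, where maybe_start_recording() is a hypothetical wrapper rather than code from this commit:

```python
# Sketch only; maybe_start_recording() is a hypothetical helper, not part of this commit.
from llama_stack.testing.inference_recorder import setup_inference_recording

TEST_RECORDING_CONTEXT = None


def maybe_start_recording() -> None:
    """Enter the recording context only if one was actually created."""
    global TEST_RECORDING_CONTEXT
    TEST_RECORDING_CONTEXT = setup_inference_recording()  # returns None in live mode
    if TEST_RECORDING_CONTEXT:
        TEST_RECORDING_CONTEXT.__enter__()  # entered for record/replay modes only
```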
@@ -12,6 +12,7 @@ import os
 import sqlite3
 from collections.abc import Generator
 from contextlib import contextmanager
+from enum import StrEnum
 from pathlib import Path
 from typing import Any, Literal, cast
@@ -31,6 +32,12 @@ CompletionChoice.model_fields["finish_reason"].annotation = Literal["stop", "len
 CompletionChoice.model_rebuild()


+class InferenceMode(StrEnum):
+    LIVE = "live"
+    RECORD = "record"
+    REPLAY = "replay"
+
+
 def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
     """Create a normalized hash of the request for consistent matching."""
     # Extract just the endpoint path
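Because InferenceMode is a StrEnum (hence the new `from enum import StrEnum` above), its members behave as plain strings, so values read from environment variables compare cleanly. A small standalone sketch of that behavior, assuming Python 3.11+:

```python
from enum import StrEnum


class InferenceMode(StrEnum):
    LIVE = "live"
    RECORD = "record"
    REPLAY = "replay"


# Members compare equal to their string values...
assert InferenceMode.RECORD == "record"
# ...can be constructed from a raw env-var string...
assert InferenceMode("replay") is InferenceMode.REPLAY
# ...and format as the plain value.
assert f"{InferenceMode.LIVE}" == "live"
```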
@@ -44,23 +51,33 @@ def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict
     return hashlib.sha256(normalized_json.encode()).hexdigest()


-def get_inference_mode() -> str:
-    return os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower()
+def get_inference_mode() -> InferenceMode:
+    return InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower())


 def setup_inference_recording():
+    """
+    Returns a context manager that can be used to record or replay inference requests. This is to be used in tests
+    to increase their reliability and reduce reliance on expensive, external services.
+
+    Currently, this is only supported for OpenAI and Ollama clients. These should cover the vast majority of use cases.
+    Calls to the /models endpoint are not currently trapped. We probably need to add support for this.
+
+    Two environment variables are required:
+    - LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', or 'replay'.
+    - LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in.
+
+    The recordings are stored in a SQLite database and a JSON file for each request. The SQLite database is used to
+    quickly find the correct recording for a given request. The JSON files are used to store the request and response
+    bodies.
+    """
     mode = get_inference_mode()

-    if mode not in ["live", "record", "replay"]:
+    if mode not in InferenceMode:
         raise ValueError(f"Invalid LLAMA_STACK_TEST_INFERENCE_MODE: {mode}. Must be 'live', 'record', or 'replay'")

-    if mode == "live":
-        # Return a no-op context manager for live mode
-        @contextmanager
-        def live_mode():
-            yield
-
-        return live_mode()
+    if mode == InferenceMode.LIVE:
+        return None

     if "LLAMA_STACK_TEST_RECORDING_DIR" not in os.environ:
         raise ValueError("LLAMA_STACK_TEST_RECORDING_DIR must be set for recording or replaying")
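Pulling the new docstring together, a minimal usage sketch (the recording directory and the surrounding glue are illustrative, not from this commit):

```python
import os

from llama_stack.testing.inference_recorder import setup_inference_recording

os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "record"  # 'live', 'record', or 'replay'
os.environ["LLAMA_STACK_TEST_RECORDING_DIR"] = "/tmp/llama-stack-recordings"  # illustrative path

ctx = setup_inference_recording()
if ctx is not None:  # live mode now returns None instead of a no-op context manager
    with ctx:
        ...  # OpenAI/Ollama client calls made here are recorded (or replayed)
```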
@@ -203,7 +220,7 @@ class ResponseStorage:
 async def _patched_inference_method(original_method, self, client_type, method_name=None, *args, **kwargs):
     global _current_mode, _current_storage

-    if _current_mode == "live" or _current_storage is None:
+    if _current_mode == InferenceMode.LIVE or _current_storage is None:
         # Normal operation
         return await original_method(self, *args, **kwargs)
@@ -261,7 +278,7 @@ async def _patched_inference_method(original_method, self, client_type, method_n

     request_hash = normalize_request(method, url, headers, body)

-    if _current_mode == "replay":
+    if _current_mode == InferenceMode.REPLAY:
         recording = _current_storage.find_recording(request_hash)
         if recording:
             response_body = recording["response"]["body"]
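The request_hash used for the replay lookup comes from normalize_request(), shown earlier: a stable digest over the method, endpoint path, and a canonically serialized body, so the same logical request maps to the same recording across runs. An illustrative re-implementation of that idea (example_request_hash() is a stand-in, not the real function, which may normalize additional fields):

```python
import hashlib
import json
from typing import Any
from urllib.parse import urlparse


def example_request_hash(method: str, url: str, body: dict[str, Any]) -> str:
    """Stand-in for normalize_request(): same idea, possibly fewer normalization steps."""
    normalized = {
        "method": method.upper(),
        "endpoint": urlparse(url).path,  # ignore host/port so recordings stay portable
        "body": body,
    }
    normalized_json = json.dumps(normalized, sort_keys=True)  # canonical key order
    return hashlib.sha256(normalized_json.encode()).hexdigest()


print(example_request_hash(
    "POST",
    "http://localhost:11434/v1/chat/completions",
    {"model": "llama3.2:3b", "messages": [{"role": "user", "content": "hi"}]},
))
```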
@@ -283,7 +300,7 @@ async def _patched_inference_method(original_method, self, client_type, method_n
                 f"To record this response, run with LLAMA_STACK_INFERENCE_MODE=record"
             )

-    elif _current_mode == "record":
+    elif _current_mode == InferenceMode.RECORD:
         response = await original_method(self, *args, **kwargs)

         request_data = {
@@ -20,7 +20,7 @@
     "@radix-ui/react-tooltip": "^1.2.6",
     "class-variance-authority": "^0.7.1",
     "clsx": "^2.1.1",
-    "llama-stack-client": ""0.2.16",
+    "llama-stack-client": "0.2.16",
     "lucide-react": "^0.510.0",
     "next": "15.3.3",
     "next-auth": "^4.24.11",
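This one-character change in llama_stack/ui/package.json is the "brokenness" fix from the commit title: the stray quote made the whole file invalid JSON. A quick sketch of the failure on a reduced fragment (the fragments below are trimmed-down stand-ins, not the full file):

```python
import json

broken = '{"llama-stack-client": ""0.2.16"}'  # stray quote, as in the old line
fixed = '{"llama-stack-client": "0.2.16"}'

try:
    json.loads(broken)
except json.JSONDecodeError as err:
    print(f"old fragment rejected: {err}")

print(json.loads(fixed))  # {'llama-stack-client': '0.2.16'}
```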
@@ -22,6 +22,7 @@ from llama_stack.apis.inference import (
     OpenAIEmbeddingUsage,
 )
 from llama_stack.testing.inference_recorder import (
+    InferenceMode,
     ResponseStorage,
     inference_recording,
     normalize_request,
@@ -169,7 +170,7 @@ class TestInferenceRecording:

         temp_storage_dir = temp_storage_dir / "test_recording_mode"
         with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
-            with inference_recording(mode="record", storage_dir=str(temp_storage_dir)):
+            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                 response = await client.chat.completions.create(
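The remaining test hunks all make the same substitution, replacing string modes with InferenceMode members. A condensed sketch of the record-then-replay pattern these tests exercise (the model name is illustrative, and the real tests patch the OpenAI client so no server is contacted):

```python
from openai import AsyncOpenAI

from llama_stack.testing.inference_recorder import InferenceMode, inference_recording


async def record_then_replay(temp_storage_dir: str) -> None:
    request = {"model": "llama3.2:3b", "messages": [{"role": "user", "content": "hi"}]}

    # First pass: hit the backend (mocked in the tests) and persist the response.
    with inference_recording(mode=InferenceMode.RECORD, storage_dir=temp_storage_dir):
        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
        await client.chat.completions.create(**request)

    # Second pass: the identical request is served from the recording, no network call.
    with inference_recording(mode=InferenceMode.REPLAY, storage_dir=temp_storage_dir):
        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
        await client.chat.completions.create(**request)
```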
@@ -198,7 +199,7 @@ class TestInferenceRecording:
         temp_storage_dir = temp_storage_dir / "test_replay_mode"
         # First, record a response
         with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
-            with inference_recording(mode="record", storage_dir=str(temp_storage_dir)):
+            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                 response = await client.chat.completions.create(
@@ -210,7 +211,7 @@ class TestInferenceRecording:

         # Now test replay mode - should not call the original method
         with patch("openai.resources.chat.completions.AsyncCompletions.create") as mock_create_patch:
-            with inference_recording(mode="replay", storage_dir=str(temp_storage_dir)):
+            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                 response = await client.chat.completions.create(
@@ -230,7 +231,7 @@ class TestInferenceRecording:
         """Test that replay mode fails when no recording is found."""
         temp_storage_dir = temp_storage_dir / "test_replay_missing_recording"
         with patch("openai.resources.chat.completions.AsyncCompletions.create"):
-            with inference_recording(mode="replay", storage_dir=str(temp_storage_dir)):
+            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                 with pytest.raises(RuntimeError, match="No recorded response found"):
@@ -247,7 +248,7 @@ class TestInferenceRecording:
         temp_storage_dir = temp_storage_dir / "test_embeddings_recording"
         # Record
         with patch("openai.resources.embeddings.AsyncEmbeddings.create", side_effect=mock_create):
-            with inference_recording(mode="record", storage_dir=str(temp_storage_dir)):
+            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                 response = await client.embeddings.create(
@@ -258,7 +259,7 @@ class TestInferenceRecording:

         # Replay
         with patch("openai.resources.embeddings.AsyncEmbeddings.create") as mock_create_patch:
-            with inference_recording(mode="replay", storage_dir=str(temp_storage_dir)):
+            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                 response = await client.embeddings.create(
@@ -279,7 +280,7 @@ class TestInferenceRecording:
             return real_openai_chat_response

         with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
-            with inference_recording(mode="live"):
+            with inference_recording(mode=InferenceMode.LIVE):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                 response = await client.chat.completions.create(