Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-23 05:19:40 +00:00)
add docs, fix brokenness of llama_stack/ui/package.json
commit d7970f813c
parent 9b3a860beb
4 changed files with 44 additions and 26 deletions
```diff
@@ -12,6 +12,7 @@ import os
 import sqlite3
 from collections.abc import Generator
 from contextlib import contextmanager
+from enum import StrEnum
 from pathlib import Path
 from typing import Any, Literal, cast
 
@@ -31,6 +32,12 @@ CompletionChoice.model_fields["finish_reason"].annotation = Literal["stop", "len
 CompletionChoice.model_rebuild()
 
 
+class InferenceMode(StrEnum):
+    LIVE = "live"
+    RECORD = "record"
+    REPLAY = "replay"
+
+
 def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
     """Create a normalized hash of the request for consistent matching."""
     # Extract just the endpoint path
```
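Because InferenceMode subclasses StrEnum (added to the standard library in Python 3.11), each member is itself a str and compares equal to the literal it replaces, which is what makes the mechanical "live" to InferenceMode.LIVE substitutions in the hunks below safe. A quick illustration of those semantics, not part of the patch:

```python
from enum import StrEnum  # requires Python 3.11+

class InferenceMode(StrEnum):
    LIVE = "live"
    RECORD = "record"
    REPLAY = "replay"

# Members are real strings, so any call site still comparing against
# bare literals keeps working during the migration.
assert InferenceMode.LIVE == "live"
assert f"mode={InferenceMode.RECORD}" == "mode=record"

# Constructing from a string returns the canonical member and raises
# ValueError for anything unrecognized.
assert InferenceMode("replay") is InferenceMode.REPLAY
```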
```diff
@@ -44,23 +51,33 @@ def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict
     return hashlib.sha256(normalized_json.encode()).hexdigest()
 
 
-def get_inference_mode() -> str:
-    return os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower()
+def get_inference_mode() -> InferenceMode:
+    return InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower())
 
 
 def setup_inference_recording():
+    """
+    Returns a context manager that can be used to record or replay inference requests. This is to be used in tests
+    to increase their reliability and reduce reliance on expensive, external services.
+
+    Currently, this is only supported for OpenAI and Ollama clients. These should cover the vast majority of use cases.
+    Calls to the /models endpoint are not currently trapped. We probably need to add support for this.
+
+    Two environment variables are required:
+    - LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', or 'replay'.
+    - LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in.
+
+    The recordings are stored in a SQLite database and a JSON file for each request. The SQLite database is used to
+    quickly find the correct recording for a given request. The JSON files are used to store the request and response
+    bodies.
+    """
     mode = get_inference_mode()
 
-    if mode not in ["live", "record", "replay"]:
+    if mode not in InferenceMode:
         raise ValueError(f"Invalid LLAMA_STACK_TEST_INFERENCE_MODE: {mode}. Must be 'live', 'record', or 'replay'")
 
-    if mode == "live":
-        # Return a no-op context manager for live mode
-        @contextmanager
-        def live_mode():
-            yield
-
-        return live_mode()
+    if mode == InferenceMode.LIVE:
+        return None
 
     if "LLAMA_STACK_TEST_RECORDING_DIR" not in os.environ:
         raise ValueError("LLAMA_STACK_TEST_RECORDING_DIR must be set for recording or replaying")
```
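The docstring added here is the "docs" half of the commit. Note also the behavioral change in the same hunk: in live mode the function now returns None instead of a no-op context manager, so call sites that previously entered the returned value unconditionally must guard against None. A minimal usage sketch under that assumption; the import path and test body are hypothetical, since this page does not show the patched file's name:

```python
import os

os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "replay"
os.environ["LLAMA_STACK_TEST_RECORDING_DIR"] = "tests/recordings"

# Hypothetical import path; the diff does not name the module being patched.
from llama_stack.testing.inference_recorder import setup_inference_recording

def run_inference_test() -> None:
    """Stand-in for a real test body that calls an OpenAI or Ollama client."""

ctx = setup_inference_recording()
if ctx is not None:
    with ctx:
        run_inference_test()
else:
    # Live mode now returns None rather than a no-op context manager.
    run_inference_test()
```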
```diff
@@ -203,7 +220,7 @@ class ResponseStorage:
 async def _patched_inference_method(original_method, self, client_type, method_name=None, *args, **kwargs):
     global _current_mode, _current_storage
 
-    if _current_mode == "live" or _current_storage is None:
+    if _current_mode == InferenceMode.LIVE or _current_storage is None:
         # Normal operation
         return await original_method(self, *args, **kwargs)
 
@@ -261,7 +278,7 @@ async def _patched_inference_method(original_method, self, client_type, method_n
 
     request_hash = normalize_request(method, url, headers, body)
 
-    if _current_mode == "replay":
+    if _current_mode == InferenceMode.REPLAY:
         recording = _current_storage.find_recording(request_hash)
         if recording:
             response_body = recording["response"]["body"]
```
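The replay branch looks recordings up by the hash that normalize_request produced earlier. The body of that function is mostly cut off by the diff context; below is a sketch consistent with the visible pieces (its docstring, the "Extract just the endpoint path" comment, and the final hashlib line), where the normalization details in the middle are assumptions:

```python
import hashlib
import json
from typing import Any
from urllib.parse import urlparse

def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
    """Create a normalized hash of the request for consistent matching."""
    # Extract just the endpoint path, so differing hosts/ports between
    # test runs do not change the key (assumed rationale).
    endpoint = urlparse(url).path

    # Assumed: headers are accepted but deliberately left out of the hash,
    # and keys are sorted so semantically identical requests serialize
    # identically regardless of dict insertion order.
    normalized = {"method": method.upper(), "endpoint": endpoint, "body": body}
    normalized_json = json.dumps(normalized, sort_keys=True)
    return hashlib.sha256(normalized_json.encode()).hexdigest()
```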
```diff
@@ -283,7 +300,7 @@ async def _patched_inference_method(original_method, self, client_type, method_n
                 f"To record this response, run with LLAMA_STACK_INFERENCE_MODE=record"
             )
 
-    elif _current_mode == "record":
+    elif _current_mode == InferenceMode.RECORD:
         response = await original_method(self, *args, **kwargs)
 
         request_data = {
```
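The record branch captures the live response and hands it to storage. Per the new docstring, that storage is a SQLite index plus one JSON file per request; here is a sketch of the shape such a store could take, where the table and field names are assumptions rather than the patch's actual ResponseStorage implementation:

```python
import json
import sqlite3
from pathlib import Path
from typing import Any

def store_recording(
    recording_dir: Path,
    request_hash: str,
    request_data: dict[str, Any],
    response_body: dict[str, Any],
) -> None:
    # One JSON file per request holds the full request/response bodies...
    json_path = recording_dir / f"{request_hash}.json"
    json_path.write_text(
        json.dumps({"request": request_data, "response": {"body": response_body}}, indent=2)
    )
    # ...while a SQLite row lets replay find the right file by hash quickly.
    conn = sqlite3.connect(recording_dir / "index.sqlite")
    try:
        with conn:  # commits on success
            conn.execute(
                "CREATE TABLE IF NOT EXISTS recordings (request_hash TEXT PRIMARY KEY, json_path TEXT)"
            )
            conn.execute(
                "INSERT OR REPLACE INTO recordings VALUES (?, ?)",
                (request_hash, str(json_path)),
            )
    finally:
        conn.close()
```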