feat(tests): auto-merge all model list responses and unify recordings
commit 34f5da6cdc (parent faf891b40c)
121 changed files with 17457 additions and 16126 deletions
@@ -30,6 +30,9 @@ from openai.types.completion_choice import CompletionChoice
 CompletionChoice.model_fields["finish_reason"].annotation = Literal["stop", "length", "content_filter"] | None
 CompletionChoice.model_rebuild()
 
+REPO_ROOT = Path(__file__).parent.parent.parent
+DEFAULT_STORAGE_DIR = REPO_ROOT / "tests/integration/recordings"
+
 
 class InferenceMode(StrEnum):
     LIVE = "live"
@@ -51,7 +54,7 @@ def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict
 
 
 def get_inference_mode() -> InferenceMode:
-    return InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower())
+    return InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "replay").lower())
 
 
 def setup_inference_recording():
@@ -60,28 +63,18 @@ def setup_inference_recording():
     to increase their reliability and reduce reliance on expensive, external services.
 
     Currently, this is only supported for OpenAI and Ollama clients. These should cover the vast majority of use cases.
-    Calls to the /models endpoint are not currently trapped. We probably need to add support for this.
 
-    Two environment variables are required:
-    - LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', or 'replay'.
-    - LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in.
+    Two environment variables are supported:
+    - LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', or 'replay'. Default is 'replay'.
+    - LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in. Default is 'tests/integration/recordings'.
 
-    The recordings are stored in a SQLite database and a JSON file for each request. The SQLite database is used to
-    quickly find the correct recording for a given request. The JSON files are used to store the request and response
-    bodies.
+    The recordings are stored as JSON files.
     """
     mode = get_inference_mode()
 
     if mode not in InferenceMode:
         raise ValueError(f"Invalid LLAMA_STACK_TEST_INFERENCE_MODE: {mode}. Must be 'live', 'record', or 'replay'")
 
     if mode == InferenceMode.LIVE:
         return None
 
-    if "LLAMA_STACK_TEST_RECORDING_DIR" not in os.environ:
-        raise ValueError("LLAMA_STACK_TEST_RECORDING_DIR must be set for recording or replaying")
-    storage_dir = os.environ["LLAMA_STACK_TEST_RECORDING_DIR"]
-
+    storage_dir = os.environ.get("LLAMA_STACK_TEST_RECORDING_DIR", DEFAULT_STORAGE_DIR)
     return inference_recording(mode=mode, storage_dir=storage_dir)
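A minimal sketch of wiring this into a test suite (the import path llama_stack.testing.inference_recorder and the pytest fixture are illustrative assumptions, not part of this commit):

import pytest

from llama_stack.testing.inference_recorder import setup_inference_recording


@pytest.fixture(autouse=True)
def recorded_inference(monkeypatch):
    # 'replay' is now the default mode; set 'record' to capture new recordings
    monkeypatch.setenv("LLAMA_STACK_TEST_INFERENCE_MODE", "replay")
    ctx = setup_inference_recording()
    if ctx is None:  # live mode: nothing is patched, real services are used
        yield
        return
    with ctx:  # record/replay: inference clients are patched for the test
        yield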
@@ -134,8 +127,8 @@ class ResponseStorage:
     def store_recording(self, request_hash: str, request: dict[str, Any], response: dict[str, Any]):
         """Store a request/response pair."""
         # Generate unique response filename
-        response_file = f"{request_hash[:12]}.json"
-        response_path = self.responses_dir / response_file
+        short_hash = request_hash[:12]
+        response_file = f"{short_hash}.json"
 
         # Serialize response body if needed
         serialized_response = dict(response)
@@ -147,6 +140,14 @@ class ResponseStorage:
             # Handle single response
             serialized_response["body"] = _serialize_response(serialized_response["body"])
 
+        # For model-list endpoints (/api/tags, /v1/models), include a digest of the
+        # model identifiers in the filename to distinguish recorded variants
+        endpoint = request.get("endpoint")
+        if endpoint in ("/api/tags", "/v1/models"):
+            digest = _model_identifiers_digest(endpoint, response)
+            response_file = f"models-{short_hash}-{digest}.json"
+
+        response_path = self.responses_dir / response_file
+
         # Save response to JSON file
         with open(response_path, "w") as f:
             json.dump({"request": request, "response": serialized_response}, f, indent=2)
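As a worked illustration of the resulting naming scheme (the hash and digest values below are made up):

short_hash = "a1b2c3d4e5f6"  # request_hash[:12], illustrative value
digest = "0f1e2d3c"          # _model_identifiers_digest(...), illustrative value

print(f"{short_hash}.json")                  # ordinary recording: a1b2c3d4e5f6.json
print(f"models-{short_hash}-{digest}.json")  # model-list variant: models-a1b2c3d4e5f6-0f1e2d3c.json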
@@ -161,19 +162,85 @@ class ResponseStorage:
         if not response_path.exists():
             return None
 
-        with open(response_path) as f:
-            data = json.load(f)
-
-        # Deserialize response body if needed
-        if "response" in data and "body" in data["response"]:
-            if isinstance(data["response"]["body"], list):
-                # Handle streaming responses
-                data["response"]["body"] = [_deserialize_response(chunk) for chunk in data["response"]["body"]]
-            else:
-                # Handle single response
-                data["response"]["body"] = _deserialize_response(data["response"]["body"])
-
-        return cast(dict[str, Any], data)
+        return _recording_from_file(response_path)
+
+    def _model_list_responses(self, short_hash: str) -> list[dict[str, Any]]:
+        results: list[dict[str, Any]] = []
+        for path in self.responses_dir.glob(f"models-{short_hash}-*.json"):
+            data = _recording_from_file(path)
+            results.append(data)
+        return results
+
+
+def _recording_from_file(response_path) -> dict[str, Any]:
+    with open(response_path) as f:
+        data = json.load(f)
+
+    # Deserialize response body if needed
+    if "response" in data and "body" in data["response"]:
+        if isinstance(data["response"]["body"], list):
+            # Handle streaming responses
+            data["response"]["body"] = [_deserialize_response(chunk) for chunk in data["response"]["body"]]
+        else:
+            # Handle single response
+            data["response"]["body"] = _deserialize_response(data["response"]["body"])
+
+    return cast(dict[str, Any], data)
+
+
+def _model_identifiers_digest(endpoint: str, response: dict[str, Any]) -> str:
+    def _extract_model_identifiers():
+        """Extract a stable set of identifiers for model-list endpoints.
+
+        Supported endpoints:
+        - '/api/tags' (Ollama): response body has 'models': [ { name/model/digest/id/... }, ... ]
+        - '/v1/models' (OpenAI): response body has 'data': [ { id: ... }, ... ]
+        Returns a list of unique identifiers or None if structure doesn't match.
+        """
+        body = response["body"]
+        if endpoint == "/api/tags":
+            items = body.get("models")
+            idents = [m.model for m in items]
+        else:
+            items = body.get("data")
+            idents = [m.id for m in items]
+        return sorted(set(idents))
+
+    identifiers = _extract_model_identifiers()
+    return hashlib.sha1(("|".join(identifiers)).encode("utf-8")).hexdigest()[:8]
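A self-contained sketch of why this digest is stable: sorting the deduplicated identifiers makes it independent of listing order (the model names are illustrative):

import hashlib


def _digest(idents: list[str]) -> str:
    # mirrors the sha1-over-sorted-unique-identifiers scheme above
    return hashlib.sha1("|".join(sorted(set(idents))).encode("utf-8")).hexdigest()[:8]


# Two servers listing the same models in different order agree on the digest,
# so their recordings map to the same filename variant.
assert _digest(["llama3.2:3b", "all-minilm:latest"]) == _digest(["all-minilm:latest", "llama3.2:3b"])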
+
+
+def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]]) -> dict[str, Any] | None:
+    """Return a single, unioned recording for supported model-list endpoints."""
+    seen: dict[str, dict[str, Any]] = {}
+    for rec in records:
+        body = rec["response"]["body"]
+        if endpoint == "/api/tags":
+            items = body.models
+        elif endpoint == "/v1/models":
+            items = body.data
+        else:
+            items = []
+
+        for m in items:
+            if endpoint == "/v1/models":
+                key = m.id
+            else:
+                key = m.model
+            seen[key] = m
+
+    ordered = [seen[k] for k in sorted(seen.keys())]
+    canonical = records[0]
+    canonical_req = canonical.get("request", {})
+    if isinstance(canonical_req, dict):
+        canonical_req["endpoint"] = endpoint
+    if endpoint == "/v1/models":
+        body = {"data": ordered, "object": "list"}
+    else:
+        from ollama import ListResponse
+
+        body = ListResponse(models=ordered)
+    return {"request": canonical_req, "response": {"body": body, "is_streaming": False}}
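A hypothetical illustration of the union logic for '/v1/models', assuming the function above is in scope (the records are hand-built stand-ins for the JSON recordings; SimpleNamespace mimics the OpenAI model objects):

from types import SimpleNamespace

rec_a = {"request": {"endpoint": "/v1/models"},
         "response": {"body": SimpleNamespace(data=[SimpleNamespace(id="gpt-4o")])}}
rec_b = {"request": {"endpoint": "/v1/models"},
         "response": {"body": SimpleNamespace(data=[SimpleNamespace(id="gpt-4o-mini")])}}

combined = _combine_model_list_responses("/v1/models", [rec_a, rec_b])
# The union is sorted by identifier and wrapped as a single non-streaming response
assert [m.id for m in combined["response"]["body"]["data"]] == ["gpt-4o", "gpt-4o-mini"]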
 
 
 async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
@@ -195,8 +262,6 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         raise ValueError(f"Unknown client type: {client_type}")
 
     url = base_url.rstrip("/") + endpoint
-
-    # Normalize request for matching
     method = "POST"
     headers = {}
     body = kwargs
@@ -204,7 +269,12 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
     request_hash = normalize_request(method, url, headers, body)
 
     if _current_mode == InferenceMode.REPLAY:
-        recording = _current_storage.find_recording(request_hash)
+        # Special handling for model-list endpoints: return union of all responses
+        if endpoint in ("/api/tags", "/v1/models"):
+            records = _current_storage._model_list_responses(request_hash[:12])
+            recording = _combine_model_list_responses(endpoint, records)
+        else:
+            recording = _current_storage.find_recording(request_hash)
         if recording:
             response_body = recording["response"]["body"]
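For context, replay locates every recorded variant by globbing on the 12-character request-hash prefix before merging; a small sketch of that lookup (the directory layout and hash value are illustrative assumptions, not taken from this commit):

from pathlib import Path

responses_dir = Path("tests/integration/recordings/responses")  # assumed layout
short_hash = "a1b2c3d4e5f6"  # request_hash[:12], illustrative

# Each differently-provisioned server that was recorded contributes one
# models-{hash}-{digest}.json variant; replay merges all of them.
for path in sorted(responses_dir.glob(f"models-{short_hash}-*.json")):
    print(path.name)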
@@ -274,12 +344,14 @@ def patch_inference_clients():
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
     from openai.resources.embeddings import AsyncEmbeddings
+    from openai.resources.models import AsyncModels
 
     # Store original methods for both OpenAI and Ollama clients
     _original_methods = {
         "chat_completions_create": AsyncChatCompletions.create,
         "completions_create": AsyncCompletions.create,
         "embeddings_create": AsyncEmbeddings.create,
+        "models_list": AsyncModels.list,
         "ollama_generate": OllamaAsyncClient.generate,
         "ollama_chat": OllamaAsyncClient.chat,
         "ollama_embed": OllamaAsyncClient.embed,
@@ -304,10 +376,16 @@ def patch_inference_clients():
             _original_methods["embeddings_create"], self, "openai", "/v1/embeddings", *args, **kwargs
         )
 
+    async def patched_models_list(self, *args, **kwargs):
+        return await _patched_inference_method(
+            _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
+        )
+
     # Apply OpenAI patches
     AsyncChatCompletions.create = patched_chat_completions_create
     AsyncCompletions.create = patched_completions_create
     AsyncEmbeddings.create = patched_embeddings_create
+    AsyncModels.list = patched_models_list
 
     # Create patched methods for Ollama client
     async def patched_ollama_generate(self, *args, **kwargs):
@@ -361,11 +439,13 @@ def unpatch_inference_clients():
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
     from openai.resources.embeddings import AsyncEmbeddings
+    from openai.resources.models import AsyncModels
 
     # Restore OpenAI client methods
     AsyncChatCompletions.create = _original_methods["chat_completions_create"]
     AsyncCompletions.create = _original_methods["completions_create"]
     AsyncEmbeddings.create = _original_methods["embeddings_create"]
+    AsyncModels.list = _original_methods["models_list"]
 
     # Restore Ollama client methods if they were patched
     OllamaAsyncClient.generate = _original_methods["ollama_generate"]
@@ -379,15 +459,11 @@ def unpatch_inference_clients():
 
 
 @contextmanager
-def inference_recording(mode: str = "live", storage_dir: str | Path | None = None) -> Generator[None, None, None]:
+def inference_recording(mode: str, storage_dir: str | Path) -> Generator[None, None, None]:
     """Context manager for inference recording/replaying."""
     global _current_mode, _current_storage
 
-    # Set defaults
-    if storage_dir is None:
-        storage_dir_path = Path.home() / ".llama" / "recordings"
-    else:
-        storage_dir_path = Path(storage_dir)
+    storage_dir_path = Path(storage_dir)
 
     # Store previous state
     prev_mode = _current_mode
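Since the defaults moved out of this signature into setup_inference_recording(), direct callers now pass both arguments explicitly; a minimal usage sketch (import path assumed as above):

from pathlib import Path

from llama_stack.testing.inference_recorder import inference_recording

with inference_recording(mode="record", storage_dir=Path("tests/integration/recordings")):
    ...  # run inference; each request/response pair is written as a JSON recording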