mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
feat(tests): auto-merge all model list responses and unify recordings (#3320)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Integration Tests (Replay) / Integration Tests (, , , client=, vision=) (push) Failing after 3s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 4s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 7s
Update ReadTheDocs / update-readthedocs (push) Failing after 3s
Test External API and Providers / test-external (venv) (push) Failing after 5s
Vector IO Integration Tests / test-matrix (push) Failing after 7s
Python Package Build Test / build (3.13) (push) Failing after 8s
Python Package Build Test / build (3.12) (push) Failing after 8s
Unit Tests / unit-tests (3.13) (push) Failing after 14s
Unit Tests / unit-tests (3.12) (push) Failing after 14s
UI Tests / ui-tests (22) (push) Successful in 1m7s
Pre-commit / pre-commit (push) Successful in 2m34s
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Integration Tests (Replay) / Integration Tests (, , , client=, vision=) (push) Failing after 3s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 4s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 7s
Update ReadTheDocs / update-readthedocs (push) Failing after 3s
Test External API and Providers / test-external (venv) (push) Failing after 5s
Vector IO Integration Tests / test-matrix (push) Failing after 7s
Python Package Build Test / build (3.13) (push) Failing after 8s
Python Package Build Test / build (3.12) (push) Failing after 8s
Unit Tests / unit-tests (3.13) (push) Failing after 14s
Unit Tests / unit-tests (3.12) (push) Failing after 14s
UI Tests / ui-tests (22) (push) Successful in 1m7s
Pre-commit / pre-commit (push) Successful in 2m34s
One needed to specify record-replay related environment variables for running integration tests. We could not use defaults because integration tests could be run against Ollama instances which could be running different models. For example, text vs vision tests needed separate instances of Ollama because a single instance typically cannot serve both of these models if you assume the standard CI worker configuration on Github. As a result, `client.list()` as returned by the Ollama client would be different between these runs and we'd end up overwriting responses. This PR "solves" it by adding a small amount of complexity -- we store model list responses specially, keyed by the hashes of the models they return. At replay time, we merge all of them and pretend that we have the union of all models available. ## Test Plan Re-recorded all the tests using `scripts/integration-tests.sh --inference-mode record`, including the vision tests.
This commit is contained in:
parent
d948e63340
commit
c3d3a0b833
122 changed files with 17460 additions and 16129 deletions
|
@ -30,6 +30,9 @@ from openai.types.completion_choice import CompletionChoice
|
|||
CompletionChoice.model_fields["finish_reason"].annotation = Literal["stop", "length", "content_filter"] | None
|
||||
CompletionChoice.model_rebuild()
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent
|
||||
DEFAULT_STORAGE_DIR = REPO_ROOT / "tests/integration/recordings"
|
||||
|
||||
|
||||
class InferenceMode(StrEnum):
|
||||
LIVE = "live"
|
||||
|
@ -51,7 +54,7 @@ def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict
|
|||
|
||||
|
||||
def get_inference_mode() -> InferenceMode:
|
||||
return InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower())
|
||||
return InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "replay").lower())
|
||||
|
||||
|
||||
def setup_inference_recording():
|
||||
|
@ -60,28 +63,18 @@ def setup_inference_recording():
|
|||
to increase their reliability and reduce reliance on expensive, external services.
|
||||
|
||||
Currently, this is only supported for OpenAI and Ollama clients. These should cover the vast majority of use cases.
|
||||
Calls to the /models endpoint are not currently trapped. We probably need to add support for this.
|
||||
|
||||
Two environment variables are required:
|
||||
- LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', or 'replay'.
|
||||
- LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in.
|
||||
Two environment variables are supported:
|
||||
- LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', or 'replay'. Default is 'replay'.
|
||||
- LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in. Default is 'tests/integration/recordings'.
|
||||
|
||||
The recordings are stored in a SQLite database and a JSON file for each request. The SQLite database is used to
|
||||
quickly find the correct recording for a given request. The JSON files are used to store the request and response
|
||||
bodies.
|
||||
The recordings are stored as JSON files.
|
||||
"""
|
||||
mode = get_inference_mode()
|
||||
|
||||
if mode not in InferenceMode:
|
||||
raise ValueError(f"Invalid LLAMA_STACK_TEST_INFERENCE_MODE: {mode}. Must be 'live', 'record', or 'replay'")
|
||||
|
||||
if mode == InferenceMode.LIVE:
|
||||
return None
|
||||
|
||||
if "LLAMA_STACK_TEST_RECORDING_DIR" not in os.environ:
|
||||
raise ValueError("LLAMA_STACK_TEST_RECORDING_DIR must be set for recording or replaying")
|
||||
storage_dir = os.environ["LLAMA_STACK_TEST_RECORDING_DIR"]
|
||||
|
||||
storage_dir = os.environ.get("LLAMA_STACK_TEST_RECORDING_DIR", DEFAULT_STORAGE_DIR)
|
||||
return inference_recording(mode=mode, storage_dir=storage_dir)
|
||||
|
||||
|
||||
|
@ -134,8 +127,8 @@ class ResponseStorage:
|
|||
def store_recording(self, request_hash: str, request: dict[str, Any], response: dict[str, Any]):
|
||||
"""Store a request/response pair."""
|
||||
# Generate unique response filename
|
||||
response_file = f"{request_hash[:12]}.json"
|
||||
response_path = self.responses_dir / response_file
|
||||
short_hash = request_hash[:12]
|
||||
response_file = f"{short_hash}.json"
|
||||
|
||||
# Serialize response body if needed
|
||||
serialized_response = dict(response)
|
||||
|
@ -147,6 +140,14 @@ class ResponseStorage:
|
|||
# Handle single response
|
||||
serialized_response["body"] = _serialize_response(serialized_response["body"])
|
||||
|
||||
# If this is an Ollama /api/tags recording, include models digest in filename to distinguish variants
|
||||
endpoint = request.get("endpoint")
|
||||
if endpoint in ("/api/tags", "/v1/models"):
|
||||
digest = _model_identifiers_digest(endpoint, response)
|
||||
response_file = f"models-{short_hash}-{digest}.json"
|
||||
|
||||
response_path = self.responses_dir / response_file
|
||||
|
||||
# Save response to JSON file
|
||||
with open(response_path, "w") as f:
|
||||
json.dump({"request": request, "response": serialized_response}, f, indent=2)
|
||||
|
@ -161,19 +162,85 @@ class ResponseStorage:
|
|||
if not response_path.exists():
|
||||
return None
|
||||
|
||||
with open(response_path) as f:
|
||||
data = json.load(f)
|
||||
return _recording_from_file(response_path)
|
||||
|
||||
# Deserialize response body if needed
|
||||
if "response" in data and "body" in data["response"]:
|
||||
if isinstance(data["response"]["body"], list):
|
||||
# Handle streaming responses
|
||||
data["response"]["body"] = [_deserialize_response(chunk) for chunk in data["response"]["body"]]
|
||||
def _model_list_responses(self, short_hash: str) -> list[dict[str, Any]]:
|
||||
results: list[dict[str, Any]] = []
|
||||
for path in self.responses_dir.glob(f"models-{short_hash}-*.json"):
|
||||
data = _recording_from_file(path)
|
||||
results.append(data)
|
||||
return results
|
||||
|
||||
|
||||
def _recording_from_file(response_path) -> dict[str, Any]:
|
||||
with open(response_path) as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Deserialize response body if needed
|
||||
if "response" in data and "body" in data["response"]:
|
||||
if isinstance(data["response"]["body"], list):
|
||||
# Handle streaming responses
|
||||
data["response"]["body"] = [_deserialize_response(chunk) for chunk in data["response"]["body"]]
|
||||
else:
|
||||
# Handle single response
|
||||
data["response"]["body"] = _deserialize_response(data["response"]["body"])
|
||||
|
||||
return cast(dict[str, Any], data)
|
||||
|
||||
|
||||
def _model_identifiers_digest(endpoint: str, response: dict[str, Any]) -> str:
|
||||
def _extract_model_identifiers():
|
||||
"""Extract a stable set of identifiers for model-list endpoints.
|
||||
|
||||
Supported endpoints:
|
||||
- '/api/tags' (Ollama): response body has 'models': [ { name/model/digest/id/... }, ... ]
|
||||
- '/v1/models' (OpenAI): response body has 'data': [ { id: ... }, ... ]
|
||||
Returns a list of unique identifiers or None if structure doesn't match.
|
||||
"""
|
||||
body = response["body"]
|
||||
if endpoint == "/api/tags":
|
||||
items = body.get("models")
|
||||
idents = [m.model for m in items]
|
||||
else:
|
||||
items = body.get("data")
|
||||
idents = [m.id for m in items]
|
||||
return sorted(set(idents))
|
||||
|
||||
identifiers = _extract_model_identifiers()
|
||||
return hashlib.sha1(("|".join(identifiers)).encode("utf-8")).hexdigest()[:8]
|
||||
|
||||
|
||||
def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]]) -> dict[str, Any] | None:
|
||||
"""Return a single, unioned recording for supported model-list endpoints."""
|
||||
seen: dict[str, dict[str, Any]] = {}
|
||||
for rec in records:
|
||||
body = rec["response"]["body"]
|
||||
if endpoint == "/api/tags":
|
||||
items = body.models
|
||||
elif endpoint == "/v1/models":
|
||||
items = body.data
|
||||
else:
|
||||
items = []
|
||||
|
||||
for m in items:
|
||||
if endpoint == "/v1/models":
|
||||
key = m.id
|
||||
else:
|
||||
# Handle single response
|
||||
data["response"]["body"] = _deserialize_response(data["response"]["body"])
|
||||
key = m.model
|
||||
seen[key] = m
|
||||
|
||||
return cast(dict[str, Any], data)
|
||||
ordered = [seen[k] for k in sorted(seen.keys())]
|
||||
canonical = records[0]
|
||||
canonical_req = canonical.get("request", {})
|
||||
if isinstance(canonical_req, dict):
|
||||
canonical_req["endpoint"] = endpoint
|
||||
if endpoint == "/v1/models":
|
||||
body = {"data": ordered, "object": "list"}
|
||||
else:
|
||||
from ollama import ListResponse
|
||||
|
||||
body = ListResponse(models=ordered)
|
||||
return {"request": canonical_req, "response": {"body": body, "is_streaming": False}}
|
||||
|
||||
|
||||
async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
|
||||
|
@ -195,8 +262,6 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
|
|||
raise ValueError(f"Unknown client type: {client_type}")
|
||||
|
||||
url = base_url.rstrip("/") + endpoint
|
||||
|
||||
# Normalize request for matching
|
||||
method = "POST"
|
||||
headers = {}
|
||||
body = kwargs
|
||||
|
@ -204,7 +269,12 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
|
|||
request_hash = normalize_request(method, url, headers, body)
|
||||
|
||||
if _current_mode == InferenceMode.REPLAY:
|
||||
recording = _current_storage.find_recording(request_hash)
|
||||
# Special handling for model-list endpoints: return union of all responses
|
||||
if endpoint in ("/api/tags", "/v1/models"):
|
||||
records = _current_storage._model_list_responses(request_hash[:12])
|
||||
recording = _combine_model_list_responses(endpoint, records)
|
||||
else:
|
||||
recording = _current_storage.find_recording(request_hash)
|
||||
if recording:
|
||||
response_body = recording["response"]["body"]
|
||||
|
||||
|
@ -274,12 +344,14 @@ def patch_inference_clients():
|
|||
from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
|
||||
from openai.resources.completions import AsyncCompletions
|
||||
from openai.resources.embeddings import AsyncEmbeddings
|
||||
from openai.resources.models import AsyncModels
|
||||
|
||||
# Store original methods for both OpenAI and Ollama clients
|
||||
_original_methods = {
|
||||
"chat_completions_create": AsyncChatCompletions.create,
|
||||
"completions_create": AsyncCompletions.create,
|
||||
"embeddings_create": AsyncEmbeddings.create,
|
||||
"models_list": AsyncModels.list,
|
||||
"ollama_generate": OllamaAsyncClient.generate,
|
||||
"ollama_chat": OllamaAsyncClient.chat,
|
||||
"ollama_embed": OllamaAsyncClient.embed,
|
||||
|
@ -304,10 +376,16 @@ def patch_inference_clients():
|
|||
_original_methods["embeddings_create"], self, "openai", "/v1/embeddings", *args, **kwargs
|
||||
)
|
||||
|
||||
async def patched_models_list(self, *args, **kwargs):
|
||||
return await _patched_inference_method(
|
||||
_original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
|
||||
)
|
||||
|
||||
# Apply OpenAI patches
|
||||
AsyncChatCompletions.create = patched_chat_completions_create
|
||||
AsyncCompletions.create = patched_completions_create
|
||||
AsyncEmbeddings.create = patched_embeddings_create
|
||||
AsyncModels.list = patched_models_list
|
||||
|
||||
# Create patched methods for Ollama client
|
||||
async def patched_ollama_generate(self, *args, **kwargs):
|
||||
|
@ -361,11 +439,13 @@ def unpatch_inference_clients():
|
|||
from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
|
||||
from openai.resources.completions import AsyncCompletions
|
||||
from openai.resources.embeddings import AsyncEmbeddings
|
||||
from openai.resources.models import AsyncModels
|
||||
|
||||
# Restore OpenAI client methods
|
||||
AsyncChatCompletions.create = _original_methods["chat_completions_create"]
|
||||
AsyncCompletions.create = _original_methods["completions_create"]
|
||||
AsyncEmbeddings.create = _original_methods["embeddings_create"]
|
||||
AsyncModels.list = _original_methods["models_list"]
|
||||
|
||||
# Restore Ollama client methods if they were patched
|
||||
OllamaAsyncClient.generate = _original_methods["ollama_generate"]
|
||||
|
@ -379,16 +459,10 @@ def unpatch_inference_clients():
|
|||
|
||||
|
||||
@contextmanager
|
||||
def inference_recording(mode: str = "live", storage_dir: str | Path | None = None) -> Generator[None, None, None]:
|
||||
def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Generator[None, None, None]:
|
||||
"""Context manager for inference recording/replaying."""
|
||||
global _current_mode, _current_storage
|
||||
|
||||
# Set defaults
|
||||
if storage_dir is None:
|
||||
storage_dir_path = Path.home() / ".llama" / "recordings"
|
||||
else:
|
||||
storage_dir_path = Path(storage_dir)
|
||||
|
||||
# Store previous state
|
||||
prev_mode = _current_mode
|
||||
prev_storage = _current_storage
|
||||
|
@ -397,7 +471,9 @@ def inference_recording(mode: str = "live", storage_dir: str | Path | None = Non
|
|||
_current_mode = mode
|
||||
|
||||
if mode in ["record", "replay"]:
|
||||
_current_storage = ResponseStorage(storage_dir_path)
|
||||
if storage_dir is None:
|
||||
raise ValueError("storage_dir is required for record and replay modes")
|
||||
_current_storage = ResponseStorage(Path(storage_dir))
|
||||
patch_inference_clients()
|
||||
|
||||
yield
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue