diff --git a/llama_stack/testing/inference_recorder.py b/llama_stack/testing/inference_recorder.py
index e78f493a6..be1961ecc 100644
--- a/llama_stack/testing/inference_recorder.py
+++ b/llama_stack/testing/inference_recorder.py
@@ -7,6 +7,7 @@
 from __future__ import annotations  # for forward references
 
 import hashlib
+import inspect
 import json
 import os
 from collections.abc import Generator
@@ -198,16 +199,11 @@ def _model_identifiers_digest(endpoint: str, response: dict[str, Any]) -> str:
 
         Supported endpoints:
         - '/api/tags' (Ollama): response body has 'models': [ { name/model/digest/id/... }, ... ]
-        - '/v1/models' (OpenAI): response body has 'data': [ { id: ... }, ... ]
+        - '/v1/models' (OpenAI): response body is: [ { id: ... }, ... ]
         Returns a list of unique identifiers or None if structure doesn't match.
         """
-        body = response["body"]
-        if endpoint == "/api/tags":
-            items = body.get("models")
-            idents = [m.model for m in items]
-        else:
-            items = body.get("data")
-            idents = [m.id for m in items]
+        items = response["body"]
+        idents = [m.model if endpoint == "/api/tags" else m.id for m in items]
         return sorted(set(idents))
 
     identifiers = _extract_model_identifiers()
@@ -219,28 +215,22 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]])
     seen: dict[str, dict[str, Any]] = {}
     for rec in records:
         body = rec["response"]["body"]
-        if endpoint == "/api/tags":
-            items = body.models
-        elif endpoint == "/v1/models":
-            items = body.data
-        else:
-            items = []
-
-        for m in items:
-            if endpoint == "/v1/models":
+        if endpoint == "/v1/models":
+            for m in body:
                 key = m.id
-            else:
+                seen[key] = m
+        elif endpoint == "/api/tags":
+            for m in body.models:
                 key = m.model
-            seen[key] = m
+                seen[key] = m
 
     ordered = [seen[k] for k in sorted(seen.keys())]
     canonical = records[0]
     canonical_req = canonical.get("request", {})
     if isinstance(canonical_req, dict):
         canonical_req["endpoint"] = endpoint
-    if endpoint == "/v1/models":
-        body = {"data": ordered, "object": "list"}
-    else:
+    body = ordered
+    if endpoint == "/api/tags":
         from ollama import ListResponse
 
         body = ListResponse(models=ordered)
@@ -252,7 +242,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
 
     if _current_mode == InferenceMode.LIVE or _current_storage is None:
         # Normal operation
-        return await original_method(self, *args, **kwargs)
+        if inspect.iscoroutinefunction(original_method):
+            return await original_method(self, *args, **kwargs)
+        else:
+            return original_method(self, *args, **kwargs)
 
     # Get base URL based on client type
     if client_type == "openai":
@@ -282,7 +275,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         if recording:
             response_body = recording["response"]["body"]
 
-            if recording["response"].get("is_streaming", False):
+            if recording["response"].get("is_streaming", False) or endpoint == "/v1/models":
 
                 async def replay_stream():
                     for chunk in response_body:
@@ -300,7 +293,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
             )
 
     elif _current_mode == InferenceMode.RECORD:
-        response = await original_method(self, *args, **kwargs)
+        if inspect.iscoroutinefunction(original_method):
+            response = await original_method(self, *args, **kwargs)
+        else:
+            response = original_method(self, *args, **kwargs)
 
         request_data = {
             "method": method,
@@ -314,7 +310,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         # Determine if this is a streaming request based on request parameters
        is_streaming = body.get("stream", False)
 
-        if is_streaming:
+        if is_streaming or endpoint == "/v1/models":
            # For streaming responses, we need to collect all chunks immediately before yielding
            # This ensures the recording is saved even if the generator isn't fully consumed
            chunks = []
@@ -380,10 +376,19 @@ def patch_inference_clients():
             _original_methods["embeddings_create"], self, "openai", "/v1/embeddings", *args, **kwargs
         )
 
-    async def patched_models_list(self, *args, **kwargs):
-        return await _patched_inference_method(
-            _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
-        )
+    def patched_models_list(self, *args, **kwargs):
+        import asyncio
+        import concurrent.futures
+
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            future = executor.submit(
+                lambda: asyncio.run(
+                    _patched_inference_method(
+                        _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
+                    )
+                )
+            )
+            return future.result()
 
     # Apply OpenAI patches
     AsyncChatCompletions.create = patched_chat_completions_create