diff --git a/llama_stack/testing/inference_recorder.py b/llama_stack/testing/inference_recorder.py
index 67a46a1c5..0b2c01a1b 100644
--- a/llama_stack/testing/inference_recorder.py
+++ b/llama_stack/testing/inference_recorder.py
@@ -108,13 +108,29 @@ def _deserialize_response(data: dict[str, Any]) -> Any:
         try:
             # Import the original class and reconstruct the object
             module_path, class_name = data["__type__"].rsplit(".", 1)
+
+            # Handle generic types (e.g. AsyncPage[Model]) by removing the generic part
+            if "[" in class_name and "]" in class_name:
+                class_name = class_name.split("[")[0]
+
             module = __import__(module_path, fromlist=[class_name])
             cls = getattr(module, class_name)

             if not hasattr(cls, "model_validate"):
                 raise ValueError(f"Pydantic class {cls} does not support model_validate?")

-            return cls.model_validate(data["__data__"])
+            # Special handling for AsyncPage - convert nested model dicts to proper model objects
+            validate_data = data["__data__"]
+            if class_name == "AsyncPage" and isinstance(validate_data, dict) and "data" in validate_data:
+                # Convert model dictionaries to objects with attributes so they work with .id access
+                from types import SimpleNamespace
+
+                validate_data = dict(validate_data)
+                validate_data["data"] = [
+                    SimpleNamespace(**item) if isinstance(item, dict) else item for item in validate_data["data"]
+                ]
+
+            return cls.model_validate(validate_data)
         except (ImportError, AttributeError, TypeError, ValueError) as e:
             logger.warning(f"Failed to deserialize object of type {data['__type__']}: {e}")
             return data["__data__"]
@@ -332,9 +348,11 @@ def patch_inference_clients():
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
     from openai.resources.embeddings import AsyncEmbeddings
+    from openai.resources.models import AsyncModels

     # Store original methods for both OpenAI and Ollama clients
     _original_methods = {
+        "model_list": AsyncModels.list,
         "chat_completions_create": AsyncChatCompletions.create,
         "completions_create": AsyncCompletions.create,
         "embeddings_create": AsyncEmbeddings.create,
@@ -347,6 +365,55 @@ def patch_inference_clients():
     }

     # Create patched methods for OpenAI client
+    def patched_model_list(self, *args, **kwargs):
+        # The original models.list() returns an AsyncPaginator that can be used with async for
+        # We need to create a wrapper that preserves this behavior
+        class PatchedAsyncPaginator:
+            def __init__(self, original_method, instance, client_type, endpoint, args, kwargs):
+                self.original_method = original_method
+                self.instance = instance
+                self.client_type = client_type
+                self.endpoint = endpoint
+                self.args = args
+                self.kwargs = kwargs
+                self._result = None
+
+            def __await__(self):
+                # Make it awaitable like the original AsyncPaginator
+                async def _await():
+                    self._result = await _patched_inference_method(
+                        self.original_method, self.instance, self.client_type, self.endpoint, *self.args, **self.kwargs
+                    )
+                    return self._result
+
+                return _await().__await__()
+
+            def __aiter__(self):
+                # Make it async iterable like the original AsyncPaginator
+                return self
+
+            async def __anext__(self):
+                # Get the result if we haven't already
+                if self._result is None:
+                    self._result = await _patched_inference_method(
+                        self.original_method, self.instance, self.client_type, self.endpoint, *self.args, **self.kwargs
+                    )
+
+                # Initialize iteration on first call
+                if not hasattr(self, "_iter_index"):
+                    # Extract the data list from the result
+                    self._data_list = self._result.data
+                    self._iter_index = 0
+
+                # Return next item from the list
+                if self._iter_index >= len(self._data_list):
+                    raise StopAsyncIteration
+                item = self._data_list[self._iter_index]
+                self._iter_index += 1
+                return item
+
+        return PatchedAsyncPaginator(_original_methods["model_list"], self, "openai", "/v1/models", args, kwargs)
+
     async def patched_chat_completions_create(self, *args, **kwargs):
         return await _patched_inference_method(
             _original_methods["chat_completions_create"], self, "openai", "/v1/chat/completions", *args, **kwargs
@@ -363,6 +430,7 @@ def patch_inference_clients():
     )

     # Apply OpenAI patches
+    AsyncModels.list = patched_model_list
     AsyncChatCompletions.create = patched_chat_completions_create
     AsyncCompletions.create = patched_completions_create
     AsyncEmbeddings.create = patched_embeddings_create
@@ -419,8 +487,10 @@ def unpatch_inference_clients():
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
     from openai.resources.embeddings import AsyncEmbeddings
+    from openai.resources.models import AsyncModels

     # Restore OpenAI client methods
+    AsyncModels.list = _original_methods["model_list"]
     AsyncChatCompletions.create = _original_methods["chat_completions_create"]
     AsyncCompletions.create = _original_methods["completions_create"]
     AsyncEmbeddings.create = _original_methods["embeddings_create"]
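
Note on the paginator shim: openai's `client.models.list()` returns an `AsyncPaginator`, which callers may either `await` directly or consume with `async for`, so `PatchedAsyncPaginator` has to implement both `__await__` and the async-iteration protocol. Below is a minimal, self-contained sketch of that dual protocol, with a hypothetical `fetch_page()` coroutine standing in for `_patched_inference_method`:

```python
import asyncio
from types import SimpleNamespace


async def fetch_page():
    # Hypothetical stand-in for the recorded /v1/models response
    return SimpleNamespace(data=[SimpleNamespace(id="model-a"), SimpleNamespace(id="model-b")])


class DualUsePaginator:
    """Awaitable (resolves to the whole page) and async-iterable (yields items)."""

    def __init__(self):
        self._result = None

    def __await__(self):
        # `page = await paginator` resolves to the full page object
        async def _await():
            self._result = await fetch_page()
            return self._result

        return _await().__await__()

    def __aiter__(self):
        # `async for item in paginator` walks the page's data list
        async def _iter():
            if self._result is None:
                self._result = await fetch_page()
            for item in self._result.data:
                yield item

        return _iter()


async def main():
    page = await DualUsePaginator()  # awaitable form
    print([m.id for m in page.data])
    async for model in DualUsePaginator():  # async-for form
        print(model.id)


asyncio.run(main())
```

Either entry point lazily triggers the single underlying request, matching how `__anext__` in the patch caches `_result` before iterating.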