test: improve generic type handling in response deserialization

Enhance the inference recorder's deserialization logic to handle
generic types like AsyncPage[Model] by stripping the generic parameters
before class resolution.
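
For illustration (not part of the diff), the stripping step reduces a
parametrized type name to its importable base class, roughly:

    # Sketch of the stripping step; the type string is a made-up example
    type_name = "openai.pagination.AsyncPage[Model]"
    module_path, class_name = type_name.rsplit(".", 1)  # class_name == "AsyncPage[Model]"
    if "[" in class_name and "]" in class_name:
        class_name = class_name.split("[")[0]  # class_name == "AsyncPage"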

Add special handling for AsyncPage objects by converting nested model
dictionaries to SimpleNamespace objects, enabling attribute access
(e.g., .id) on the deserialized data.
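
A minimal sketch of that conversion, using a made-up payload:

    from types import SimpleNamespace

    page = {"data": [{"id": "llama3", "object": "model"}]}  # hypothetical recorded page
    page["data"] = [
        SimpleNamespace(**item) if isinstance(item, dict) else item
        for item in page["data"]
    ]
    assert page["data"][0].id == "llama3"  # attribute access now works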

Signed-off-by: Derek Higgins <derekh@redhat.com>
Author: Derek Higgins <derekh@redhat.com>
Date:   2025-08-13 14:08:20 +01:00
parent 711735891a
commit 91a010fb12


@@ -108,13 +108,29 @@ def _deserialize_response(data: dict[str, Any]) -> Any:
     try:
         # Import the original class and reconstruct the object
         module_path, class_name = data["__type__"].rsplit(".", 1)
+        # Handle generic types (e.g. AsyncPage[Model]) by removing the generic part
+        if "[" in class_name and "]" in class_name:
+            class_name = class_name.split("[")[0]
         module = __import__(module_path, fromlist=[class_name])
         cls = getattr(module, class_name)
         if not hasattr(cls, "model_validate"):
             raise ValueError(f"Pydantic class {cls} does not support model_validate?")
-        return cls.model_validate(data["__data__"])
+
+        # Special handling for AsyncPage - convert nested model dicts to proper model objects
+        validate_data = data["__data__"]
+        if class_name == "AsyncPage" and isinstance(validate_data, dict) and "data" in validate_data:
+            # Convert model dictionaries to objects with attributes so they work with .id access
+            from types import SimpleNamespace
+
+            validate_data = dict(validate_data)
+            validate_data["data"] = [
+                SimpleNamespace(**item) if isinstance(item, dict) else item for item in validate_data["data"]
+            ]
+
+        return cls.model_validate(validate_data)
     except (ImportError, AttributeError, TypeError, ValueError) as e:
         logger.warning(f"Failed to deserialize object of type {data['__type__']}: {e}")
         return data["__data__"]
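
Taken together, the hunk above lets a recorded AsyncPage round-trip. A
sketch of the intended behavior, assuming an entry shaped the way the
code above reads it (the __type__/__data__ layout comes from the diff;
the model values are invented):

    recorded = {
        "__type__": "openai.pagination.AsyncPage[Model]",
        "__data__": {"object": "list", "data": [{"id": "llama3", "object": "model"}]},
    }
    page = _deserialize_response(recorded)
    print(page.data[0].id)  # "llama3" - attribute access via SimpleNamespace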
@@ -332,9 +348,11 @@ def patch_inference_clients():
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
     from openai.resources.embeddings import AsyncEmbeddings
+    from openai.resources.models import AsyncModels
 
     # Store original methods for both OpenAI and Ollama clients
     _original_methods = {
+        "model_list": AsyncModels.list,
         "chat_completions_create": AsyncChatCompletions.create,
         "completions_create": AsyncCompletions.create,
         "embeddings_create": AsyncEmbeddings.create,
@@ -347,6 +365,55 @@ def patch_inference_clients():
     }
 
     # Create patched methods for OpenAI client
+    def patched_model_list(self, *args, **kwargs):
+        # The original models.list() returns an AsyncPaginator that can be used with async for
+        # We need to create a wrapper that preserves this behavior
+        class PatchedAsyncPaginator:
+            def __init__(self, original_method, instance, client_type, endpoint, args, kwargs):
+                self.original_method = original_method
+                self.instance = instance
+                self.client_type = client_type
+                self.endpoint = endpoint
+                self.args = args
+                self.kwargs = kwargs
+                self._result = None
+
+            def __await__(self):
+                # Make it awaitable like the original AsyncPaginator
+                async def _await():
+                    self._result = await _patched_inference_method(
+                        self.original_method, self.instance, self.client_type, self.endpoint, *self.args, **self.kwargs
+                    )
+                    return self._result
+
+                return _await().__await__()
+
+            def __aiter__(self):
+                # Make it async iterable like the original AsyncPaginator
+                return self
+
+            async def __anext__(self):
+                # Get the result if we haven't already
+                if self._result is None:
+                    self._result = await _patched_inference_method(
+                        self.original_method, self.instance, self.client_type, self.endpoint, *self.args, **self.kwargs
+                    )
+
+                # Initialize iteration on first call
+                if not hasattr(self, "_iter_index"):
+                    # Extract the data list from the result
+                    self._data_list = self._result.data
+                    self._iter_index = 0
+
+                # Return next item from the list
+                if self._iter_index >= len(self._data_list):
+                    raise StopAsyncIteration
+                item = self._data_list[self._iter_index]
+                self._iter_index += 1
+                return item
+
+        return PatchedAsyncPaginator(_original_methods["model_list"], self, "openai", "/v1/models", args, kwargs)
+
     async def patched_chat_completions_create(self, *args, **kwargs):
         return await _patched_inference_method(
             _original_methods["chat_completions_create"], self, "openai", "/v1/chat/completions", *args, **kwargs
@@ -363,6 +430,7 @@ def patch_inference_clients():
         )
 
     # Apply OpenAI patches
+    AsyncModels.list = patched_model_list
     AsyncChatCompletions.create = patched_chat_completions_create
     AsyncCompletions.create = patched_completions_create
     AsyncEmbeddings.create = patched_embeddings_create
@@ -419,8 +487,10 @@ def unpatch_inference_clients():
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
     from openai.resources.embeddings import AsyncEmbeddings
+    from openai.resources.models import AsyncModels
 
     # Restore OpenAI client methods
+    AsyncModels.list = _original_methods["model_list"]
     AsyncChatCompletions.create = _original_methods["chat_completions_create"]
     AsyncCompletions.create = _original_methods["completions_create"]
     AsyncEmbeddings.create = _original_methods["embeddings_create"]
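
A hypothetical fixture-style use of the pair, assuming both helpers are
importable from the recorder module:

    patch_inference_clients()
    try:
        # exercise inference calls here; models.list(), chat.completions.create(),
        # etc. all go through the recording/replay path
        ...
    finally:
        unpatch_inference_clients()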