mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 14:08:00 +00:00
test: improve generic type handling in response deserialization
Enhance the inference recorder's deserialization logic to handle generic types like AsyncPage[Model] by stripping the generic parameters before class resolution. Add special handling for AsyncPage objects by converting nested model dictionaries to SimpleNamespace objects, enabling attribute access (e.g., .id) on the deserialized data. Signed-off-by: Derek Higgins <derekh@redhat.com>
This commit is contained in:
parent
711735891a
commit
91a010fb12
1 changed file with 71 additions and 1 deletion
|
@ -108,13 +108,29 @@ def _deserialize_response(data: dict[str, Any]) -> Any:
|
||||||
try:
|
try:
|
||||||
# Import the original class and reconstruct the object
|
# Import the original class and reconstruct the object
|
||||||
module_path, class_name = data["__type__"].rsplit(".", 1)
|
module_path, class_name = data["__type__"].rsplit(".", 1)
|
||||||
|
|
||||||
|
# Handle generic types (e.g. AsyncPage[Model]) by removing the generic part
|
||||||
|
if "[" in class_name and "]" in class_name:
|
||||||
|
class_name = class_name.split("[")[0]
|
||||||
|
|
||||||
module = __import__(module_path, fromlist=[class_name])
|
module = __import__(module_path, fromlist=[class_name])
|
||||||
cls = getattr(module, class_name)
|
cls = getattr(module, class_name)
|
||||||
|
|
||||||
if not hasattr(cls, "model_validate"):
|
if not hasattr(cls, "model_validate"):
|
||||||
raise ValueError(f"Pydantic class {cls} does not support model_validate?")
|
raise ValueError(f"Pydantic class {cls} does not support model_validate?")
|
||||||
|
|
||||||
return cls.model_validate(data["__data__"])
|
# Special handling for AsyncPage - convert nested model dicts to proper model objects
|
||||||
|
validate_data = data["__data__"]
|
||||||
|
if class_name == "AsyncPage" and isinstance(validate_data, dict) and "data" in validate_data:
|
||||||
|
# Convert model dictionaries to objects with attributes so they work with .id access
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
validate_data = dict(validate_data)
|
||||||
|
validate_data["data"] = [
|
||||||
|
SimpleNamespace(**item) if isinstance(item, dict) else item for item in validate_data["data"]
|
||||||
|
]
|
||||||
|
|
||||||
|
return cls.model_validate(validate_data)
|
||||||
except (ImportError, AttributeError, TypeError, ValueError) as e:
|
except (ImportError, AttributeError, TypeError, ValueError) as e:
|
||||||
logger.warning(f"Failed to deserialize object of type {data['__type__']}: {e}")
|
logger.warning(f"Failed to deserialize object of type {data['__type__']}: {e}")
|
||||||
return data["__data__"]
|
return data["__data__"]
|
||||||
|
@ -332,9 +348,11 @@ def patch_inference_clients():
|
||||||
from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
|
from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
|
||||||
from openai.resources.completions import AsyncCompletions
|
from openai.resources.completions import AsyncCompletions
|
||||||
from openai.resources.embeddings import AsyncEmbeddings
|
from openai.resources.embeddings import AsyncEmbeddings
|
||||||
|
from openai.resources.models import AsyncModels
|
||||||
|
|
||||||
# Store original methods for both OpenAI and Ollama clients
|
# Store original methods for both OpenAI and Ollama clients
|
||||||
_original_methods = {
|
_original_methods = {
|
||||||
|
"model_list": AsyncModels.list,
|
||||||
"chat_completions_create": AsyncChatCompletions.create,
|
"chat_completions_create": AsyncChatCompletions.create,
|
||||||
"completions_create": AsyncCompletions.create,
|
"completions_create": AsyncCompletions.create,
|
||||||
"embeddings_create": AsyncEmbeddings.create,
|
"embeddings_create": AsyncEmbeddings.create,
|
||||||
|
@ -347,6 +365,55 @@ def patch_inference_clients():
|
||||||
}
|
}
|
||||||
|
|
||||||
# Create patched methods for OpenAI client
|
# Create patched methods for OpenAI client
|
||||||
|
def patched_model_list(self, *args, **kwargs):
    """Patched stand-in for ``AsyncModels.list`` that goes through ``_patched_inference_method``.

    This is intentionally a plain ``def`` rather than ``async def``: the original
    ``models.list()`` returns an AsyncPaginator that can be used with ``async for``,
    so the replacement returns an object that supports both ``await`` and
    ``async for`` instead of returning a coroutine.

    Args:
        self: The resource instance the patched method is bound to.
        *args: Positional arguments forwarded unchanged to the wrapped call.
        **kwargs: Keyword arguments forwarded unchanged to the wrapped call.

    Returns:
        A ``PatchedAsyncPaginator`` wrapping the deferred call for the
        ``"openai"`` client type and the ``"/v1/models"`` endpoint.
    """

    class PatchedAsyncPaginator:
        """Awaitable and async-iterable wrapper around one deferred patched call."""

        def __init__(self, original_method, instance, client_type, endpoint, args, kwargs):
            self.original_method = original_method
            self.instance = instance
            self.client_type = client_type
            self.endpoint = endpoint
            self.args = args
            self.kwargs = kwargs
            # Last result obtained from the patched call; None until first fetched.
            self._result = None
            # Iterator backing direct ``__anext__`` calls on the paginator itself.
            self._default_iter = None

        async def _fetch(self):
            """Run the patched call, remember its result, and return it."""
            self._result = await _patched_inference_method(
                self.original_method, self.instance, self.client_type, self.endpoint, *self.args, **self.kwargs
            )
            return self._result

        def __await__(self):
            # Each ``await`` performs the call, exactly as before.
            return self._fetch().__await__()

        async def __aiter__(self):
            # Async generator: every ``async for`` gets its own cursor, so the
            # paginator can be iterated more than once and by several consumers
            # without them sharing (and exhausting) a single index.
            page = self._result if self._result is not None else await self._fetch()
            for item in page.data:
                yield item

        async def __anext__(self):
            # Kept so callers that drive the paginator directly (without going
            # through ``__aiter__``) continue to work.
            if self._default_iter is None:
                self._default_iter = self.__aiter__()
            return await self._default_iter.__anext__()

    return PatchedAsyncPaginator(_original_methods["model_list"], self, "openai", "/v1/models", args, kwargs)
|
||||||
|
|
||||||
async def patched_chat_completions_create(self, *args, **kwargs):
|
async def patched_chat_completions_create(self, *args, **kwargs):
|
||||||
return await _patched_inference_method(
|
return await _patched_inference_method(
|
||||||
_original_methods["chat_completions_create"], self, "openai", "/v1/chat/completions", *args, **kwargs
|
_original_methods["chat_completions_create"], self, "openai", "/v1/chat/completions", *args, **kwargs
|
||||||
|
@ -363,6 +430,7 @@ def patch_inference_clients():
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply OpenAI patches
|
# Apply OpenAI patches
|
||||||
|
AsyncModels.list = patched_model_list
|
||||||
AsyncChatCompletions.create = patched_chat_completions_create
|
AsyncChatCompletions.create = patched_chat_completions_create
|
||||||
AsyncCompletions.create = patched_completions_create
|
AsyncCompletions.create = patched_completions_create
|
||||||
AsyncEmbeddings.create = patched_embeddings_create
|
AsyncEmbeddings.create = patched_embeddings_create
|
||||||
|
@ -419,8 +487,10 @@ def unpatch_inference_clients():
|
||||||
from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
|
from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
|
||||||
from openai.resources.completions import AsyncCompletions
|
from openai.resources.completions import AsyncCompletions
|
||||||
from openai.resources.embeddings import AsyncEmbeddings
|
from openai.resources.embeddings import AsyncEmbeddings
|
||||||
|
from openai.resources.models import AsyncModels
|
||||||
|
|
||||||
# Restore OpenAI client methods
|
# Restore OpenAI client methods
|
||||||
|
AsyncModels.list = _original_methods["model_list"]
|
||||||
AsyncChatCompletions.create = _original_methods["chat_completions_create"]
|
AsyncChatCompletions.create = _original_methods["chat_completions_create"]
|
||||||
AsyncCompletions.create = _original_methods["completions_create"]
|
AsyncCompletions.create = _original_methods["completions_create"]
|
||||||
AsyncEmbeddings.create = _original_methods["embeddings_create"]
|
AsyncEmbeddings.create = _original_methods["embeddings_create"]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue