mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
fix: direct client pydantic type casting (#1145)
# What does this PR do? - Closes #1142 - Root cause is due to having `Union[str, AgentToolGroupWithArgs]` ## Test Plan - Test with script described in issue. - Print out final converted pydantic object <img width="1470" alt="image" src="https://github.com/user-attachments/assets/15dc9cd0-f37a-4b91-905f-3fe4f59a08c6" /> [//]: # (## Documentation)
This commit is contained in:
parent
8585b95a28
commit
e8cb9e0adb
6 changed files with 26 additions and 13 deletions
|
@ -13,7 +13,7 @@ import re
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional, TypeVar, get_args, get_origin
|
from typing import Any, Optional, TypeVar, Union, get_args, get_origin
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import yaml
|
import yaml
|
||||||
|
@ -81,12 +81,13 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
origin = get_origin(annotation)
|
origin = get_origin(annotation)
|
||||||
|
|
||||||
if origin is list:
|
if origin is list:
|
||||||
item_type = get_args(annotation)[0]
|
item_type = get_args(annotation)[0]
|
||||||
try:
|
try:
|
||||||
return [convert_to_pydantic(item_type, item) for item in value]
|
return [convert_to_pydantic(item_type, item) for item in value]
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error converting list {value}")
|
print(f"Error converting list {value} into {item_type}")
|
||||||
return value
|
return value
|
||||||
|
|
||||||
elif origin is dict:
|
elif origin is dict:
|
||||||
|
@ -94,17 +95,26 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
|
||||||
try:
|
try:
|
||||||
return {k: convert_to_pydantic(val_type, v) for k, v in value.items()}
|
return {k: convert_to_pydantic(val_type, v) for k, v in value.items()}
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error converting dict {value}")
|
print(f"Error converting dict {value} into {val_type}")
|
||||||
return value
|
return value
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Handle Pydantic models and discriminated unions
|
# Handle Pydantic models and discriminated unions
|
||||||
return TypeAdapter(annotation).validate_python(value)
|
return TypeAdapter(annotation).validate_python(value)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
cprint(
|
# TODO: this is a workaround for having Union[str, AgentToolGroup] in API schema.
|
||||||
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
|
# We should get rid of any non-discriminated unions in the API schema.
|
||||||
"yellow",
|
if origin is Union:
|
||||||
)
|
for union_type in get_args(annotation):
|
||||||
|
try:
|
||||||
|
return convert_to_pydantic(union_type, value)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
cprint(
|
||||||
|
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
|
||||||
|
"yellow",
|
||||||
|
)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
@ -421,4 +431,5 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
if param_name in body:
|
if param_name in body:
|
||||||
value = body.get(param_name)
|
value = body.get(param_name)
|
||||||
converted_body[param_name] = convert_to_pydantic(param.annotation, value)
|
converted_body[param_name] = convert_to_pydantic(param.annotation, value)
|
||||||
|
|
||||||
return converted_body
|
return converted_body
|
||||||
|
|
|
@ -77,7 +77,7 @@ def typeannotation(
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def wrap(cls: Type[T]) -> Type[T]:
|
def wrap(cls: Type[T]) -> Type[T]:
|
||||||
setattr(cls, "__repr__", _compact_dataclass_repr)
|
cls.__repr__ = _compact_dataclass_repr
|
||||||
if not dataclasses.is_dataclass(cls):
|
if not dataclasses.is_dataclass(cls):
|
||||||
cls = dataclasses.dataclass( # type: ignore[call-overload]
|
cls = dataclasses.dataclass( # type: ignore[call-overload]
|
||||||
cls,
|
cls,
|
||||||
|
|
|
@ -203,7 +203,7 @@ def schema_to_type(schema: Schema, *, module: types.ModuleType, class_name: str)
|
||||||
if type_def.default is not dataclasses.MISSING:
|
if type_def.default is not dataclasses.MISSING:
|
||||||
raise TypeError("disallowed: `default` for top-level type definitions")
|
raise TypeError("disallowed: `default` for top-level type definitions")
|
||||||
|
|
||||||
setattr(type_def.type, "__module__", module.__name__)
|
type_def.type.__module__ = module.__name__
|
||||||
setattr(module, type_name, type_def.type)
|
setattr(module, type_name, type_def.type)
|
||||||
|
|
||||||
return node_to_typedef(module, class_name, top_node).type
|
return node_to_typedef(module, class_name, top_node).type
|
||||||
|
|
|
@ -325,7 +325,7 @@ class TupleDeserializer(Deserializer[Tuple[Any, ...]]):
|
||||||
f"type `{self.container_type}` expects a JSON `array` of length {count} but received length {len(data)}"
|
f"type `{self.container_type}` expects a JSON `array` of length {count} but received length {len(data)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return tuple(item_parser.parse(item) for item_parser, item in zip(self.item_parsers, data))
|
return tuple(item_parser.parse(item) for item_parser, item in zip(self.item_parsers, data, strict=False))
|
||||||
|
|
||||||
|
|
||||||
class UnionDeserializer(Deserializer):
|
class UnionDeserializer(Deserializer):
|
||||||
|
|
|
@ -263,8 +263,8 @@ def extend_enum(
|
||||||
enum_class: Type[enum.Enum] = enum.Enum(extend.__name__, values) # type: ignore
|
enum_class: Type[enum.Enum] = enum.Enum(extend.__name__, values) # type: ignore
|
||||||
|
|
||||||
# assign the newly created type to the same module where the extending class is defined
|
# assign the newly created type to the same module where the extending class is defined
|
||||||
setattr(enum_class, "__module__", extend.__module__)
|
enum_class.__module__ = extend.__module__
|
||||||
setattr(enum_class, "__doc__", extend.__doc__)
|
enum_class.__doc__ = extend.__doc__
|
||||||
setattr(sys.modules[extend.__module__], extend.__name__, enum_class)
|
setattr(sys.modules[extend.__module__], extend.__name__, enum_class)
|
||||||
|
|
||||||
return enum.unique(enum_class)
|
return enum.unique(enum_class)
|
||||||
|
@ -874,6 +874,7 @@ def is_generic_instance(obj: Any, typ: TypeLike) -> bool:
|
||||||
for tuple_item_type, item in zip(
|
for tuple_item_type, item in zip(
|
||||||
(tuple_item_type for tuple_item_type in typing.get_args(typ)),
|
(tuple_item_type for tuple_item_type in typing.get_args(typ)),
|
||||||
(item for item in obj),
|
(item for item in obj),
|
||||||
|
strict=False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif origin_type is Union:
|
elif origin_type is Union:
|
||||||
|
@ -954,6 +955,7 @@ class RecursiveChecker:
|
||||||
for tuple_item_type, item in zip(
|
for tuple_item_type, item in zip(
|
||||||
(tuple_item_type for tuple_item_type in typing.get_args(typ)),
|
(tuple_item_type for tuple_item_type in typing.get_args(typ)),
|
||||||
(item for item in obj),
|
(item for item in obj),
|
||||||
|
strict=False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif origin_type is Union:
|
elif origin_type is Union:
|
||||||
|
|
|
@ -216,7 +216,7 @@ class TypedTupleSerializer(Serializer[tuple]):
|
||||||
self.item_generators = tuple(_get_serializer(item_type, context) for item_type in item_types)
|
self.item_generators = tuple(_get_serializer(item_type, context) for item_type in item_types)
|
||||||
|
|
||||||
def generate(self, obj: tuple) -> List[JsonType]:
|
def generate(self, obj: tuple) -> List[JsonType]:
|
||||||
return [item_generator.generate(item) for item_generator, item in zip(self.item_generators, obj)]
|
return [item_generator.generate(item) for item_generator, item in zip(self.item_generators, obj, strict=False)]
|
||||||
|
|
||||||
|
|
||||||
class CustomSerializer(Serializer):
|
class CustomSerializer(Serializer):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue