mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
fix: direct client pydantic type casting (#1145)
# What does this PR do? - Closes #1142 - Root cause is due to having `Union[str, AgenToolGroupWithArgs]` ## Test Plan - Test with script described in issue. - Print out final converted pydantic object <img width="1470" alt="image" src="https://github.com/user-attachments/assets/15dc9cd0-f37a-4b91-905f-3fe4f59a08c6" /> [//]: # (## Documentation)
This commit is contained in:
parent
8585b95a28
commit
e8cb9e0adb
6 changed files with 26 additions and 13 deletions
|
@ -13,7 +13,7 @@ import re
|
|||
from concurrent.futures import ThreadPoolExecutor
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, TypeVar, get_args, get_origin
|
||||
from typing import Any, Optional, TypeVar, Union, get_args, get_origin
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
|
@ -81,12 +81,13 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
|
|||
return value
|
||||
|
||||
origin = get_origin(annotation)
|
||||
|
||||
if origin is list:
|
||||
item_type = get_args(annotation)[0]
|
||||
try:
|
||||
return [convert_to_pydantic(item_type, item) for item in value]
|
||||
except Exception:
|
||||
print(f"Error converting list {value}")
|
||||
print(f"Error converting list {value} into {item_type}")
|
||||
return value
|
||||
|
||||
elif origin is dict:
|
||||
|
@ -94,17 +95,26 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
|
|||
try:
|
||||
return {k: convert_to_pydantic(val_type, v) for k, v in value.items()}
|
||||
except Exception:
|
||||
print(f"Error converting dict {value}")
|
||||
print(f"Error converting dict {value} into {val_type}")
|
||||
return value
|
||||
|
||||
try:
|
||||
# Handle Pydantic models and discriminated unions
|
||||
return TypeAdapter(annotation).validate_python(value)
|
||||
|
||||
except Exception as e:
|
||||
cprint(
|
||||
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
|
||||
"yellow",
|
||||
)
|
||||
# TODO: this is workardound for having Union[str, AgentToolGroup] in API schema.
|
||||
# We should get rid of any non-discriminated unions in the API schema.
|
||||
if origin is Union:
|
||||
for union_type in get_args(annotation):
|
||||
try:
|
||||
return convert_to_pydantic(union_type, value)
|
||||
except Exception:
|
||||
continue
|
||||
cprint(
|
||||
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
|
||||
"yellow",
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
|
@ -421,4 +431,5 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
if param_name in body:
|
||||
value = body.get(param_name)
|
||||
converted_body[param_name] = convert_to_pydantic(param.annotation, value)
|
||||
|
||||
return converted_body
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue