Make sure we _really_ fix library client with client-side types

This commit is contained in:
Ashwin Bharambe 2024-12-09 17:01:35 -08:00
parent 4904ca64df
commit 020e175a70

View file

@ -13,7 +13,7 @@ import threading
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from pathlib import Path
from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
from typing import Any, Generator, get_args, get_origin, Optional, Type, TypeVar, Union
import yaml
from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
@ -111,15 +111,62 @@ def stream_across_asyncio_run_boundary(
future.result()
def convert_pydantic_to_json_value(value: Any) -> dict:
if isinstance(value, BaseModel):
return json.loads(value.model_dump_json())
elif isinstance(value, Enum):
def convert_pydantic_to_json_value(value: Any, cast_to: Type) -> dict:
if isinstance(value, Enum):
return value.value
elif isinstance(value, list):
return [convert_pydantic_to_json_value(item) for item in value]
return [convert_pydantic_to_json_value(item, cast_to) for item in value]
elif isinstance(value, dict):
return {k: convert_pydantic_to_json_value(v) for k, v in value.items()}
return {k: convert_pydantic_to_json_value(v, cast_to) for k, v in value.items()}
elif isinstance(value, BaseModel):
# This is quite hacky and we should figure out how to use stuff from
# generated client-sdk code (using ApiResponse.parse() essentially)
value_dict = json.loads(value.model_dump_json())
origin = get_origin(cast_to)
if origin is Union:
args = get_args(cast_to)
for arg in args:
arg_name = arg.__name__.split(".")[-1]
value_name = value.__class__.__name__.split(".")[-1]
if arg_name == value_name:
return arg(**value_dict)
# assume we have the correct association between the server-side type and the client-side type
return cast_to(**value_dict)
return value
def convert_to_pydantic(annotation: Any, value: Any) -> Any:
if isinstance(annotation, type) and annotation in {str, int, float, bool}:
return value
origin = get_origin(annotation)
if origin is list:
item_type = get_args(annotation)[0]
try:
return [convert_to_pydantic(item_type, item) for item in value]
except Exception:
print(f"Error converting list {value}")
return value
elif origin is dict:
key_type, val_type = get_args(annotation)
try:
return {k: convert_to_pydantic(val_type, v) for k, v in value.items()}
except Exception:
print(f"Error converting dict {value}")
return value
try:
# Handle Pydantic models and discriminated unions
return TypeAdapter(annotation).validate_python(value)
except Exception as e:
cprint(
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
"yellow",
)
return value
@ -143,23 +190,14 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
return asyncio.run(self.async_client.initialize())
def get(self, *args, **kwargs):
def request(self, *args, **kwargs):
if kwargs.get("stream"):
return stream_across_asyncio_run_boundary(
lambda: self.async_client.get(*args, **kwargs),
lambda: self.async_client.request(*args, **kwargs),
self.pool_executor,
)
else:
return asyncio.run(self.async_client.get(*args, **kwargs))
def post(self, *args, **kwargs):
if kwargs.get("stream"):
return stream_across_asyncio_run_boundary(
lambda: self.async_client.post(*args, **kwargs),
self.pool_executor,
)
else:
return asyncio.run(self.async_client.post(*args, **kwargs))
return asyncio.run(self.async_client.request(*args, **kwargs))
class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
@ -227,38 +265,27 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
self.endpoint_impls = endpoint_impls
return True
async def get(
async def request(
self,
path: str,
cast_to: Any,
options: Any,
*,
stream=False,
**kwargs,
stream_cls=None,
):
if not self.endpoint_impls:
raise ValueError("Client not initialized")
params = options.params or {}
params |= options.json_data or {}
if stream:
return self._call_streaming(path, "GET")
return self._call_streaming(options.url, params, cast_to)
else:
return await self._call_non_streaming(path, "GET")
return await self._call_non_streaming(options.url, params, cast_to)
async def post(
self,
path: str,
*,
body: dict = None,
stream=False,
**kwargs,
async def _call_non_streaming(
self, path: str, body: dict = None, cast_to: Any = None
):
if not self.endpoint_impls:
raise ValueError("Client not initialized")
if stream:
return self._call_streaming(path, "POST", body)
else:
return await self._call_non_streaming(path, "POST", body)
async def _call_non_streaming(self, path: str, method: str, body: dict = None):
await start_trace(path, {"__location__": "library_client"})
try:
func = self.endpoint_impls.get(path)
@ -266,11 +293,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body)
return convert_pydantic_to_json_value(await func(**body))
return convert_pydantic_to_json_value(await func(**body), cast_to)
finally:
await end_trace()
async def _call_streaming(self, path: str, method: str, body: dict = None):
async def _call_streaming(self, path: str, body: dict = None, cast_to: Any = None):
await start_trace(path, {"__location__": "library_client"})
try:
func = self.endpoint_impls.get(path)
@ -279,7 +306,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
body = self._convert_body(path, body)
async for chunk in await func(**body):
yield convert_pydantic_to_json_value(chunk)
yield convert_pydantic_to_json_value(chunk, cast_to)
finally:
await end_trace()
@ -298,40 +325,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
for param_name, param in sig.parameters.items():
if param_name in body:
value = body.get(param_name)
converted_body[param_name] = self._convert_to_pydantic(
converted_body[param_name] = convert_to_pydantic(
param.annotation, value
)
return converted_body
def _convert_to_pydantic(self, annotation: Any, value: Any) -> Any:
if isinstance(annotation, type) and annotation in {str, int, float, bool}:
return value
origin = get_origin(annotation)
if origin is list:
item_type = get_args(annotation)[0]
try:
return [self._convert_to_pydantic(item_type, item) for item in value]
except Exception:
print(f"Error converting list {value}")
return value
elif origin is dict:
key_type, val_type = get_args(annotation)
try:
return {
k: self._convert_to_pydantic(val_type, v) for k, v in value.items()
}
except Exception:
print(f"Error converting dict {value}")
return value
try:
# Handle Pydantic models and discriminated unions
return TypeAdapter(annotation).validate_python(value)
except Exception as e:
cprint(
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
"yellow",
)
return value