Make sure we _really_ fix library client with client-side types

Ashwin Bharambe 2024-12-09 17:01:35 -08:00
parent 4904ca64df
commit 020e175a70


@@ -13,7 +13,7 @@ import threading
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from pathlib import Path
-from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
+from typing import Any, Generator, get_args, get_origin, Optional, Type, TypeVar, Union
 
 import yaml
 from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
@@ -111,18 +111,65 @@ def stream_across_asyncio_run_boundary(
         future.result()
 
 
-def convert_pydantic_to_json_value(value: Any) -> dict:
-    if isinstance(value, BaseModel):
-        return json.loads(value.model_dump_json())
-    elif isinstance(value, Enum):
+def convert_pydantic_to_json_value(value: Any, cast_to: Type) -> dict:
+    if isinstance(value, Enum):
         return value.value
     elif isinstance(value, list):
-        return [convert_pydantic_to_json_value(item) for item in value]
+        return [convert_pydantic_to_json_value(item, cast_to) for item in value]
     elif isinstance(value, dict):
-        return {k: convert_pydantic_to_json_value(v) for k, v in value.items()}
+        return {k: convert_pydantic_to_json_value(v, cast_to) for k, v in value.items()}
+    elif isinstance(value, BaseModel):
+        # This is quite hacky and we should figure out how to use stuff from
+        # generated client-sdk code (using ApiResponse.parse() essentially)
+        value_dict = json.loads(value.model_dump_json())
+
+        origin = get_origin(cast_to)
+        if origin is Union:
+            args = get_args(cast_to)
+            for arg in args:
+                arg_name = arg.__name__.split(".")[-1]
+                value_name = value.__class__.__name__.split(".")[-1]
+                if arg_name == value_name:
+                    return arg(**value_dict)
+
+        # assume we have the correct association between the server-side type and the client-side type
+        return cast_to(**value_dict)
     return value
 
 
+def convert_to_pydantic(annotation: Any, value: Any) -> Any:
+    if isinstance(annotation, type) and annotation in {str, int, float, bool}:
+        return value
+
+    origin = get_origin(annotation)
+    if origin is list:
+        item_type = get_args(annotation)[0]
+        try:
+            return [convert_to_pydantic(item_type, item) for item in value]
+        except Exception:
+            print(f"Error converting list {value}")
+            return value
+    elif origin is dict:
+        key_type, val_type = get_args(annotation)
+        try:
+            return {k: convert_to_pydantic(val_type, v) for k, v in value.items()}
+        except Exception:
+            print(f"Error converting dict {value}")
+            return value
+
+    try:
+        # Handle Pydantic models and discriminated unions
+        return TypeAdapter(annotation).validate_python(value)
+    except Exception as e:
+        cprint(
+            f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
+            "yellow",
+        )
+        return value
+
+
 class LlamaStackAsLibraryClient(LlamaStackClient):
     def __init__(
         self,
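
Note: the Union branch above picks the client-side class whose unqualified name matches the server-side model's class name. A minimal standalone sketch of that selection logic, using hypothetical RouteInfo/VersionInfo models that are not part of this diff:

from typing import Union, get_args, get_origin

from pydantic import BaseModel


class RouteInfo(BaseModel):  # hypothetical client-side type
    route: str
    method: str


class VersionInfo(BaseModel):  # hypothetical client-side type
    version: str


def pick_union_member(cast_to, value):
    # Mirrors the matching above: compare unqualified class names of each
    # Union member against the class name of the server-side value.
    if get_origin(cast_to) is Union:
        value_name = value.__class__.__name__.split(".")[-1]
        for arg in get_args(cast_to):
            if arg.__name__.split(".")[-1] == value_name:
                return arg
    return cast_to


class ServerVersionInfo(BaseModel):  # stand-in for the server-side model
    version: str


ServerVersionInfo.__name__ = "VersionInfo"  # simulate the shared class name across packages

assert pick_union_member(Union[RouteInfo, VersionInfo], ServerVersionInfo(version="0.1")) is VersionInfo

Matching by name only works while server-side and client-side schemas stay in lockstep, which is presumably why the "assume we have the correct association" comment stays in place.
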
@@ -143,23 +190,14 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
 
         return asyncio.run(self.async_client.initialize())
 
-    def get(self, *args, **kwargs):
+    def request(self, *args, **kwargs):
         if kwargs.get("stream"):
             return stream_across_asyncio_run_boundary(
-                lambda: self.async_client.get(*args, **kwargs),
+                lambda: self.async_client.request(*args, **kwargs),
                 self.pool_executor,
             )
         else:
-            return asyncio.run(self.async_client.get(*args, **kwargs))
-
-    def post(self, *args, **kwargs):
-        if kwargs.get("stream"):
-            return stream_across_asyncio_run_boundary(
-                lambda: self.async_client.post(*args, **kwargs),
-                self.pool_executor,
-            )
-        else:
-            return asyncio.run(self.async_client.post(*args, **kwargs))
+            return asyncio.run(self.async_client.request(*args, **kwargs))
 
 
 class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
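
Folding get() and post() into a single request() override matters because the generated client routes every resource call through request(); overriding it is what lets the library-backed client be driven exactly like the remote one. A hedged usage sketch (the import path and template name are assumptions, not shown in this commit):

from llama_stack.distribution.library_client import LlamaStackAsLibraryClient  # assumed module path

client = LlamaStackAsLibraryClient("ollama")  # illustrative template name
client.initialize()

# Generated resource methods funnel through the request() override above,
# which dispatches via asyncio.run(...) or stream_across_asyncio_run_boundary(...).
models = client.models.list()
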
@@ -227,38 +265,27 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         self.endpoint_impls = endpoint_impls
         return True
 
-    async def get(
+    async def request(
         self,
-        path: str,
+        cast_to: Any,
+        options: Any,
         *,
         stream=False,
-        **kwargs,
+        stream_cls=None,
     ):
         if not self.endpoint_impls:
             raise ValueError("Client not initialized")
 
+        params = options.params or {}
+        params |= options.json_data or {}
         if stream:
-            return self._call_streaming(path, "GET")
+            return self._call_streaming(options.url, params, cast_to)
         else:
-            return await self._call_non_streaming(path, "GET")
+            return await self._call_non_streaming(options.url, params, cast_to)
 
-    async def post(
-        self,
-        path: str,
-        *,
-        body: dict = None,
-        stream=False,
-        **kwargs,
+    async def _call_non_streaming(
+        self, path: str, body: dict = None, cast_to: Any = None
     ):
-        if not self.endpoint_impls:
-            raise ValueError("Client not initialized")
-
-        if stream:
-            return self._call_streaming(path, "POST", body)
-        else:
-            return await self._call_non_streaming(path, "POST", body)
-
-    async def _call_non_streaming(self, path: str, method: str, body: dict = None):
         await start_trace(path, {"__location__": "library_client"})
         try:
             func = self.endpoint_impls.get(path)
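
The new request() reads three attributes off the SDK's per-request options object (url, params, json_data) and merges the latter two into one dict before dispatching. A rough sketch of that merge, using a stand-in dataclass rather than the SDK's real options type:

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeRequestOptions:  # stand-in only; the real options object comes from the generated SDK
    url: str
    params: Optional[dict] = None
    json_data: Optional[dict] = None


opts = FakeRequestOptions(url="/models/list", params={"limit": 10})  # illustrative route
params = opts.params or {}
params |= opts.json_data or {}  # same merge request() performs before routing
# request(...) then calls _call_streaming(opts.url, params, cast_to) or
# _call_non_streaming(opts.url, params, cast_to).
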
@@ -266,11 +293,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 raise ValueError(f"No endpoint found for {path}")
 
             body = self._convert_body(path, body)
-            return convert_pydantic_to_json_value(await func(**body))
+            return convert_pydantic_to_json_value(await func(**body), cast_to)
         finally:
             await end_trace()
 
-    async def _call_streaming(self, path: str, method: str, body: dict = None):
+    async def _call_streaming(self, path: str, body: dict = None, cast_to: Any = None):
         await start_trace(path, {"__location__": "library_client"})
         try:
             func = self.endpoint_impls.get(path)
@@ -279,7 +306,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             body = self._convert_body(path, body)
             async for chunk in await func(**body):
-                yield convert_pydantic_to_json_value(chunk)
+                yield convert_pydantic_to_json_value(chunk, cast_to)
         finally:
             await end_trace()
@@ -298,40 +325,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         for param_name, param in sig.parameters.items():
             if param_name in body:
                 value = body.get(param_name)
-                converted_body[param_name] = self._convert_to_pydantic(
+                converted_body[param_name] = convert_to_pydantic(
                     param.annotation, value
                 )
 
         return converted_body
-
-    def _convert_to_pydantic(self, annotation: Any, value: Any) -> Any:
-        if isinstance(annotation, type) and annotation in {str, int, float, bool}:
-            return value
-
-        origin = get_origin(annotation)
-        if origin is list:
-            item_type = get_args(annotation)[0]
-            try:
-                return [self._convert_to_pydantic(item_type, item) for item in value]
-            except Exception:
-                print(f"Error converting list {value}")
-                return value
-        elif origin is dict:
-            key_type, val_type = get_args(annotation)
-            try:
-                return {
-                    k: self._convert_to_pydantic(val_type, v) for k, v in value.items()
-                }
-            except Exception:
-                print(f"Error converting dict {value}")
-                return value
-
-        try:
-            # Handle Pydantic models and discriminated unions
-            return TypeAdapter(annotation).validate_python(value)
-        except Exception as e:
-            cprint(
-                f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
-                "yellow",
-            )
-            return value
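
The method removed above is the same logic that now lives at module level as convert_to_pydantic (earlier in this diff); its core is Pydantic's TypeAdapter, which revalidates the plain JSON values crossing the library-client boundary back into typed models and discriminated unions. A minimal sketch with a hypothetical Message model, not taken from this commit:

from typing import List

from pydantic import BaseModel, TypeAdapter


class Message(BaseModel):  # hypothetical parameter annotation
    role: str
    content: str


raw = [{"role": "user", "content": "hello"}]  # plain JSON, as a caller would pass it
messages = TypeAdapter(List[Message]).validate_python(raw)
assert messages[0].content == "hello"
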