mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 16:54:42 +00:00
Make sure we _really_ fix library client with client-side types
This commit is contained in:
parent
4904ca64df
commit
020e175a70
1 changed file with 71 additions and 77 deletions
|
@ -13,7 +13,7 @@ import threading
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
|
from typing import Any, Generator, get_args, get_origin, Optional, Type, TypeVar, Union
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
|
from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
|
||||||
|
@ -111,18 +111,65 @@ def stream_across_asyncio_run_boundary(
|
||||||
future.result()
|
future.result()
|
||||||
|
|
||||||
|
|
||||||
def convert_pydantic_to_json_value(value: Any, cast_to: Type) -> Any:
    """Convert a server-side return value into the type the client asked for.

    Enums are replaced by their ``.value``; lists and dicts are converted
    element by element; ``BaseModel`` instances are dumped to JSON and
    re-instantiated as ``cast_to`` (or as the matching member when ``cast_to``
    is a ``Union``). Any other value is returned unchanged.

    Args:
        value: The object returned by the endpoint implementation.
        cast_to: The type the caller expects. For a parameterized ``List[X]``
            or ``Dict[K, V]`` the element type (``X`` / ``V``) is used for the
            contained items; otherwise ``cast_to`` is passed through as-is.

    Returns:
        The converted value (scalar, list, dict, or an instance of
        ``cast_to`` / one of its ``Union`` members).

    Raises:
        TypeError: Propagated from ``cast_to(**value_dict)`` when ``cast_to``
            cannot be instantiated from the dumped fields (for example a
            ``Union`` with no member whose name matches ``value``'s class).
    """
    if isinstance(value, Enum):
        return value.value

    if isinstance(value, list):
        # Casting each element to the *container* type (e.g. List[Model])
        # cannot work, so unwrap to the element type when it is known.
        item_type = _container_value_type(cast_to, list)
        return [convert_pydantic_to_json_value(item, item_type) for item in value]

    if isinstance(value, dict):
        val_type = _container_value_type(cast_to, dict)
        return {
            k: convert_pydantic_to_json_value(v, val_type) for k, v in value.items()
        }

    if isinstance(value, BaseModel):
        # This is quite hacky and we should figure out how to use stuff from
        # generated client-sdk code (using ApiResponse.parse() essentially)
        value_dict = json.loads(value.model_dump_json())

        if get_origin(cast_to) is Union:
            # Pick the union member whose class name matches the server-side
            # class name of the value.
            value_name = value.__class__.__name__.split(".")[-1]
            for arg in get_args(cast_to):
                # Not every union member is guaranteed to expose __name__.
                arg_name = getattr(arg, "__name__", "").split(".")[-1]
                if arg_name == value_name:
                    return arg(**value_dict)

        # assume we have the correct association between the server-side type
        # and the client-side type
        return cast_to(**value_dict)

    return value


def _container_value_type(cast_to: Any, container: type) -> Any:
    """Return the value type of a parameterized ``container`` annotation.

    ``List[X]`` yields ``X`` and ``Dict[K, V]`` yields ``V``. When ``cast_to``
    is not a parameterized form of ``container`` it is returned unchanged.
    """
    if get_origin(cast_to) is container:
        type_args = get_args(cast_to)
        if type_args:
            return type_args[-1]
    return cast_to
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_pydantic(annotation: Any, value: Any) -> Any:
    """Best-effort conversion of ``value`` to the type described by ``annotation``.

    Conversion never raises for a failed conversion: the original ``value``
    is returned and a message is printed instead.

    Args:
        annotation: The target type annotation (typically a parameter
            annotation taken from an endpoint signature).
        value: The raw value to convert.

    Returns:
        ``value`` unchanged when ``annotation`` is ``str``/``int``/``float``/
        ``bool`` or an unparameterized list/dict annotation; a converted
        list/dict for parameterized ``List[X]`` / ``Dict[K, V]``; otherwise
        the result of ``TypeAdapter(annotation).validate_python(value)``, or
        ``value`` itself if that validation fails.
    """
    if isinstance(annotation, type) and annotation in {str, int, float, bool}:
        return value

    origin = get_origin(annotation)
    type_args = get_args(annotation)

    if origin is list:
        if not type_args:
            # Bare ``List`` carries no element type, so there is nothing to
            # convert to (indexing the empty args tuple would raise).
            return value
        item_type = type_args[0]
        try:
            return [convert_to_pydantic(item_type, item) for item in value]
        except Exception:
            print(f"Error converting list {value}")
            return value

    if origin is dict:
        if len(type_args) != 2:
            # Bare ``Dict`` carries no value type, so there is nothing to
            # convert to (unpacking the empty args tuple would raise).
            return value
        _, val_type = type_args
        try:
            return {k: convert_to_pydantic(val_type, v) for k, v in value.items()}
        except Exception:
            print(f"Error converting dict {value}")
            return value

    try:
        # Handle Pydantic models and discriminated unions
        return TypeAdapter(annotation).validate_python(value)
    except Exception as e:
        cprint(
            f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
            "yellow",
        )
        return value
|
||||||
|
|
||||||
|
|
||||||
class LlamaStackAsLibraryClient(LlamaStackClient):
|
class LlamaStackAsLibraryClient(LlamaStackClient):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -143,23 +190,14 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
||||||
|
|
||||||
return asyncio.run(self.async_client.initialize())
|
return asyncio.run(self.async_client.initialize())
|
||||||
|
|
||||||
def request(self, *args, **kwargs):
    """Synchronously forward a request to the wrapped async client.

    All positional and keyword arguments are passed through unchanged to
    ``self.async_client.request``.

    Returns:
        For a truthy ``stream`` keyword, whatever
        ``stream_across_asyncio_run_boundary`` returns for the async call and
        ``self.pool_executor``; otherwise the result of running the async
        call to completion with ``asyncio.run``.
    """
    if not kwargs.get("stream"):
        return asyncio.run(self.async_client.request(*args, **kwargs))

    # The coroutine is created lazily inside the lambda so that it is built
    # by whoever ends up invoking it rather than at this call site.
    return stream_across_asyncio_run_boundary(
        lambda: self.async_client.request(*args, **kwargs),
        self.pool_executor,
    )
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
|
@ -227,38 +265,27 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
self.endpoint_impls = endpoint_impls
|
self.endpoint_impls = endpoint_impls
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def request(
    self,
    cast_to: Any,
    options: Any,
    *,
    stream=False,
    stream_cls=None,
):
    """Dispatch a client request to the in-process endpoint implementation.

    Args:
        cast_to: The type the result should be converted to; forwarded to
            the ``_call_*`` helpers.
        options: Request options; ``options.url``, ``options.params`` and
            ``options.json_data`` are read.
        stream: When truthy, return the result of ``_call_streaming``
            without awaiting it; otherwise await ``_call_non_streaming``.
        stream_cls: Accepted but not used by this method.

    Returns:
        The return value of ``_call_streaming`` or the awaited result of
        ``_call_non_streaming``.

    Raises:
        ValueError: If ``self.endpoint_impls`` is empty or unset.
    """
    if not self.endpoint_impls:
        raise ValueError("Client not initialized")

    # Merge into a fresh dict. Updating ``options.params`` itself (as an
    # in-place ``|=`` on it would) leaks the JSON body into the caller's
    # options object. ``json_data`` wins on key collisions.
    params = dict(options.params or {})
    params.update(options.json_data or {})

    if stream:
        return self._call_streaming(options.url, params, cast_to)
    return await self._call_non_streaming(options.url, params, cast_to)
|
||||||
|
|
||||||
async def post(
|
async def _call_non_streaming(
|
||||||
self,
|
self, path: str, body: dict = None, cast_to: Any = None
|
||||||
path: str,
|
|
||||||
*,
|
|
||||||
body: dict = None,
|
|
||||||
stream=False,
|
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
if not self.endpoint_impls:
|
|
||||||
raise ValueError("Client not initialized")
|
|
||||||
|
|
||||||
if stream:
|
|
||||||
return self._call_streaming(path, "POST", body)
|
|
||||||
else:
|
|
||||||
return await self._call_non_streaming(path, "POST", body)
|
|
||||||
|
|
||||||
async def _call_non_streaming(self, path: str, method: str, body: dict = None):
|
|
||||||
await start_trace(path, {"__location__": "library_client"})
|
await start_trace(path, {"__location__": "library_client"})
|
||||||
try:
|
try:
|
||||||
func = self.endpoint_impls.get(path)
|
func = self.endpoint_impls.get(path)
|
||||||
|
@ -266,11 +293,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
raise ValueError(f"No endpoint found for {path}")
|
raise ValueError(f"No endpoint found for {path}")
|
||||||
|
|
||||||
body = self._convert_body(path, body)
|
body = self._convert_body(path, body)
|
||||||
return convert_pydantic_to_json_value(await func(**body))
|
return convert_pydantic_to_json_value(await func(**body), cast_to)
|
||||||
finally:
|
finally:
|
||||||
await end_trace()
|
await end_trace()
|
||||||
|
|
||||||
async def _call_streaming(self, path: str, method: str, body: dict = None):
|
async def _call_streaming(self, path: str, body: dict = None, cast_to: Any = None):
|
||||||
await start_trace(path, {"__location__": "library_client"})
|
await start_trace(path, {"__location__": "library_client"})
|
||||||
try:
|
try:
|
||||||
func = self.endpoint_impls.get(path)
|
func = self.endpoint_impls.get(path)
|
||||||
|
@ -279,7 +306,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
|
|
||||||
body = self._convert_body(path, body)
|
body = self._convert_body(path, body)
|
||||||
async for chunk in await func(**body):
|
async for chunk in await func(**body):
|
||||||
yield convert_pydantic_to_json_value(chunk)
|
yield convert_pydantic_to_json_value(chunk, cast_to)
|
||||||
finally:
|
finally:
|
||||||
await end_trace()
|
await end_trace()
|
||||||
|
|
||||||
|
@ -298,40 +325,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
for param_name, param in sig.parameters.items():
|
for param_name, param in sig.parameters.items():
|
||||||
if param_name in body:
|
if param_name in body:
|
||||||
value = body.get(param_name)
|
value = body.get(param_name)
|
||||||
converted_body[param_name] = self._convert_to_pydantic(
|
converted_body[param_name] = convert_to_pydantic(
|
||||||
param.annotation, value
|
param.annotation, value
|
||||||
)
|
)
|
||||||
return converted_body
|
return converted_body
|
||||||
|
|
||||||
def _convert_to_pydantic(self, annotation: Any, value: Any) -> Any:
|
|
||||||
if isinstance(annotation, type) and annotation in {str, int, float, bool}:
|
|
||||||
return value
|
|
||||||
|
|
||||||
origin = get_origin(annotation)
|
|
||||||
if origin is list:
|
|
||||||
item_type = get_args(annotation)[0]
|
|
||||||
try:
|
|
||||||
return [self._convert_to_pydantic(item_type, item) for item in value]
|
|
||||||
except Exception:
|
|
||||||
print(f"Error converting list {value}")
|
|
||||||
return value
|
|
||||||
|
|
||||||
elif origin is dict:
|
|
||||||
key_type, val_type = get_args(annotation)
|
|
||||||
try:
|
|
||||||
return {
|
|
||||||
k: self._convert_to_pydantic(val_type, v) for k, v in value.items()
|
|
||||||
}
|
|
||||||
except Exception:
|
|
||||||
print(f"Error converting dict {value}")
|
|
||||||
return value
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Handle Pydantic models and discriminated unions
|
|
||||||
return TypeAdapter(annotation).validate_python(value)
|
|
||||||
except Exception as e:
|
|
||||||
cprint(
|
|
||||||
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
|
|
||||||
"yellow",
|
|
||||||
)
|
|
||||||
return value
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue