Fixes for library client (#587)

The library client used _server_-side types, which was no bueno. The fix here
is not the completely correct fix, but it is good enough for now and for the
demo notebook.
This commit is contained in:
Ashwin Bharambe 2024-12-09 17:14:37 -08:00 committed by GitHub
parent 7615da78b8
commit a4d8a6009a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 89 additions and 84 deletions

View file

@ -4368,14 +4368,11 @@
"step_id": {
"type": "string"
},
"model_response_text_delta": {
"text_delta": {
"type": "string"
},
"tool_call_delta": {
"$ref": "#/components/schemas/ToolCallDelta"
},
"tool_response_text_delta": {
"type": "string"
}
},
"additionalProperties": false,

View file

@ -132,8 +132,6 @@ components:
const: step_progress
default: step_progress
type: string
model_response_text_delta:
type: string
step_id:
type: string
step_type:
@ -143,10 +141,10 @@ components:
- shield_call
- memory_retrieval
type: string
text_delta:
type: string
tool_call_delta:
$ref: '#/components/schemas/ToolCallDelta'
tool_response_text_delta:
type: string
required:
- event_type
- step_type

View file

@ -340,9 +340,8 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
step_type: StepType
step_id: str
model_response_text_delta: Optional[str] = None
text_delta: Optional[str] = None
tool_call_delta: Optional[ToolCallDelta] = None
tool_response_text_delta: Optional[str] = None
@json_schema_type

View file

@ -121,7 +121,7 @@ class EventLogger:
else:
yield event, LogEvent(
role=None,
content=event.payload.model_response_text_delta,
content=event.payload.text_delta,
end="",
color="yellow",
)

View file

@ -6,16 +6,18 @@
import asyncio
import inspect
import json
import os
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from pathlib import Path
from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
from typing import Any, Generator, get_args, get_origin, Optional, Type, TypeVar, Union
import yaml
from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
from pydantic import TypeAdapter
from pydantic import BaseModel, TypeAdapter
from rich.console import Console
from termcolor import cprint
@ -109,6 +111,65 @@ def stream_across_asyncio_run_boundary(
future.result()
def convert_pydantic_to_json_value(value: Any, cast_to: Type) -> Any:
    """Convert a server-side return value into the form described by ``cast_to``.

    Conversion rules, applied in order:

    * ``Enum`` members are replaced by their ``.value``.
    * Lists and dicts are converted recursively. When ``cast_to`` is a
      parameterized ``list`` / ``dict`` type, the contained items are converted
      against the element type / value type instead of the container type.
    * Pydantic ``BaseModel`` instances are serialized with ``model_dump_json()``,
      parsed back into a plain dict, and passed as keyword arguments to
      ``cast_to`` (or, when ``cast_to`` is a ``Union``, to the member whose class
      name equals the class name of ``value``).
    * Anything else is returned unchanged.

    Args:
        value: The object to convert.
        cast_to: The target type. May be ``None``, in which case models are
            returned as plain JSON-compatible dicts.

    Returns:
        The converted value; its type depends on ``value`` and ``cast_to``.
    """
    if isinstance(value, Enum):
        return value.value

    if isinstance(value, list):
        # For List[X] every element must be cast to X; calling the list type
        # itself with the model's fields as keyword arguments raises TypeError.
        item_type = cast_to
        if get_origin(cast_to) is list:
            type_args = get_args(cast_to)
            if type_args:
                item_type = type_args[0]
        return [convert_pydantic_to_json_value(item, item_type) for item in value]

    if isinstance(value, dict):
        # Same reasoning as for lists: Dict[K, V] values are cast to V.
        val_type = cast_to
        if get_origin(cast_to) is dict:
            type_args = get_args(cast_to)
            if len(type_args) == 2:
                val_type = type_args[1]
        return {
            k: convert_pydantic_to_json_value(v, val_type) for k, v in value.items()
        }

    if isinstance(value, BaseModel):
        # This is quite hacky and we should figure out how to use stuff from
        # generated client-sdk code (using ApiResponse.parse() essentially)
        value_dict = json.loads(value.model_dump_json())

        if cast_to is None:
            # No target type was supplied; the JSON-compatible dict is the
            # best representation available.
            return value_dict

        if get_origin(cast_to) is Union:
            value_name = value.__class__.__name__.split(".")[-1]
            for arg in get_args(cast_to):
                # Not every Union member is guaranteed to expose __name__.
                arg_name = getattr(arg, "__name__", None)
                if arg_name is not None and arg_name.split(".")[-1] == value_name:
                    return arg(**value_dict)

        # assume we have the correct association between the server-side type
        # and the client-side type
        return cast_to(**value_dict)

    return value
def convert_to_pydantic(annotation: Any, value: Any) -> Any:
    """Best-effort conversion of ``value`` to the type described by ``annotation``.

    * ``str`` / ``int`` / ``float`` / ``bool`` annotations return ``value`` as is.
    * Parameterized ``list`` / ``dict`` annotations convert each element (or each
      dict value) recursively. Unparameterized ``List`` / ``Dict`` annotations
      carry no element type, so ``value`` is returned unchanged.
    * Everything else is validated through ``TypeAdapter(annotation)``.

    Conversion failures are reported on the console and the original ``value``
    is returned instead of raising.

    Args:
        annotation: The target type annotation.
        value: The raw value to convert.

    Returns:
        The converted value, or ``value`` itself when conversion is not
        applicable or fails.
    """
    if isinstance(annotation, type) and annotation in {str, int, float, bool}:
        return value

    origin = get_origin(annotation)
    type_args = get_args(annotation)

    if origin is list:
        if not type_args:
            # Bare ``List`` has no element type; indexing the empty args tuple
            # would raise IndexError.
            return value
        item_type = type_args[0]
        try:
            return [convert_to_pydantic(item_type, item) for item in value]
        except Exception:
            print(f"Error converting list {value}")
            return value

    if origin is dict:
        if len(type_args) != 2:
            # Bare ``Dict`` has no key/value types to unpack.
            return value
        val_type = type_args[1]
        try:
            return {k: convert_to_pydantic(val_type, v) for k, v in value.items()}
        except Exception:
            print(f"Error converting dict {value}")
            return value

    try:
        # Handle Pydantic models and discriminated unions
        return TypeAdapter(annotation).validate_python(value)
    except Exception as e:
        cprint(
            f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
            "yellow",
        )
        return value
class LlamaStackAsLibraryClient(LlamaStackClient):
def __init__(
self,
@ -129,23 +190,14 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
return asyncio.run(self.async_client.initialize())
def get(self, *args, **kwargs):
def request(self, *args, **kwargs):
if kwargs.get("stream"):
return stream_across_asyncio_run_boundary(
lambda: self.async_client.get(*args, **kwargs),
lambda: self.async_client.request(*args, **kwargs),
self.pool_executor,
)
else:
return asyncio.run(self.async_client.get(*args, **kwargs))
def post(self, *args, **kwargs):
if kwargs.get("stream"):
return stream_across_asyncio_run_boundary(
lambda: self.async_client.post(*args, **kwargs),
self.pool_executor,
)
else:
return asyncio.run(self.async_client.post(*args, **kwargs))
return asyncio.run(self.async_client.request(*args, **kwargs))
class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
@ -187,8 +239,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
if self.config_path_or_template_name.endswith(".yaml"):
print_pip_install_help(self.config.providers)
else:
prefix = "!" if in_notebook() else ""
cprint(
f"Please run:\n\nllama stack build --template {self.config_path_or_template_name} --image-type venv\n\n",
f"Please run:\n\n{prefix}llama stack build --template {self.config_path_or_template_name} --image-type venv\n\n",
"yellow",
)
return False
@ -212,38 +265,27 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
self.endpoint_impls = endpoint_impls
return True
async def get(
async def request(
self,
path: str,
cast_to: Any,
options: Any,
*,
stream=False,
**kwargs,
stream_cls=None,
):
if not self.endpoint_impls:
raise ValueError("Client not initialized")
params = options.params or {}
params |= options.json_data or {}
if stream:
return self._call_streaming(path, "GET")
return self._call_streaming(options.url, params, cast_to)
else:
return await self._call_non_streaming(path, "GET")
return await self._call_non_streaming(options.url, params, cast_to)
async def post(
self,
path: str,
*,
body: dict = None,
stream=False,
**kwargs,
async def _call_non_streaming(
self, path: str, body: dict = None, cast_to: Any = None
):
if not self.endpoint_impls:
raise ValueError("Client not initialized")
if stream:
return self._call_streaming(path, "POST", body)
else:
return await self._call_non_streaming(path, "POST", body)
async def _call_non_streaming(self, path: str, method: str, body: dict = None):
await start_trace(path, {"__location__": "library_client"})
try:
func = self.endpoint_impls.get(path)
@ -251,11 +293,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body)
return await func(**body)
return convert_pydantic_to_json_value(await func(**body), cast_to)
finally:
await end_trace()
async def _call_streaming(self, path: str, method: str, body: dict = None):
async def _call_streaming(self, path: str, body: dict = None, cast_to: Any = None):
await start_trace(path, {"__location__": "library_client"})
try:
func = self.endpoint_impls.get(path)
@ -264,7 +306,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
body = self._convert_body(path, body)
async for chunk in await func(**body):
yield chunk
yield convert_pydantic_to_json_value(chunk, cast_to)
finally:
await end_trace()
@ -283,38 +325,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
for param_name, param in sig.parameters.items():
if param_name in body:
value = body.get(param_name)
converted_body[param_name] = self._convert_param(
converted_body[param_name] = convert_to_pydantic(
param.annotation, value
)
return converted_body
def _convert_param(self, annotation: Any, value: Any) -> Any:
if isinstance(annotation, type) and annotation in {str, int, float, bool}:
return value
origin = get_origin(annotation)
if origin is list:
item_type = get_args(annotation)[0]
try:
return [self._convert_param(item_type, item) for item in value]
except Exception:
print(f"Error converting list {value}")
return value
elif origin is dict:
key_type, val_type = get_args(annotation)
try:
return {k: self._convert_param(val_type, v) for k, v in value.items()}
except Exception:
print(f"Error converting dict {value}")
return value
try:
# Handle Pydantic models and discriminated unions
return TypeAdapter(annotation).validate_python(value)
except Exception as e:
cprint(
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
"yellow",
)
return value

View file

@ -451,7 +451,7 @@ class ChatAgent(ShieldRunnerMixin):
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
model_response_text_delta="",
text_delta="",
tool_call_delta=delta,
)
)
@ -465,7 +465,7 @@ class ChatAgent(ShieldRunnerMixin):
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
model_response_text_delta=event.delta,
text_delta=event.delta,
)
)
)