forked from phoenix-oss/llama-stack-mirror
Fixes for library client (#587)
Library client used _server_ side types which was no bueno. The fix here is not the completely correct fix, but it is good enough for now and for the demo notebook.
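The core idea, as a minimal illustrative sketch (ServerHealthInfo / ClientHealthInfo below are hypothetical stand-ins, not real llama-stack models; the actual conversion lives in convert_pydantic_to_json_value / convert_to_pydantic in the library client diff below): round-trip the server-side pydantic object through plain JSON data and rebuild it as the type the client SDK expects, so server-only classes never leak to callers.

    # Illustrative sketch only: hypothetical types, not the llama-stack API.
    import json

    from pydantic import BaseModel


    class ServerHealthInfo(BaseModel):  # what an in-process server impl returns
        status: str
        version: str


    class ClientHealthInfo(BaseModel):  # what llama-stack-client callers expect
        status: str
        version: str


    def to_client_type(value: BaseModel, cast_to: type) -> BaseModel:
        # Dump the server-side model to plain JSON data, then rebuild it as the
        # client-side type, so server classes never reach the caller.
        return cast_to(**json.loads(value.model_dump_json()))


    print(to_client_type(ServerHealthInfo(status="OK", version="0.0.1"), ClientHealthInfo))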
This commit is contained in:
parent
7615da78b8
commit
a4d8a6009a
6 changed files with 89 additions and 84 deletions

@@ -4368,14 +4368,11 @@
         "step_id": {
           "type": "string"
         },
-        "model_response_text_delta": {
+        "text_delta": {
           "type": "string"
         },
         "tool_call_delta": {
           "$ref": "#/components/schemas/ToolCallDelta"
-        },
-        "tool_response_text_delta": {
-          "type": "string"
         }
       },
       "additionalProperties": false,

@@ -132,8 +132,6 @@ components:
           const: step_progress
           default: step_progress
           type: string
-        model_response_text_delta:
-          type: string
         step_id:
           type: string
         step_type:

@@ -143,10 +141,10 @@ components:
           - shield_call
           - memory_retrieval
           type: string
+        text_delta:
+          type: string
         tool_call_delta:
           $ref: '#/components/schemas/ToolCallDelta'
-        tool_response_text_delta:
-          type: string
       required:
       - event_type
       - step_type

@@ -340,9 +340,8 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
     step_type: StepType
     step_id: str

-    model_response_text_delta: Optional[str] = None
+    text_delta: Optional[str] = None
     tool_call_delta: Optional[ToolCallDelta] = None
-    tool_response_text_delta: Optional[str] = None


 @json_schema_type

@@ -121,7 +121,7 @@ class EventLogger:
             else:
                 yield event, LogEvent(
                     role=None,
-                    content=event.payload.model_response_text_delta,
+                    content=event.payload.text_delta,
                     end="",
                     color="yellow",
                 )

@@ -6,16 +6,18 @@

 import asyncio
 import inspect
+import json
 import os
 import queue
 import threading
 from concurrent.futures import ThreadPoolExecutor
+from enum import Enum
 from pathlib import Path
-from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
+from typing import Any, Generator, get_args, get_origin, Optional, Type, TypeVar, Union

 import yaml
 from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
-from pydantic import TypeAdapter
+from pydantic import BaseModel, TypeAdapter
 from rich.console import Console

 from termcolor import cprint

@@ -109,6 +111,65 @@ def stream_across_asyncio_run_boundary(
         future.result()


+def convert_pydantic_to_json_value(value: Any, cast_to: Type) -> dict:
+    if isinstance(value, Enum):
+        return value.value
+    elif isinstance(value, list):
+        return [convert_pydantic_to_json_value(item, cast_to) for item in value]
+    elif isinstance(value, dict):
+        return {k: convert_pydantic_to_json_value(v, cast_to) for k, v in value.items()}
+    elif isinstance(value, BaseModel):
+        # This is quite hacky and we should figure out how to use stuff from
+        # generated client-sdk code (using ApiResponse.parse() essentially)
+        value_dict = json.loads(value.model_dump_json())
+
+        origin = get_origin(cast_to)
+        if origin is Union:
+            args = get_args(cast_to)
+            for arg in args:
+                arg_name = arg.__name__.split(".")[-1]
+                value_name = value.__class__.__name__.split(".")[-1]
+                if arg_name == value_name:
+                    return arg(**value_dict)
+
+        # assume we have the correct association between the server-side type and the client-side type
+        return cast_to(**value_dict)
+
+    return value
+
+
+def convert_to_pydantic(annotation: Any, value: Any) -> Any:
+    if isinstance(annotation, type) and annotation in {str, int, float, bool}:
+        return value
+
+    origin = get_origin(annotation)
+    if origin is list:
+        item_type = get_args(annotation)[0]
+        try:
+            return [convert_to_pydantic(item_type, item) for item in value]
+        except Exception:
+            print(f"Error converting list {value}")
+            return value
+
+    elif origin is dict:
+        key_type, val_type = get_args(annotation)
+        try:
+            return {k: convert_to_pydantic(val_type, v) for k, v in value.items()}
+        except Exception:
+            print(f"Error converting dict {value}")
+            return value
+
+    try:
+        # Handle Pydantic models and discriminated unions
+        return TypeAdapter(annotation).validate_python(value)
+    except Exception as e:
+        cprint(
+            f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
+            "yellow",
+        )
+        return value
+
+
 class LlamaStackAsLibraryClient(LlamaStackClient):
     def __init__(
         self,

@@ -129,23 +190,14 @@ class LlamaStackAsLibraryClient(LlamaStackClient):

         return asyncio.run(self.async_client.initialize())

-    def get(self, *args, **kwargs):
+    def request(self, *args, **kwargs):
         if kwargs.get("stream"):
             return stream_across_asyncio_run_boundary(
-                lambda: self.async_client.get(*args, **kwargs),
+                lambda: self.async_client.request(*args, **kwargs),
                 self.pool_executor,
             )
         else:
-            return asyncio.run(self.async_client.get(*args, **kwargs))
-
-    def post(self, *args, **kwargs):
-        if kwargs.get("stream"):
-            return stream_across_asyncio_run_boundary(
-                lambda: self.async_client.post(*args, **kwargs),
-                self.pool_executor,
-            )
-        else:
-            return asyncio.run(self.async_client.post(*args, **kwargs))
+            return asyncio.run(self.async_client.request(*args, **kwargs))


 class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):

@@ -187,8 +239,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             if self.config_path_or_template_name.endswith(".yaml"):
                 print_pip_install_help(self.config.providers)
             else:
+                prefix = "!" if in_notebook() else ""
                 cprint(
-                    f"Please run:\n\nllama stack build --template {self.config_path_or_template_name} --image-type venv\n\n",
+                    f"Please run:\n\n{prefix}llama stack build --template {self.config_path_or_template_name} --image-type venv\n\n",
                     "yellow",
                 )
                 return False

@@ -212,38 +265,27 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         self.endpoint_impls = endpoint_impls
         return True

-    async def get(
+    async def request(
         self,
-        path: str,
+        cast_to: Any,
+        options: Any,
         *,
         stream=False,
-        **kwargs,
+        stream_cls=None,
     ):
         if not self.endpoint_impls:
             raise ValueError("Client not initialized")

+        params = options.params or {}
+        params |= options.json_data or {}
         if stream:
-            return self._call_streaming(path, "GET")
+            return self._call_streaming(options.url, params, cast_to)
         else:
-            return await self._call_non_streaming(path, "GET")
+            return await self._call_non_streaming(options.url, params, cast_to)

-    async def post(
-        self,
-        path: str,
-        *,
-        body: dict = None,
-        stream=False,
-        **kwargs,
+    async def _call_non_streaming(
+        self, path: str, body: dict = None, cast_to: Any = None
     ):
-        if not self.endpoint_impls:
-            raise ValueError("Client not initialized")
-
-        if stream:
-            return self._call_streaming(path, "POST", body)
-        else:
-            return await self._call_non_streaming(path, "POST", body)
-
-    async def _call_non_streaming(self, path: str, method: str, body: dict = None):
         await start_trace(path, {"__location__": "library_client"})
         try:
             func = self.endpoint_impls.get(path)

@@ -251,11 +293,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 raise ValueError(f"No endpoint found for {path}")

             body = self._convert_body(path, body)
-            return await func(**body)
+            return convert_pydantic_to_json_value(await func(**body), cast_to)
         finally:
             await end_trace()

-    async def _call_streaming(self, path: str, method: str, body: dict = None):
+    async def _call_streaming(self, path: str, body: dict = None, cast_to: Any = None):
         await start_trace(path, {"__location__": "library_client"})
         try:
             func = self.endpoint_impls.get(path)

@@ -264,7 +306,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):

             body = self._convert_body(path, body)
             async for chunk in await func(**body):
-                yield chunk
+                yield convert_pydantic_to_json_value(chunk, cast_to)
         finally:
             await end_trace()

@@ -283,38 +325,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         for param_name, param in sig.parameters.items():
             if param_name in body:
                 value = body.get(param_name)
-                converted_body[param_name] = self._convert_param(
+                converted_body[param_name] = convert_to_pydantic(
                     param.annotation, value
                 )
         return converted_body
-
-    def _convert_param(self, annotation: Any, value: Any) -> Any:
-        if isinstance(annotation, type) and annotation in {str, int, float, bool}:
-            return value
-
-        origin = get_origin(annotation)
-        if origin is list:
-            item_type = get_args(annotation)[0]
-            try:
-                return [self._convert_param(item_type, item) for item in value]
-            except Exception:
-                print(f"Error converting list {value}")
-                return value
-
-        elif origin is dict:
-            key_type, val_type = get_args(annotation)
-            try:
-                return {k: self._convert_param(val_type, v) for k, v in value.items()}
-            except Exception:
-                print(f"Error converting dict {value}")
-                return value
-
-        try:
-            # Handle Pydantic models and discriminated unions
-            return TypeAdapter(annotation).validate_python(value)
-        except Exception as e:
-            cprint(
-                f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
-                "yellow",
-            )
-            return value

@@ -451,7 +451,7 @@ class ChatAgent(ShieldRunnerMixin):
                     payload=AgentTurnResponseStepProgressPayload(
                         step_type=StepType.inference.value,
                         step_id=step_id,
-                        model_response_text_delta="",
+                        text_delta="",
                         tool_call_delta=delta,
                     )
                 )

@@ -465,7 +465,7 @@ class ChatAgent(ShieldRunnerMixin):
                         payload=AgentTurnResponseStepProgressPayload(
                             step_type=StepType.inference.value,
                             step_id=step_id,
-                            model_response_text_delta=event.delta,
+                            text_delta=event.delta,
                         )
                     )
                 )