From a4d8a6009a5a518cb32af71d20db1369a56f936d Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Mon, 9 Dec 2024 17:14:37 -0800
Subject: [PATCH] Fixes for library client (#587)

Library client used _server_ side types which was no bueno. The fix here
is not the completely correct fix, but it is good enough for now and for
the demo notebook.
---
 docs/resources/llama-stack-spec.html        |   5 +-
 docs/resources/llama-stack-spec.yaml        |   6 +-
 llama_stack/apis/agents/agents.py           |   3 +-
 llama_stack/apis/agents/event_logger.py     |   2 +-
 llama_stack/distribution/library_client.py  | 153 ++++++++++--------
 .../agents/meta_reference/agent_instance.py |   4 +-
 6 files changed, 89 insertions(+), 84 deletions(-)

diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index d1040f186..14e311cfc 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -4368,14 +4368,11 @@
                 "step_id": {
                     "type": "string"
                 },
-                "model_response_text_delta": {
+                "text_delta": {
                     "type": "string"
                 },
                 "tool_call_delta": {
                     "$ref": "#/components/schemas/ToolCallDelta"
-                },
-                "tool_response_text_delta": {
-                    "type": "string"
                 }
             },
             "additionalProperties": false,
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index 0b737a697..86fcae23d 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -132,8 +132,6 @@ components:
           const: step_progress
           default: step_progress
           type: string
-        model_response_text_delta:
-          type: string
         step_id:
           type: string
         step_type:
@@ -143,10 +141,10 @@ components:
           - shield_call
           - memory_retrieval
           type: string
+        text_delta:
+          type: string
         tool_call_delta:
           $ref: '#/components/schemas/ToolCallDelta'
-        tool_response_text_delta:
-          type: string
       required:
       - event_type
       - step_type
diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index 6e41df4f6..575f336af 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -340,9 +340,8 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
     step_type: StepType
     step_id: str
 
-    model_response_text_delta: Optional[str] = None
+    text_delta: Optional[str] = None
     tool_call_delta: Optional[ToolCallDelta] = None
-    tool_response_text_delta: Optional[str] = None
 
 
 @json_schema_type
diff --git a/llama_stack/apis/agents/event_logger.py b/llama_stack/apis/agents/event_logger.py
index 25931b821..737ba385c 100644
--- a/llama_stack/apis/agents/event_logger.py
+++ b/llama_stack/apis/agents/event_logger.py
@@ -121,7 +121,7 @@ class EventLogger:
             else:
                 yield event, LogEvent(
                     role=None,
-                    content=event.payload.model_response_text_delta,
+                    content=event.payload.text_delta,
                     end="",
                     color="yellow",
                 )
diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py
index 08c8e2b5d..9265bb560 100644
--- a/llama_stack/distribution/library_client.py
+++ b/llama_stack/distribution/library_client.py
@@ -6,16 +6,18 @@
 
 import asyncio
 import inspect
+import json
 import os
 import queue
 import threading
 from concurrent.futures import ThreadPoolExecutor
+from enum import Enum
 from pathlib import Path
-from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
+from typing import Any, Generator, get_args, get_origin, Optional, Type, TypeVar, Union
 
 import yaml
 from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
-from pydantic import TypeAdapter
+from pydantic import BaseModel, TypeAdapter
 from rich.console import Console
 from termcolor import cprint
 
@@ -109,6 +111,65 @@ def stream_across_asyncio_run_boundary(
         future.result()
 
 
+def convert_pydantic_to_json_value(value: Any, cast_to: Type) -> dict:
+    if isinstance(value, Enum):
+        return value.value
+    elif isinstance(value, list):
+        return [convert_pydantic_to_json_value(item, cast_to) for item in value]
+    elif isinstance(value, dict):
+        return {k: convert_pydantic_to_json_value(v, cast_to) for k, v in value.items()}
+    elif isinstance(value, BaseModel):
+        # This is quite hacky and we should figure out how to use stuff from
+        # generated client-sdk code (using ApiResponse.parse() essentially)
+        value_dict = json.loads(value.model_dump_json())
+
+        origin = get_origin(cast_to)
+        if origin is Union:
+            args = get_args(cast_to)
+            for arg in args:
+                arg_name = arg.__name__.split(".")[-1]
+                value_name = value.__class__.__name__.split(".")[-1]
+                if arg_name == value_name:
+                    return arg(**value_dict)
+
+        # assume we have the correct association between the server-side type and the client-side type
+        return cast_to(**value_dict)
+
+    return value
+
+
+def convert_to_pydantic(annotation: Any, value: Any) -> Any:
+    if isinstance(annotation, type) and annotation in {str, int, float, bool}:
+        return value
+
+    origin = get_origin(annotation)
+    if origin is list:
+        item_type = get_args(annotation)[0]
+        try:
+            return [convert_to_pydantic(item_type, item) for item in value]
+        except Exception:
+            print(f"Error converting list {value}")
+            return value
+
+    elif origin is dict:
+        key_type, val_type = get_args(annotation)
+        try:
+            return {k: convert_to_pydantic(val_type, v) for k, v in value.items()}
+        except Exception:
+            print(f"Error converting dict {value}")
+            return value
+
+    try:
+        # Handle Pydantic models and discriminated unions
+        return TypeAdapter(annotation).validate_python(value)
+    except Exception as e:
+        cprint(
+            f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
+            "yellow",
+        )
+        return value
+
+
 class LlamaStackAsLibraryClient(LlamaStackClient):
     def __init__(
         self,
@@ -129,23 +190,14 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
 
         return asyncio.run(self.async_client.initialize())
 
-    def get(self, *args, **kwargs):
+    def request(self, *args, **kwargs):
         if kwargs.get("stream"):
             return stream_across_asyncio_run_boundary(
-                lambda: self.async_client.get(*args, **kwargs),
+                lambda: self.async_client.request(*args, **kwargs),
                 self.pool_executor,
             )
         else:
-            return asyncio.run(self.async_client.get(*args, **kwargs))
-
-    def post(self, *args, **kwargs):
-        if kwargs.get("stream"):
-            return stream_across_asyncio_run_boundary(
-                lambda: self.async_client.post(*args, **kwargs),
-                self.pool_executor,
-            )
-        else:
-            return asyncio.run(self.async_client.post(*args, **kwargs))
+            return asyncio.run(self.async_client.request(*args, **kwargs))
 
 
 class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
@@ -187,8 +239,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         if self.config_path_or_template_name.endswith(".yaml"):
             print_pip_install_help(self.config.providers)
         else:
+            prefix = "!" if in_notebook() else ""
             cprint(
-                f"Please run:\n\nllama stack build --template {self.config_path_or_template_name} --image-type venv\n\n",
+                f"Please run:\n\n{prefix}llama stack build --template {self.config_path_or_template_name} --image-type venv\n\n",
                 "yellow",
             )
             return False
@@ -212,38 +265,27 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         self.endpoint_impls = endpoint_impls
         return True
 
-    async def get(
+    async def request(
         self,
-        path: str,
+        cast_to: Any,
+        options: Any,
         *,
         stream=False,
-        **kwargs,
+        stream_cls=None,
     ):
         if not self.endpoint_impls:
             raise ValueError("Client not initialized")
 
+        params = options.params or {}
+        params |= options.json_data or {}
         if stream:
-            return self._call_streaming(path, "GET")
+            return self._call_streaming(options.url, params, cast_to)
         else:
-            return await self._call_non_streaming(path, "GET")
+            return await self._call_non_streaming(options.url, params, cast_to)
 
-    async def post(
-        self,
-        path: str,
-        *,
-        body: dict = None,
-        stream=False,
-        **kwargs,
+    async def _call_non_streaming(
+        self, path: str, body: dict = None, cast_to: Any = None
     ):
-        if not self.endpoint_impls:
-            raise ValueError("Client not initialized")
-
-        if stream:
-            return self._call_streaming(path, "POST", body)
-        else:
-            return await self._call_non_streaming(path, "POST", body)
-
-    async def _call_non_streaming(self, path: str, method: str, body: dict = None):
         await start_trace(path, {"__location__": "library_client"})
         try:
             func = self.endpoint_impls.get(path)
@@ -251,11 +293,11 @@
                 raise ValueError(f"No endpoint found for {path}")
 
             body = self._convert_body(path, body)
-            return await func(**body)
+            return convert_pydantic_to_json_value(await func(**body), cast_to)
         finally:
             await end_trace()
 
-    async def _call_streaming(self, path: str, method: str, body: dict = None):
+    async def _call_streaming(self, path: str, body: dict = None, cast_to: Any = None):
         await start_trace(path, {"__location__": "library_client"})
         try:
             func = self.endpoint_impls.get(path)
@@ -264,7 +306,7 @@
 
             body = self._convert_body(path, body)
             async for chunk in await func(**body):
-                yield chunk
+                yield convert_pydantic_to_json_value(chunk, cast_to)
         finally:
             await end_trace()
 
@@ -283,38 +325,7 @@
         for param_name, param in sig.parameters.items():
             if param_name in body:
                 value = body.get(param_name)
-                converted_body[param_name] = self._convert_param(
+                converted_body[param_name] = convert_to_pydantic(
                     param.annotation, value
                 )
         return converted_body
-
-    def _convert_param(self, annotation: Any, value: Any) -> Any:
-        if isinstance(annotation, type) and annotation in {str, int, float, bool}:
-            return value
-
-        origin = get_origin(annotation)
-        if origin is list:
-            item_type = get_args(annotation)[0]
-            try:
-                return [self._convert_param(item_type, item) for item in value]
-            except Exception:
-                print(f"Error converting list {value}")
-                return value
-
-        elif origin is dict:
-            key_type, val_type = get_args(annotation)
-            try:
-                return {k: self._convert_param(val_type, v) for k, v in value.items()}
-            except Exception:
-                print(f"Error converting dict {value}")
-                return value
-
-        try:
-            # Handle Pydantic models and discriminated unions
-            return TypeAdapter(annotation).validate_python(value)
-        except Exception as e:
-            cprint(
-                f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
-                "yellow",
-            )
-            return value
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index e367f3c41..126c2e193 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -451,7 +451,7 @@ class ChatAgent(ShieldRunnerMixin):
                         payload=AgentTurnResponseStepProgressPayload(
                             step_type=StepType.inference.value,
                             step_id=step_id,
-                            model_response_text_delta="",
+                            text_delta="",
                             tool_call_delta=delta,
                         )
                     )
@@ -465,7 +465,7 @@ class ChatAgent(ShieldRunnerMixin):
                     payload=AgentTurnResponseStepProgressPayload(
                         step_type=StepType.inference.value,
                         step_id=step_id,
-                        model_response_text_delta=event.delta,
+                        text_delta=event.delta,
                     )
                 )
             )
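
Usage sketch (reviewer note, not part of the patch): a minimal end-to-end
exercise of the path this change touches. The "ollama" template, the model
id, and the exact chat_completion signature below are assumptions; adjust
them for your local distribution and client-sdk version. The point is that
the generated client-sdk methods now funnel into the overridden request(),
so results come back through convert_pydantic_to_json_value() instead of as
raw server-side types.

    from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

    # Build the in-process client from a distribution template (assumed here:
    # "ollama"); initialize() wires up the endpoint implementations.
    client = LlamaStackAsLibraryClient("ollama")
    assert client.initialize()

    # Non-streaming: request() runs the async client under asyncio.run() and
    # converts the server-side pydantic result toward the client-side cast_to.
    response = client.inference.chat_completion(
        model_id="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(response)

    # Streaming: stream=True crosses the asyncio.run() boundary through
    # stream_across_asyncio_run_boundary(), yielding converted chunks.
    for chunk in client.inference.chat_completion(
        model_id="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
    ):
        print(chunk)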