diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index d1040f186..14e311cfc 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -4368,14 +4368,11 @@
"step_id": {
"type": "string"
},
- "model_response_text_delta": {
+ "text_delta": {
"type": "string"
},
"tool_call_delta": {
"$ref": "#/components/schemas/ToolCallDelta"
- },
- "tool_response_text_delta": {
- "type": "string"
}
},
"additionalProperties": false,
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index 0b737a697..86fcae23d 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -132,8 +132,6 @@ components:
const: step_progress
default: step_progress
type: string
- model_response_text_delta:
- type: string
step_id:
type: string
step_type:
@@ -143,10 +141,10 @@ components:
- shield_call
- memory_retrieval
type: string
+ text_delta:
+ type: string
tool_call_delta:
$ref: '#/components/schemas/ToolCallDelta'
- tool_response_text_delta:
- type: string
required:
- event_type
- step_type
diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index 6e41df4f6..575f336af 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -340,9 +340,8 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
step_type: StepType
step_id: str
- model_response_text_delta: Optional[str] = None
+ text_delta: Optional[str] = None
tool_call_delta: Optional[ToolCallDelta] = None
- tool_response_text_delta: Optional[str] = None
@json_schema_type
diff --git a/llama_stack/apis/agents/event_logger.py b/llama_stack/apis/agents/event_logger.py
index 25931b821..737ba385c 100644
--- a/llama_stack/apis/agents/event_logger.py
+++ b/llama_stack/apis/agents/event_logger.py
@@ -121,7 +121,7 @@ class EventLogger:
else:
yield event, LogEvent(
role=None,
- content=event.payload.model_response_text_delta,
+ content=event.payload.text_delta,
end="",
color="yellow",
)
diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py
index 08c8e2b5d..9265bb560 100644
--- a/llama_stack/distribution/library_client.py
+++ b/llama_stack/distribution/library_client.py
@@ -6,16 +6,18 @@
import asyncio
import inspect
+import json
import os
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
+from enum import Enum
from pathlib import Path
-from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
+from typing import Any, Generator, get_args, get_origin, Optional, Type, TypeVar, Union
import yaml
from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
-from pydantic import TypeAdapter
+from pydantic import BaseModel, TypeAdapter
from rich.console import Console
from termcolor import cprint
@@ -109,6 +111,65 @@ def stream_across_asyncio_run_boundary(
future.result()
+def convert_pydantic_to_json_value(value: Any, cast_to: Type) -> Any:
+ if isinstance(value, Enum):
+ return value.value
+ elif isinstance(value, list):
+ return [convert_pydantic_to_json_value(item, cast_to) for item in value]
+ elif isinstance(value, dict):
+ return {k: convert_pydantic_to_json_value(v, cast_to) for k, v in value.items()}
+ elif isinstance(value, BaseModel):
+ # This is quite hacky and we should figure out how to use stuff from
+ # generated client-sdk code (using ApiResponse.parse() essentially)
+ value_dict = json.loads(value.model_dump_json())
+
+ origin = get_origin(cast_to)
+ if origin is Union:
+ args = get_args(cast_to)
+ for arg in args:
+ arg_name = arg.__name__.split(".")[-1]
+ value_name = value.__class__.__name__.split(".")[-1]
+ if arg_name == value_name:
+ return arg(**value_dict)
+
+ # assume we have the correct association between the server-side type and the client-side type
+ return cast_to(**value_dict)
+
+ return value
+
+
+def convert_to_pydantic(annotation: Any, value: Any) -> Any:
+ if isinstance(annotation, type) and annotation in {str, int, float, bool}:
+ return value
+
+ origin = get_origin(annotation)
+ if origin is list:
+ item_type = get_args(annotation)[0]
+ try:
+ return [convert_to_pydantic(item_type, item) for item in value]
+ except Exception:
+ print(f"Error converting list {value}")
+ return value
+
+ elif origin is dict:
+ key_type, val_type = get_args(annotation)
+ try:
+ return {k: convert_to_pydantic(val_type, v) for k, v in value.items()}
+ except Exception:
+ print(f"Error converting dict {value}")
+ return value
+
+ try:
+ # Handle Pydantic models and discriminated unions
+ return TypeAdapter(annotation).validate_python(value)
+ except Exception as e:
+ cprint(
+ f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
+ "yellow",
+ )
+ return value
+
+
class LlamaStackAsLibraryClient(LlamaStackClient):
def __init__(
self,
@@ -129,23 +190,14 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
return asyncio.run(self.async_client.initialize())
- def get(self, *args, **kwargs):
+ def request(self, *args, **kwargs):
if kwargs.get("stream"):
return stream_across_asyncio_run_boundary(
- lambda: self.async_client.get(*args, **kwargs),
+ lambda: self.async_client.request(*args, **kwargs),
self.pool_executor,
)
else:
- return asyncio.run(self.async_client.get(*args, **kwargs))
-
- def post(self, *args, **kwargs):
- if kwargs.get("stream"):
- return stream_across_asyncio_run_boundary(
- lambda: self.async_client.post(*args, **kwargs),
- self.pool_executor,
- )
- else:
- return asyncio.run(self.async_client.post(*args, **kwargs))
+ return asyncio.run(self.async_client.request(*args, **kwargs))
class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
@@ -187,8 +239,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
if self.config_path_or_template_name.endswith(".yaml"):
print_pip_install_help(self.config.providers)
else:
+ prefix = "!" if in_notebook() else ""
cprint(
- f"Please run:\n\nllama stack build --template {self.config_path_or_template_name} --image-type venv\n\n",
+ f"Please run:\n\n{prefix}llama stack build --template {self.config_path_or_template_name} --image-type venv\n\n",
"yellow",
)
return False
@@ -212,38 +265,27 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
self.endpoint_impls = endpoint_impls
return True
- async def get(
+ async def request(
self,
- path: str,
+ cast_to: Any,
+ options: Any,
*,
stream=False,
- **kwargs,
+ stream_cls=None,
):
if not self.endpoint_impls:
raise ValueError("Client not initialized")
+ params = options.params or {}
+ params |= options.json_data or {}
if stream:
- return self._call_streaming(path, "GET")
+ return self._call_streaming(options.url, params, cast_to)
else:
- return await self._call_non_streaming(path, "GET")
+ return await self._call_non_streaming(options.url, params, cast_to)
- async def post(
- self,
- path: str,
- *,
- body: dict = None,
- stream=False,
- **kwargs,
+ async def _call_non_streaming(
+        self, path: str, body: Optional[dict] = None, cast_to: Any = None
):
- if not self.endpoint_impls:
- raise ValueError("Client not initialized")
-
- if stream:
- return self._call_streaming(path, "POST", body)
- else:
- return await self._call_non_streaming(path, "POST", body)
-
- async def _call_non_streaming(self, path: str, method: str, body: dict = None):
await start_trace(path, {"__location__": "library_client"})
try:
func = self.endpoint_impls.get(path)
@@ -251,11 +293,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body)
- return await func(**body)
+ return convert_pydantic_to_json_value(await func(**body), cast_to)
finally:
await end_trace()
- async def _call_streaming(self, path: str, method: str, body: dict = None):
+    async def _call_streaming(self, path: str, body: Optional[dict] = None, cast_to: Any = None):
await start_trace(path, {"__location__": "library_client"})
try:
func = self.endpoint_impls.get(path)
@@ -264,7 +306,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
body = self._convert_body(path, body)
async for chunk in await func(**body):
- yield chunk
+ yield convert_pydantic_to_json_value(chunk, cast_to)
finally:
await end_trace()
@@ -283,38 +325,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
for param_name, param in sig.parameters.items():
if param_name in body:
value = body.get(param_name)
- converted_body[param_name] = self._convert_param(
+ converted_body[param_name] = convert_to_pydantic(
param.annotation, value
)
return converted_body
-
- def _convert_param(self, annotation: Any, value: Any) -> Any:
- if isinstance(annotation, type) and annotation in {str, int, float, bool}:
- return value
-
- origin = get_origin(annotation)
- if origin is list:
- item_type = get_args(annotation)[0]
- try:
- return [self._convert_param(item_type, item) for item in value]
- except Exception:
- print(f"Error converting list {value}")
- return value
-
- elif origin is dict:
- key_type, val_type = get_args(annotation)
- try:
- return {k: self._convert_param(val_type, v) for k, v in value.items()}
- except Exception:
- print(f"Error converting dict {value}")
- return value
-
- try:
- # Handle Pydantic models and discriminated unions
- return TypeAdapter(annotation).validate_python(value)
- except Exception as e:
- cprint(
- f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
- "yellow",
- )
- return value
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index e367f3c41..126c2e193 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -451,7 +451,7 @@ class ChatAgent(ShieldRunnerMixin):
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
- model_response_text_delta="",
+ text_delta="",
tool_call_delta=delta,
)
)
@@ -465,7 +465,7 @@ class ChatAgent(ShieldRunnerMixin):
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
- model_response_text_delta=event.delta,
+ text_delta=event.delta,
)
)
)