Fixes for library client

This commit is contained in:
Ashwin Bharambe 2024-12-09 11:41:36 -08:00
parent 7615da78b8
commit 4904ca64df
6 changed files with 32 additions and 21 deletions

View file

@ -4368,14 +4368,11 @@
"step_id": { "step_id": {
"type": "string" "type": "string"
}, },
"model_response_text_delta": { "text_delta": {
"type": "string" "type": "string"
}, },
"tool_call_delta": { "tool_call_delta": {
"$ref": "#/components/schemas/ToolCallDelta" "$ref": "#/components/schemas/ToolCallDelta"
},
"tool_response_text_delta": {
"type": "string"
} }
}, },
"additionalProperties": false, "additionalProperties": false,

View file

@ -132,8 +132,6 @@ components:
 const: step_progress
 default: step_progress
 type: string
-model_response_text_delta:
-type: string
 step_id:
 type: string
 step_type:
@ -143,10 +141,10 @@
 - shield_call
 - memory_retrieval
 type: string
+text_delta:
+type: string
 tool_call_delta:
 $ref: '#/components/schemas/ToolCallDelta'
-tool_response_text_delta:
-type: string
 required:
 - event_type
 - step_type

View file

@ -340,9 +340,8 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
 step_type: StepType
 step_id: str
-model_response_text_delta: Optional[str] = None
+text_delta: Optional[str] = None
 tool_call_delta: Optional[ToolCallDelta] = None
-tool_response_text_delta: Optional[str] = None
 @json_schema_type

View file

@ -121,7 +121,7 @@ class EventLogger:
 else:
 yield event, LogEvent(
 role=None,
-content=event.payload.model_response_text_delta,
+content=event.payload.text_delta,
 end="",
 color="yellow",
 )

View file

@ -6,16 +6,18 @@
 import asyncio
 import inspect
+import json
 import os
 import queue
 import threading
 from concurrent.futures import ThreadPoolExecutor
+from enum import Enum
 from pathlib import Path
 from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
 import yaml
 from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
-from pydantic import TypeAdapter
+from pydantic import BaseModel, TypeAdapter
 from rich.console import Console
 from termcolor import cprint
@ -109,6 +111,18 @@ def stream_across_asyncio_run_boundary(
future.result() future.result()
def convert_pydantic_to_json_value(value: Any) -> dict:
if isinstance(value, BaseModel):
return json.loads(value.model_dump_json())
elif isinstance(value, Enum):
return value.value
elif isinstance(value, list):
return [convert_pydantic_to_json_value(item) for item in value]
elif isinstance(value, dict):
return {k: convert_pydantic_to_json_value(v) for k, v in value.items()}
return value
class LlamaStackAsLibraryClient(LlamaStackClient): class LlamaStackAsLibraryClient(LlamaStackClient):
def __init__( def __init__(
self, self,
@ -187,8 +201,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
 if self.config_path_or_template_name.endswith(".yaml"):
 print_pip_install_help(self.config.providers)
 else:
+prefix = "!" if in_notebook() else ""
 cprint(
-f"Please run:\n\nllama stack build --template {self.config_path_or_template_name} --image-type venv\n\n",
+f"Please run:\n\n{prefix}llama stack build --template {self.config_path_or_template_name} --image-type venv\n\n",
 "yellow",
 )
 return False
@ -251,7 +266,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
 raise ValueError(f"No endpoint found for {path}")
 body = self._convert_body(path, body)
-return await func(**body)
+return convert_pydantic_to_json_value(await func(**body))
 finally:
 await end_trace()
@ -264,7 +279,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
 body = self._convert_body(path, body)
 async for chunk in await func(**body):
-yield chunk
+yield convert_pydantic_to_json_value(chunk)
 finally:
 await end_trace()
@ -283,12 +298,12 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
 for param_name, param in sig.parameters.items():
 if param_name in body:
 value = body.get(param_name)
-converted_body[param_name] = self._convert_param(
+converted_body[param_name] = self._convert_to_pydantic(
 param.annotation, value
 )
 return converted_body
-def _convert_param(self, annotation: Any, value: Any) -> Any:
+def _convert_to_pydantic(self, annotation: Any, value: Any) -> Any:
 if isinstance(annotation, type) and annotation in {str, int, float, bool}:
 return value
@ -296,7 +311,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
 if origin is list:
 item_type = get_args(annotation)[0]
 try:
-return [self._convert_param(item_type, item) for item in value]
+return [self._convert_to_pydantic(item_type, item) for item in value]
 except Exception:
 print(f"Error converting list {value}")
 return value
@ -304,7 +319,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
 elif origin is dict:
 key_type, val_type = get_args(annotation)
 try:
-return {k: self._convert_param(val_type, v) for k, v in value.items()}
+return {
+    k: self._convert_to_pydantic(val_type, v) for k, v in value.items()
+}
 except Exception:
 print(f"Error converting dict {value}")
 return value

View file

@ -451,7 +451,7 @@ class ChatAgent(ShieldRunnerMixin):
 payload=AgentTurnResponseStepProgressPayload(
 step_type=StepType.inference.value,
 step_id=step_id,
-model_response_text_delta="",
+text_delta="",
 tool_call_delta=delta,
 )
 )
@ -465,7 +465,7 @@ class ChatAgent(ShieldRunnerMixin):
 payload=AgentTurnResponseStepProgressPayload(
 step_type=StepType.inference.value,
 step_id=step_id,
-model_response_text_delta=event.delta,
+text_delta=event.delta,
 )
 )
 )