Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 08:44:44 +00:00)
Fixes for library client
Commit 4904ca64df (parent 7615da78b8)
6 changed files with 32 additions and 21 deletions
@@ -4368,14 +4368,11 @@
         "step_id": {
           "type": "string"
         },
-        "model_response_text_delta": {
+        "text_delta": {
           "type": "string"
         },
         "tool_call_delta": {
           "$ref": "#/components/schemas/ToolCallDelta"
         },
         "tool_response_text_delta": {
           "type": "string"
         }
       },
       "additionalProperties": false,
@@ -132,8 +132,6 @@ components:
           const: step_progress
           default: step_progress
           type: string
-        model_response_text_delta:
-          type: string
         step_id:
           type: string
         step_type:
@@ -143,10 +141,10 @@ components:
           - shield_call
           - memory_retrieval
           type: string
+        text_delta:
+          type: string
         tool_call_delta:
           $ref: '#/components/schemas/ToolCallDelta'
         tool_response_text_delta:
           type: string
       required:
       - event_type
       - step_type
@@ -340,9 +340,8 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
     step_type: StepType
     step_id: str
 
-    model_response_text_delta: Optional[str] = None
+    text_delta: Optional[str] = None
     tool_call_delta: Optional[ToolCallDelta] = None
     tool_response_text_delta: Optional[str] = None
 
 
 @json_schema_type
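The rename is mechanical but API-visible: anything that previously read model_response_text_delta must now read text_delta. A minimal sketch of the updated payload, with StepType and ToolCallDelta stubbed out for illustration (the real definitions live elsewhere in the agents API):

    from enum import Enum
    from typing import Optional
    from pydantic import BaseModel


    class StepType(Enum):  # stub; the real enum has more members
        inference = "inference"


    class ToolCallDelta(BaseModel):  # stub for illustration only
        content: str = ""


    class AgentTurnResponseStepProgressPayload(BaseModel):
        step_type: StepType
        step_id: str

        text_delta: Optional[str] = None  # was: model_response_text_delta
        tool_call_delta: Optional[ToolCallDelta] = None
        tool_response_text_delta: Optional[str] = None


    payload = AgentTurnResponseStepProgressPayload(
        step_type=StepType.inference, step_id="step-0", text_delta="Hello"
    )
    print(payload.text_delta)  # "Hello"

The next hunk updates the one in-tree consumer, EventLogger, accordingly.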
@@ -121,7 +121,7 @@ class EventLogger:
             else:
                 yield event, LogEvent(
                     role=None,
-                    content=event.payload.model_response_text_delta,
+                    content=event.payload.text_delta,
                     end="",
                     color="yellow",
                 )
@@ -6,16 +6,18 @@
 
 import asyncio
 import inspect
+import json
 import os
 import queue
 import threading
 from concurrent.futures import ThreadPoolExecutor
+from enum import Enum
 from pathlib import Path
 from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
 
 import yaml
 from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
-from pydantic import TypeAdapter
+from pydantic import BaseModel, TypeAdapter
 from rich.console import Console
 
 from termcolor import cprint
@@ -109,6 +111,18 @@ def stream_across_asyncio_run_boundary(
             future.result()
 
 
+def convert_pydantic_to_json_value(value: Any) -> dict:
+    if isinstance(value, BaseModel):
+        return json.loads(value.model_dump_json())
+    elif isinstance(value, Enum):
+        return value.value
+    elif isinstance(value, list):
+        return [convert_pydantic_to_json_value(item) for item in value]
+    elif isinstance(value, dict):
+        return {k: convert_pydantic_to_json_value(v) for k, v in value.items()}
+    return value
+
+
 class LlamaStackAsLibraryClient(LlamaStackClient):
     def __init__(
         self,
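The new helper recursively lowers Pydantic models, enums, lists, and dicts into plain JSON-compatible values. A quick sketch of its behavior on a nested value (the Color and Point types are illustrative, not from llama-stack):

    import json
    from enum import Enum
    from typing import Any
    from pydantic import BaseModel


    def convert_pydantic_to_json_value(value: Any) -> dict:
        # Same shape as the helper added in this commit; the annotation follows
        # the commit, though scalars pass through unchanged.
        if isinstance(value, BaseModel):
            return json.loads(value.model_dump_json())
        elif isinstance(value, Enum):
            return value.value
        elif isinstance(value, list):
            return [convert_pydantic_to_json_value(item) for item in value]
        elif isinstance(value, dict):
            return {k: convert_pydantic_to_json_value(v) for k, v in value.items()}
        return value


    class Color(Enum):
        red = "red"


    class Point(BaseModel):
        x: int
        color: Color


    print(convert_pydantic_to_json_value({"points": [Point(x=1, color=Color.red)]}))
    # {'points': [{'x': 1, 'color': 'red'}]}

Going through model_dump_json() and json.loads, rather than model_dump(), lets Pydantic apply its JSON-mode serializers (enums, dates, nested models) before the structure is rebuilt as plain dicts.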
@@ -187,8 +201,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         if self.config_path_or_template_name.endswith(".yaml"):
             print_pip_install_help(self.config.providers)
         else:
+            prefix = "!" if in_notebook() else ""
             cprint(
-                f"Please run:\n\nllama stack build --template {self.config_path_or_template_name} --image-type venv\n\n",
+                f"Please run:\n\n{prefix}llama stack build --template {self.config_path_or_template_name} --image-type venv\n\n",
                 "yellow",
             )
         return False
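The "!" prefix matters because a shell command pasted into a Jupyter cell needs it, while a terminal does not. in_notebook() is referenced but not defined in this hunk; a common way to implement such a check, offered here only as an assumed sketch and not necessarily how llama-stack does it, probes the IPython shell class:

    def in_notebook() -> bool:
        # Assumed helper: detect a Jupyter kernel via the IPython shell name.
        try:
            from IPython import get_ipython

            shell = get_ipython()
            return shell is not None and shell.__class__.__name__ == "ZMQInteractiveShell"
        except ImportError:
            return False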
@@ -251,7 +266,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 raise ValueError(f"No endpoint found for {path}")
 
             body = self._convert_body(path, body)
-            return await func(**body)
+            return convert_pydantic_to_json_value(await func(**body))
         finally:
             await end_trace()
@@ -264,7 +279,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
 
                 body = self._convert_body(path, body)
                 async for chunk in await func(**body):
-                    yield chunk
+                    yield convert_pydantic_to_json_value(chunk)
             finally:
                 await end_trace()
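Both the unary and the streaming paths now return plain JSON values instead of Pydantic objects, matching what the HTTP client would deserialize off the wire. One concrete consequence, sketched with a hypothetical chunk type and reusing the convert_pydantic_to_json_value sketch above: converted chunks are directly json.dumps-able, raw models are not.

    import json
    from pydantic import BaseModel


    class Chunk(BaseModel):  # hypothetical stand-in for a streamed event type
        text: str


    chunk = Chunk(text="hi")
    # json.dumps(chunk) would raise TypeError: not JSON serializable.
    print(json.dumps(convert_pydantic_to_json_value(chunk)))  # {"text": "hi"}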
@@ -283,12 +298,12 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         for param_name, param in sig.parameters.items():
             if param_name in body:
                 value = body.get(param_name)
-                converted_body[param_name] = self._convert_param(
+                converted_body[param_name] = self._convert_to_pydantic(
                     param.annotation, value
                 )
         return converted_body
 
-    def _convert_param(self, annotation: Any, value: Any) -> Any:
+    def _convert_to_pydantic(self, annotation: Any, value: Any) -> Any:
         if isinstance(annotation, type) and annotation in {str, int, float, bool}:
             return value
@@ -296,7 +311,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         if origin is list:
             item_type = get_args(annotation)[0]
             try:
-                return [self._convert_param(item_type, item) for item in value]
+                return [self._convert_to_pydantic(item_type, item) for item in value]
             except Exception:
                 print(f"Error converting list {value}")
                 return value
@@ -304,7 +319,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         elif origin is dict:
             key_type, val_type = get_args(annotation)
             try:
-                return {k: self._convert_param(val_type, v) for k, v in value.items()}
+                return {
+                    k: self._convert_to_pydantic(val_type, v) for k, v in value.items()
+                }
             except Exception:
                 print(f"Error converting dict {value}")
                 return value
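The renamed _convert_to_pydantic is the inverse of convert_pydantic_to_json_value: request bodies arrive as plain dicts and are rebuilt into the parameter types the endpoint signatures declare. For annotations that are not primitives, lists, or dicts, the fallback presumably relies on Pydantic's TypeAdapter (already present in the imports); a sketch of that final step with an illustrative model:

    from pydantic import BaseModel, TypeAdapter


    class ToolCall(BaseModel):  # illustrative model, not the real llama-stack type
        name: str
        arguments: dict


    # TypeAdapter validates a plain JSON dict back into the annotated type.
    adapter = TypeAdapter(ToolCall)
    obj = adapter.validate_python({"name": "search", "arguments": {"q": "llama"}})
    print(obj.name)  # "search"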
@@ -451,7 +451,7 @@ class ChatAgent(ShieldRunnerMixin):
                 payload=AgentTurnResponseStepProgressPayload(
                     step_type=StepType.inference.value,
                     step_id=step_id,
-                    model_response_text_delta="",
+                    text_delta="",
                     tool_call_delta=delta,
                 )
             )
@@ -465,7 +465,7 @@ class ChatAgent(ShieldRunnerMixin):
                 payload=AgentTurnResponseStepProgressPayload(
                     step_type=StepType.inference.value,
                     step_id=step_id,
-                    model_response_text_delta=event.delta,
+                    text_delta=event.delta,
                 )
             )
         )