move all implementations to use updated type

This commit is contained in:
Ashwin Bharambe 2025-01-13 20:04:19 -08:00
parent aced2ce07e
commit 9a5803a429
8 changed files with 139 additions and 208 deletions

View file

@ -22,12 +22,11 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmeth
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, URL
from llama_stack.apis.inference import (
CompletionMessage,
SamplingParams,
ToolCall,
ToolCallDelta,
ToolChoice,
ToolPromptFormat,
ToolResponse,
@ -216,8 +215,7 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
step_type: StepType
step_id: str
text_delta: Optional[str] = None
tool_call_delta: Optional[ToolCallDelta] = None
delta: ContentDelta
@json_schema_type

View file

@ -11,9 +11,13 @@ from llama_models.llama3.api.tool_utils import ToolUtils
from termcolor import cprint
from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
from llama_stack.apis.common.content_types import ToolCallParseStatus
from llama_stack.apis.inference import ToolResponseMessage
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
class LogEvent:
def __init__(
@ -57,8 +61,11 @@ class EventLogger:
# since it does not produce event but instead
# a Message
if isinstance(chunk, ToolResponseMessage):
yield chunk, LogEvent(
role="CustomTool", content=chunk.content, color="grey"
yield (
chunk,
LogEvent(
role="CustomTool", content=chunk.content, color="grey"
),
)
continue
@ -80,14 +87,20 @@ class EventLogger:
):
violation = event.payload.step_details.violation
if not violation:
yield event, LogEvent(
role=step_type, content="No Violation", color="magenta"
yield (
event,
LogEvent(
role=step_type, content="No Violation", color="magenta"
),
)
else:
yield event, LogEvent(
role=step_type,
content=f"{violation.metadata} {violation.user_message}",
color="red",
yield (
event,
LogEvent(
role=step_type,
content=f"{violation.metadata} {violation.user_message}",
color="red",
),
)
# handle inference
@ -95,8 +108,11 @@ class EventLogger:
if stream:
if event_type == EventType.step_start.value:
# TODO: Currently this event is never received
yield event, LogEvent(
role=step_type, content="", end="", color="yellow"
yield (
event,
LogEvent(
role=step_type, content="", end="", color="yellow"
),
)
elif event_type == EventType.step_progress.value:
# HACK: if previous was not step/event was not inference's step_progress
@ -107,24 +123,34 @@ class EventLogger:
previous_event_type != EventType.step_progress.value
and previous_step_type != StepType.inference
):
yield event, LogEvent(
role=step_type, content="", end="", color="yellow"
yield (
event,
LogEvent(
role=step_type, content="", end="", color="yellow"
),
)
if event.payload.tool_call_delta:
if isinstance(event.payload.tool_call_delta.content, str):
yield event, LogEvent(
role=None,
content=event.payload.tool_call_delta.content,
end="",
color="cyan",
delta = event.payload.delta
if delta.type == "tool_call":
if delta.parse_status == ToolCallParseStatus.success:
yield (
event,
LogEvent(
role=None,
content=delta.content,
end="",
color="cyan",
),
)
else:
yield event, LogEvent(
role=None,
content=event.payload.text_delta,
end="",
color="yellow",
yield (
event,
LogEvent(
role=None,
content=delta.text,
end="",
color="yellow",
),
)
else:
# step_complete
@ -140,10 +166,13 @@ class EventLogger:
)
else:
content = response.content
yield event, LogEvent(
role=step_type,
content=content,
color="yellow",
yield (
event,
LogEvent(
role=step_type,
content=content,
color="yellow",
),
)
# handle tool_execution
@ -155,16 +184,22 @@ class EventLogger:
):
details = event.payload.step_details
for t in details.tool_calls:
yield event, LogEvent(
role=step_type,
content=f"Tool:{t.tool_name} Args:{t.arguments}",
color="green",
yield (
event,
LogEvent(
role=step_type,
content=f"Tool:{t.tool_name} Args:{t.arguments}",
color="green",
),
)
for r in details.tool_responses:
yield event, LogEvent(
role=step_type,
content=f"Tool:{r.tool_name} Response:{r.content}",
color="green",
yield (
event,
LogEvent(
role=step_type,
content=f"Tool:{r.tool_name} Response:{r.content}",
color="green",
),
)
if (
@ -172,15 +207,16 @@ class EventLogger:
and event_type == EventType.step_complete.value
):
details = event.payload.step_details
inserted_context = interleaved_text_media_as_str(
details.inserted_context
)
inserted_context = interleaved_content_as_str(details.inserted_context)
content = f"fetched {len(inserted_context)} bytes from {details.memory_bank_ids}"
yield event, LogEvent(
role=step_type,
content=content,
color="cyan",
yield (
event,
LogEvent(
role=step_type,
content=content,
color="cyan",
),
)
previous_event_type = event_type