forked from phoenix-oss/llama-stack-mirror
move all implementations to use updated type
This commit is contained in:
parent
aced2ce07e
commit
9a5803a429
8 changed files with 139 additions and 208 deletions
|
@ -11,9 +11,13 @@ from llama_models.llama3.api.tool_utils import ToolUtils
|
|||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
|
||||
|
||||
from llama_stack.apis.common.content_types import ToolCallParseStatus
|
||||
from llama_stack.apis.inference import ToolResponseMessage
|
||||
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
|
||||
class LogEvent:
|
||||
def __init__(
|
||||
|
@ -57,8 +61,11 @@ class EventLogger:
|
|||
# since it does not produce event but instead
|
||||
# a Message
|
||||
if isinstance(chunk, ToolResponseMessage):
|
||||
yield chunk, LogEvent(
|
||||
role="CustomTool", content=chunk.content, color="grey"
|
||||
yield (
|
||||
chunk,
|
||||
LogEvent(
|
||||
role="CustomTool", content=chunk.content, color="grey"
|
||||
),
|
||||
)
|
||||
continue
|
||||
|
||||
|
@ -80,14 +87,20 @@ class EventLogger:
|
|||
):
|
||||
violation = event.payload.step_details.violation
|
||||
if not violation:
|
||||
yield event, LogEvent(
|
||||
role=step_type, content="No Violation", color="magenta"
|
||||
yield (
|
||||
event,
|
||||
LogEvent(
|
||||
role=step_type, content="No Violation", color="magenta"
|
||||
),
|
||||
)
|
||||
else:
|
||||
yield event, LogEvent(
|
||||
role=step_type,
|
||||
content=f"{violation.metadata} {violation.user_message}",
|
||||
color="red",
|
||||
yield (
|
||||
event,
|
||||
LogEvent(
|
||||
role=step_type,
|
||||
content=f"{violation.metadata} {violation.user_message}",
|
||||
color="red",
|
||||
),
|
||||
)
|
||||
|
||||
# handle inference
|
||||
|
@ -95,8 +108,11 @@ class EventLogger:
|
|||
if stream:
|
||||
if event_type == EventType.step_start.value:
|
||||
# TODO: Currently this event is never received
|
||||
yield event, LogEvent(
|
||||
role=step_type, content="", end="", color="yellow"
|
||||
yield (
|
||||
event,
|
||||
LogEvent(
|
||||
role=step_type, content="", end="", color="yellow"
|
||||
),
|
||||
)
|
||||
elif event_type == EventType.step_progress.value:
|
||||
# HACK: if previous was not step/event was not inference's step_progress
|
||||
|
@ -107,24 +123,34 @@ class EventLogger:
|
|||
previous_event_type != EventType.step_progress.value
|
||||
and previous_step_type != StepType.inference
|
||||
):
|
||||
yield event, LogEvent(
|
||||
role=step_type, content="", end="", color="yellow"
|
||||
yield (
|
||||
event,
|
||||
LogEvent(
|
||||
role=step_type, content="", end="", color="yellow"
|
||||
),
|
||||
)
|
||||
|
||||
if event.payload.tool_call_delta:
|
||||
if isinstance(event.payload.tool_call_delta.content, str):
|
||||
yield event, LogEvent(
|
||||
role=None,
|
||||
content=event.payload.tool_call_delta.content,
|
||||
end="",
|
||||
color="cyan",
|
||||
delta = event.payload.delta
|
||||
if delta.type == "tool_call":
|
||||
if delta.parse_status == ToolCallParseStatus.success:
|
||||
yield (
|
||||
event,
|
||||
LogEvent(
|
||||
role=None,
|
||||
content=delta.content,
|
||||
end="",
|
||||
color="cyan",
|
||||
),
|
||||
)
|
||||
else:
|
||||
yield event, LogEvent(
|
||||
role=None,
|
||||
content=event.payload.text_delta,
|
||||
end="",
|
||||
color="yellow",
|
||||
yield (
|
||||
event,
|
||||
LogEvent(
|
||||
role=None,
|
||||
content=delta.text,
|
||||
end="",
|
||||
color="yellow",
|
||||
),
|
||||
)
|
||||
else:
|
||||
# step_complete
|
||||
|
@ -140,10 +166,13 @@ class EventLogger:
|
|||
)
|
||||
else:
|
||||
content = response.content
|
||||
yield event, LogEvent(
|
||||
role=step_type,
|
||||
content=content,
|
||||
color="yellow",
|
||||
yield (
|
||||
event,
|
||||
LogEvent(
|
||||
role=step_type,
|
||||
content=content,
|
||||
color="yellow",
|
||||
),
|
||||
)
|
||||
|
||||
# handle tool_execution
|
||||
|
@ -155,16 +184,22 @@ class EventLogger:
|
|||
):
|
||||
details = event.payload.step_details
|
||||
for t in details.tool_calls:
|
||||
yield event, LogEvent(
|
||||
role=step_type,
|
||||
content=f"Tool:{t.tool_name} Args:{t.arguments}",
|
||||
color="green",
|
||||
yield (
|
||||
event,
|
||||
LogEvent(
|
||||
role=step_type,
|
||||
content=f"Tool:{t.tool_name} Args:{t.arguments}",
|
||||
color="green",
|
||||
),
|
||||
)
|
||||
for r in details.tool_responses:
|
||||
yield event, LogEvent(
|
||||
role=step_type,
|
||||
content=f"Tool:{r.tool_name} Response:{r.content}",
|
||||
color="green",
|
||||
yield (
|
||||
event,
|
||||
LogEvent(
|
||||
role=step_type,
|
||||
content=f"Tool:{r.tool_name} Response:{r.content}",
|
||||
color="green",
|
||||
),
|
||||
)
|
||||
|
||||
if (
|
||||
|
@ -172,15 +207,16 @@ class EventLogger:
|
|||
and event_type == EventType.step_complete.value
|
||||
):
|
||||
details = event.payload.step_details
|
||||
inserted_context = interleaved_text_media_as_str(
|
||||
details.inserted_context
|
||||
)
|
||||
inserted_context = interleaved_content_as_str(details.inserted_context)
|
||||
content = f"fetched {len(inserted_context)} bytes from {details.memory_bank_ids}"
|
||||
|
||||
yield event, LogEvent(
|
||||
role=step_type,
|
||||
content=content,
|
||||
color="cyan",
|
||||
yield (
|
||||
event,
|
||||
LogEvent(
|
||||
role=step_type,
|
||||
content=content,
|
||||
color="cyan",
|
||||
),
|
||||
)
|
||||
|
||||
previous_event_type = event_type
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue