Merge remote-tracking branch 'origin/main' into support_more_data_format

Botao Chen 2025-01-14 11:55:13 -08:00
commit 8d7bb1140f
20 changed files with 381 additions and 414 deletions

View file

@@ -22,12 +22,11 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmeth
 from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import Annotated

-from llama_stack.apis.common.content_types import InterleavedContent, URL
+from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, URL
 from llama_stack.apis.inference import (
     CompletionMessage,
     SamplingParams,
     ToolCall,
-    ToolCallDelta,
     ToolChoice,
     ToolPromptFormat,
     ToolResponse,
@@ -216,8 +215,7 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
     step_type: StepType
     step_id: str

-    text_delta: Optional[str] = None
-    tool_call_delta: Optional[ToolCallDelta] = None
+    delta: ContentDelta


 @json_schema_type
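
Note (not part of the diff): agent clients now dispatch on the single discriminated `delta` field instead of checking `text_delta` and `tool_call_delta` separately. A minimal sketch under that assumption; `handle_step_progress` is a hypothetical helper, not something this commit adds.

from llama_stack.apis.common.content_types import TextDelta, ToolCallDelta


def handle_step_progress(payload) -> None:
    # payload: an AgentTurnResponseStepProgressPayload from a streamed agent turn
    delta = payload.delta
    if isinstance(delta, TextDelta):
        # plain streamed text
        print(delta.text, end="", flush=True)
    elif isinstance(delta, ToolCallDelta):
        # content is a partial string while streaming, or a parsed ToolCall
        # once parse_status reports success
        print(f"[tool_call:{delta.parse_status.value}] {delta.content}")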

View file

@@ -11,9 +11,13 @@ from llama_models.llama3.api.tool_utils import ToolUtils
 from termcolor import cprint

 from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
+from llama_stack.apis.common.content_types import ToolCallParseStatus
 from llama_stack.apis.inference import ToolResponseMessage
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    interleaved_content_as_str,
+)


 class LogEvent:
     def __init__(
@@ -57,8 +61,11 @@ class EventLogger:
                 # since it does not produce event but instead
                 # a Message
                 if isinstance(chunk, ToolResponseMessage):
-                    yield chunk, LogEvent(
-                        role="CustomTool", content=chunk.content, color="grey"
+                    yield (
+                        chunk,
+                        LogEvent(
+                            role="CustomTool", content=chunk.content, color="grey"
+                        ),
                     )
                 continue

@@ -80,14 +87,20 @@ class EventLogger:
             ):
                 violation = event.payload.step_details.violation
                 if not violation:
-                    yield event, LogEvent(
-                        role=step_type, content="No Violation", color="magenta"
+                    yield (
+                        event,
+                        LogEvent(
+                            role=step_type, content="No Violation", color="magenta"
+                        ),
                     )
                 else:
-                    yield event, LogEvent(
-                        role=step_type,
-                        content=f"{violation.metadata} {violation.user_message}",
-                        color="red",
+                    yield (
+                        event,
+                        LogEvent(
+                            role=step_type,
+                            content=f"{violation.metadata} {violation.user_message}",
+                            color="red",
+                        ),
                     )

             # handle inference
@@ -95,8 +108,11 @@ class EventLogger:
                 if stream:
                     if event_type == EventType.step_start.value:
                         # TODO: Currently this event is never received
-                        yield event, LogEvent(
-                            role=step_type, content="", end="", color="yellow"
+                        yield (
+                            event,
+                            LogEvent(
+                                role=step_type, content="", end="", color="yellow"
+                            ),
                         )
                     elif event_type == EventType.step_progress.value:
                         # HACK: if previous was not step/event was not inference's step_progress
@@ -107,24 +123,34 @@ class EventLogger:
                             previous_event_type != EventType.step_progress.value
                             and previous_step_type != StepType.inference
                         ):
-                            yield event, LogEvent(
-                                role=step_type, content="", end="", color="yellow"
+                            yield (
+                                event,
+                                LogEvent(
+                                    role=step_type, content="", end="", color="yellow"
+                                ),
                             )

-                        if event.payload.tool_call_delta:
-                            if isinstance(event.payload.tool_call_delta.content, str):
-                                yield event, LogEvent(
-                                    role=None,
-                                    content=event.payload.tool_call_delta.content,
-                                    end="",
-                                    color="cyan",
+                        delta = event.payload.delta
+                        if delta.type == "tool_call":
+                            if delta.parse_status == ToolCallParseStatus.succeeded:
+                                yield (
+                                    event,
+                                    LogEvent(
+                                        role=None,
+                                        content=delta.content,
+                                        end="",
+                                        color="cyan",
+                                    ),
                                 )
                         else:
-                            yield event, LogEvent(
-                                role=None,
-                                content=event.payload.text_delta,
-                                end="",
-                                color="yellow",
+                            yield (
+                                event,
+                                LogEvent(
+                                    role=None,
+                                    content=delta.text,
+                                    end="",
+                                    color="yellow",
+                                ),
                             )
                     else:
                         # step_complete
@@ -140,10 +166,13 @@ class EventLogger:
                             )
                         else:
                             content = response.content
-                        yield event, LogEvent(
-                            role=step_type,
-                            content=content,
-                            color="yellow",
+                        yield (
+                            event,
+                            LogEvent(
+                                role=step_type,
+                                content=content,
+                                color="yellow",
+                            ),
                         )

             # handle tool_execution
@@ -155,16 +184,22 @@ class EventLogger:
             ):
                 details = event.payload.step_details
                 for t in details.tool_calls:
-                    yield event, LogEvent(
-                        role=step_type,
-                        content=f"Tool:{t.tool_name} Args:{t.arguments}",
-                        color="green",
+                    yield (
+                        event,
+                        LogEvent(
+                            role=step_type,
+                            content=f"Tool:{t.tool_name} Args:{t.arguments}",
+                            color="green",
+                        ),
                     )
                 for r in details.tool_responses:
-                    yield event, LogEvent(
-                        role=step_type,
-                        content=f"Tool:{r.tool_name} Response:{r.content}",
-                        color="green",
+                    yield (
+                        event,
+                        LogEvent(
+                            role=step_type,
+                            content=f"Tool:{r.tool_name} Response:{r.content}",
+                            color="green",
+                        ),
                     )

             if (
@@ -172,15 +207,16 @@ class EventLogger:
                 and event_type == EventType.step_complete.value
             ):
                 details = event.payload.step_details
-                inserted_context = interleaved_text_media_as_str(
-                    details.inserted_context
-                )
+                inserted_context = interleaved_content_as_str(details.inserted_context)
                 content = f"fetched {len(inserted_context)} bytes from {details.memory_bank_ids}"

-                yield event, LogEvent(
-                    role=step_type,
-                    content=content,
-                    color="cyan",
+                yield (
+                    event,
+                    LogEvent(
+                        role=step_type,
+                        content=content,
+                        color="cyan",
+                    ),
                 )

             previous_event_type = event_type
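
Note (not part of the diff): a hedged sketch of how a caller might drain the logger after this change. The `log` method name, its argument, and the `print` helper on LogEvent are assumptions for illustration; only the yielded `(event, LogEvent)` pairs and the keyword arguments above come from this file.

async def show_turn(turn_stream, logger) -> None:
    # turn_stream: async iterable of agent turn response chunks (assumed)
    async for _event, log_event in logger.log(turn_stream):
        log_event.print()  # assumed cprint-style helper on LogEvent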

View file

@@ -5,10 +5,12 @@
 # the root directory of this source tree.

 import base64
+from enum import Enum
 from typing import Annotated, List, Literal, Optional, Union

-from llama_models.schema_utils import json_schema_type, register_schema
+from llama_models.llama3.api.datatypes import ToolCall
+from llama_models.schema_utils import json_schema_type, register_schema
 from pydantic import BaseModel, Field, field_serializer, model_validator
@@ -60,3 +62,42 @@ InterleavedContent = register_schema(
     Union[str, InterleavedContentItem, List[InterleavedContentItem]],
     name="InterleavedContent",
 )
+
+
+class TextDelta(BaseModel):
+    type: Literal["text"] = "text"
+    text: str
+
+
+class ImageDelta(BaseModel):
+    type: Literal["image"] = "image"
+    data: bytes
+
+
+@json_schema_type
+class ToolCallParseStatus(Enum):
+    started = "started"
+    in_progress = "in_progress"
+    failed = "failed"
+    succeeded = "succeeded"
+
+
+@json_schema_type
+class ToolCallDelta(BaseModel):
+    type: Literal["tool_call"] = "tool_call"
+
+    # you either send an in-progress tool call so the client can stream a long
+    # code generation or you send the final parsed tool call at the end of the
+    # stream
+    content: Union[str, ToolCall]
+    parse_status: ToolCallParseStatus
+
+
+# streaming completions send a stream of ContentDeltas
+ContentDelta = register_schema(
+    Annotated[
+        Union[TextDelta, ImageDelta, ToolCallDelta],
+        Field(discriminator="type"),
+    ],
+    name="ContentDelta",
+)
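
Note (not part of the diff): because `ContentDelta` is a discriminated union on `type`, a plain dict validates to the right variant. A small sketch assuming the definitions above are importable from llama_stack.apis.common.content_types; pydantic v2's `TypeAdapter` is used only for illustration (the file already relies on v2 features such as `field_serializer`).

from pydantic import TypeAdapter

from llama_stack.apis.common.content_types import ContentDelta, TextDelta, ToolCallDelta

adapter = TypeAdapter(ContentDelta)

# "type" selects the variant via the discriminator
text = adapter.validate_python({"type": "text", "text": "hi"})
assert isinstance(text, TextDelta)

# a partial tool call mid-stream; parse_status is coerced to the enum by value
partial = adapter.validate_python(
    {"type": "tool_call", "content": "get_weather(", "parse_status": "in_progress"}
)
assert isinstance(partial, ToolCallDelta)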

View file

@@ -29,7 +29,7 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmeth
 from pydantic import BaseModel, Field, field_validator
 from typing_extensions import Annotated

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
 from llama_stack.apis.models import Model
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol

@@ -147,26 +147,12 @@ class ChatCompletionResponseEventType(Enum):
     progress = "progress"


-@json_schema_type
-class ToolCallParseStatus(Enum):
-    started = "started"
-    in_progress = "in_progress"
-    failure = "failure"
-    success = "success"
-
-
-@json_schema_type
-class ToolCallDelta(BaseModel):
-    content: Union[str, ToolCall]
-    parse_status: ToolCallParseStatus
-
-
 @json_schema_type
 class ChatCompletionResponseEvent(BaseModel):
     """Chat completion response event."""

     event_type: ChatCompletionResponseEventType
-    delta: Union[str, ToolCallDelta]
+    delta: ContentDelta
     logprobs: Optional[List[TokenLogProbs]] = None
     stop_reason: Optional[StopReason] = None

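
Note (not part of the diff): with the delta unified, a streaming chat client switches on `delta.type` instead of special-casing a bare string. A hedged sketch, assuming each streamed chunk exposes an `event` shaped like `ChatCompletionResponseEvent` above; the chunk iterable itself is left abstract.

from llama_stack.apis.common.content_types import ToolCallParseStatus


def render_chat_stream(chunks) -> None:
    for chunk in chunks:
        delta = chunk.event.delta
        if delta.type == "text":
            print(delta.text, end="", flush=True)
        elif delta.type == "tool_call" and delta.parse_status == ToolCallParseStatus.succeeded:
            print(f"\n[tool call] {delta.content}")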