move all implementations to use updated type

Ashwin Bharambe 2025-01-13 20:04:19 -08:00
parent aced2ce07e
commit 9a5803a429
8 changed files with 139 additions and 208 deletions
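
The "updated type" is the ContentDelta union from llama_stack.apis.common.content_types: streaming payloads stop carrying separate text_delta: Optional[str] and tool_call_delta: Optional[ToolCallDelta] fields and instead carry a single delta discriminated on its type field ("text" vs "tool_call"). A minimal sketch of those types, reconstructed from the usages in this diff (the real module may define more variants and stricter field types; the ToolCall stand-in below is simplified):

from enum import Enum
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field


class ToolCall(BaseModel):
    # simplified stand-in for the llama_models ToolCall type
    call_id: str
    tool_name: str
    arguments: dict


class ToolCallParseStatus(Enum):
    started = "started"
    in_progress = "in_progress"
    failure = "failure"
    success = "success"


class TextDelta(BaseModel):
    type: Literal["text"] = "text"
    text: str


class ToolCallDelta(BaseModel):
    type: Literal["tool_call"] = "tool_call"
    # a partial argument string while the call is still being parsed,
    # a full ToolCall once parse_status is success
    content: Union[str, ToolCall]
    parse_status: ToolCallParseStatus


ContentDelta = Annotated[
    Union[TextDelta, ToolCallDelta],
    Field(discriminator="type"),
]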

View file

@@ -22,12 +22,11 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmeth
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, URL
from llama_stack.apis.inference import (
CompletionMessage,
SamplingParams,
ToolCall,
ToolCallDelta,
ToolChoice,
ToolPromptFormat,
ToolResponse,
@@ -216,8 +215,7 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
step_type: StepType
step_id: str
text_delta: Optional[str] = None
tool_call_delta: Optional[ToolCallDelta] = None
delta: ContentDelta
@json_schema_type

View file

@@ -11,9 +11,13 @@ from llama_models.llama3.api.tool_utils import ToolUtils
from termcolor import cprint
from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
from llama_stack.apis.common.content_types import ToolCallParseStatus
from llama_stack.apis.inference import ToolResponseMessage
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
class LogEvent:
def __init__(
@@ -57,8 +61,11 @@ class EventLogger:
# since it does not produce event but instead
# a Message
if isinstance(chunk, ToolResponseMessage):
yield chunk, LogEvent(
role="CustomTool", content=chunk.content, color="grey"
yield (
chunk,
LogEvent(
role="CustomTool", content=chunk.content, color="grey"
),
)
continue
@@ -80,14 +87,20 @@ class EventLogger:
):
violation = event.payload.step_details.violation
if not violation:
yield event, LogEvent(
role=step_type, content="No Violation", color="magenta"
yield (
event,
LogEvent(
role=step_type, content="No Violation", color="magenta"
),
)
else:
yield event, LogEvent(
role=step_type,
content=f"{violation.metadata} {violation.user_message}",
color="red",
yield (
event,
LogEvent(
role=step_type,
content=f"{violation.metadata} {violation.user_message}",
color="red",
),
)
# handle inference
@@ -95,8 +108,11 @@ class EventLogger:
if stream:
if event_type == EventType.step_start.value:
# TODO: Currently this event is never received
yield event, LogEvent(
role=step_type, content="", end="", color="yellow"
yield (
event,
LogEvent(
role=step_type, content="", end="", color="yellow"
),
)
elif event_type == EventType.step_progress.value:
# HACK: if previous was not step/event was not inference's step_progress
@@ -107,24 +123,34 @@ class EventLogger:
previous_event_type != EventType.step_progress.value
and previous_step_type != StepType.inference
):
yield event, LogEvent(
role=step_type, content="", end="", color="yellow"
yield (
event,
LogEvent(
role=step_type, content="", end="", color="yellow"
),
)
if event.payload.tool_call_delta:
if isinstance(event.payload.tool_call_delta.content, str):
yield event, LogEvent(
role=None,
content=event.payload.tool_call_delta.content,
end="",
color="cyan",
delta = event.payload.delta
if delta.type == "tool_call":
if delta.parse_status == ToolCallParseStatus.success:
yield (
event,
LogEvent(
role=None,
content=delta.content,
end="",
color="cyan",
),
)
else:
yield event, LogEvent(
role=None,
content=event.payload.text_delta,
end="",
color="yellow",
yield (
event,
LogEvent(
role=None,
content=delta.text,
end="",
color="yellow",
),
)
else:
# step_complete
@@ -140,10 +166,13 @@ class EventLogger:
)
else:
content = response.content
yield event, LogEvent(
role=step_type,
content=content,
color="yellow",
yield (
event,
LogEvent(
role=step_type,
content=content,
color="yellow",
),
)
# handle tool_execution
@@ -155,16 +184,22 @@ class EventLogger:
):
details = event.payload.step_details
for t in details.tool_calls:
yield event, LogEvent(
role=step_type,
content=f"Tool:{t.tool_name} Args:{t.arguments}",
color="green",
yield (
event,
LogEvent(
role=step_type,
content=f"Tool:{t.tool_name} Args:{t.arguments}",
color="green",
),
)
for r in details.tool_responses:
yield event, LogEvent(
role=step_type,
content=f"Tool:{r.tool_name} Response:{r.content}",
color="green",
yield (
event,
LogEvent(
role=step_type,
content=f"Tool:{r.tool_name} Response:{r.content}",
color="green",
),
)
if (
@@ -172,15 +207,16 @@ class EventLogger:
and event_type == EventType.step_complete.value
):
details = event.payload.step_details
inserted_context = interleaved_text_media_as_str(
details.inserted_context
)
inserted_context = interleaved_content_as_str(details.inserted_context)
content = f"fetched {len(inserted_context)} bytes from {details.memory_bank_ids}"
yield event, LogEvent(
role=step_type,
content=content,
color="cyan",
yield (
event,
LogEvent(
role=step_type,
content=content,
color="cyan",
),
)
previous_event_type = event_type
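
The net effect for consumers of the step_progress payload: instead of checking which of two optional fields is set, they branch on delta.type, and for tool calls additionally on parse_status. A small sketch of that pattern, mirroring the logger above (render_progress and its payload argument are illustrative names, not part of the API):

from llama_stack.apis.common.content_types import ToolCallParseStatus


def render_progress(payload) -> str:
    # payload is an AgentTurnResponseStepProgressPayload carrying the new `delta` field
    delta = payload.delta
    if delta.type == "tool_call":
        # only a fully parsed tool call is worth printing; partial chunks are skipped
        if delta.parse_status == ToolCallParseStatus.success:
            return f"Tool:{delta.content.tool_name} Args:{delta.content.arguments}"
        return ""
    # otherwise this is a text delta
    return delta.text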

View file

@@ -40,7 +40,12 @@ from llama_stack.apis.agents import (
ToolExecutionStep,
Turn,
)
from llama_stack.apis.common.content_types import TextContentItem, URL
from llama_stack.apis.common.content_types import (
TextContentItem,
ToolCallDelta,
ToolCallParseStatus,
URL,
)
from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
CompletionMessage,
@@ -49,8 +54,6 @@ from llama_stack.apis.inference import (
SamplingParams,
StopReason,
SystemMessage,
ToolCallDelta,
ToolCallParseStatus,
ToolDefinition,
ToolResponse,
ToolResponseMessage,
@@ -411,7 +414,7 @@ class ChatAgent(ShieldRunnerMixin):
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
tool_call_delta=ToolCallDelta(
delta=ToolCallDelta(
parse_status=ToolCallParseStatus.success,
content=ToolCall(
call_id="",
@@ -507,7 +510,7 @@ class ChatAgent(ShieldRunnerMixin):
continue
delta = event.delta
if isinstance(delta, ToolCallDelta):
if delta.type == "tool_call":
if delta.parse_status == ToolCallParseStatus.success:
tool_calls.append(delta.content)
if stream:
@@ -516,21 +519,20 @@ class ChatAgent(ShieldRunnerMixin):
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
text_delta="",
tool_call_delta=delta,
delta=delta,
)
)
)
elif isinstance(delta, str):
content += delta
elif delta.type == "text":
content += delta.text
if stream and event.stop_reason is None:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
text_delta=event.delta,
delta=delta,
)
)
)
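
One design payoff shows up in this file: because the agents API and the inference API now share the same ContentDelta union, ChatAgent no longer has to split a provider chunk into text_delta and tool_call_delta; it forwards event.delta into its own step_progress payload unchanged. A sketch of that pass-through, assuming the same llama_stack.apis.agents imports this file already uses:

from llama_stack.apis.agents import (
    AgentTurnResponseEvent,
    AgentTurnResponseStepProgressPayload,
    AgentTurnResponseStreamChunk,
    StepType,
)


def forward_inference_delta(step_id: str, event) -> AgentTurnResponseStreamChunk:
    # `event` is a ChatCompletionResponseEvent from the inference provider; its delta
    # (TextDelta or ToolCallDelta) is re-emitted as agent step progress without conversion
    return AgentTurnResponseStreamChunk(
        event=AgentTurnResponseEvent(
            payload=AgentTurnResponseStepProgressPayload(
                step_type=StepType.inference.value,
                step_id=step_id,
                delta=event.delta,
            )
        )
    )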

View file

@@ -16,6 +16,11 @@ from llama_models.llama3.api.datatypes import (
)
from llama_models.sku_list import resolve_model
from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -32,8 +37,6 @@ from llama_stack.apis.inference import (
Message,
ResponseFormat,
TokenLogProbs,
ToolCallDelta,
ToolCallParseStatus,
ToolChoice,
)
from llama_stack.apis.models import Model, ModelType
@@ -190,14 +193,14 @@ class MetaReferenceInferenceImpl(
]
yield CompletionResponseStreamChunk(
delta=text,
delta=TextDelta(text=text),
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
if stop_reason is None:
yield CompletionResponseStreamChunk(
delta="",
delta=TextDelta(text=""),
stop_reason=StopReason.out_of_tokens,
)
@@ -352,7 +355,7 @@ class MetaReferenceInferenceImpl(
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
delta=TextDelta(text=""),
)
)
@@ -392,7 +395,7 @@ class MetaReferenceInferenceImpl(
parse_status=ToolCallParseStatus.in_progress,
)
else:
delta = text
delta = TextDelta(text=text)
if stop_reason is None:
if request.logprobs:
@@ -449,7 +452,7 @@ class MetaReferenceInferenceImpl(
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
delta=TextDelta(text=""),
stop_reason=stop_reason,
)
)
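
On the provider side the change is mechanical: everywhere a raw string used to be assigned to a chunk's delta, it is now wrapped in TextDelta. A condensed sketch of the completion-stream pattern shown above (token_iter is an illustrative stand-in for the provider's token generator, not a real helper):

from llama_stack.apis.common.content_types import TextDelta
from llama_stack.apis.inference import CompletionResponseStreamChunk, StopReason


async def stream_completion(token_iter):
    # one chunk per generated text token, with the raw string wrapped in TextDelta
    async for text in token_iter:
        yield CompletionResponseStreamChunk(
            delta=TextDelta(text=text),
            stop_reason=None,
        )
    # close the stream with an empty delta that carries the stop reason
    yield CompletionResponseStreamChunk(
        delta=TextDelta(text=""),
        stop_reason=StopReason.out_of_tokens,
    )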

View file

@@ -30,6 +30,11 @@ from groq.types.shared.function_definition import FunctionDefinition
from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -40,8 +45,6 @@ from llama_stack.apis.inference import (
Message,
StopReason,
ToolCall,
ToolCallDelta,
ToolCallParseStatus,
ToolDefinition,
ToolPromptFormat,
)
@@ -162,7 +165,7 @@ def convert_chat_completion_response(
def _map_finish_reason_to_stop_reason(
finish_reason: Literal["stop", "length", "tool_calls"]
finish_reason: Literal["stop", "length", "tool_calls"],
) -> StopReason:
"""
Convert a Groq chat completion finish_reason to a StopReason.
@@ -185,7 +188,6 @@ def _map_finish_reason_to_stop_reason(
async def convert_chat_completion_response_stream(
stream: Stream[ChatCompletionChunk],
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
event_type = ChatCompletionResponseEventType.start
for chunk in stream:
choice = chunk.choices[0]
@@ -194,7 +196,7 @@ async def convert_chat_completion_response_stream(
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=choice.delta.content or "",
delta=TextDelta(text=choice.delta.content or ""),
logprobs=None,
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
)
@@ -221,7 +223,7 @@ async def convert_chat_completion_response_stream(
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=choice.delta.content or "",
delta=TextDelta(text=choice.delta.content or ""),
logprobs=None,
)
)

View file

@@ -34,6 +34,11 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
from openai.types.completion import Completion as OpenAICompletion
from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs
from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -48,8 +53,6 @@ from llama_stack.apis.inference import (
Message,
SystemMessage,
TokenLogProbs,
ToolCallDelta,
ToolCallParseStatus,
ToolResponseMessage,
UserMessage,
)
@@ -432,69 +435,6 @@ async def convert_openai_chat_completion_stream(
"""
Convert a stream of OpenAI chat completion chunks into a stream
of ChatCompletionResponseStreamChunk.
OpenAI ChatCompletionChunk:
choices: List[Choice]
OpenAI Choice: # different from the non-streamed Choice
delta: ChoiceDelta
finish_reason: Optional[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]]
logprobs: Optional[ChoiceLogprobs]
OpenAI ChoiceDelta:
content: Optional[str]
role: Optional[Literal["system", "user", "assistant", "tool"]]
tool_calls: Optional[List[ChoiceDeltaToolCall]]
OpenAI ChoiceDeltaToolCall:
index: int
id: Optional[str]
function: Optional[ChoiceDeltaToolCallFunction]
type: Optional[Literal["function"]]
OpenAI ChoiceDeltaToolCallFunction:
name: Optional[str]
arguments: Optional[str]
->
ChatCompletionResponseStreamChunk:
event: ChatCompletionResponseEvent
ChatCompletionResponseEvent:
event_type: ChatCompletionResponseEventType
delta: Union[str, ToolCallDelta]
logprobs: Optional[List[TokenLogProbs]]
stop_reason: Optional[StopReason]
ChatCompletionResponseEventType:
start = "start"
progress = "progress"
complete = "complete"
ToolCallDelta:
content: Union[str, ToolCall]
parse_status: ToolCallParseStatus
ToolCall:
call_id: str
tool_name: str
arguments: str
ToolCallParseStatus:
started = "started"
in_progress = "in_progress"
failure = "failure"
success = "success"
TokenLogProbs:
logprobs_by_token: Dict[str, float]
- token, logprob
StopReason:
end_of_turn = "end_of_turn"
end_of_message = "end_of_message"
out_of_tokens = "out_of_tokens"
"""
# generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
@@ -543,7 +483,7 @@ async def convert_openai_chat_completion_stream(
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
delta=choice.delta.content,
delta=TextDelta(text=choice.delta.content),
logprobs=_convert_openai_logprobs(choice.logprobs),
)
)
@@ -570,7 +510,7 @@ async def convert_openai_chat_completion_stream(
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
delta=choice.delta.content or "", # content is not optional
delta=TextDelta(text=choice.delta.content or ""),
logprobs=_convert_openai_logprobs(choice.logprobs),
)
)
@@ -578,7 +518,7 @@ async def convert_openai_chat_completion_stream(
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
delta=TextDelta(text=""),
stop_reason=stop_reason,
)
)
@@ -653,18 +593,6 @@ def _convert_openai_completion_logprobs(
) -> Optional[List[TokenLogProbs]]:
"""
Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs.
OpenAI CompletionLogprobs:
text_offset: Optional[List[int]]
token_logprobs: Optional[List[float]]
tokens: Optional[List[str]]
top_logprobs: Optional[List[Dict[str, float]]]
->
TokenLogProbs:
logprobs_by_token: Dict[str, float]
- token, logprob
"""
if not logprobs:
return None
@@ -679,28 +607,6 @@ def convert_openai_completion_choice(
) -> CompletionResponse:
"""
Convert an OpenAI Completion Choice into a CompletionResponse.
OpenAI Completion Choice:
text: str
finish_reason: str
logprobs: Optional[ChoiceLogprobs]
->
CompletionResponse:
completion_message: CompletionMessage
logprobs: Optional[List[TokenLogProbs]]
CompletionMessage:
role: Literal["assistant"]
content: str | ImageMedia | List[str | ImageMedia]
stop_reason: StopReason
tool_calls: List[ToolCall]
class StopReason(Enum):
end_of_turn = "end_of_turn"
end_of_message = "end_of_message"
out_of_tokens = "out_of_tokens"
"""
return CompletionResponse(
content=choice.text,
@@ -715,32 +621,11 @@ async def convert_openai_completion_stream(
"""
Convert a stream of OpenAI Completions into a stream
of ChatCompletionResponseStreamChunks.
OpenAI Completion:
id: str
choices: List[OpenAICompletionChoice]
created: int
model: str
system_fingerprint: Optional[str]
usage: Optional[OpenAICompletionUsage]
OpenAI CompletionChoice:
finish_reason: str
index: int
logprobs: Optional[OpenAILogprobs]
text: str
->
CompletionResponseStreamChunk:
delta: str
stop_reason: Optional[StopReason]
logprobs: Optional[List[TokenLogProbs]]
"""
async for chunk in stream:
choice = chunk.choices[0]
yield CompletionResponseStreamChunk(
delta=choice.text,
delta=TextDelta(text=choice.text),
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
logprobs=_convert_openai_completion_logprobs(choice.logprobs),
)

View file

@@ -18,6 +18,7 @@ from llama_models.llama3.api.datatypes import (
from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.content_types import ToolCallParseStatus
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
@@ -27,8 +28,6 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat,
LogProbConfig,
SystemMessage,
ToolCallDelta,
ToolCallParseStatus,
ToolChoice,
UserMessage,
)
@@ -196,7 +195,9 @@ class TestInference:
1 <= len(chunks) <= 6
) # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
for chunk in chunks:
if chunk.delta: # if there's a token, we expect logprobs
if (
chunk.delta.type == "text" and chunk.delta.text
): # if there's a token, we expect logprobs
assert chunk.logprobs, "Logprobs should not be empty"
assert all(
len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
@@ -463,7 +464,7 @@ class TestInference:
if "Llama3.1" in inference_model:
assert all(
isinstance(chunk.event.delta, ToolCallDelta)
chunk.event.delta.type == "tool_call"
for chunk in grouped[ChatCompletionResponseEventType.progress]
)
first = grouped[ChatCompletionResponseEventType.progress][0]
@@ -475,7 +476,7 @@ class TestInference:
last = grouped[ChatCompletionResponseEventType.progress][-1]
# assert last.event.stop_reason == expected_stop_reason
assert last.event.delta.parse_status == ToolCallParseStatus.success
assert isinstance(last.event.delta.content, ToolCall)
assert last.event.delta.content.type == "tool_call"
call = last.event.delta.content
assert call.tool_name == "get_weather"
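
The migration rule these test hunks encode for downstream code: swap isinstance checks on the delta classes for checks of the type discriminator. Expressed as tiny helpers (illustrative names, mirroring the assertions above):

def is_tool_call_chunk(chunk) -> bool:
    # replaces: isinstance(chunk.event.delta, ToolCallDelta)
    return chunk.event.delta.type == "tool_call"


def has_text_token(chunk) -> bool:
    # replaces: bool(chunk.delta)  (back when delta was a plain string)
    return chunk.delta.type == "text" and bool(chunk.delta.text)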

View file

@@ -11,7 +11,13 @@ from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import SamplingParams, StopReason
from pydantic import BaseModel
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
from llama_stack.apis.common.content_types import (
ImageContentItem,
TextContentItem,
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
ChatCompletionResponse,
@@ -22,8 +28,6 @@ from llama_stack.apis.inference import (
CompletionResponse,
CompletionResponseStreamChunk,
Message,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -138,7 +142,7 @@ async def process_completion_stream_response(
text = ""
continue
yield CompletionResponseStreamChunk(
delta=text,
delta=TextDelta(text=text),
stop_reason=stop_reason,
)
if finish_reason:
@@ -149,7 +153,7 @@ async def process_completion_stream_response(
break
yield CompletionResponseStreamChunk(
delta="",
delta=TextDelta(text=""),
stop_reason=stop_reason,
)
@@ -160,7 +164,7 @@ async def process_chat_completion_stream_response(
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
delta=TextDelta(text=""),
)
)
@@ -227,7 +231,7 @@ async def process_chat_completion_stream_response(
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=text,
delta=TextDelta(text=text),
stop_reason=stop_reason,
)
)
@@ -262,7 +266,7 @@ async def process_chat_completion_stream_response(
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
delta=TextDelta(text=""),
stop_reason=stop_reason,
)
)
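
Here and in the Groq/OpenAI adapters above, the chat-completion stream processors keep the same tool-call lifecycle, just with the shared types: partial argument text streams as ToolCallDelta chunks whose content is a string and whose parse_status is in_progress, the fully parsed call is emitted once with parse_status success, and the stream closes with an empty TextDelta. A compressed sketch of that emission order with illustrative placeholder values (the real processors derive them from the decoded model output):

from llama_stack.apis.common.content_types import (
    TextDelta,
    ToolCallDelta,
    ToolCallParseStatus,
)
from llama_stack.apis.inference import (
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
    StopReason,
    ToolCall,
)


def example_tool_call_stream() -> list[ChatCompletionResponseStreamChunk]:
    return [
        # stream opens with an empty text delta
        ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.start,
                delta=TextDelta(text=""),
            )
        ),
        # tool-call arguments arrive as partial strings while parsing is in progress
        ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.progress,
                delta=ToolCallDelta(
                    content='{"location": "San Francisco',
                    parse_status=ToolCallParseStatus.in_progress,
                ),
            )
        ),
        # once the call parses, it is emitted whole with parse_status success
        ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.progress,
                delta=ToolCallDelta(
                    content=ToolCall(
                        call_id="call-0",
                        tool_name="get_weather",
                        arguments={"location": "San Francisco"},
                    ),
                    parse_status=ToolCallParseStatus.success,
                ),
            )
        ),
        # and the stream closes with an empty text delta plus the stop reason
        ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.complete,
                delta=TextDelta(text=""),
                stop_reason=StopReason.end_of_turn,
            )
        ),
    ]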