move all implementations to use updated type

Ashwin Bharambe 2025-01-13 20:04:19 -08:00
parent aced2ce07e
commit 9a5803a429
8 changed files with 139 additions and 208 deletions
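The diffs below migrate every implementation from the old pair of payload fields (`text_delta: Optional[str]`, `tool_call_delta: Optional[ToolCallDelta]`) to a single `delta` field typed as `ContentDelta`, and wrap raw strings in `TextDelta`. The canonical definitions live in `llama_stack/apis/common/content_types.py`; the following is only a rough sketch reconstructed from how the types are used in this commit (the imported names match the diff, but the exact class bodies are assumptions):

# Sketch of the updated delta types, inferred from their usage in this commit.
# The real definitions live in llama_stack.apis.common.content_types.
from enum import Enum
from typing import Literal, Union

from llama_models.llama3.api.datatypes import ToolCall
from pydantic import BaseModel, Field
from typing_extensions import Annotated


class ToolCallParseStatus(Enum):
    started = "started"
    in_progress = "in_progress"
    failure = "failure"
    success = "success"


class TextDelta(BaseModel):
    type: Literal["text"] = "text"
    text: str


class ToolCallDelta(BaseModel):
    type: Literal["tool_call"] = "tool_call"
    # While parsing is in progress, `content` is the raw string seen so far;
    # once parse_status is `success` it is a fully formed ToolCall.
    content: Union[str, ToolCall]
    parse_status: ToolCallParseStatus


# Consumers dispatch on `delta.type` ("text" vs "tool_call"), as the updated
# agent, event logger, and tests below do.
ContentDelta = Annotated[Union[TextDelta, ToolCallDelta], Field(discriminator="type")]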

View file

@@ -22,12 +22,11 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import Annotated

-from llama_stack.apis.common.content_types import InterleavedContent, URL
+from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, URL
 from llama_stack.apis.inference import (
     CompletionMessage,
     SamplingParams,
     ToolCall,
-    ToolCallDelta,
     ToolChoice,
     ToolPromptFormat,
     ToolResponse,
@@ -216,8 +215,7 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
     step_type: StepType
     step_id: str

-    text_delta: Optional[str] = None
-    tool_call_delta: Optional[ToolCallDelta] = None
+    delta: ContentDelta


 @json_schema_type

View file

@@ -11,9 +11,13 @@ from llama_models.llama3.api.tool_utils import ToolUtils
 from termcolor import cprint

 from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
+from llama_stack.apis.common.content_types import ToolCallParseStatus
 from llama_stack.apis.inference import ToolResponseMessage
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    interleaved_content_as_str,
+)


 class LogEvent:
     def __init__(
@@ -57,8 +61,11 @@ class EventLogger:
             # since it does not produce event but instead
             # a Message
             if isinstance(chunk, ToolResponseMessage):
-                yield chunk, LogEvent(
-                    role="CustomTool", content=chunk.content, color="grey"
+                yield (
+                    chunk,
+                    LogEvent(
+                        role="CustomTool", content=chunk.content, color="grey"
+                    ),
                 )
                 continue
@@ -80,14 +87,20 @@
             ):
                 violation = event.payload.step_details.violation
                 if not violation:
-                    yield event, LogEvent(
-                        role=step_type, content="No Violation", color="magenta"
+                    yield (
+                        event,
+                        LogEvent(
+                            role=step_type, content="No Violation", color="magenta"
+                        ),
                     )
                 else:
-                    yield event, LogEvent(
-                        role=step_type,
-                        content=f"{violation.metadata} {violation.user_message}",
-                        color="red",
+                    yield (
+                        event,
+                        LogEvent(
+                            role=step_type,
+                            content=f"{violation.metadata} {violation.user_message}",
+                            color="red",
+                        ),
                     )

             # handle inference
@@ -95,8 +108,11 @@
                if stream:
                    if event_type == EventType.step_start.value:
                        # TODO: Currently this event is never received
-                        yield event, LogEvent(
-                            role=step_type, content="", end="", color="yellow"
+                        yield (
+                            event,
+                            LogEvent(
+                                role=step_type, content="", end="", color="yellow"
+                            ),
                         )
                    elif event_type == EventType.step_progress.value:
                        # HACK: if previous was not step/event was not inference's step_progress
@@ -107,24 +123,34 @@
                            previous_event_type != EventType.step_progress.value
                            and previous_step_type != StepType.inference
                        ):
-                            yield event, LogEvent(
-                                role=step_type, content="", end="", color="yellow"
+                            yield (
+                                event,
+                                LogEvent(
+                                    role=step_type, content="", end="", color="yellow"
+                                ),
                             )

-                        if event.payload.tool_call_delta:
-                            if isinstance(event.payload.tool_call_delta.content, str):
-                                yield event, LogEvent(
-                                    role=None,
-                                    content=event.payload.tool_call_delta.content,
-                                    end="",
-                                    color="cyan",
+                        delta = event.payload.delta
+                        if delta.type == "tool_call":
+                            if delta.parse_status == ToolCallParseStatus.success:
+                                yield (
+                                    event,
+                                    LogEvent(
+                                        role=None,
+                                        content=delta.content,
+                                        end="",
+                                        color="cyan",
+                                    ),
                                 )
                         else:
-                            yield event, LogEvent(
-                                role=None,
-                                content=event.payload.text_delta,
-                                end="",
-                                color="yellow",
+                            yield (
+                                event,
+                                LogEvent(
+                                    role=None,
+                                    content=delta.text,
+                                    end="",
+                                    color="yellow",
+                                ),
                             )
                    else:
                        # step_complete
@@ -140,10 +166,13 @@
                        )
                    else:
                        content = response.content
-                    yield event, LogEvent(
-                        role=step_type,
-                        content=content,
-                        color="yellow",
+                    yield (
+                        event,
+                        LogEvent(
+                            role=step_type,
+                            content=content,
+                            color="yellow",
+                        ),
                     )

             # handle tool_execution
@@ -155,16 +184,22 @@
             ):
                 details = event.payload.step_details
                 for t in details.tool_calls:
-                    yield event, LogEvent(
-                        role=step_type,
-                        content=f"Tool:{t.tool_name} Args:{t.arguments}",
-                        color="green",
+                    yield (
+                        event,
+                        LogEvent(
+                            role=step_type,
+                            content=f"Tool:{t.tool_name} Args:{t.arguments}",
+                            color="green",
+                        ),
                     )
                 for r in details.tool_responses:
-                    yield event, LogEvent(
-                        role=step_type,
-                        content=f"Tool:{r.tool_name} Response:{r.content}",
-                        color="green",
+                    yield (
+                        event,
+                        LogEvent(
+                            role=step_type,
+                            content=f"Tool:{r.tool_name} Response:{r.content}",
+                            color="green",
+                        ),
                     )

             if (
@@ -172,15 +207,16 @@
                 and event_type == EventType.step_complete.value
             ):
                 details = event.payload.step_details
-                inserted_context = interleaved_text_media_as_str(
-                    details.inserted_context
-                )
+                inserted_context = interleaved_content_as_str(details.inserted_context)
                 content = f"fetched {len(inserted_context)} bytes from {details.memory_bank_ids}"

-                yield event, LogEvent(
-                    role=step_type,
-                    content=content,
-                    color="cyan",
+                yield (
+                    event,
+                    LogEvent(
+                        role=step_type,
+                        content=content,
+                        color="cyan",
+                    ),
                 )

             previous_event_type = event_type

View file

@@ -40,7 +40,12 @@ from llama_stack.apis.agents import (
     ToolExecutionStep,
     Turn,
 )
-from llama_stack.apis.common.content_types import TextContentItem, URL
+from llama_stack.apis.common.content_types import (
+    TextContentItem,
+    ToolCallDelta,
+    ToolCallParseStatus,
+    URL,
+)
 from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,
     CompletionMessage,
@@ -49,8 +54,6 @@ from llama_stack.apis.inference import (
     SamplingParams,
     StopReason,
     SystemMessage,
-    ToolCallDelta,
-    ToolCallParseStatus,
     ToolDefinition,
     ToolResponse,
     ToolResponseMessage,
@@ -411,7 +414,7 @@ class ChatAgent(ShieldRunnerMixin):
                        payload=AgentTurnResponseStepProgressPayload(
                            step_type=StepType.tool_execution.value,
                            step_id=step_id,
-                            tool_call_delta=ToolCallDelta(
+                            delta=ToolCallDelta(
                                parse_status=ToolCallParseStatus.success,
                                content=ToolCall(
                                    call_id="",
@@ -507,7 +510,7 @@
                continue

            delta = event.delta
-            if isinstance(delta, ToolCallDelta):
+            if delta.type == "tool_call":
                if delta.parse_status == ToolCallParseStatus.success:
                    tool_calls.append(delta.content)
                if stream:
@@ -516,21 +519,20 @@
                            payload=AgentTurnResponseStepProgressPayload(
                                step_type=StepType.inference.value,
                                step_id=step_id,
-                                text_delta="",
-                                tool_call_delta=delta,
+                                delta=delta,
                            )
                        )
                    )
-            elif isinstance(delta, str):
-                content += delta
+            elif delta.type == "text":
+                content += delta.text
                if stream and event.stop_reason is None:
                    yield AgentTurnResponseStreamChunk(
                        event=AgentTurnResponseEvent(
                            payload=AgentTurnResponseStepProgressPayload(
                                step_type=StepType.inference.value,
                                step_id=step_id,
-                                text_delta=event.delta,
+                                delta=delta,
                            )
                        )
                    )

View file

@@ -16,6 +16,11 @@ from llama_models.llama3.api.datatypes import (
 )
 from llama_models.sku_list import resolve_model

+from llama_stack.apis.common.content_types import (
+    TextDelta,
+    ToolCallDelta,
+    ToolCallParseStatus,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -32,8 +37,6 @@ from llama_stack.apis.inference import (
     Message,
     ResponseFormat,
     TokenLogProbs,
-    ToolCallDelta,
-    ToolCallParseStatus,
     ToolChoice,
 )
 from llama_stack.apis.models import Model, ModelType
@@ -190,14 +193,14 @@ class MetaReferenceInferenceImpl(
                    ]

                yield CompletionResponseStreamChunk(
-                    delta=text,
+                    delta=TextDelta(text=text),
                    stop_reason=stop_reason,
                    logprobs=logprobs if request.logprobs else None,
                )

            if stop_reason is None:
                yield CompletionResponseStreamChunk(
-                    delta="",
+                    delta=TextDelta(text=""),
                    stop_reason=StopReason.out_of_tokens,
                )
@@ -352,7 +355,7 @@
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.start,
-                    delta="",
+                    delta=TextDelta(text=""),
                )
            )
@@ -392,7 +395,7 @@
                        parse_status=ToolCallParseStatus.in_progress,
                    )
                else:
-                    delta = text
+                    delta = TextDelta(text=text)

                if stop_reason is None:
                    if request.logprobs:
@@ -449,7 +452,7 @@
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.complete,
-                    delta="",
+                    delta=TextDelta(text=""),
                    stop_reason=stop_reason,
                )
            )

View file

@@ -30,6 +30,11 @@ from groq.types.shared.function_definition import FunctionDefinition
 from llama_models.llama3.api.datatypes import ToolParamDefinition

+from llama_stack.apis.common.content_types import (
+    TextDelta,
+    ToolCallDelta,
+    ToolCallParseStatus,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -40,8 +45,6 @@ from llama_stack.apis.inference import (
     Message,
     StopReason,
     ToolCall,
-    ToolCallDelta,
-    ToolCallParseStatus,
     ToolDefinition,
     ToolPromptFormat,
 )
@@ -162,7 +165,7 @@
 def _map_finish_reason_to_stop_reason(
-    finish_reason: Literal["stop", "length", "tool_calls"]
+    finish_reason: Literal["stop", "length", "tool_calls"],
 ) -> StopReason:
     """
     Convert a Groq chat completion finish_reason to a StopReason.
@@ -185,7 +188,6 @@
 async def convert_chat_completion_response_stream(
     stream: Stream[ChatCompletionChunk],
 ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
-
     event_type = ChatCompletionResponseEventType.start
     for chunk in stream:
         choice = chunk.choices[0]
@@ -194,7 +196,7 @@
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
                 event_type=ChatCompletionResponseEventType.complete,
-                delta=choice.delta.content or "",
+                delta=TextDelta(text=choice.delta.content or ""),
                 logprobs=None,
                 stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
             )
@@ -221,7 +223,7 @@
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
                 event_type=event_type,
-                delta=choice.delta.content or "",
+                delta=TextDelta(text=choice.delta.content or ""),
                 logprobs=None,
             )
         )

View file

@@ -34,6 +34,11 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
 from openai.types.completion import Completion as OpenAICompletion
 from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs

+from llama_stack.apis.common.content_types import (
+    TextDelta,
+    ToolCallDelta,
+    ToolCallParseStatus,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -48,8 +53,6 @@ from llama_stack.apis.inference import (
     Message,
     SystemMessage,
     TokenLogProbs,
-    ToolCallDelta,
-    ToolCallParseStatus,
     ToolResponseMessage,
     UserMessage,
 )
@@ -432,69 +435,6 @@ async def convert_openai_chat_completion_stream(
     """
     Convert a stream of OpenAI chat completion chunks into a stream
     of ChatCompletionResponseStreamChunk.
-
-    OpenAI ChatCompletionChunk:
-        choices: List[Choice]
-
-    OpenAI Choice: # different from the non-streamed Choice
-        delta: ChoiceDelta
-        finish_reason: Optional[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]]
-        logprobs: Optional[ChoiceLogprobs]
-
-    OpenAI ChoiceDelta:
-        content: Optional[str]
-        role: Optional[Literal["system", "user", "assistant", "tool"]]
-        tool_calls: Optional[List[ChoiceDeltaToolCall]]
-
-    OpenAI ChoiceDeltaToolCall:
-        index: int
-        id: Optional[str]
-        function: Optional[ChoiceDeltaToolCallFunction]
-        type: Optional[Literal["function"]]
-
-    OpenAI ChoiceDeltaToolCallFunction:
-        name: Optional[str]
-        arguments: Optional[str]
-
-    ->
-
-    ChatCompletionResponseStreamChunk:
-        event: ChatCompletionResponseEvent
-
-    ChatCompletionResponseEvent:
-        event_type: ChatCompletionResponseEventType
-        delta: Union[str, ToolCallDelta]
-        logprobs: Optional[List[TokenLogProbs]]
-        stop_reason: Optional[StopReason]
-
-    ChatCompletionResponseEventType:
-        start = "start"
-        progress = "progress"
-        complete = "complete"
-
-    ToolCallDelta:
-        content: Union[str, ToolCall]
-        parse_status: ToolCallParseStatus
-
-    ToolCall:
-        call_id: str
-        tool_name: str
-        arguments: str
-
-    ToolCallParseStatus:
-        started = "started"
-        in_progress = "in_progress"
-        failure = "failure"
-        success = "success"
-
-    TokenLogProbs:
-        logprobs_by_token: Dict[str, float]
-            - token, logprob
-
-    StopReason:
-        end_of_turn = "end_of_turn"
-        end_of_message = "end_of_message"
-        out_of_tokens = "out_of_tokens"
     """

     # generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
@@ -543,7 +483,7 @@ async def convert_openai_chat_completion_stream(
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=next(event_type),
-                    delta=choice.delta.content,
+                    delta=TextDelta(text=choice.delta.content),
                    logprobs=_convert_openai_logprobs(choice.logprobs),
                )
            )
@@ -570,7 +510,7 @@
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=next(event_type),
-                    delta=choice.delta.content or "",  # content is not optional
+                    delta=TextDelta(text=choice.delta.content or ""),
                    logprobs=_convert_openai_logprobs(choice.logprobs),
                )
            )
@@ -578,7 +518,7 @@
    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.complete,
-            delta="",
+            delta=TextDelta(text=""),
            stop_reason=stop_reason,
        )
    )
@@ -653,18 +593,6 @@
 ) -> Optional[List[TokenLogProbs]]:
     """
     Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs.
-
-    OpenAI CompletionLogprobs:
-        text_offset: Optional[List[int]]
-        token_logprobs: Optional[List[float]]
-        tokens: Optional[List[str]]
-        top_logprobs: Optional[List[Dict[str, float]]]
-
-    ->
-
-    TokenLogProbs:
-        logprobs_by_token: Dict[str, float]
-            - token, logprob
     """
     if not logprobs:
         return None
@@ -679,28 +607,6 @@
 ) -> CompletionResponse:
     """
     Convert an OpenAI Completion Choice into a CompletionResponse.
-
-    OpenAI Completion Choice:
-        text: str
-        finish_reason: str
-        logprobs: Optional[ChoiceLogprobs]
-
-    ->
-
-    CompletionResponse:
-        completion_message: CompletionMessage
-        logprobs: Optional[List[TokenLogProbs]]
-
-    CompletionMessage:
-        role: Literal["assistant"]
-        content: str | ImageMedia | List[str | ImageMedia]
-        stop_reason: StopReason
-        tool_calls: List[ToolCall]
-
-    class StopReason(Enum):
-        end_of_turn = "end_of_turn"
-        end_of_message = "end_of_message"
-        out_of_tokens = "out_of_tokens"
     """
     return CompletionResponse(
         content=choice.text,
@@ -715,32 +621,11 @@
     """
     Convert a stream of OpenAI Completions into a stream
     of ChatCompletionResponseStreamChunks.
-
-    OpenAI Completion:
-        id: str
-        choices: List[OpenAICompletionChoice]
-        created: int
-        model: str
-        system_fingerprint: Optional[str]
-        usage: Optional[OpenAICompletionUsage]
-
-    OpenAI CompletionChoice:
-        finish_reason: str
-        index: int
-        logprobs: Optional[OpenAILogprobs]
-        text: str
-
-    ->
-
-    CompletionResponseStreamChunk:
-        delta: str
-        stop_reason: Optional[StopReason]
-        logprobs: Optional[List[TokenLogProbs]]
     """
     async for chunk in stream:
         choice = chunk.choices[0]
         yield CompletionResponseStreamChunk(
-            delta=choice.text,
+            delta=TextDelta(text=choice.text),
             stop_reason=_convert_openai_finish_reason(choice.finish_reason),
             logprobs=_convert_openai_completion_logprobs(choice.logprobs),
         )

View file

@@ -18,6 +18,7 @@ from llama_models.llama3.api.datatypes import (
 from pydantic import BaseModel, ValidationError

+from llama_stack.apis.common.content_types import ToolCallParseStatus
 from llama_stack.apis.inference import (
     ChatCompletionResponse,
     ChatCompletionResponseEventType,
@@ -27,8 +28,6 @@ from llama_stack.apis.inference import (
     JsonSchemaResponseFormat,
     LogProbConfig,
     SystemMessage,
-    ToolCallDelta,
-    ToolCallParseStatus,
     ToolChoice,
     UserMessage,
 )
@@ -196,7 +195,9 @@ class TestInference:
            1 <= len(chunks) <= 6
        )  # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
        for chunk in chunks:
-            if chunk.delta:  # if there's a token, we expect logprobs
+            if (
+                chunk.delta.type == "text" and chunk.delta.text
+            ):  # if there's a token, we expect logprobs
                assert chunk.logprobs, "Logprobs should not be empty"
                assert all(
                    len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
@@ -463,7 +464,7 @@
        if "Llama3.1" in inference_model:
            assert all(
-                isinstance(chunk.event.delta, ToolCallDelta)
+                chunk.event.delta.type == "tool_call"
                for chunk in grouped[ChatCompletionResponseEventType.progress]
            )
            first = grouped[ChatCompletionResponseEventType.progress][0]
@@ -475,7 +476,7 @@
        last = grouped[ChatCompletionResponseEventType.progress][-1]
        # assert last.event.stop_reason == expected_stop_reason
        assert last.event.delta.parse_status == ToolCallParseStatus.success
-        assert isinstance(last.event.delta.content, ToolCall)
+        assert last.event.delta.content.type == "tool_call"
        call = last.event.delta.content
        assert call.tool_name == "get_weather"

View file

@@ -11,7 +11,13 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import SamplingParams, StopReason
 from pydantic import BaseModel

-from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
+from llama_stack.apis.common.content_types import (
+    ImageContentItem,
+    TextContentItem,
+    TextDelta,
+    ToolCallDelta,
+    ToolCallParseStatus,
+)

 from llama_stack.apis.inference import (
     ChatCompletionResponse,
@@ -22,8 +28,6 @@ from llama_stack.apis.inference import (
     CompletionResponse,
     CompletionResponseStreamChunk,
     Message,
-    ToolCallDelta,
-    ToolCallParseStatus,
 )

 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -138,7 +142,7 @@ async def process_completion_stream_response(
            text = ""
            continue
        yield CompletionResponseStreamChunk(
-            delta=text,
+            delta=TextDelta(text=text),
            stop_reason=stop_reason,
        )
        if finish_reason:
@@ -149,7 +153,7 @@
            break

    yield CompletionResponseStreamChunk(
-        delta="",
+        delta=TextDelta(text=""),
        stop_reason=stop_reason,
    )
@@ -160,7 +164,7 @@ async def process_chat_completion_stream_response(
    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.start,
-            delta="",
+            delta=TextDelta(text=""),
        )
    )
@@ -227,7 +231,7 @@
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.progress,
-                delta=text,
+                delta=TextDelta(text=text),
                stop_reason=stop_reason,
            )
        )
@@ -262,7 +266,7 @@
    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.complete,
-            delta="",
+            delta=TextDelta(text=""),
            stop_reason=stop_reason,
        )
    )
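
Taken together, these changes mean every streamed chunk now carries a typed delta. A minimal, hypothetical consumer loop (not part of this commit) that dispatches on `delta.type` the same way the updated ChatAgent and EventLogger code above does might look like:

# Hypothetical consumer of the updated streaming API (not in this commit);
# it mirrors the dispatch pattern used by ChatAgent and EventLogger above.
from llama_stack.apis.common.content_types import ToolCallParseStatus


async def collect_stream(stream):
    content = ""
    tool_calls = []
    async for chunk in stream:
        delta = chunk.event.delta
        if delta.type == "text":
            # TextDelta: accumulate the raw text
            content += delta.text
        elif delta.type == "tool_call":
            # ToolCallDelta: only a fully parsed delta carries a ToolCall
            if delta.parse_status == ToolCallParseStatus.success:
                tool_calls.append(delta.content)
    return content, tool_calls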