diff --git a/llama_stack/apis/common/content_types.py b/llama_stack/apis/common/content_types.py
index e648f9a19..8d0ade6a1 100644
--- a/llama_stack/apis/common/content_types.py
+++ b/llama_stack/apis/common/content_types.py
@@ -98,6 +98,8 @@ class ToolCallParseStatus(Enum):
 
 @json_schema_type
 class ToolCallDelta(BaseModel):
+    """Deprecated: use InProgressToolCallDelta or FinalToolCallDelta instead"""
+
     type: Literal["tool_call"] = "tool_call"
 
     # you either send an in-progress tool call so the client can stream a long
@@ -107,10 +109,34 @@ class ToolCallDelta(BaseModel):
     parse_status: ToolCallParseStatus
 
 
+@json_schema_type
+class InProgressToolCallDelta(BaseModel):
+    """Delta sent while the tool call is in progress
+
+    :param type: Discriminator type of the content item. Always "in_progress_tool_call"
+    :param tool_call: The fragment of tool-call text generated so far
+    """
+
+    type: Literal["in_progress_tool_call"] = "in_progress_tool_call"
+    tool_call: str
+
+
+@json_schema_type
+class FinalToolCallDelta(BaseModel):
+    """Delta sent when the tool call is complete
+
+    :param type: Discriminator type of the content item. Always "final_tool_call"
+    :param tool_call: The final parsed tool call. If the tool call failed to parse, this will be None.
+    """
+
+    type: Literal["final_tool_call"] = "final_tool_call"
+    tool_call: Optional[ToolCall]
+
+
 # streaming completions send a stream of ContentDeltas
 ContentDelta = register_schema(
     Annotated[
-        Union[TextDelta, ImageDelta, ToolCallDelta],
+        Union[TextDelta, ImageDelta, ToolCallDelta, InProgressToolCallDelta, FinalToolCallDelta],
         Field(discriminator="type"),
     ],
     name="ContentDelta",
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index 61f0ee3f4..76e0b8d8e 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -17,6 +17,8 @@ from llama_models.llama3.api.datatypes import (
 from llama_models.sku_list import resolve_model
 
 from llama_stack.apis.common.content_types import (
+    FinalToolCallDelta,
+    InProgressToolCallDelta,
     TextDelta,
     ToolCallDelta,
     ToolCallParseStatus,
@@ -342,6 +344,7 @@ class MetaReferenceInferenceImpl(
                 if not ipython and token_result.text.startswith("<|python_tag|>"):
                     ipython = True
 
+                    # send the deprecated delta
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
                             event_type=ChatCompletionResponseEventType.progress,
                             delta=ToolCallDelta(
                                 tool_call="",
                                 parse_status=ToolCallParseStatus.started,
                             ),
                         )
                     )
+                    # send the new delta
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=InProgressToolCallDelta(
+                                tool_call="",
+                            ),
+                        )
+                    )
                     continue
 
                 if token_result.text == "<|eot_id|>":
@@ -363,10 +375,7 @@
                 text = token_result.text
 
                 if ipython:
-                    delta = ToolCallDelta(
-                        tool_call=text,
-                        parse_status=ToolCallParseStatus.in_progress,
-                    )
+                    delta = InProgressToolCallDelta(tool_call=text)
                 else:
                     delta = TextDelta(text=text)
 
                 if request.logprobs:
                     assert len(token_result.logprobs) == 1
                     logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
+
+                if isinstance(delta, InProgressToolCallDelta):
+                    # send the deprecated delta alongside the new one
+                    deprecated_delta = ToolCallDelta(
+                        tool_call=delta.tool_call,
+                        parse_status=ToolCallParseStatus.in_progress,
+                    )
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=deprecated_delta,
+                            stop_reason=stop_reason,
+                            logprobs=logprobs if request.logprobs else None,
+                        )
+                    )
+
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
                         event_type=ChatCompletionResponseEventType.progress,
@@ -391,6 +415,7 @@
             parsed_tool_calls = len(message.tool_calls) > 0
 
             if ipython and not parsed_tool_calls:
+                # send the deprecated delta
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
                         event_type=ChatCompletionResponseEventType.progress,
                         delta=ToolCallDelta(
                             tool_call="",
                             parse_status=ToolCallParseStatus.failed,
                         ),
                         stop_reason=stop_reason,
                     )
                 )
-
+                # send the new delta
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=FinalToolCallDelta(tool_call=None),
+                    )
+                )
             for tool_call in message.tool_calls:
+                # send the deprecated delta
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
                         event_type=ChatCompletionResponseEventType.progress,
                         delta=ToolCallDelta(
                             tool_call=tool_call,
                             parse_status=ToolCallParseStatus.succeeded,
                         ),
                         stop_reason=stop_reason,
                     )
                 )
+                # send the new delta
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=FinalToolCallDelta(tool_call=tool_call),
+                    )
+                )
 
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
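
Net effect on the wire: during the deprecation window each tool-call event is emitted twice, once as the legacy ToolCallDelta and once as InProgressToolCallDelta/FinalToolCallDelta, so clients migrate by switching on the discriminator field. A minimal consumer sketch, assuming stream is an async iterator of ChatCompletionResponseStreamChunk (e.g. from chat_completion(..., stream=True)); the consume function and print handling are illustrative, not part of this patch:

    # Illustrative consumer, not part of this patch; assumes `stream`
    # yields ChatCompletionResponseStreamChunk objects.
    async def consume(stream):
        async for chunk in stream:
            delta = chunk.event.delta
            if delta.type == "text":
                print(delta.text, end="")
            elif delta.type == "in_progress_tool_call":
                # partial tool-call text; accumulate or render progress
                print(delta.tool_call, end="")
            elif delta.type == "final_tool_call":
                # a fully parsed ToolCall, or None if parsing failed
                if delta.tool_call is not None:
                    print(f"\n[tool call] {delta.tool_call.tool_name}")
            elif delta.type == "tool_call":
                # legacy ToolCallDelta; drop this branch once migrated
                pass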