RFC: Split the tool call delta type

This commit is contained in:
Ashwin Bharambe 2025-02-13 12:58:57 -08:00
parent efdd60014d
commit ed6b7a72d3
2 changed files with 71 additions and 6 deletions

View file

@@ -98,6 +98,8 @@ class ToolCallParseStatus(Enum):
@json_schema_type
class ToolCallDelta(BaseModel):
"""Deprecated: use InProgressToolCallDelta or ParsedToolCallDelta instead"""
type: Literal["tool_call"] = "tool_call"
# you either send an in-progress tool call so the client can stream a long
@@ -107,10 +109,34 @@ class ToolCallDelta(BaseModel):
parse_status: ToolCallParseStatus
@json_schema_type
class InProgressToolCallDelta(BaseModel):
    """Delta sent while a tool call is still being streamed.

    :param type: Discriminator type of the content item. Always "in_progress_tool_call"
    :param tool_call: The partial tool-call text produced so far (the inference
        code yields this token-by-token, so it is a raw string fragment, not a
        parsed ToolCall)
    """

    # Literal default doubles as the discriminator value for the ContentDelta union.
    type: Literal["in_progress_tool_call"] = "in_progress_tool_call"
    # Raw accumulated tool-call text; may be "" for the initial signal chunk.
    tool_call: str
@json_schema_type
class FinalToolCallDelta(BaseModel):
    """Delta sent when the tool call is complete.

    :param type: Discriminator type of the content item. Always "final_tool_call"
    :param tool_call: The final parsed tool call. If the tool call failed to
        parse, this will be None.
    """

    # Literal default doubles as the discriminator value for the ContentDelta union.
    type: Literal["final_tool_call"] = "final_tool_call"
    # None signals "tool call text was emitted but could not be parsed".
    tool_call: Optional[ToolCall]
# Streaming completions send a stream of ContentDeltas. The union is
# discriminated on each member model's `type` field, so pydantic can pick the
# concrete delta class when deserializing a chunk.
# NOTE: ToolCallDelta is deprecated (see its docstring) but stays in the union
# so existing clients that still consume it keep working.
ContentDelta = register_schema(
    Annotated[
        Union[TextDelta, ImageDelta, ToolCallDelta, InProgressToolCallDelta, FinalToolCallDelta],
        Field(discriminator="type"),
    ],
    name="ContentDelta",
)

View file

@@ -17,6 +17,8 @@ from llama_models.llama3.api.datatypes import (
from llama_models.sku_list import resolve_model
from llama_stack.apis.common.content_types import (
FinalToolCallDelta,
InProgressToolCallDelta,
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
@@ -342,6 +344,7 @@ class MetaReferenceInferenceImpl(
if not ipython and token_result.text.startswith("<|python_tag|>"):
ipython = True
# send the deprecated delta
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
@@ -351,6 +354,15 @@ class MetaReferenceInferenceImpl(
),
)
)
# send the new delta
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=InProgressToolCallDelta(
tool_call="",
),
)
)
continue
if token_result.text == "<|eot_id|>":
@@ -363,10 +375,7 @@ class MetaReferenceInferenceImpl(
text = token_result.text
if ipython:
delta = ToolCallDelta(
tool_call=text,
parse_status=ToolCallParseStatus.in_progress,
)
delta = InProgressToolCallDelta(tool_call=text)
else:
delta = TextDelta(text=text)
@@ -375,6 +384,21 @@ class MetaReferenceInferenceImpl(
assert len(token_result.logprobs) == 1
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
if isinstance(delta, InProgressToolCallDelta):
deprecated_delta = ToolCallDelta(
tool_call=delta.tool_call,
parse_status=ToolCallParseStatus.in_progress,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=deprecated_delta,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
@@ -391,6 +415,7 @@ class MetaReferenceInferenceImpl(
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
# send the deprecated delta
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
@@ -401,8 +426,15 @@ class MetaReferenceInferenceImpl(
stop_reason=stop_reason,
)
)
# send the new delta
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=FinalToolCallDelta(tool_call=None),
)
)
for tool_call in message.tool_calls:
# send the deprecated delta
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
@@ -413,6 +445,13 @@ class MetaReferenceInferenceImpl(
stop_reason=stop_reason,
)
)
# send the new delta
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=FinalToolCallDelta(tool_call=tool_call),
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(