RFC: Split the tool call delta type
Commit ed6b7a72d3 (parent efdd60014d). 2 changed files with 71 additions and 6 deletions.

This RFC splits the catch-all ToolCallDelta into two purpose-specific deltas: InProgressToolCallDelta, which streams the raw tool-call text as it is generated, and FinalToolCallDelta, which carries the fully parsed ToolCall, or None when parsing fails. ToolCallDelta itself is kept but deprecated, and the meta-reference inference provider emits both the old and the new delta for each event so existing clients keep working during the migration.
llama_stack/apis/common/content_types.py:

@@ -98,6 +98,8 @@ class ToolCallParseStatus(Enum):
 
 @json_schema_type
 class ToolCallDelta(BaseModel):
+    """Deprecated: use InProgressToolCallDelta or FinalToolCallDelta instead"""
+
     type: Literal["tool_call"] = "tool_call"
 
     # you either send an in-progress tool call so the client can stream a long
@@ -107,10 +109,34 @@ class ToolCallDelta(BaseModel):
     parse_status: ToolCallParseStatus
 
 
+@json_schema_type
+class InProgressToolCallDelta(BaseModel):
+    """Delta sent when the tool call is in progress
+
+    :param type: Discriminator type of the content item. Always "in_progress_tool_call"
+    :param tool_call: The tool call that is currently being executed
+    """
+
+    type: Literal["in_progress_tool_call"] = "in_progress_tool_call"
+    tool_call: str
+
+
+@json_schema_type
+class FinalToolCallDelta(BaseModel):
+    """Delta sent when the tool call is complete
+
+    :param type: Discriminator type of the content item. Always "final_tool_call"
+    :param tool_call: The final parsed tool call. If the tool call failed to parse, this will be None.
+    """
+
+    type: Literal["final_tool_call"] = "final_tool_call"
+    tool_call: Optional[ToolCall]
+
+
 # streaming completions send a stream of ContentDeltas
 ContentDelta = register_schema(
     Annotated[
-        Union[TextDelta, ImageDelta, ToolCallDelta],
+        Union[TextDelta, ImageDelta, ToolCallDelta, InProgressToolCallDelta, FinalToolCallDelta],
         Field(discriminator="type"),
     ],
     name="ContentDelta",
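The widened union keeps the deprecated ToolCallDelta as a valid member, so old clients still validate, while new clients can switch on the two new "type" values. For context, here is a minimal standalone sketch of the split using plain Pydantic v2; llama-stack's @json_schema_type decorator and register_schema wrapper are omitted, and ToolCall is stubbed (its call_id/tool_name/arguments fields are illustrative assumptions, not the exact llama_models definition).

from typing import Annotated, Literal, Optional, Union

from pydantic import BaseModel, Field, TypeAdapter


class ToolCall(BaseModel):
    # stand-in for llama_models' ToolCall, for illustration only
    call_id: str
    tool_name: str
    arguments: dict


class InProgressToolCallDelta(BaseModel):
    type: Literal["in_progress_tool_call"] = "in_progress_tool_call"
    tool_call: str  # raw, partially generated tool-call text


class FinalToolCallDelta(BaseModel):
    type: Literal["final_tool_call"] = "final_tool_call"
    tool_call: Optional[ToolCall]  # None if the tool call failed to parse


# the "type" discriminator lets a decoder pick the right model from raw JSON
ToolCallDeltaUnion = Annotated[
    Union[InProgressToolCallDelta, FinalToolCallDelta],
    Field(discriminator="type"),
]

adapter = TypeAdapter(ToolCallDeltaUnion)
chunk = adapter.validate_python({"type": "in_progress_tool_call", "tool_call": '{"na'})
assert isinstance(chunk, InProgressToolCallDelta)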
Meta-reference inference provider (MetaReferenceInferenceImpl):

@@ -17,6 +17,8 @@ from llama_models.llama3.api.datatypes import (
 from llama_models.sku_list import resolve_model
 
 from llama_stack.apis.common.content_types import (
+    FinalToolCallDelta,
+    InProgressToolCallDelta,
     TextDelta,
     ToolCallDelta,
     ToolCallParseStatus,
@@ -342,6 +344,7 @@ class MetaReferenceInferenceImpl(
 
             if not ipython and token_result.text.startswith("<|python_tag|>"):
                 ipython = True
+                # send the deprecated delta
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
                         event_type=ChatCompletionResponseEventType.progress,
@@ -351,6 +354,15 @@ class MetaReferenceInferenceImpl(
                         ),
                     )
                 )
+                # send the new delta
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=InProgressToolCallDelta(
+                            tool_call="",
+                        ),
+                    )
+                )
                 continue
 
             if token_result.text == "<|eot_id|>":
@@ -363,10 +375,7 @@ class MetaReferenceInferenceImpl(
                 text = token_result.text
 
             if ipython:
-                delta = ToolCallDelta(
-                    tool_call=text,
-                    parse_status=ToolCallParseStatus.in_progress,
-                )
+                delta = InProgressToolCallDelta(tool_call=text)
             else:
                 delta = TextDelta(text=text)
 
@@ -375,6 +384,21 @@ class MetaReferenceInferenceImpl(
                     assert len(token_result.logprobs) == 1
 
                     logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
 
+            if isinstance(delta, InProgressToolCallDelta):
+                deprecated_delta = ToolCallDelta(
+                    tool_call=delta.tool_call,
+                    parse_status=ToolCallParseStatus.in_progress,
+                )
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=deprecated_delta,
+                        stop_reason=stop_reason,
+                        logprobs=logprobs if request.logprobs else None,
+                    )
+                )
+
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=ChatCompletionResponseEventType.progress,
@@ -391,6 +415,7 @@ class MetaReferenceInferenceImpl(
 
         parsed_tool_calls = len(message.tool_calls) > 0
         if ipython and not parsed_tool_calls:
+            # send the deprecated delta
            yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=ChatCompletionResponseEventType.progress,
@@ -401,8 +426,15 @@ class MetaReferenceInferenceImpl(
                     stop_reason=stop_reason,
                 )
             )
+            # send the new delta
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=FinalToolCallDelta(tool_call=None),
+                )
+            )
 
         for tool_call in message.tool_calls:
+            # send the deprecated delta
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=ChatCompletionResponseEventType.progress,
@@ -413,6 +445,13 @@ class MetaReferenceInferenceImpl(
                     stop_reason=stop_reason,
                 )
             )
+            # send the new delta
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=FinalToolCallDelta(tool_call=tool_call),
+                )
+            )
 
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
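For a client, the split means branching on the delta's "type" discriminator instead of inspecting parse_status. Below is a minimal consumer sketch under a few assumptions not fixed by this diff: the SDK yields chunks shaped like the ChatCompletionResponseStreamChunk objects above, TextDelta's discriminator value is "text", and the deprecated "tool_call" deltas duplicate the new ones during the migration window, so they are skipped to avoid double-counting.

def consume(stream):
    # accumulate raw tool-call text from the in-progress deltas
    tool_call_text = ""
    for chunk in stream:
        delta = chunk.event.delta
        if delta.type == "tool_call":
            # deprecated ToolCallDelta; the same data arrives via the
            # new delta types, so ignore it here
            continue
        elif delta.type == "in_progress_tool_call":
            tool_call_text += delta.tool_call
        elif delta.type == "final_tool_call":
            if delta.tool_call is None:
                # the server could not parse the accumulated tool-call text
                print(f"tool call failed to parse: {tool_call_text!r}")
            else:
                print(f"parsed tool call: {delta.tool_call}")
        elif delta.type == "text":
            print(delta.text, end="")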