mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-10 04:08:31 +00:00
RFC: Split the tool call delta type
This commit is contained in:
parent
efdd60014d
commit
ed6b7a72d3
2 changed files with 71 additions and 6 deletions
|
@ -98,6 +98,8 @@ class ToolCallParseStatus(Enum):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ToolCallDelta(BaseModel):
|
class ToolCallDelta(BaseModel):
|
||||||
|
"""Deprecated: use InProgressToolCallDelta or ParsedToolCallDelta instead"""
|
||||||
|
|
||||||
type: Literal["tool_call"] = "tool_call"
|
type: Literal["tool_call"] = "tool_call"
|
||||||
|
|
||||||
# you either send an in-progress tool call so the client can stream a long
|
# you either send an in-progress tool call so the client can stream a long
|
||||||
|
@ -107,10 +109,34 @@ class ToolCallDelta(BaseModel):
|
||||||
parse_status: ToolCallParseStatus
|
parse_status: ToolCallParseStatus
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
class InProgressToolCallDelta(BaseModel):
    """Delta sent when the tool call is in progress

    :param type: Discriminator type of the content item. Always "in_progress_tool_call"
    :param tool_call: The tool call that is currently being executed
    """

    # Fixed discriminator value; lets the ContentDelta tagged union route on `type`.
    type: Literal["in_progress_tool_call"] = "in_progress_tool_call"
    # Raw text of the tool call streamed so far — a plain string, since the call
    # has not been parsed into a structured ToolCall yet at this stage.
    tool_call: str
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
class FinalToolCallDelta(BaseModel):
    """Delta sent when the tool call is complete

    :param type: Discriminator type of the content item. Always "final_tool_call"
    :param tool_call: The final parsed tool call. If the tool call failed to parse, this will be None.
    """

    # Fixed discriminator value; lets the ContentDelta tagged union route on `type`.
    type: Literal["final_tool_call"] = "final_tool_call"
    # Structured result of parsing the streamed text. Deliberately has no default:
    # the sender must pass an explicit None on parse failure rather than omit it.
    tool_call: Optional[ToolCall]
|
||||||
|
|
||||||
|
|
||||||
# streaming completions send a stream of ContentDeltas
# The union is discriminated on each member's `type` field. ToolCallDelta is
# retained for backward compatibility alongside the newer InProgressToolCallDelta
# and FinalToolCallDelta variants that replace it.
ContentDelta = register_schema(
    Annotated[
        Union[TextDelta, ImageDelta, ToolCallDelta, InProgressToolCallDelta, FinalToolCallDelta],
        Field(discriminator="type"),
    ],
    name="ContentDelta",
)
|
|
|
@ -17,6 +17,8 @@ from llama_models.llama3.api.datatypes import (
|
||||||
from llama_models.sku_list import resolve_model
|
from llama_models.sku_list import resolve_model
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import (
|
from llama_stack.apis.common.content_types import (
|
||||||
|
FinalToolCallDelta,
|
||||||
|
InProgressToolCallDelta,
|
||||||
TextDelta,
|
TextDelta,
|
||||||
ToolCallDelta,
|
ToolCallDelta,
|
||||||
ToolCallParseStatus,
|
ToolCallParseStatus,
|
||||||
|
@ -342,6 +344,7 @@ class MetaReferenceInferenceImpl(
|
||||||
|
|
||||||
if not ipython and token_result.text.startswith("<|python_tag|>"):
|
if not ipython and token_result.text.startswith("<|python_tag|>"):
|
||||||
ipython = True
|
ipython = True
|
||||||
|
# send the deprecated delta
|
||||||
yield ChatCompletionResponseStreamChunk(
|
yield ChatCompletionResponseStreamChunk(
|
||||||
event=ChatCompletionResponseEvent(
|
event=ChatCompletionResponseEvent(
|
||||||
event_type=ChatCompletionResponseEventType.progress,
|
event_type=ChatCompletionResponseEventType.progress,
|
||||||
|
@ -351,6 +354,15 @@ class MetaReferenceInferenceImpl(
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
# send the new delta
|
||||||
|
yield ChatCompletionResponseStreamChunk(
|
||||||
|
event=ChatCompletionResponseEvent(
|
||||||
|
event_type=ChatCompletionResponseEventType.progress,
|
||||||
|
delta=InProgressToolCallDelta(
|
||||||
|
tool_call="",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if token_result.text == "<|eot_id|>":
|
if token_result.text == "<|eot_id|>":
|
||||||
|
@ -363,10 +375,7 @@ class MetaReferenceInferenceImpl(
|
||||||
text = token_result.text
|
text = token_result.text
|
||||||
|
|
||||||
if ipython:
|
if ipython:
|
||||||
delta = ToolCallDelta(
|
delta = InProgressToolCallDelta(tool_call=text)
|
||||||
tool_call=text,
|
|
||||||
parse_status=ToolCallParseStatus.in_progress,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
delta = TextDelta(text=text)
|
delta = TextDelta(text=text)
|
||||||
|
|
||||||
|
@ -375,6 +384,21 @@ class MetaReferenceInferenceImpl(
|
||||||
assert len(token_result.logprobs) == 1
|
assert len(token_result.logprobs) == 1
|
||||||
|
|
||||||
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
|
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
|
||||||
|
|
||||||
|
if isinstance(delta, InProgressToolCallDelta):
|
||||||
|
deprecated_delta = ToolCallDelta(
|
||||||
|
tool_call=delta.tool_call,
|
||||||
|
parse_status=ToolCallParseStatus.in_progress,
|
||||||
|
)
|
||||||
|
yield ChatCompletionResponseStreamChunk(
|
||||||
|
event=ChatCompletionResponseEvent(
|
||||||
|
event_type=ChatCompletionResponseEventType.progress,
|
||||||
|
delta=deprecated_delta,
|
||||||
|
stop_reason=stop_reason,
|
||||||
|
logprobs=logprobs if request.logprobs else None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
yield ChatCompletionResponseStreamChunk(
|
yield ChatCompletionResponseStreamChunk(
|
||||||
event=ChatCompletionResponseEvent(
|
event=ChatCompletionResponseEvent(
|
||||||
event_type=ChatCompletionResponseEventType.progress,
|
event_type=ChatCompletionResponseEventType.progress,
|
||||||
|
@ -391,6 +415,7 @@ class MetaReferenceInferenceImpl(
|
||||||
|
|
||||||
parsed_tool_calls = len(message.tool_calls) > 0
|
parsed_tool_calls = len(message.tool_calls) > 0
|
||||||
if ipython and not parsed_tool_calls:
|
if ipython and not parsed_tool_calls:
|
||||||
|
# send the deprecated delta
|
||||||
yield ChatCompletionResponseStreamChunk(
|
yield ChatCompletionResponseStreamChunk(
|
||||||
event=ChatCompletionResponseEvent(
|
event=ChatCompletionResponseEvent(
|
||||||
event_type=ChatCompletionResponseEventType.progress,
|
event_type=ChatCompletionResponseEventType.progress,
|
||||||
|
@ -401,8 +426,15 @@ class MetaReferenceInferenceImpl(
|
||||||
stop_reason=stop_reason,
|
stop_reason=stop_reason,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
# send the new delta
|
||||||
|
yield ChatCompletionResponseStreamChunk(
|
||||||
|
event=ChatCompletionResponseEvent(
|
||||||
|
event_type=ChatCompletionResponseEventType.progress,
|
||||||
|
delta=FinalToolCallDelta(tool_call=None),
|
||||||
|
)
|
||||||
|
)
|
||||||
for tool_call in message.tool_calls:
|
for tool_call in message.tool_calls:
|
||||||
|
# send the deprecated delta
|
||||||
yield ChatCompletionResponseStreamChunk(
|
yield ChatCompletionResponseStreamChunk(
|
||||||
event=ChatCompletionResponseEvent(
|
event=ChatCompletionResponseEvent(
|
||||||
event_type=ChatCompletionResponseEventType.progress,
|
event_type=ChatCompletionResponseEventType.progress,
|
||||||
|
@ -413,6 +445,13 @@ class MetaReferenceInferenceImpl(
|
||||||
stop_reason=stop_reason,
|
stop_reason=stop_reason,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
# send the new delta
|
||||||
|
yield ChatCompletionResponseStreamChunk(
|
||||||
|
event=ChatCompletionResponseEvent(
|
||||||
|
event_type=ChatCompletionResponseEventType.progress,
|
||||||
|
delta=FinalToolCallDelta(tool_call=tool_call),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
yield ChatCompletionResponseStreamChunk(
|
yield ChatCompletionResponseStreamChunk(
|
||||||
event=ChatCompletionResponseEvent(
|
event=ChatCompletionResponseEvent(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue