move all implementations to use updated type

This commit is contained in:
Ashwin Bharambe 2025-01-13 20:04:19 -08:00
parent aced2ce07e
commit 9a5803a429
8 changed files with 139 additions and 208 deletions

View file

@@ -40,7 +40,12 @@ from llama_stack.apis.agents import (
ToolExecutionStep,
Turn,
)
from llama_stack.apis.common.content_types import TextContentItem, URL
from llama_stack.apis.common.content_types import (
TextContentItem,
ToolCallDelta,
ToolCallParseStatus,
URL,
)
from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
CompletionMessage,
@@ -49,8 +54,6 @@ from llama_stack.apis.inference import (
SamplingParams,
StopReason,
SystemMessage,
ToolCallDelta,
ToolCallParseStatus,
ToolDefinition,
ToolResponse,
ToolResponseMessage,
@@ -411,7 +414,7 @@ class ChatAgent(ShieldRunnerMixin):
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
tool_call_delta=ToolCallDelta(
delta=ToolCallDelta(
parse_status=ToolCallParseStatus.success,
content=ToolCall(
call_id="",
@@ -507,7 +510,7 @@ class ChatAgent(ShieldRunnerMixin):
continue
delta = event.delta
if isinstance(delta, ToolCallDelta):
if delta.type == "tool_call":
if delta.parse_status == ToolCallParseStatus.success:
tool_calls.append(delta.content)
if stream:
@@ -516,21 +519,20 @@
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
text_delta="",
tool_call_delta=delta,
delta=delta,
)
)
)
elif isinstance(delta, str):
content += delta
elif delta.type == "text":
content += delta.text
if stream and event.stop_reason is None:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
text_delta=event.delta,
delta=delta,
)
)
)

View file

@@ -16,6 +16,11 @@ from llama_models.llama3.api.datatypes import (
)
from llama_models.sku_list import resolve_model
from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -32,8 +37,6 @@ from llama_stack.apis.inference import (
Message,
ResponseFormat,
TokenLogProbs,
ToolCallDelta,
ToolCallParseStatus,
ToolChoice,
)
from llama_stack.apis.models import Model, ModelType
@@ -190,14 +193,14 @@ class MetaReferenceInferenceImpl(
]
yield CompletionResponseStreamChunk(
delta=text,
delta=TextDelta(text=text),
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
if stop_reason is None:
yield CompletionResponseStreamChunk(
delta="",
delta=TextDelta(text=""),
stop_reason=StopReason.out_of_tokens,
)
@@ -352,7 +355,7 @@ class MetaReferenceInferenceImpl(
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
delta=TextDelta(text=""),
)
)
@@ -392,7 +395,7 @@ class MetaReferenceInferenceImpl(
parse_status=ToolCallParseStatus.in_progress,
)
else:
delta = text
delta = TextDelta(text=text)
if stop_reason is None:
if request.logprobs:
@@ -449,7 +452,7 @@ class MetaReferenceInferenceImpl(
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
delta=TextDelta(text=""),
stop_reason=stop_reason,
)
)