From aced2ce07e02295d13db11be39156adb24ce07f3 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 13 Jan 2025 19:38:44 -0800 Subject: [PATCH] introduce and use a generic ContentDelta --- llama_stack/apis/common/content_types.py | 43 +++++++++++++++++++++++- llama_stack/apis/inference/inference.py | 18 ++-------- 2 files changed, 44 insertions(+), 17 deletions(-) diff --git a/llama_stack/apis/common/content_types.py b/llama_stack/apis/common/content_types.py index 629e0e94d..3b61fa243 100644 --- a/llama_stack/apis/common/content_types.py +++ b/llama_stack/apis/common/content_types.py @@ -5,10 +5,12 @@ # the root directory of this source tree. import base64 +from enum import Enum from typing import Annotated, List, Literal, Optional, Union -from llama_models.schema_utils import json_schema_type, register_schema +from llama_models.llama3.api.datatypes import ToolCall +from llama_models.schema_utils import json_schema_type, register_schema from pydantic import BaseModel, Field, field_serializer, model_validator @@ -60,3 +62,42 @@ InterleavedContent = register_schema( Union[str, InterleavedContentItem, List[InterleavedContentItem]], name="InterleavedContent", ) + + +class TextDelta(BaseModel): + type: Literal["text"] = "text" + text: str + + +class ImageDelta(BaseModel): + type: Literal["image"] = "image" + data: bytes + + +@json_schema_type +class ToolCallParseStatus(Enum): + started = "started" + in_progress = "in_progress" + failed = "failed" + succeeded = "succeeded" + + +@json_schema_type +class ToolCallDelta(BaseModel): + type: Literal["tool_call"] = "tool_call" + + # you either send an in-progress tool call so the client can stream a long + # code generation or you send the final parsed tool call at the end of the + # stream + content: Union[str, ToolCall] + parse_status: ToolCallParseStatus + + +# streaming completions send a stream of ContentDeltas +ContentDelta = register_schema( + Annotated[ + Union[TextDelta, ImageDelta, ToolCallDelta], + Field(discriminator="type"), + ], + name="ContentDelta", +) diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 4a453700c..b525aa331 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -29,7 +29,7 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmeth from pydantic import BaseModel, Field, field_validator from typing_extensions import Annotated -from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent from llama_stack.apis.models import Model from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol @@ -147,26 +147,12 @@ class ChatCompletionResponseEventType(Enum): progress = "progress" -@json_schema_type -class ToolCallParseStatus(Enum): - started = "started" - in_progress = "in_progress" - failure = "failure" - success = "success" - - -@json_schema_type -class ToolCallDelta(BaseModel): - content: Union[str, ToolCall] - parse_status: ToolCallParseStatus - - @json_schema_type class ChatCompletionResponseEvent(BaseModel): """Chat completion response event.""" event_type: ChatCompletionResponseEventType - delta: Union[str, ToolCallDelta] + delta: ContentDelta logprobs: Optional[List[TokenLogProbs]] = None stop_reason: Optional[StopReason] = None