introduce and use a generic ContentDelta

This commit is contained in:
Ashwin Bharambe 2025-01-13 19:38:44 -08:00
parent 9ec54dcbe7
commit aced2ce07e
2 changed files with 44 additions and 17 deletions

View file

@ -5,10 +5,12 @@
# the root directory of this source tree.
import base64
from enum import Enum
from typing import Annotated, List, Literal, Optional, Union
from llama_models.schema_utils import json_schema_type, register_schema
from llama_models.llama3.api.datatypes import ToolCall
from llama_models.schema_utils import json_schema_type, register_schema
from pydantic import BaseModel, Field, field_serializer, model_validator
@ -60,3 +62,42 @@ InterleavedContent = register_schema(
Union[str, InterleavedContentItem, List[InterleavedContentItem]],
name="InterleavedContent",
)
class TextDelta(BaseModel):
type: Literal["text"] = "text"
text: str
class ImageDelta(BaseModel):
type: Literal["image"] = "image"
data: bytes
@json_schema_type
class ToolCallParseStatus(Enum):
started = "started"
in_progress = "in_progress"
failed = "failed"
succeeded = "succeeded"
@json_schema_type
class ToolCallDelta(BaseModel):
type: Literal["tool_call"] = "tool_call"
# you either send an in-progress tool call so the client can stream a long
# code generation or you send the final parsed tool call at the end of the
# stream
content: Union[str, ToolCall]
parse_status: ToolCallParseStatus
# streaming completions send a stream of ContentDeltas
ContentDelta = register_schema(
Annotated[
Union[TextDelta, ImageDelta, ToolCallDelta],
Field(discriminator="type"),
],
name="ContentDelta",
)

View file

@ -29,7 +29,7 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmeth
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
from llama_stack.apis.models import Model
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@ -147,26 +147,12 @@ class ChatCompletionResponseEventType(Enum):
progress = "progress"
@json_schema_type
class ToolCallParseStatus(Enum):
started = "started"
in_progress = "in_progress"
failure = "failure"
success = "success"
@json_schema_type
class ToolCallDelta(BaseModel):
content: Union[str, ToolCall]
parse_status: ToolCallParseStatus
@json_schema_type
class ChatCompletionResponseEvent(BaseModel):
"""Chat completion response event."""
event_type: ChatCompletionResponseEventType
delta: Union[str, ToolCallDelta]
delta: ContentDelta
logprobs: Optional[List[TokenLogProbs]] = None
stop_reason: Optional[StopReason] = None