introduce and use a generic ContentDelta

Ashwin Bharambe 2025-01-13 19:38:44 -08:00
parent 9ec54dcbe7
commit aced2ce07e
2 changed files with 44 additions and 17 deletions

llama_stack/apis/common/content_types.py

@@ -5,10 +5,12 @@
 # the root directory of this source tree.
 import base64
+from enum import Enum
 from typing import Annotated, List, Literal, Optional, Union
+from llama_models.llama3.api.datatypes import ToolCall
 from llama_models.schema_utils import json_schema_type, register_schema
 from pydantic import BaseModel, Field, field_serializer, model_validator
@@ -60,3 +62,42 @@ InterleavedContent = register_schema(
     Union[str, InterleavedContentItem, List[InterleavedContentItem]],
     name="InterleavedContent",
 )
+
+
+class TextDelta(BaseModel):
+    type: Literal["text"] = "text"
+    text: str
+
+
+class ImageDelta(BaseModel):
+    type: Literal["image"] = "image"
+    data: bytes
+
+
+@json_schema_type
+class ToolCallParseStatus(Enum):
+    started = "started"
+    in_progress = "in_progress"
+    failed = "failed"
+    succeeded = "succeeded"
+
+
+@json_schema_type
+class ToolCallDelta(BaseModel):
+    type: Literal["tool_call"] = "tool_call"
+
+    # you either send an in-progress tool call so the client can stream a long
+    # code generation or you send the final parsed tool call at the end of the
+    # stream
+    content: Union[str, ToolCall]
+    parse_status: ToolCallParseStatus
+
+
+# streaming completions send a stream of ContentDeltas
+ContentDelta = register_schema(
+    Annotated[
+        Union[TextDelta, ImageDelta, ToolCallDelta],
+        Field(discriminator="type"),
+    ],
+    name="ContentDelta",
+)
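
The key design point is that ContentDelta is a `type`-discriminated union, so a streaming client can validate each chunk into the right delta class without guessing. Below is a minimal, self-contained sketch of that behavior; it is illustrative only, not the llama-stack definitions themselves. It assumes pydantic v2 (TypeAdapter), ImageDelta is omitted, and ToolCallDelta's fields are simplified to plain strings (the real `content` may also be a fully parsed ToolCall).

    # Sketch only: simplified stand-ins for the classes added above.
    from typing import Annotated, Literal, Union

    from pydantic import BaseModel, Field, TypeAdapter


    class TextDelta(BaseModel):
        type: Literal["text"] = "text"
        text: str


    class ToolCallDelta(BaseModel):
        type: Literal["tool_call"] = "tool_call"
        content: str          # simplified; real field is Union[str, ToolCall]
        parse_status: str     # simplified; real field is ToolCallParseStatus


    ContentDelta = Annotated[
        Union[TextDelta, ToolCallDelta],
        Field(discriminator="type"),
    ]

    adapter = TypeAdapter(ContentDelta)

    # The "type" field routes each payload to the correct model class.
    chunk = adapter.validate_python({"type": "text", "text": "Hel"})
    assert isinstance(chunk, TextDelta)

    chunk = adapter.validate_python(
        {"type": "tool_call", "content": '{"name": "get_', "parse_status": "in_progress"}
    )
    assert isinstance(chunk, ToolCallDelta)

With the discriminator, providers can interleave text, image, and tool-call deltas on a single stream, and clients receive a well-typed object per chunk instead of the previous `Union[str, ToolCallDelta]`.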

llama_stack/apis/inference/inference.py

@@ -29,7 +29,7 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, Field, field_validator
 from typing_extensions import Annotated
 
-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
 from llama_stack.apis.models import Model
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@@ -147,26 +147,12 @@ class ChatCompletionResponseEventType(Enum):
     progress = "progress"
 
 
-@json_schema_type
-class ToolCallParseStatus(Enum):
-    started = "started"
-    in_progress = "in_progress"
-    failure = "failure"
-    success = "success"
-
-
-@json_schema_type
-class ToolCallDelta(BaseModel):
-    content: Union[str, ToolCall]
-    parse_status: ToolCallParseStatus
-
-
 @json_schema_type
 class ChatCompletionResponseEvent(BaseModel):
     """Chat completion response event."""
 
     event_type: ChatCompletionResponseEventType
-    delta: Union[str, ToolCallDelta]
+    delta: ContentDelta
     logprobs: Optional[List[TokenLogProbs]] = None
     stop_reason: Optional[StopReason] = None
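
Since `ChatCompletionResponseEvent.delta` is now always a ContentDelta (never a bare `str`), a streaming consumer can dispatch on `delta.type`. The sketch below shows one way to accumulate such a stream; the field names mirror the diff above, but `accumulate_stream` and the `events` iterable are hypothetical, not part of the llama-stack client API.

    # Sketch of a stream consumer dispatching on delta.type. `events` is assumed
    # to be an iterable of ChatCompletionResponseEvent-like objects; only the
    # fields shown in the diff above are relied upon.
    def accumulate_stream(events):
        text_parts = []          # accumulated TextDelta.text fragments
        tool_call_text = []      # in-progress (string) tool-call fragments
        final_tool_call = None   # fully parsed ToolCall, if one arrives

        for event in events:
            delta = event.delta
            if delta.type == "text":
                text_parts.append(delta.text)
            elif delta.type == "tool_call":
                if isinstance(delta.content, str):
                    # in-progress tool call: partial text being generated
                    tool_call_text.append(delta.content)
                else:
                    # final parsed ToolCall sent at the end of the stream
                    final_tool_call = delta.content
            elif delta.type == "image":
                # ImageDelta carries raw bytes in delta.data; handling is app-specific
                pass

        return {
            "text": "".join(text_parts),
            "tool_call": final_tool_call,
            "partial_tool_call_text": "".join(tool_call_text),
        }

The `parse_status` field (started / in_progress / failed / succeeded) tells such a consumer whether the final tool call parsed cleanly or whether it should fall back to the accumulated raw text.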