forked from phoenix-oss/llama-stack-mirror
introduce and use a generic ContentDelta
This commit is contained in:
parent
9ec54dcbe7
commit
aced2ce07e
2 changed files with 44 additions and 17 deletions
|
@ -5,10 +5,12 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
from enum import Enum
|
||||
from typing import Annotated, List, Literal, Optional, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, register_schema
|
||||
from llama_models.llama3.api.datatypes import ToolCall
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, register_schema
|
||||
from pydantic import BaseModel, Field, field_serializer, model_validator
|
||||
|
||||
|
||||
|
@ -60,3 +62,42 @@ InterleavedContent = register_schema(
|
|||
Union[str, InterleavedContentItem, List[InterleavedContentItem]],
|
||||
name="InterleavedContent",
|
||||
)
|
||||
|
||||
|
||||
class TextDelta(BaseModel):
|
||||
type: Literal["text"] = "text"
|
||||
text: str
|
||||
|
||||
|
||||
class ImageDelta(BaseModel):
|
||||
type: Literal["image"] = "image"
|
||||
data: bytes
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolCallParseStatus(Enum):
|
||||
started = "started"
|
||||
in_progress = "in_progress"
|
||||
failed = "failed"
|
||||
succeeded = "succeeded"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolCallDelta(BaseModel):
|
||||
type: Literal["tool_call"] = "tool_call"
|
||||
|
||||
# you either send an in-progress tool call so the client can stream a long
|
||||
# code generation or you send the final parsed tool call at the end of the
|
||||
# stream
|
||||
content: Union[str, ToolCall]
|
||||
parse_status: ToolCallParseStatus
|
||||
|
||||
|
||||
# streaming completions send a stream of ContentDeltas
|
||||
ContentDelta = register_schema(
|
||||
Annotated[
|
||||
Union[TextDelta, ImageDelta, ToolCallDelta],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
name="ContentDelta",
|
||||
)
|
||||
|
|
|
@ -29,7 +29,7 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmeth
|
|||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
||||
|
@ -147,26 +147,12 @@ class ChatCompletionResponseEventType(Enum):
|
|||
progress = "progress"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolCallParseStatus(Enum):
|
||||
started = "started"
|
||||
in_progress = "in_progress"
|
||||
failure = "failure"
|
||||
success = "success"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolCallDelta(BaseModel):
|
||||
content: Union[str, ToolCall]
|
||||
parse_status: ToolCallParseStatus
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ChatCompletionResponseEvent(BaseModel):
|
||||
"""Chat completion response event."""
|
||||
|
||||
event_type: ChatCompletionResponseEventType
|
||||
delta: Union[str, ToolCallDelta]
|
||||
delta: ContentDelta
|
||||
logprobs: Optional[List[TokenLogProbs]] = None
|
||||
stop_reason: Optional[StopReason] = None
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue