content part fixes

This commit is contained in:
Ashwin Bharambe 2025-08-13 07:08:41 -07:00
parent e48d062233
commit 6bd215706d
No known key found for this signature in database
GPG key ID: A7318BD657B83EA8
5 changed files with 271 additions and 33 deletions

View file

@ -624,19 +624,23 @@ class OpenAIResponseObjectStreamResponseMcpCallCompleted(BaseModel):
@json_schema_type
class OpenAIResponseContentPart(BaseModel):
"""Base class for response content parts."""
id: str
type: str
class OpenAIResponseContentPartOutputText(BaseModel):
type: Literal["output_text"] = "output_text"
text: str
# TODO: add annotations, logprobs, etc.
@json_schema_type
class OpenAIResponseContentPartText(OpenAIResponseContentPart):
"""Text content part for streaming responses."""
class OpenAIResponseContentPartRefusal(BaseModel):
type: Literal["refusal"] = "refusal"
refusal: str
text: str
type: Literal["text"] = "text"
OpenAIResponseContentPart = Annotated[
OpenAIResponseContentPartOutputText | OpenAIResponseContentPartRefusal,
Field(discriminator="type"),
]
register_schema(OpenAIResponseContentPart, name="OpenAIResponseContentPart")
@json_schema_type

View file

@ -20,7 +20,7 @@ from llama_stack.apis.agents.openai_responses import (
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
OpenAIResponseContentPartText,
OpenAIResponseContentPartOutputText,
OpenAIResponseInput,
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputMessageContent,
@ -481,7 +481,6 @@ class OpenAIResponsesImpl:
# Track tool call items for streaming events
tool_call_item_ids: dict[int, str] = {}
# Track content parts for streaming events
content_part_id: str | None = None
content_part_emitted = False
async for chunk in completion_result:
@ -493,14 +492,12 @@ class OpenAIResponsesImpl:
if chunk_choice.delta.content:
# Emit content_part.added event for first text chunk
if not content_part_emitted:
content_part_id = f"cp_text_{uuid.uuid4()}"
content_part_emitted = True
sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartAdded(
response_id=response_id,
item_id=message_item_id,
part=OpenAIResponseContentPartText(
id=content_part_id,
part=OpenAIResponseContentPartOutputText(
text="", # Will be filled incrementally via text deltas
),
sequence_number=sequence_number,
@ -618,14 +615,13 @@ class OpenAIResponsesImpl:
tool_calls = None
# Emit content_part.done event if text content was streamed (before content gets cleared)
if content_part_emitted and content_part_id:
if content_part_emitted:
final_text = "".join(chat_response_content)
sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartDone(
response_id=response_id,
item_id=message_item_id,
part=OpenAIResponseContentPartText(
id=content_part_id,
part=OpenAIResponseContentPartOutputText(
text=final_text,
),
sequence_number=sequence_number,