Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-11 19:56:03 +00:00
Improve OpenAI responses type safety with Sequence and match statements
- Change list to Sequence in OpenAI response API types to fix list invariance issues
- Use match statements for proper union type narrowing in stream chunk handling
- Reduces errors in openai_responses.py from 76 to 12 (84% reduction)
This commit is contained in:
parent 051af6e892
commit d4d55bc0fe
2 changed files with 32 additions and 26 deletions
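Before the diff itself, a note on the first bullet of the commit message: `list` is invariant in its element type, so a `list[Member]` is not accepted where a `list[Union]` is expected, while the read-only `Sequence` is covariant and does accept it. A minimal sketch of the difference (toy names, not taken from this commit):

```python
from collections.abc import Sequence


class Message: ...
class ToolCall: ...

Output = Message | ToolCall  # stands in for a union such as OpenAIResponseOutput


def take_list(items: list[Output]) -> None: ...
def take_seq(items: Sequence[Output]) -> None: ...


messages: list[Message] = [Message()]
take_list(messages)  # checker error: list is invariant, so list[Message] is not list[Message | ToolCall]
take_seq(messages)   # accepted: Sequence is covariant, so list[Message] satisfies Sequence[Message | ToolCall]
```

Declaring the Pydantic fields as `Sequence[...]` therefore lets builder code keep returning narrowly typed lists without per-call-site casts.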
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from collections.abc import Sequence
 from typing import Annotated, Any, Literal
 
 from pydantic import BaseModel, Field, model_validator
@@ -202,7 +203,7 @@ class OpenAIResponseMessage(BaseModel):
     scenarios.
     """
 
-    content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]
+    content: str | Sequence[OpenAIResponseInputMessageContent] | Sequence[OpenAIResponseOutputMessageContent]
     role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
     type: Literal["message"] = "message"
 
@@ -254,10 +255,10 @@ class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
     """
 
     id: str
-    queries: list[str]
+    queries: Sequence[str]
     status: str
     type: Literal["file_search_call"] = "file_search_call"
-    results: list[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None
+    results: Sequence[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None
 
 
 @json_schema_type
@@ -597,7 +598,7 @@ class OpenAIResponseObject(BaseModel):
     id: str
     model: str
     object: Literal["response"] = "response"
-    output: list[OpenAIResponseOutput]
+    output: Sequence[OpenAIResponseOutput]
     parallel_tool_calls: bool = False
     previous_response_id: str | None = None
     prompt: OpenAIResponsePrompt | None = None
@@ -607,7 +608,7 @@ class OpenAIResponseObject(BaseModel):
     # before the field was added. New responses will have this set always.
     text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
     top_p: float | None = None
-    tools: list[OpenAIResponseTool] | None = None
+    tools: Sequence[OpenAIResponseTool] | None = None
     truncation: str | None = None
     usage: OpenAIResponseUsage | None = None
     instructions: str | None = None
@@ -1315,7 +1316,7 @@ class ListOpenAIResponseInputItem(BaseModel):
     :param object: Object type identifier, always "list"
     """
 
-    data: list[OpenAIResponseInput]
+    data: Sequence[OpenAIResponseInput]
     object: Literal["list"] = "list"
 
 
@@ -1326,7 +1327,7 @@ class OpenAIResponseObjectWithInput(OpenAIResponseObject):
     :param input: List of input items that led to this response
     """
 
-    input: list[OpenAIResponseInput]
+    input: Sequence[OpenAIResponseInput]
 
     def to_response_object(self) -> OpenAIResponseObject:
         """Convert to OpenAIResponseObject by excluding input field."""
@@ -1344,7 +1345,7 @@ class ListOpenAIResponseObject(BaseModel):
     :param object: Object type identifier, always "list"
     """
 
-    data: list[OpenAIResponseObjectWithInput]
+    data: Sequence[OpenAIResponseObjectWithInput]
     has_more: bool
     first_id: str
     last_id: str
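The hunks above all touch declaration sites in the Pydantic API models; the remaining two hunks are in the streaming implementation. At runtime the annotation change is intended to be a no-op, since Pydantic validates `Sequence[...]` fields from plain lists. A small hedged sanity check, using a toy model rather than the real `OpenAIResponseOutputMessageFileSearchToolCall` and assuming Pydantic v2:

```python
from collections.abc import Sequence

from pydantic import BaseModel


class FileSearchCall(BaseModel):
    # Toy stand-in for the fields changed above, not the real class definition.
    id: str
    queries: Sequence[str]


call = FileSearchCall(id="fs_1", queries=["llama", "stack"])
print(call.queries)       # plain list input still validates
print(call.model_dump())  # and still dumps as an array of strings
```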
@@ -289,16 +289,19 @@ class OpenAIResponsesImpl:
         failed_response = None
 
         async for stream_chunk in stream_gen:
-            if stream_chunk.type in {"response.completed", "response.incomplete"}:
-                if final_response is not None:
-                    raise ValueError(
-                        "The response stream produced multiple terminal responses! "
-                        f"Earlier response from {final_event_type}"
-                    )
-                final_response = stream_chunk.response
-                final_event_type = stream_chunk.type
-            elif stream_chunk.type == "response.failed":
-                failed_response = stream_chunk.response
+            match stream_chunk.type:
+                case "response.completed" | "response.incomplete":
+                    if final_response is not None:
+                        raise ValueError(
+                            "The response stream produced multiple terminal responses! "
+                            f"Earlier response from {final_event_type}"
+                        )
+                    final_response = stream_chunk.response
+                    final_event_type = stream_chunk.type
+                case "response.failed":
+                    failed_response = stream_chunk.response
+                case _:
+                    pass  # Other event types don't have .response
 
         if failed_response is not None:
             error_message = (
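The point of the match rewrite above is static narrowing: the stream chunk type is a union discriminated by a Literal `type` field, and inside each `case` a checker can narrow the chunk so that `.response` (or `.item`) is known to exist. A minimal sketch with hypothetical event classes (the real stream chunk union is much larger):

```python
from dataclasses import dataclass
from typing import Literal


@dataclass
class Completed:
    type: Literal["response.completed"]
    response: str


@dataclass
class TextDelta:
    type: Literal["response.output_text.delta"]
    delta: str


Chunk = Completed | TextDelta


def handle(chunk: Chunk) -> None:
    match chunk.type:
        case "response.completed":
            # Checkers that narrow on literal discriminators (e.g. pyright) treat
            # chunk as Completed here, so accessing .response is type-safe.
            print(chunk.response)
        case "response.output_text.delta":
            print(chunk.delta, end="")
        case _:
            pass  # other event types fall through, as in the diff


handle(Completed(type="response.completed", response="done"))
```

The previous membership test against a set literal is harder for checkers to narrow; the `match` value patterns behave like chained equality comparisons and also make the fallthrough for unrelated event types explicit.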
@@ -370,14 +373,16 @@ class OpenAIResponsesImpl:
 
         output_items = []
         async for stream_chunk in orchestrator.create_response():
-            if stream_chunk.type in {"response.completed", "response.incomplete"}:
-                final_response = stream_chunk.response
-            elif stream_chunk.type == "response.failed":
-                failed_response = stream_chunk.response
-
-            if stream_chunk.type == "response.output_item.done":
-                item = stream_chunk.item
-                output_items.append(item)
+            match stream_chunk.type:
+                case "response.completed" | "response.incomplete":
+                    final_response = stream_chunk.response
+                case "response.failed":
+                    failed_response = stream_chunk.response
+                case "response.output_item.done":
+                    item = stream_chunk.item
+                    output_items.append(item)
+                case _:
+                    pass  # Other event types
 
         # Store and sync before yielding terminal events
         # This ensures the storage/syncing happens even if the consumer breaks after receiving the event