Improve OpenAI responses type safety with Sequence and match statements

- Change list to Sequence in OpenAI response API types to fix list invariance issues
- Use match statements for proper union type narrowing in stream chunk handling
- Reduces errors in openai_responses.py from 76 to 12 (84% reduction)
This commit is contained in:
Ashwin Bharambe 2025-10-28 14:41:01 -07:00
parent 051af6e892
commit d4d55bc0fe
2 changed files with 32 additions and 26 deletions

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from collections.abc import Sequence
from typing import Annotated, Any, Literal from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
@ -202,7 +203,7 @@ class OpenAIResponseMessage(BaseModel):
scenarios. scenarios.
""" """
content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent] content: str | Sequence[OpenAIResponseInputMessageContent] | Sequence[OpenAIResponseOutputMessageContent]
role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"] role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
type: Literal["message"] = "message" type: Literal["message"] = "message"
@ -254,10 +255,10 @@ class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
""" """
id: str id: str
queries: list[str] queries: Sequence[str]
status: str status: str
type: Literal["file_search_call"] = "file_search_call" type: Literal["file_search_call"] = "file_search_call"
results: list[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None results: Sequence[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None
@json_schema_type @json_schema_type
@ -597,7 +598,7 @@ class OpenAIResponseObject(BaseModel):
id: str id: str
model: str model: str
object: Literal["response"] = "response" object: Literal["response"] = "response"
output: list[OpenAIResponseOutput] output: Sequence[OpenAIResponseOutput]
parallel_tool_calls: bool = False parallel_tool_calls: bool = False
previous_response_id: str | None = None previous_response_id: str | None = None
prompt: OpenAIResponsePrompt | None = None prompt: OpenAIResponsePrompt | None = None
@ -607,7 +608,7 @@ class OpenAIResponseObject(BaseModel):
# before the field was added. New responses will have this set always. # before the field was added. New responses will have this set always.
text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
top_p: float | None = None top_p: float | None = None
tools: list[OpenAIResponseTool] | None = None tools: Sequence[OpenAIResponseTool] | None = None
truncation: str | None = None truncation: str | None = None
usage: OpenAIResponseUsage | None = None usage: OpenAIResponseUsage | None = None
instructions: str | None = None instructions: str | None = None
@ -1315,7 +1316,7 @@ class ListOpenAIResponseInputItem(BaseModel):
:param object: Object type identifier, always "list" :param object: Object type identifier, always "list"
""" """
data: list[OpenAIResponseInput] data: Sequence[OpenAIResponseInput]
object: Literal["list"] = "list" object: Literal["list"] = "list"
@ -1326,7 +1327,7 @@ class OpenAIResponseObjectWithInput(OpenAIResponseObject):
:param input: List of input items that led to this response :param input: List of input items that led to this response
""" """
input: list[OpenAIResponseInput] input: Sequence[OpenAIResponseInput]
def to_response_object(self) -> OpenAIResponseObject: def to_response_object(self) -> OpenAIResponseObject:
"""Convert to OpenAIResponseObject by excluding input field.""" """Convert to OpenAIResponseObject by excluding input field."""
@ -1344,7 +1345,7 @@ class ListOpenAIResponseObject(BaseModel):
:param object: Object type identifier, always "list" :param object: Object type identifier, always "list"
""" """
data: list[OpenAIResponseObjectWithInput] data: Sequence[OpenAIResponseObjectWithInput]
has_more: bool has_more: bool
first_id: str first_id: str
last_id: str last_id: str

View file

@ -289,7 +289,8 @@ class OpenAIResponsesImpl:
failed_response = None failed_response = None
async for stream_chunk in stream_gen: async for stream_chunk in stream_gen:
if stream_chunk.type in {"response.completed", "response.incomplete"}: match stream_chunk.type:
case "response.completed" | "response.incomplete":
if final_response is not None: if final_response is not None:
raise ValueError( raise ValueError(
"The response stream produced multiple terminal responses! " "The response stream produced multiple terminal responses! "
@ -297,8 +298,10 @@ class OpenAIResponsesImpl:
) )
final_response = stream_chunk.response final_response = stream_chunk.response
final_event_type = stream_chunk.type final_event_type = stream_chunk.type
elif stream_chunk.type == "response.failed": case "response.failed":
failed_response = stream_chunk.response failed_response = stream_chunk.response
case _:
pass # Other event types don't have .response
if failed_response is not None: if failed_response is not None:
error_message = ( error_message = (
@ -370,14 +373,16 @@ class OpenAIResponsesImpl:
output_items = [] output_items = []
async for stream_chunk in orchestrator.create_response(): async for stream_chunk in orchestrator.create_response():
if stream_chunk.type in {"response.completed", "response.incomplete"}: match stream_chunk.type:
case "response.completed" | "response.incomplete":
final_response = stream_chunk.response final_response = stream_chunk.response
elif stream_chunk.type == "response.failed": case "response.failed":
failed_response = stream_chunk.response failed_response = stream_chunk.response
case "response.output_item.done":
if stream_chunk.type == "response.output_item.done":
item = stream_chunk.item item = stream_chunk.item
output_items.append(item) output_items.append(item)
case _:
pass # Other event types
# Store and sync before yielding terminal events # Store and sync before yielding terminal events
# This ensures the storage/syncing happens even if the consumer breaks after receiving the event # This ensures the storage/syncing happens even if the consumer breaks after receiving the event