fix(mypy): part-03 completely resolve meta reference responses impl typing issues (#3951)

## Summary
Resolves all mypy errors in meta reference agent OpenAI responses
implementation by adding proper type narrowing, None checks, and
Sequence type support.

## Changes
- Fixed streaming.py, openai_responses.py, utils.py, tool_executor.py,
agent_instance.py
- Added Sequence type support to schema generator (ensures correct JSON
schema generation)
- Applied union type narrowing and None checks throughout

## Test plan
- All modified files pass mypy type checking (0 errors)
- Schema generator produces correct `type: array` for Sequence types

---------

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Ashwin Bharambe 2025-10-29 08:07:15 -07:00 committed by GitHub
parent e5c27dbcbf
commit a4f97559d1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 174 additions and 78 deletions

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Sequence
from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field, model_validator
@ -202,7 +203,7 @@ class OpenAIResponseMessage(BaseModel):
scenarios.
"""
content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]
content: str | Sequence[OpenAIResponseInputMessageContent] | Sequence[OpenAIResponseOutputMessageContent]
role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
type: Literal["message"] = "message"
@ -254,10 +255,10 @@ class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
"""
id: str
queries: list[str]
queries: Sequence[str]
status: str
type: Literal["file_search_call"] = "file_search_call"
results: list[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None
results: Sequence[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None
@json_schema_type
@ -597,7 +598,7 @@ class OpenAIResponseObject(BaseModel):
id: str
model: str
object: Literal["response"] = "response"
output: list[OpenAIResponseOutput]
output: Sequence[OpenAIResponseOutput]
parallel_tool_calls: bool = False
previous_response_id: str | None = None
prompt: OpenAIResponsePrompt | None = None
@ -607,7 +608,7 @@ class OpenAIResponseObject(BaseModel):
# before the field was added. New responses will have this set always.
text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
top_p: float | None = None
tools: list[OpenAIResponseTool] | None = None
tools: Sequence[OpenAIResponseTool] | None = None
truncation: str | None = None
usage: OpenAIResponseUsage | None = None
instructions: str | None = None
@ -1315,7 +1316,7 @@ class ListOpenAIResponseInputItem(BaseModel):
:param object: Object type identifier, always "list"
"""
data: list[OpenAIResponseInput]
data: Sequence[OpenAIResponseInput]
object: Literal["list"] = "list"
@ -1326,7 +1327,7 @@ class OpenAIResponseObjectWithInput(OpenAIResponseObject):
:param input: List of input items that led to this response
"""
input: list[OpenAIResponseInput]
input: Sequence[OpenAIResponseInput]
def to_response_object(self) -> OpenAIResponseObject:
"""Convert to OpenAIResponseObject by excluding input field."""
@ -1344,7 +1345,7 @@ class ListOpenAIResponseObject(BaseModel):
:param object: Object type identifier, always "list"
"""
data: list[OpenAIResponseObjectWithInput]
data: Sequence[OpenAIResponseObjectWithInput]
has_more: bool
first_id: str
last_id: str

View file

@ -91,7 +91,8 @@ class OpenAIResponsesImpl:
input: str | list[OpenAIResponseInput],
previous_response: _OpenAIResponseObjectWithInputAndMessages,
):
new_input_items = previous_response.input.copy()
# Convert Sequence to list for mutation
new_input_items = list(previous_response.input)
new_input_items.extend(previous_response.output)
if isinstance(input, str):
@ -107,7 +108,7 @@ class OpenAIResponsesImpl:
tools: list[OpenAIResponseInputTool] | None,
previous_response_id: str | None,
conversation: str | None,
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]:
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam], ToolContext]:
"""Process input with optional previous response context.
Returns:
@ -208,6 +209,9 @@ class OpenAIResponsesImpl:
messages: list[OpenAIMessageParam],
) -> None:
new_input_id = f"msg_{uuid.uuid4()}"
# Type input_items_data as the full OpenAIResponseInput union to avoid list invariance issues
input_items_data: list[OpenAIResponseInput] = []
if isinstance(input, str):
# synthesize a message from the input string
input_content = OpenAIResponseInputMessageContentText(text=input)
@ -219,7 +223,6 @@ class OpenAIResponsesImpl:
input_items_data = [input_content_item]
else:
# we already have a list of messages
input_items_data = []
for input_item in input:
if isinstance(input_item, OpenAIResponseMessage):
# These may or may not already have an id, so dump to dict, check for id, and add if missing
@ -289,16 +292,19 @@ class OpenAIResponsesImpl:
failed_response = None
async for stream_chunk in stream_gen:
if stream_chunk.type in {"response.completed", "response.incomplete"}:
if final_response is not None:
raise ValueError(
"The response stream produced multiple terminal responses! "
f"Earlier response from {final_event_type}"
)
final_response = stream_chunk.response
final_event_type = stream_chunk.type
elif stream_chunk.type == "response.failed":
failed_response = stream_chunk.response
match stream_chunk.type:
case "response.completed" | "response.incomplete":
if final_response is not None:
raise ValueError(
"The response stream produced multiple terminal responses! "
f"Earlier response from {final_event_type}"
)
final_response = stream_chunk.response
final_event_type = stream_chunk.type
case "response.failed":
failed_response = stream_chunk.response
case _:
pass # Other event types don't have .response
if failed_response is not None:
error_message = (
@ -326,6 +332,11 @@ class OpenAIResponsesImpl:
max_infer_iters: int | None = 10,
guardrail_ids: list[str] | None = None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# These should never be None when called from create_openai_response (which sets defaults)
# but we assert here to help mypy understand the types
assert text is not None, "text must not be None"
assert max_infer_iters is not None, "max_infer_iters must not be None"
# Input preprocessing
all_input, messages, tool_context = await self._process_input_with_previous_response(
input, tools, previous_response_id, conversation
@ -368,16 +379,19 @@ class OpenAIResponsesImpl:
final_response = None
failed_response = None
output_items = []
# Type as ConversationItem to avoid list invariance issues
output_items: list[ConversationItem] = []
async for stream_chunk in orchestrator.create_response():
if stream_chunk.type in {"response.completed", "response.incomplete"}:
final_response = stream_chunk.response
elif stream_chunk.type == "response.failed":
failed_response = stream_chunk.response
if stream_chunk.type == "response.output_item.done":
item = stream_chunk.item
output_items.append(item)
match stream_chunk.type:
case "response.completed" | "response.incomplete":
final_response = stream_chunk.response
case "response.failed":
failed_response = stream_chunk.response
case "response.output_item.done":
item = stream_chunk.item
output_items.append(item)
case _:
pass # Other event types
# Store and sync before yielding terminal events
# This ensures the storage/syncing happens even if the consumer breaks after receiving the event
@ -410,7 +424,8 @@ class OpenAIResponsesImpl:
self, conversation_id: str, input: str | list[OpenAIResponseInput] | None, output_items: list[ConversationItem]
) -> None:
"""Sync content and response messages to the conversation."""
conversation_items = []
# Type as ConversationItem union to avoid list invariance issues
conversation_items: list[ConversationItem] = []
if isinstance(input, str):
conversation_items.append(

View file

@ -111,7 +111,7 @@ class StreamingResponseOrchestrator:
text: OpenAIResponseText,
max_infer_iters: int,
tool_executor, # Will be the tool execution logic from the main class
instructions: str,
instructions: str | None,
safety_api,
guardrail_ids: list[str] | None = None,
prompt: OpenAIResponsePrompt | None = None,
@ -128,7 +128,9 @@ class StreamingResponseOrchestrator:
self.prompt = prompt
self.sequence_number = 0
# Store MCP tool mapping that gets built during tool processing
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools or {}
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
ctx.tool_context.previous_tools if ctx.tool_context else {}
)
# Track final messages after all tool executions
self.final_messages: list[OpenAIMessageParam] = []
# mapping for annotations
@ -229,7 +231,8 @@ class StreamingResponseOrchestrator:
params = OpenAIChatCompletionRequestWithExtraBody(
model=self.ctx.model,
messages=messages,
tools=self.ctx.chat_tools,
# Pydantic models are dict-compatible but mypy treats them as distinct types
tools=self.ctx.chat_tools, # type: ignore[arg-type]
stream=True,
temperature=self.ctx.temperature,
response_format=response_format,
@ -272,7 +275,12 @@ class StreamingResponseOrchestrator:
# Handle choices with no tool calls
for choice in current_response.choices:
if not (choice.message.tool_calls and self.ctx.response_tools):
has_tool_calls = (
isinstance(choice.message, OpenAIAssistantMessageParam)
and choice.message.tool_calls
and self.ctx.response_tools
)
if not has_tool_calls:
output_messages.append(
await convert_chat_choice_to_response_message(
choice,
@ -722,7 +730,10 @@ class StreamingResponseOrchestrator:
)
# Accumulate arguments for final response (only for subsequent chunks)
if not is_new_tool_call:
if not is_new_tool_call and response_tool_call is not None:
# Both should have functions since we're inside the tool_call.function check above
assert response_tool_call.function is not None
assert tool_call.function is not None
response_tool_call.function.arguments = (
response_tool_call.function.arguments or ""
) + tool_call.function.arguments
@ -747,10 +758,13 @@ class StreamingResponseOrchestrator:
for tool_call_index in sorted(chat_response_tool_calls.keys()):
tool_call = chat_response_tool_calls[tool_call_index]
# Ensure that arguments, if sent back to the inference provider, are not None
tool_call.function.arguments = tool_call.function.arguments or "{}"
if tool_call.function:
tool_call.function.arguments = tool_call.function.arguments or "{}"
tool_call_item_id = tool_call_item_ids[tool_call_index]
final_arguments = tool_call.function.arguments
tool_call_name = chat_response_tool_calls[tool_call_index].function.name
final_arguments: str = tool_call.function.arguments or "{}" if tool_call.function else "{}"
func = chat_response_tool_calls[tool_call_index].function
tool_call_name = func.name if func else ""
# Check if this is an MCP tool call
is_mcp_tool = tool_call_name and tool_call_name in self.mcp_tool_to_server
@ -894,12 +908,11 @@ class StreamingResponseOrchestrator:
self.sequence_number += 1
if tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server:
item = OpenAIResponseOutputMessageMCPCall(
item: OpenAIResponseOutput = OpenAIResponseOutputMessageMCPCall(
arguments="",
name=tool_call.function.name,
id=matching_item_id,
server_label=self.mcp_tool_to_server[tool_call.function.name].server_label,
status="in_progress",
)
elif tool_call.function.name == "web_search":
item = OpenAIResponseOutputMessageWebSearchToolCall(
@ -1008,7 +1021,7 @@ class StreamingResponseOrchestrator:
description=tool.description,
input_schema=tool.input_schema,
)
return convert_tooldef_to_openai_tool(tool_def)
return convert_tooldef_to_openai_tool(tool_def) # type: ignore[return-value] # Returns dict but ChatCompletionToolParam expects TypedDict
# Initialize chat_tools if not already set
if self.ctx.chat_tools is None:
@ -1016,7 +1029,7 @@ class StreamingResponseOrchestrator:
for input_tool in tools:
if input_tool.type == "function":
self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump())) # type: ignore[typeddict-item,arg-type] # Dict compatible with FunctionDefinition
elif input_tool.type in WebSearchToolTypes:
tool_name = "web_search"
# Need to access tool_groups_api from tool_executor
@ -1055,8 +1068,8 @@ class StreamingResponseOrchestrator:
if isinstance(mcp_tool.allowed_tools, list):
always_allowed = mcp_tool.allowed_tools
elif isinstance(mcp_tool.allowed_tools, AllowedToolsFilter):
always_allowed = mcp_tool.allowed_tools.always
never_allowed = mcp_tool.allowed_tools.never
# AllowedToolsFilter only has tool_names field (not allowed/disallowed)
always_allowed = mcp_tool.allowed_tools.tool_names
# Call list_mcp_tools
tool_defs = None
@ -1088,7 +1101,7 @@ class StreamingResponseOrchestrator:
openai_tool = convert_tooldef_to_chat_tool(t)
if self.ctx.chat_tools is None:
self.ctx.chat_tools = []
self.ctx.chat_tools.append(openai_tool)
self.ctx.chat_tools.append(openai_tool) # type: ignore[arg-type] # Returns dict but ChatCompletionToolParam expects TypedDict
# Add to MCP tool mapping
if t.name in self.mcp_tool_to_server:
@ -1120,13 +1133,17 @@ class StreamingResponseOrchestrator:
self, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Handle all mcp tool lists from previous response that are still valid:
for tool in self.ctx.tool_context.previous_tool_listings:
async for evt in self._reuse_mcp_list_tools(tool, output_messages):
yield evt
# Process all remaining tools (including MCP tools) and emit streaming events
if self.ctx.tool_context.tools_to_process:
async for stream_event in self._process_new_tools(self.ctx.tool_context.tools_to_process, output_messages):
yield stream_event
# tool_context can be None when no tools are provided in the response request
if self.ctx.tool_context:
for tool in self.ctx.tool_context.previous_tool_listings:
async for evt in self._reuse_mcp_list_tools(tool, output_messages):
yield evt
# Process all remaining tools (including MCP tools) and emit streaming events
if self.ctx.tool_context.tools_to_process:
async for stream_event in self._process_new_tools(
self.ctx.tool_context.tools_to_process, output_messages
):
yield stream_event
def _approval_required(self, tool_name: str) -> bool:
if tool_name not in self.mcp_tool_to_server:
@ -1220,7 +1237,7 @@ class StreamingResponseOrchestrator:
openai_tool = convert_tooldef_to_openai_tool(tool_def)
if self.ctx.chat_tools is None:
self.ctx.chat_tools = []
self.ctx.chat_tools.append(openai_tool)
self.ctx.chat_tools.append(openai_tool) # type: ignore[arg-type] # Returns dict but ChatCompletionToolParam expects TypedDict
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
id=f"mcp_list_{uuid.uuid4()}",

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
from dataclasses import dataclass
from typing import cast
from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel
@ -100,17 +101,19 @@ class ToolContext(BaseModel):
if isinstance(tool, OpenAIResponseToolMCP):
previous_tools_by_label[tool.server_label] = tool
# collect tool definitions which are the same in current and previous requests:
tools_to_process = []
tools_to_process: list[OpenAIResponseInputTool] = []
matched: dict[str, OpenAIResponseInputToolMCP] = {}
for tool in self.current_tools:
# Mypy confuses OpenAIResponseInputTool (Input union) with OpenAIResponseTool (output union)
# which differ only in MCP type (InputToolMCP vs ToolMCP). Code is correct.
for tool in cast(list[OpenAIResponseInputTool], self.current_tools): # type: ignore[assignment]
if isinstance(tool, OpenAIResponseInputToolMCP) and tool.server_label in previous_tools_by_label:
previous_tool = previous_tools_by_label[tool.server_label]
if previous_tool.allowed_tools == tool.allowed_tools:
matched[tool.server_label] = tool
else:
tools_to_process.append(tool)
tools_to_process.append(tool) # type: ignore[arg-type]
else:
tools_to_process.append(tool)
tools_to_process.append(tool) # type: ignore[arg-type]
# tools that are not the same or were not previously defined need to be processed:
self.tools_to_process = tools_to_process
# for all matched definitions, get the mcp_list_tools objects from the previous output:
@ -119,9 +122,11 @@ class ToolContext(BaseModel):
]
# reconstruct the tool to server mappings that can be reused:
for listing in self.previous_tool_listings:
# listing is OpenAIResponseOutputMessageMCPListTools which has tools: list[MCPListToolsTool]
definition = matched[listing.server_label]
for tool in listing.tools:
self.previous_tools[tool.name] = definition
for mcp_tool in listing.tools:
# mcp_tool is MCPListToolsTool which has a name: str field
self.previous_tools[mcp_tool.name] = definition
def available_tools(self) -> list[OpenAIResponseTool]:
if not self.current_tools:
@ -139,6 +144,8 @@ class ToolContext(BaseModel):
server_label=tool.server_label,
allowed_tools=tool.allowed_tools,
)
# Exhaustive check - all tool types should be handled above
raise AssertionError(f"Unexpected tool type: {type(tool)}")
return [convert_tool(tool) for tool in self.current_tools]

View file

@ -7,6 +7,7 @@
import asyncio
import re
import uuid
from collections.abc import Sequence
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
from llama_stack.apis.agents.openai_responses import (
@ -71,14 +72,14 @@ async def convert_chat_choice_to_response_message(
return OpenAIResponseMessage(
id=message_id or f"msg_{uuid.uuid4()}",
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=annotations)],
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=list(annotations))],
status="completed",
role="assistant",
)
async def convert_response_content_to_chat_content(
content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
content: str | Sequence[OpenAIResponseInputMessageContent | OpenAIResponseOutputMessageContent],
) -> str | list[OpenAIChatCompletionContentPartParam]:
"""
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
@ -88,7 +89,8 @@ async def convert_response_content_to_chat_content(
if isinstance(content, str):
return content
converted_parts = []
# Type with union to avoid list invariance issues
converted_parts: list[OpenAIChatCompletionContentPartParam] = []
for content_part in content:
if isinstance(content_part, OpenAIResponseInputMessageContentText):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
@ -158,9 +160,11 @@ async def convert_response_input_to_chat_messages(
),
)
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
# Output can be None, use empty string as fallback
output_content = input_item.output if input_item.output is not None else ""
messages.append(
OpenAIToolMessageParam(
content=input_item.output,
content=output_content,
tool_call_id=input_item.id,
)
)
@ -172,7 +176,8 @@ async def convert_response_input_to_chat_messages(
):
# these are handled by the responses impl itself and not pass through to chat completions
pass
else:
elif isinstance(input_item, OpenAIResponseMessage):
# Narrow type to OpenAIResponseMessage which has content and role attributes
content = await convert_response_content_to_chat_content(input_item.content)
message_type = await get_message_type_by_role(input_item.role)
if message_type is None:
@ -191,7 +196,8 @@ async def convert_response_input_to_chat_messages(
last_user_content = getattr(last_user_msg, "content", None)
if last_user_content == content:
continue # Skip duplicate user message
messages.append(message_type(content=content))
# Dynamic message type call - different message types have different content expectations
messages.append(message_type(content=content)) # type: ignore[call-arg,arg-type]
if len(tool_call_results):
# Check if unpaired function_call_outputs reference function_calls from previous messages
if previous_messages:
@ -237,8 +243,11 @@ async def convert_response_text_to_chat_response_format(
if text.format["type"] == "json_object":
return OpenAIResponseFormatJSONObject()
if text.format["type"] == "json_schema":
# Assert name exists for json_schema format
assert text.format.get("name"), "json_schema format requires a name"
schema_name: str = text.format["name"] # type: ignore[assignment]
return OpenAIResponseFormatJSONSchema(
json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
json_schema=OpenAIJSONSchema(name=schema_name, schema=text.format["schema"])
)
raise ValueError(f"Unsupported text format: {text.format}")
@ -251,7 +260,7 @@ async def get_message_type_by_role(role: str) -> type[OpenAIMessageParam] | None
"assistant": OpenAIAssistantMessageParam,
"developer": OpenAIDeveloperMessageParam,
}
return role_to_type.get(role)
return role_to_type.get(role) # type: ignore[return-value] # Pydantic models use ModelMetaclass
def _extract_citations_from_text(
@ -320,7 +329,8 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
# Look up shields to get their provider_resource_id (actual model ID)
model_ids = []
shields_list = await safety_api.routing_table.list_shields()
# TODO: list_shields not in Safety interface but available at runtime via API routing
shields_list = await safety_api.routing_table.list_shields() # type: ignore[attr-defined]
for guardrail_id in guardrail_ids:
matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
@ -337,7 +347,9 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
for result in response.results:
if result.flagged:
message = result.user_message or "Content blocked by safety guardrails"
flagged_categories = [cat for cat, flagged in result.categories.items() if flagged]
flagged_categories = (
[cat for cat, flagged in result.categories.items() if flagged] if result.categories else []
)
violation_type = result.metadata.get("violation_type", []) if result.metadata else []
if flagged_categories:
@ -347,6 +359,9 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
return message
# No violations found
return None
def extract_guardrail_ids(guardrails: list | None) -> list[str]:
"""Extract guardrail IDs from guardrails parameter, handling both string IDs and ResponseGuardrailSpec objects."""

View file

@ -430,6 +430,32 @@ def _unwrap_generic_list(typ: type[list[T]]) -> type[T]:
return list_type # type: ignore[no-any-return]
def is_generic_sequence(typ: object) -> bool:
"True if the specified type is a generic Sequence, i.e. `Sequence[T]`."
import collections.abc
typ = unwrap_annotated_type(typ)
return typing.get_origin(typ) is collections.abc.Sequence
def unwrap_generic_sequence(typ: object) -> type:
"""
Extracts the item type of a Sequence type.
:param typ: The Sequence type `Sequence[T]`.
:returns: The item type `T`.
"""
return rewrap_annotated_type(_unwrap_generic_sequence, typ) # type: ignore[arg-type]
def _unwrap_generic_sequence(typ: object) -> type:
"Extracts the item type of a Sequence type (e.g. returns `T` for `Sequence[T]`)."
(sequence_type,) = typing.get_args(typ) # unpack single tuple element
return sequence_type # type: ignore[no-any-return]
def is_generic_set(typ: object) -> TypeGuard[type[set]]:
"True if the specified type is a generic set, i.e. `Set[T]`."

View file

@ -18,10 +18,12 @@ from .inspection import (
TypeLike,
is_generic_dict,
is_generic_list,
is_generic_sequence,
is_type_optional,
is_type_union,
unwrap_generic_dict,
unwrap_generic_list,
unwrap_generic_sequence,
unwrap_optional_type,
unwrap_union_types,
)
@ -155,24 +157,28 @@ def python_type_to_name(data_type: TypeLike, force: bool = False) -> str:
if metadata is not None:
# type is Annotated[T, ...]
arg = typing.get_args(data_type)[0]
return python_type_to_name(arg)
return python_type_to_name(arg, force=force)
if force:
# generic types
if is_type_optional(data_type, strict=True):
inner_name = python_type_to_name(unwrap_optional_type(data_type))
inner_name = python_type_to_name(unwrap_optional_type(data_type), force=True)
return f"Optional__{inner_name}"
elif is_generic_list(data_type):
item_name = python_type_to_name(unwrap_generic_list(data_type))
item_name = python_type_to_name(unwrap_generic_list(data_type), force=True)
return f"List__{item_name}"
elif is_generic_sequence(data_type):
# Treat Sequence the same as List for schema generation purposes
item_name = python_type_to_name(unwrap_generic_sequence(data_type), force=True)
return f"List__{item_name}"
elif is_generic_dict(data_type):
key_type, value_type = unwrap_generic_dict(data_type)
key_name = python_type_to_name(key_type)
value_name = python_type_to_name(value_type)
key_name = python_type_to_name(key_type, force=True)
value_name = python_type_to_name(value_type, force=True)
return f"Dict__{key_name}__{value_name}"
elif is_type_union(data_type):
member_types = unwrap_union_types(data_type)
member_names = "__".join(python_type_to_name(member_type) for member_type in member_types)
member_names = "__".join(python_type_to_name(member_type, force=True) for member_type in member_types)
return f"Union__{member_names}"
# named system or user-defined type

View file

@ -111,7 +111,7 @@ def get_class_property_docstrings(
def docstring_to_schema(data_type: type) -> Schema:
short_description, long_description = get_class_docstrings(data_type)
schema: Schema = {
"title": python_type_to_name(data_type),
"title": python_type_to_name(data_type, force=True),
}
description = "\n".join(filter(None, [short_description, long_description]))
@ -417,6 +417,10 @@ class JsonSchemaGenerator:
if origin_type is list:
(list_type,) = typing.get_args(typ) # unpack single tuple element
return {"type": "array", "items": self.type_to_schema(list_type)}
elif origin_type is collections.abc.Sequence:
# Treat Sequence the same as list for JSON schema (both are arrays)
(sequence_type,) = typing.get_args(typ) # unpack single tuple element
return {"type": "array", "items": self.type_to_schema(sequence_type)}
elif origin_type is dict:
key_type, value_type = typing.get_args(typ)
if not (key_type is str or key_type is int or is_type_enum(key_type)):