Address PR feedback: improve code clarity and fix AllowedToolsFilter bug

- streaming.py: Extract has_tool_calls boolean for readability
- streaming.py: Replace nested function checks with assertions
- streaming.py: Fix AllowedToolsFilter handling to read tool_names instead of the nonexistent allowed/disallowed attributes
- streaming.py: Add comment explaining tool_context can be None
- streaming.py, utils.py: Clarify Pydantic/dict compatibility comments
- utils.py: Document list invariance vs Sequence covariance in type signature
- utils.py: Clarify list_shields runtime availability comment
commit 53c6f846d4 (parent 84d78ff48a)
Author: Ashwin Bharambe
Date:   2025-10-28 15:47:31 -07:00

2 changed files with 24 additions and 13 deletions

streaming.py

@@ -229,7 +229,8 @@ class StreamingResponseOrchestrator:
         params = OpenAIChatCompletionRequestWithExtraBody(
             model=self.ctx.model,
             messages=messages,
-            tools=self.ctx.chat_tools,  # type: ignore[arg-type]  # ChatCompletionFunctionToolParam compatible with expected dict type
+            # Pydantic models are dict-compatible but mypy treats them as distinct types
+            tools=self.ctx.chat_tools,  # type: ignore[arg-type]
             stream=True,
             temperature=self.ctx.temperature,
             response_format=response_format,
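
For context on the new comment: mypy treats a Pydantic model and the client's TypedDict as unrelated types even when the runtime shapes line up, which is why the ignore stays but is now explained. A minimal sketch of that mismatch, using simplified stand-in types rather than the real OpenAI/llama-stack definitions:

from typing import TypedDict

from pydantic import BaseModel


class ToolParamDict(TypedDict):
    # Stand-in for the TypedDict the chat-completion client expects
    type: str
    name: str


class ToolParamModel(BaseModel):
    # Stand-in for the Pydantic tool definition we actually hold
    type: str
    name: str


def send(tools: list[ToolParamDict]) -> None:
    """Stand-in for the chat-completion call site."""


tool = ToolParamModel(type="function", name="get_weather")
# Compatible shape at runtime, but mypy reports arg-type; hence the
# narrow per-line ignore in the diff rather than a blanket one.
send([tool])  # type: ignore[arg-type]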
@@ -272,7 +273,12 @@ class StreamingResponseOrchestrator:
         # Handle choices with no tool calls
         for choice in current_response.choices:
-            if not (isinstance(choice.message, OpenAIAssistantMessageParam) and choice.message.tool_calls and self.ctx.response_tools):
+            has_tool_calls = (
+                isinstance(choice.message, OpenAIAssistantMessageParam)
+                and choice.message.tool_calls
+                and self.ctx.response_tools
+            )
+            if not has_tool_calls:
                 output_messages.append(
                     await convert_chat_choice_to_response_message(
                         choice,
@@ -723,11 +729,12 @@ class StreamingResponseOrchestrator:
                 # Accumulate arguments for final response (only for subsequent chunks)
                 if not is_new_tool_call and response_tool_call is not None:
-                    # Need to check function is not None
-                    if response_tool_call.function and tool_call.function:
-                        response_tool_call.function.arguments = (
-                            response_tool_call.function.arguments or ""
-                        ) + tool_call.function.arguments
+                    # Both should have functions since we're inside the tool_call.function check above
+                    assert response_tool_call.function is not None
+                    assert tool_call.function is not None
+                    response_tool_call.function.arguments = (
+                        response_tool_call.function.arguments or ""
+                    ) + tool_call.function.arguments

         # Output Safety Validation for this chunk
         if self.guardrail_ids:
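
The switch from a guarding `if` to assertions is the standard mypy narrowing pattern: the old check silently skipped accumulation when either function was missing, while the asserts state the invariant and let mypy drop the Optional. A small sketch with simplified stand-in types:

from dataclasses import dataclass


@dataclass
class FunctionDelta:
    arguments: str = ""


@dataclass
class ToolCallDelta:
    function: FunctionDelta | None = None


def accumulate(target: ToolCallDelta, chunk: ToolCallDelta) -> None:
    # Invariant from the caller: both deltas carry a function here.
    # After the asserts, mypy narrows `function` from Optional to concrete.
    assert target.function is not None
    assert chunk.function is not None
    target.function.arguments = (target.function.arguments or "") + chunk.function.arguments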
@@ -1012,7 +1019,7 @@ class StreamingResponseOrchestrator:
                 description=tool.description,
                 input_schema=tool.input_schema,
             )
-            return convert_tooldef_to_openai_tool(tool_def)  # type: ignore[return-value]  # Dict compatible with ChatCompletionFunctionToolParam
+            return convert_tooldef_to_openai_tool(tool_def)  # type: ignore[return-value]  # Returns dict but ChatCompletionToolParam expects TypedDict

         # Initialize chat_tools if not already set
         if self.ctx.chat_tools is None:
@@ -1059,8 +1066,8 @@ class StreamingResponseOrchestrator:
             if isinstance(mcp_tool.allowed_tools, list):
                 always_allowed = mcp_tool.allowed_tools
             elif isinstance(mcp_tool.allowed_tools, AllowedToolsFilter):
-                always_allowed = mcp_tool.allowed_tools.allowed  # type: ignore[attr-defined]
-                never_allowed = mcp_tool.allowed_tools.disallowed  # type: ignore[attr-defined]
+                # AllowedToolsFilter only has tool_names field (not allowed/disallowed)
+                always_allowed = mcp_tool.allowed_tools.tool_names

         # Call list_mcp_tools
         tool_defs = None
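
The bug fixed here: the old code read .allowed and .disallowed, attributes AllowedToolsFilter never had, and the old type: ignore[attr-defined] comments were masking real AttributeErrors. A hypothetical minimal shape of the model, consistent with the fix:

from pydantic import BaseModel


class AllowedToolsFilter(BaseModel):
    # The model's only field per the fix; the default is illustrative.
    tool_names: list[str] | None = None


f = AllowedToolsFilter(tool_names=["get_weather"])
always_allowed = f.tool_names  # what the fixed code reads
# f.allowed would raise AttributeError at runtime - the masked bug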
@@ -1092,7 +1099,7 @@ class StreamingResponseOrchestrator:
                     openai_tool = convert_tooldef_to_chat_tool(t)
                     if self.ctx.chat_tools is None:
                         self.ctx.chat_tools = []
-                    self.ctx.chat_tools.append(openai_tool)  # type: ignore[arg-type]  # Dict compatible with ChatCompletionFunctionToolParam
+                    self.ctx.chat_tools.append(openai_tool)  # type: ignore[arg-type]  # Returns dict but ChatCompletionToolParam expects TypedDict

                     # Add to MCP tool mapping
                     if t.name in self.mcp_tool_to_server:
@@ -1124,6 +1131,7 @@ class StreamingResponseOrchestrator:
         self, output_messages: list[OpenAIResponseOutput]
     ) -> AsyncIterator[OpenAIResponseObjectStream]:
         # Handle all mcp tool lists from previous response that are still valid:
+        # tool_context can be None when no tools are provided in the response request
         if self.ctx.tool_context:
             for tool in self.ctx.tool_context.previous_tool_listings:
                 async for evt in self._reuse_mcp_list_tools(tool, output_messages):
@@ -1225,7 +1233,7 @@ class StreamingResponseOrchestrator:
                 openai_tool = convert_tooldef_to_openai_tool(tool_def)
                 if self.ctx.chat_tools is None:
                     self.ctx.chat_tools = []
-                self.ctx.chat_tools.append(openai_tool)  # type: ignore[arg-type]  # Dict compatible with ChatCompletionFunctionToolParam
+                self.ctx.chat_tools.append(openai_tool)  # type: ignore[arg-type]  # Returns dict but ChatCompletionToolParam expects TypedDict

         mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
             id=f"mcp_list_{uuid.uuid4()}",

utils.py

@@ -82,8 +82,10 @@ async def convert_chat_choice_to_response_message(
 async def convert_response_content_to_chat_content(
     content: (
         str
+        # List types for exact matches (invariant)
         | list[OpenAIResponseInputMessageContent]
         | list[OpenAIResponseOutputMessageContent]
+        # Sequence for mixed content types (covariant - accepts list of subtypes)
         | Sequence[OpenAIResponseInputMessageContent | OpenAIResponseOutputMessageContent]
     ),
 ) -> str | list[OpenAIChatCompletionContentPartParam]:
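
The two new comments encode a real typing subtlety: list is invariant in its element type, so list[Subtype] is not accepted where list[Supertype] is expected, while Sequence is covariant and accepts it. A self-contained sketch with stand-in content types:

from collections.abc import Sequence


class Content: ...


class InputContent(Content): ...


def takes_list(items: list[Content]) -> None: ...


def takes_seq(items: Sequence[Content]) -> None: ...


inputs: list[InputContent] = [InputContent()]
takes_seq(inputs)   # OK: Sequence[Content] accepts list[InputContent]
takes_list(inputs)  # mypy error: list is invariant, exact match required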
@@ -335,7 +337,8 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
     # Look up shields to get their provider_resource_id (actual model ID)
     model_ids = []
-    shields_list = await safety_api.list_shields()  # type: ignore[attr-defined]  # Safety API routing_table access
+    # list_shields not in Safety interface but available at runtime via API routing
+    shields_list = await safety_api.list_shields()  # type: ignore[attr-defined]
     for guardrail_id in guardrail_ids:
         matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
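
The relocated comment documents why the ignore is kept rather than widening the interface: list_shields is not declared on the Safety protocol but is present on the implementation routed in at runtime. A sketch of that situation with a hypothetical cut-down protocol (the real Safety API differs):

from typing import Any, Protocol


class Safety(Protocol):
    # Simplified stand-in; the real interface declares more methods.
    async def run_shield(self, shield_id: str, messages: Any) -> Any: ...


async def resolve_shield_ids(safety_api: Safety) -> list[str]:
    # Present on the routed implementation, absent from the protocol:
    shields = await safety_api.list_shields()  # type: ignore[attr-defined]
    return [shield.identifier for shield in shields.data]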