Address PR review feedback

- Simplify provider_resource_id assignment with assertion (review comment 1)
- Fix comment placement order (review comment 2)
- Refactor tool_calls list building to avoid union-attr suppression (review comment 3)
- Rename response_format to response_format_dict to avoid shadowing (review comment 4)
- Update type: ignore comments for message.content with accurate explanation of OpenAI SDK type alias resolution issue (review comment 5)
- Add assertions in litellm_openai_mixin to validate provider_resource_id is not None

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Ashwin Bharambe 2025-10-28 13:49:34 -07:00
parent 9032ba9097
commit dbd036e7b4
2 changed files with 20 additions and 16 deletions

View file

@ -195,8 +195,9 @@ class LiteLLMOpenAIMixin(
raise ValueError("Model store is not initialized") raise ValueError("Model store is not initialized")
model_obj = await self.model_store.get_model(params.model) model_obj = await self.model_store.get_model(params.model)
# Fallback to params.model ensures provider_resource_id is always str if model_obj.provider_resource_id is None:
provider_resource_id: str = (model_obj.provider_resource_id if model_obj else None) or params.model raise ValueError(f"Model {params.model} has no provider_resource_id")
provider_resource_id = model_obj.provider_resource_id
# Convert input to list if it's a string # Convert input to list if it's a string
input_list = [params.input] if isinstance(params.input, str) else params.input input_list = [params.input] if isinstance(params.input, str) else params.input
@ -233,8 +234,9 @@ class LiteLLMOpenAIMixin(
raise ValueError("Model store is not initialized") raise ValueError("Model store is not initialized")
model_obj = await self.model_store.get_model(params.model) model_obj = await self.model_store.get_model(params.model)
# Fallback to params.model ensures provider_resource_id is always str if model_obj.provider_resource_id is None:
provider_resource_id: str = (model_obj.provider_resource_id if model_obj else None) or params.model raise ValueError(f"Model {params.model} has no provider_resource_id")
provider_resource_id = model_obj.provider_resource_id
request_params = await prepare_openai_completion_params( request_params = await prepare_openai_completion_params(
model=self.get_litellm_model_name(provider_resource_id), model=self.get_litellm_model_name(provider_resource_id),
@ -279,8 +281,9 @@ class LiteLLMOpenAIMixin(
raise ValueError("Model store is not initialized") raise ValueError("Model store is not initialized")
model_obj = await self.model_store.get_model(params.model) model_obj = await self.model_store.get_model(params.model)
# Fallback to params.model ensures provider_resource_id is always str if model_obj.provider_resource_id is None:
provider_resource_id: str = (model_obj.provider_resource_id if model_obj else None) or params.model raise ValueError(f"Model {params.model} has no provider_resource_id")
provider_resource_id = model_obj.provider_resource_id
request_params = await prepare_openai_completion_params( request_params = await prepare_openai_completion_params(
model=self.get_litellm_model_name(provider_resource_id), model=self.get_litellm_model_name(provider_resource_id),

View file

@ -289,9 +289,9 @@ def process_chat_completion_response(
logprobs=None, logprobs=None,
) )
else: else:
# Otherwise, return tool calls as normal
# Filter to only valid ToolCall objects # Filter to only valid ToolCall objects
valid_tool_calls = [tc for tc in tool_calls if isinstance(tc, ToolCall)] valid_tool_calls = [tc for tc in tool_calls if isinstance(tc, ToolCall)]
# Otherwise, return tool calls as normal
return ChatCompletionResponse( return ChatCompletionResponse(
completion_message=CompletionMessage( completion_message=CompletionMessage(
tool_calls=valid_tool_calls, tool_calls=valid_tool_calls,
@ -530,7 +530,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
} }
if hasattr(message, "tool_calls") and message.tool_calls: if hasattr(message, "tool_calls") and message.tool_calls:
result["tool_calls"] = [] # type: ignore[assignment] # dict allows Any value, stricter type expected tool_calls_list = []
for tc in message.tool_calls: for tc in message.tool_calls:
# The tool.tool_name can be a str or a BuiltinTool enum. If # The tool.tool_name can be a str or a BuiltinTool enum. If
# it's the latter, convert to a string. # it's the latter, convert to a string.
@ -538,7 +538,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
if isinstance(tool_name, BuiltinTool): if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value tool_name = tool_name.value
result["tool_calls"].append( # type: ignore[union-attr] # reassigned as list above, mypy can't track tool_calls_list.append(
{ {
"id": tc.call_id, "id": tc.call_id,
"type": "function", "type": "function",
@ -548,6 +548,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
}, },
} }
) )
result["tool_calls"] = tool_calls_list # type: ignore[assignment] # dict allows Any value, stricter type expected
return result return result
@ -853,11 +854,11 @@ def _convert_openai_request_response_format(
if not response_format: if not response_format:
return None return None
# response_format can be a dict or a pydantic model # response_format can be a dict or a pydantic model
response_format = dict(response_format) # type: ignore[assignment] # OpenAIResponseFormatParam union needs dict conversion response_format_dict = dict(response_format) # type: ignore[arg-type] # OpenAIResponseFormatParam union needs dict conversion
if response_format.get("type", "") == "json_schema": # type: ignore[union-attr] # narrowed to dict but mypy doesn't track .get() if response_format_dict.get("type", "") == "json_schema":
return JsonSchemaResponseFormat( return JsonSchemaResponseFormat(
type="json_schema", # type: ignore[arg-type] # Literal["json_schema"] incompatible with expected type type="json_schema", # type: ignore[arg-type] # Literal["json_schema"] incompatible with expected type
json_schema=response_format.get("json_schema", {}).get("schema", ""), # type: ignore[union-attr] # chained .get() on reassigned dict confuses mypy json_schema=response_format_dict.get("json_schema", {}).get("schema", ""),
) )
return None return None
@ -965,12 +966,12 @@ def openai_messages_to_messages(
for message in messages: for message in messages:
converted_message: Message converted_message: Message
if message.role == "system": if message.role == "system":
converted_message = SystemMessage(content=openai_content_to_content(message.content)) # type: ignore[arg-type] # OpenAI content union broader than Message content union converted_message = SystemMessage(content=openai_content_to_content(message.content)) # type: ignore[arg-type] # message.content uses list[AliasType] but mypy expects Iterable[BaseType] due to OpenAI SDK type alias resolution
elif message.role == "user": elif message.role == "user":
converted_message = UserMessage(content=openai_content_to_content(message.content)) # type: ignore[arg-type] # OpenAI content union broader than Message content union converted_message = UserMessage(content=openai_content_to_content(message.content)) # type: ignore[arg-type] # message.content uses list[AliasType] but mypy expects Iterable[BaseType] due to OpenAI SDK type alias resolution
elif message.role == "assistant": elif message.role == "assistant":
converted_message = CompletionMessage( converted_message = CompletionMessage(
content=openai_content_to_content(message.content), # type: ignore[arg-type] # OpenAI content union broader than Message content union content=openai_content_to_content(message.content), # type: ignore[arg-type] # message.content uses list[AliasType] but mypy expects Iterable[BaseType] due to OpenAI SDK type alias resolution
tool_calls=_convert_openai_tool_calls(message.tool_calls) if message.tool_calls else [], # type: ignore[arg-type] # OpenAI tool_calls type incompatible with conversion function tool_calls=_convert_openai_tool_calls(message.tool_calls) if message.tool_calls else [], # type: ignore[arg-type] # OpenAI tool_calls type incompatible with conversion function
stop_reason=StopReason.end_of_turn, stop_reason=StopReason.end_of_turn,
) )
@ -978,7 +979,7 @@ def openai_messages_to_messages(
converted_message = ToolResponseMessage( converted_message = ToolResponseMessage(
role="tool", role="tool",
call_id=message.tool_call_id, call_id=message.tool_call_id,
content=openai_content_to_content(message.content), # type: ignore[arg-type] # OpenAI content union broader than Message content union content=openai_content_to_content(message.content), # type: ignore[arg-type] # message.content uses list[AliasType] but mypy expects Iterable[BaseType] due to OpenAI SDK type alias resolution
) )
else: else:
raise ValueError(f"Unknown role {message.role}") raise ValueError(f"Unknown role {message.role}")