chore: enable pyupgrade fixes (#1806)

# What does this PR do?

The goal of this PR is code base modernization.

Schema reflection code needed a minor adjustment to handle UnionTypes
and collections.abc.AsyncIterator. (Both are preferred for latest Python
releases.)

Note to reviewers: almost all changes here are automatically generated
by pyupgrade. Some additional unused imports were cleaned up. The only
change worth of note can be found under `docs/openapi_generator` and
`llama_stack/strong_typing/schema.py` where reflection code was updated
to deal with "newer" types.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-05-01 17:23:50 -04:00 committed by GitHub
parent ffe3d0b2cd
commit 9e6561a1ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
319 changed files with 2843 additions and 3033 deletions

View file

@ -4,20 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncIterator
from datetime import datetime
from enum import Enum
from typing import (
Annotated,
Any,
AsyncIterator,
Dict,
List,
Literal,
Optional,
Protocol,
Union,
runtime_checkable,
)
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field
@ -79,8 +69,8 @@ class StepCommon(BaseModel):
turn_id: str
step_id: str
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
started_at: datetime | None = None
completed_at: datetime | None = None
class StepType(Enum):
@ -120,8 +110,8 @@ class ToolExecutionStep(StepCommon):
"""
step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
tool_calls: List[ToolCall]
tool_responses: List[ToolResponse]
tool_calls: list[ToolCall]
tool_responses: list[ToolResponse]
@json_schema_type
@ -132,7 +122,7 @@ class ShieldCallStep(StepCommon):
"""
step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
violation: Optional[SafetyViolation]
violation: SafetyViolation | None
@json_schema_type
@ -150,12 +140,7 @@ class MemoryRetrievalStep(StepCommon):
Step = Annotated[
Union[
InferenceStep,
ToolExecutionStep,
ShieldCallStep,
MemoryRetrievalStep,
],
InferenceStep | ToolExecutionStep | ShieldCallStep | MemoryRetrievalStep,
Field(discriminator="step_type"),
]
@ -166,18 +151,13 @@ class Turn(BaseModel):
turn_id: str
session_id: str
input_messages: List[
Union[
UserMessage,
ToolResponseMessage,
]
]
steps: List[Step]
input_messages: list[UserMessage | ToolResponseMessage]
steps: list[Step]
output_message: CompletionMessage
output_attachments: Optional[List[Attachment]] = Field(default_factory=list)
output_attachments: list[Attachment] | None = Field(default_factory=list)
started_at: datetime
completed_at: Optional[datetime] = None
completed_at: datetime | None = None
@json_schema_type
@ -186,34 +166,31 @@ class Session(BaseModel):
session_id: str
session_name: str
turns: List[Turn]
turns: list[Turn]
started_at: datetime
class AgentToolGroupWithArgs(BaseModel):
name: str
args: Dict[str, Any]
args: dict[str, Any]
AgentToolGroup = Union[
str,
AgentToolGroupWithArgs,
]
AgentToolGroup = str | AgentToolGroupWithArgs
register_schema(AgentToolGroup, name="AgentTool")
class AgentConfigCommon(BaseModel):
sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
input_shields: Optional[List[str]] = Field(default_factory=list)
output_shields: Optional[List[str]] = Field(default_factory=list)
toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=None, deprecated="use tool_config instead")
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None, deprecated="use tool_config instead")
tool_config: Optional[ToolConfig] = Field(default=None)
input_shields: list[str] | None = Field(default_factory=list)
output_shields: list[str] | None = Field(default_factory=list)
toolgroups: list[AgentToolGroup] | None = Field(default_factory=list)
client_tools: list[ToolDef] | None = Field(default_factory=list)
tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
tool_config: ToolConfig | None = Field(default=None)
max_infer_iters: Optional[int] = 10
max_infer_iters: int | None = 10
def model_post_init(self, __context):
if self.tool_config:
@ -243,9 +220,9 @@ class AgentConfig(AgentConfigCommon):
model: str
instructions: str
name: Optional[str] = None
enable_session_persistence: Optional[bool] = False
response_format: Optional[ResponseFormat] = None
name: str | None = None
enable_session_persistence: bool | None = False
response_format: ResponseFormat | None = None
@json_schema_type
@ -257,16 +234,16 @@ class Agent(BaseModel):
@json_schema_type
class ListAgentsResponse(BaseModel):
data: List[Agent]
data: list[Agent]
@json_schema_type
class ListAgentSessionsResponse(BaseModel):
data: List[Session]
data: list[Session]
class AgentConfigOverridablePerTurn(AgentConfigCommon):
instructions: Optional[str] = None
instructions: str | None = None
class AgentTurnResponseEventType(Enum):
@ -284,7 +261,7 @@ class AgentTurnResponseStepStartPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.step_start.value] = AgentTurnResponseEventType.step_start.value
step_type: StepType
step_id: str
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
metadata: dict[str, Any] | None = Field(default_factory=dict)
@json_schema_type
@ -327,14 +304,12 @@ class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
AgentTurnResponseEventPayload = Annotated[
Union[
AgentTurnResponseStepStartPayload,
AgentTurnResponseStepProgressPayload,
AgentTurnResponseStepCompletePayload,
AgentTurnResponseTurnStartPayload,
AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnAwaitingInputPayload,
],
AgentTurnResponseStepStartPayload
| AgentTurnResponseStepProgressPayload
| AgentTurnResponseStepCompletePayload
| AgentTurnResponseTurnStartPayload
| AgentTurnResponseTurnCompletePayload
| AgentTurnResponseTurnAwaitingInputPayload,
Field(discriminator="event_type"),
]
register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
@ -363,18 +338,13 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
# TODO: figure out how we can simplify this and make why
# ToolResponseMessage needs to be here (it is function call
# execution from outside the system)
messages: List[
Union[
UserMessage,
ToolResponseMessage,
]
]
messages: list[UserMessage | ToolResponseMessage]
documents: Optional[List[Document]] = None
toolgroups: Optional[List[AgentToolGroup]] = None
documents: list[Document] | None = None
toolgroups: list[AgentToolGroup] | None = None
stream: Optional[bool] = False
tool_config: Optional[ToolConfig] = None
stream: bool | None = False
tool_config: ToolConfig | None = None
@json_schema_type
@ -382,8 +352,8 @@ class AgentTurnResumeRequest(BaseModel):
agent_id: str
session_id: str
turn_id: str
tool_responses: List[ToolResponse]
stream: Optional[bool] = False
tool_responses: list[ToolResponse]
stream: bool | None = False
@json_schema_type
@ -429,17 +399,12 @@ class Agents(Protocol):
self,
agent_id: str,
session_id: str,
messages: List[
Union[
UserMessage,
ToolResponseMessage,
]
],
stream: Optional[bool] = False,
documents: Optional[List[Document]] = None,
toolgroups: Optional[List[AgentToolGroup]] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
messages: list[UserMessage | ToolResponseMessage],
stream: bool | None = False,
documents: list[Document] | None = None,
toolgroups: list[AgentToolGroup] | None = None,
tool_config: ToolConfig | None = None,
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
"""Create a new turn for an agent.
:param agent_id: The ID of the agent to create the turn for.
@ -463,9 +428,9 @@ class Agents(Protocol):
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: List[ToolResponse],
stream: Optional[bool] = False,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
tool_responses: list[ToolResponse],
stream: bool | None = False,
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
"""Resume an agent turn with executed tool call responses.
When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.
@ -538,7 +503,7 @@ class Agents(Protocol):
self,
session_id: str,
agent_id: str,
turn_ids: Optional[List[str]] = None,
turn_ids: list[str] | None = None,
) -> Session:
"""Retrieve an agent session by its ID.
@ -623,14 +588,14 @@ class Agents(Protocol):
@webmethod(route="/openai/v1/responses", method="POST")
async def create_openai_response(
self,
input: Union[str, List[OpenAIResponseInputMessage]],
input: str | list[OpenAIResponseInputMessage],
model: str,
previous_response_id: Optional[str] = None,
store: Optional[bool] = True,
stream: Optional[bool] = False,
temperature: Optional[float] = None,
tools: Optional[List[OpenAIResponseInputTool]] = None,
) -> Union[OpenAIResponseObject, AsyncIterator[OpenAIResponseObjectStream]]:
previous_response_id: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
"""Create a new OpenAI response.
:param input: Input message(s) to create the response.