chore: enable pyupgrade fixes (#1806)

# What does this PR do?

The goal of this PR is code base modernization.

The schema reflection code needed a minor adjustment to handle
`types.UnionType` unions and `collections.abc.AsyncIterator`. (Both are
the preferred forms in recent Python releases.)
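
For illustration, here is a minimal sketch of the kind of reflection adjustment involved. The helper names are hypothetical, not the actual code in `llama_stack/strong_typing/schema.py`; they only show that the old and new spellings produce different runtime objects that reflection code must recognize:

```python
import collections.abc
import types
import typing


def unwrap_optional(hint: object) -> object:
    """Return T for an Optional[T]-style hint, whether written as
    typing.Optional[T], typing.Union[T, None], or the PEP 604 form T | None."""
    origin = typing.get_origin(hint)
    # typing.Union[...] reports typing.Union as its origin, while a
    # PEP 604 union (T | None) reports types.UnionType instead.
    if origin is typing.Union or origin is types.UnionType:
        args = [arg for arg in typing.get_args(hint) if arg is not type(None)]
        if len(args) == 1:
            return args[0]
    return hint


def is_async_iterator(hint: object) -> bool:
    """True for AsyncIterator[...] hints; typing.AsyncIterator[...] also
    normalizes to collections.abc.AsyncIterator as its origin."""
    return typing.get_origin(hint) is collections.abc.AsyncIterator


assert unwrap_optional(str | None) is str
assert unwrap_optional(typing.Optional[str]) is str
assert is_async_iterator(collections.abc.AsyncIterator[int])
```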

Note to reviewers: almost all changes here were generated automatically
by pyupgrade. Some unused imports were also cleaned up. The only changes
worth noting are under `docs/openapi_generator` and
`llama_stack/strong_typing/schema.py`, where reflection code was updated
to deal with the "newer" types.
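
For readers unfamiliar with pyupgrade's output, the bulk of the diff below reduces to two mechanical rewrites: PEP 585 builtin generics (`List` → `list`) and PEP 604 union syntax (`Optional[X]`/`Union[X, Y]` → `X | None`/`X | Y`). A representative before/after on a hypothetical function (not from this repo; pyupgrade applies these rewrites when run with a target like `--py310-plus`):

```python
# Before (pre-PEP 585/604 spellings, as on the `-` lines of this diff):
#
#   from typing import Any, Dict, List, Optional, Union
#
#   def lookup(deps: Dict[str, Any], names: Optional[List[str]] = None) -> Union[int, str]: ...

# After pyupgrade (as on the `+` lines of this diff):
from typing import Any


def lookup(deps: dict[str, Any], names: list[str] | None = None) -> int | str: ...
```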

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Ihar Hrachyshka authored on 2025-05-01 17:23:50 -04:00; committed by GitHub
parent ffe3d0b2cd
commit 9e6561a1ec
319 changed files with 2843 additions and 3033 deletions


@@ -4,14 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict
+from typing import Any
 from llama_stack.distribution.datatypes import Api
 from .config import MetaReferenceAgentsImplConfig
-async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, Any]):
+async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any]):
     from .agents import MetaReferenceAgentsImpl
     impl = MetaReferenceAgentsImpl(


@@ -10,8 +10,8 @@ import re
 import secrets
 import string
 import uuid
+from collections.abc import AsyncGenerator
 from datetime import datetime, timezone
-from typing import AsyncGenerator, List, Optional, Union
 import httpx
@@ -112,7 +112,7 @@ class ChatAgent(ShieldRunnerMixin):
             output_shields=agent_config.output_shields,
         )
-    def turn_to_messages(self, turn: Turn) -> List[Message]:
+    def turn_to_messages(self, turn: Turn) -> list[Message]:
         messages = []
         # NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
@@ -161,7 +161,7 @@ class ChatAgent(ShieldRunnerMixin):
     async def create_session(self, name: str) -> str:
         return await self.storage.create_session(name)
-    async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]:
+    async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
         messages = []
         if self.agent_config.instructions != "":
             messages.append(SystemMessage(content=self.agent_config.instructions))
@@ -201,8 +201,8 @@ class ChatAgent(ShieldRunnerMixin):
     async def _run_turn(
         self,
-        request: Union[AgentTurnCreateRequest, AgentTurnResumeRequest],
-        turn_id: Optional[str] = None,
+        request: AgentTurnCreateRequest | AgentTurnResumeRequest,
+        turn_id: str | None = None,
     ) -> AsyncGenerator:
         assert request.stream is True, "Non-streaming not supported"
@@ -321,10 +321,10 @@ class ChatAgent(ShieldRunnerMixin):
         self,
         session_id: str,
         turn_id: str,
-        input_messages: List[Message],
+        input_messages: list[Message],
         sampling_params: SamplingParams,
         stream: bool = False,
-        documents: Optional[List[Document]] = None,
+        documents: list[Document] | None = None,
     ) -> AsyncGenerator:
         # Doing async generators makes downstream code much simpler and everything amenable to
         # streaming. However, it also makes things complicated here because AsyncGenerators cannot
@@ -374,8 +374,8 @@ class ChatAgent(ShieldRunnerMixin):
     async def run_multiple_shields_wrapper(
         self,
         turn_id: str,
-        messages: List[Message],
-        shields: List[str],
+        messages: list[Message],
+        shields: list[str],
         touchpoint: str,
     ) -> AsyncGenerator:
         async with tracing.span("run_shields") as span:
@@ -443,10 +443,10 @@ class ChatAgent(ShieldRunnerMixin):
         self,
         session_id: str,
         turn_id: str,
-        input_messages: List[Message],
+        input_messages: list[Message],
         sampling_params: SamplingParams,
         stream: bool = False,
-        documents: Optional[List[Document]] = None,
+        documents: list[Document] | None = None,
     ) -> AsyncGenerator:
         # if document is passed in a turn, we parse the raw text of the document
         # and sent it as a user message
@@ -760,7 +760,7 @@ class ChatAgent(ShieldRunnerMixin):
     async def _initialize_tools(
         self,
-        toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
+        toolgroups_for_turn: list[AgentToolGroup] | None = None,
     ) -> None:
         toolgroup_to_args = {}
         for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
@@ -847,7 +847,7 @@ class ChatAgent(ShieldRunnerMixin):
             tool_name_to_args,
         )
-    def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
+    def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, str | None]:
         """Parse a toolgroup name into its components.
         Args:
@@ -921,7 +921,7 @@ async def get_raw_document_text(document: Document) -> str:
 def _interpret_content_as_attachment(
     content: str,
-) -> Optional[Attachment]:
+) -> Attachment | None:
     match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
     if match:
         snippet = match.group(1)


@@ -8,7 +8,7 @@ import json
 import logging
 import shutil
 import uuid
-from typing import AsyncGenerator, List, Optional, Union
+from collections.abc import AsyncGenerator
 from llama_stack.apis.agents import (
     Agent,
@@ -142,16 +142,11 @@ class MetaReferenceAgentsImpl(Agents):
         self,
         agent_id: str,
         session_id: str,
-        messages: List[
-            Union[
-                UserMessage,
-                ToolResponseMessage,
-            ]
-        ],
-        toolgroups: Optional[List[AgentToolGroup]] = None,
-        documents: Optional[List[Document]] = None,
-        stream: Optional[bool] = False,
-        tool_config: Optional[ToolConfig] = None,
+        messages: list[UserMessage | ToolResponseMessage],
+        toolgroups: list[AgentToolGroup] | None = None,
+        documents: list[Document] | None = None,
+        stream: bool | None = False,
+        tool_config: ToolConfig | None = None,
     ) -> AsyncGenerator:
         request = AgentTurnCreateRequest(
             agent_id=agent_id,
@@ -180,8 +175,8 @@ class MetaReferenceAgentsImpl(Agents):
         agent_id: str,
         session_id: str,
         turn_id: str,
-        tool_responses: List[ToolResponse],
-        stream: Optional[bool] = False,
+        tool_responses: list[ToolResponse],
+        stream: bool | None = False,
     ) -> AsyncGenerator:
         request = AgentTurnResumeRequest(
             agent_id=agent_id,
@@ -219,7 +214,7 @@ class MetaReferenceAgentsImpl(Agents):
         self,
         agent_id: str,
         session_id: str,
-        turn_ids: Optional[List[str]] = None,
+        turn_ids: list[str] | None = None,
     ) -> Session:
         agent = await self._get_agent_impl(agent_id)
         session_info = await agent.storage.get_session_info(session_id)
@@ -265,13 +260,13 @@ class MetaReferenceAgentsImpl(Agents):
     async def create_openai_response(
         self,
-        input: Union[str, List[OpenAIResponseInputMessage]],
+        input: str | list[OpenAIResponseInputMessage],
         model: str,
-        previous_response_id: Optional[str] = None,
-        store: Optional[bool] = True,
-        stream: Optional[bool] = False,
-        temperature: Optional[float] = None,
-        tools: Optional[List[OpenAIResponseInputTool]] = None,
+        previous_response_id: str | None = None,
+        store: bool | None = True,
+        stream: bool | None = False,
+        temperature: float | None = None,
+        tools: list[OpenAIResponseInputTool] | None = None,
     ) -> OpenAIResponseObject:
         return await self.openai_responses_impl.create_openai_response(
             input, model, previous_response_id, store, stream, temperature, tools


@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict
+from typing import Any
 from pydantic import BaseModel
@@ -16,7 +16,7 @@ class MetaReferenceAgentsImplConfig(BaseModel):
     persistence_store: KVStoreConfig
     @classmethod
-    def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
         return {
             "persistence_store": SqliteKVStoreConfig.sample_run_config(
                 __distro_dir__=__distro_dir__,


@@ -6,7 +6,8 @@
 import json
 import uuid
-from typing import AsyncIterator, List, Optional, Union, cast
+from collections.abc import AsyncIterator
+from typing import cast
 from openai.types.chat import ChatCompletionToolParam
@@ -49,15 +50,15 @@ logger = get_logger(name=__name__, category="openai_responses")
 OPENAI_RESPONSES_PREFIX = "openai_responses:"
-async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> List[OpenAIMessageParam]:
-    messages: List[OpenAIMessageParam] = []
+async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> list[OpenAIMessageParam]:
+    messages: list[OpenAIMessageParam] = []
     for output_message in previous_response.output:
         if isinstance(output_message, OpenAIResponseOutputMessage):
             messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text))
     return messages
-async def _openai_choices_to_output_messages(choices: List[OpenAIChoice]) -> List[OpenAIResponseOutputMessage]:
+async def _openai_choices_to_output_messages(choices: list[OpenAIChoice]) -> list[OpenAIResponseOutputMessage]:
     output_messages = []
     for choice in choices:
         output_content = ""
@@ -101,22 +102,22 @@ class OpenAIResponsesImpl:
     async def create_openai_response(
         self,
-        input: Union[str, List[OpenAIResponseInputMessage]],
+        input: str | list[OpenAIResponseInputMessage],
         model: str,
-        previous_response_id: Optional[str] = None,
-        store: Optional[bool] = True,
-        stream: Optional[bool] = False,
-        temperature: Optional[float] = None,
-        tools: Optional[List[OpenAIResponseInputTool]] = None,
+        previous_response_id: str | None = None,
+        store: bool | None = True,
+        stream: bool | None = False,
+        temperature: float | None = None,
+        tools: list[OpenAIResponseInputTool] | None = None,
     ):
         stream = False if stream is None else stream
-        messages: List[OpenAIMessageParam] = []
+        messages: list[OpenAIMessageParam] = []
         if previous_response_id:
             previous_response = await self.get_openai_response(previous_response_id)
             messages.extend(await _previous_response_to_messages(previous_response))
         # TODO: refactor this user_content parsing out into a separate method
-        user_content: Union[str, List[OpenAIChatCompletionContentPartParam]] = ""
+        user_content: str | list[OpenAIChatCompletionContentPartParam] = ""
         if isinstance(input, list):
             user_content = []
             for user_input in input:
@@ -179,7 +180,7 @@ class OpenAIResponsesImpl:
         # dump and reload to map to our pydantic types
         chat_response = OpenAIChatCompletion(**chat_response.model_dump())
-        output_messages: List[OpenAIResponseOutput] = []
+        output_messages: list[OpenAIResponseOutput] = []
         if chat_response.choices[0].message.tool_calls:
             output_messages.extend(
                 await self._execute_tool_and_return_final_output(model, stream, chat_response, messages, temperature)
@@ -215,9 +216,9 @@ class OpenAIResponsesImpl:
         return response
     async def _convert_response_tools_to_chat_tools(
-        self, tools: List[OpenAIResponseInputTool]
-    ) -> List[ChatCompletionToolParam]:
-        chat_tools: List[ChatCompletionToolParam] = []
+        self, tools: list[OpenAIResponseInputTool]
+    ) -> list[ChatCompletionToolParam]:
+        chat_tools: list[ChatCompletionToolParam] = []
         for input_tool in tools:
             # TODO: Handle other tool types
             if input_tool.type == "web_search":
@@ -247,10 +248,10 @@ class OpenAIResponsesImpl:
         model_id: str,
         stream: bool,
         chat_response: OpenAIChatCompletion,
-        messages: List[OpenAIMessageParam],
+        messages: list[OpenAIMessageParam],
         temperature: float,
-    ) -> List[OpenAIResponseOutput]:
-        output_messages: List[OpenAIResponseOutput] = []
+    ) -> list[OpenAIResponseOutput]:
+        output_messages: list[OpenAIResponseOutput] = []
         choice = chat_response.choices[0]
         # If the choice is not an assistant message, we don't need to execute any tools
@@ -314,7 +315,7 @@ class OpenAIResponsesImpl:
     async def _execute_tool_call(
         self,
         function: OpenAIChatCompletionToolCallFunction,
-    ) -> Optional[ToolInvocationResult]:
+    ) -> ToolInvocationResult | None:
         if not function.name:
             return None
         function_args = json.loads(function.arguments) if function.arguments else {}


@@ -8,7 +8,6 @@ import json
 import logging
 import uuid
 from datetime import datetime, timezone
-from typing import List, Optional
 from pydantic import BaseModel
@@ -25,9 +24,9 @@ class AgentSessionInfo(BaseModel):
     session_id: str
     session_name: str
     # TODO: is this used anywhere?
-    vector_db_id: Optional[str] = None
+    vector_db_id: str | None = None
     started_at: datetime
-    access_attributes: Optional[AccessAttributes] = None
+    access_attributes: AccessAttributes | None = None
 class AgentPersistence:
@@ -55,7 +54,7 @@ class AgentPersistence:
         )
         return session_id
-    async def get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]:
+    async def get_session_info(self, session_id: str) -> AgentSessionInfo | None:
         value = await self.kvstore.get(
             key=f"session:{self.agent_id}:{session_id}",
         )
@@ -78,7 +77,7 @@ class AgentPersistence:
         return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes())
-    async def get_session_if_accessible(self, session_id: str) -> Optional[AgentSessionInfo]:
+    async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
         """Get session info if the user has access to it. For internal use by sub-session methods."""
         session_info = await self.get_session_info(session_id)
         if not session_info:
@@ -106,7 +105,7 @@ class AgentPersistence:
             value=turn.model_dump_json(),
         )
-    async def get_session_turns(self, session_id: str) -> List[Turn]:
+    async def get_session_turns(self, session_id: str) -> list[Turn]:
         if not await self.get_session_if_accessible(session_id):
             raise ValueError(f"Session {session_id} not found or access denied")
@@ -125,7 +124,7 @@ class AgentPersistence:
         turns.sort(key=lambda x: (x.completed_at or datetime.min))
         return turns
-    async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]:
+    async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
         if not await self.get_session_if_accessible(session_id):
             raise ValueError(f"Session {session_id} not found or access denied")
@@ -145,7 +144,7 @@ class AgentPersistence:
             value=step.model_dump_json(),
         )
-    async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
+    async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> ToolExecutionStep | None:
         if not await self.get_session_if_accessible(session_id):
             return None
@@ -163,7 +162,7 @@ class AgentPersistence:
             value=str(num_infer_iters),
         )
-    async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
+    async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> int | None:
         if not await self.get_session_if_accessible(session_id):
             return None


@@ -6,7 +6,6 @@
 import asyncio
 import logging
-from typing import List
 from llama_stack.apis.inference import Message
 from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
@@ -25,14 +24,14 @@ class ShieldRunnerMixin:
     def __init__(
         self,
         safety_api: Safety,
-        input_shields: List[str] = None,
-        output_shields: List[str] = None,
+        input_shields: list[str] = None,
+        output_shields: list[str] = None,
     ):
         self.safety_api = safety_api
         self.input_shields = input_shields
         self.output_shields = output_shields
-    async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None:
+    async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None:
         async def run_shield_with_span(identifier: str):
             async with tracing.span(f"run_shield_{identifier}"):
                 return await self.safety_api.run_shield(