chore: enable pyupgrade fixes (#1806)

# What does this PR do?

The goal of this PR is code base modernization.

Schema reflection code needed a minor adjustment to handle UnionTypes
and collections.abc.AsyncIterator. (Both are preferred for latest Python
releases.)

Note to reviewers: almost all changes here are automatically generated
by pyupgrade. Some additional unused imports were cleaned up. The only
change worth of note can be found under `docs/openapi_generator` and
`llama_stack/strong_typing/schema.py` where reflection code was updated
to deal with "newer" types.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-05-01 17:23:50 -04:00 committed by GitHub
parent ffe3d0b2cd
commit 9e6561a1ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
319 changed files with 2843 additions and 3033 deletions

View file

@ -5,10 +5,10 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union
from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field
from typing_extensions import Annotated, Protocol, runtime_checkable
from typing_extensions import Protocol, runtime_checkable
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@ -29,13 +29,13 @@ class RAGDocument(BaseModel):
document_id: str
content: InterleavedContent | URL
mime_type: str | None = None
metadata: Dict[str, Any] = Field(default_factory=dict)
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class RAGQueryResult(BaseModel):
content: Optional[InterleavedContent] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
content: InterleavedContent | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
@ -59,10 +59,7 @@ class LLMRAGQueryGeneratorConfig(BaseModel):
RAGQueryGeneratorConfig = Annotated[
Union[
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,
],
DefaultRAGQueryGeneratorConfig | LLMRAGQueryGeneratorConfig,
Field(discriminator="type"),
]
register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
@ -83,7 +80,7 @@ class RAGToolRuntime(Protocol):
@webmethod(route="/tool-runtime/rag-tool/insert", method="POST")
async def insert(
self,
documents: List[RAGDocument],
documents: list[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
@ -94,8 +91,8 @@ class RAGToolRuntime(Protocol):
async def query(
self,
content: InterleavedContent,
vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
vector_db_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
"""Query the RAG system for context; typically invoked by the agent"""
...

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Literal
from pydantic import BaseModel, Field
from typing_extensions import Protocol, runtime_checkable
@ -24,7 +24,7 @@ class ToolParameter(BaseModel):
parameter_type: str
description: str
required: bool = Field(default=True)
default: Optional[Any] = None
default: Any | None = None
@json_schema_type
@ -40,39 +40,39 @@ class Tool(Resource):
toolgroup_id: str
tool_host: ToolHost
description: str
parameters: List[ToolParameter]
metadata: Optional[Dict[str, Any]] = None
parameters: list[ToolParameter]
metadata: dict[str, Any] | None = None
@json_schema_type
class ToolDef(BaseModel):
name: str
description: Optional[str] = None
parameters: Optional[List[ToolParameter]] = None
metadata: Optional[Dict[str, Any]] = None
description: str | None = None
parameters: list[ToolParameter] | None = None
metadata: dict[str, Any] | None = None
@json_schema_type
class ToolGroupInput(BaseModel):
toolgroup_id: str
provider_id: str
args: Optional[Dict[str, Any]] = None
mcp_endpoint: Optional[URL] = None
args: dict[str, Any] | None = None
mcp_endpoint: URL | None = None
@json_schema_type
class ToolGroup(Resource):
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
mcp_endpoint: Optional[URL] = None
args: Optional[Dict[str, Any]] = None
mcp_endpoint: URL | None = None
args: dict[str, Any] | None = None
@json_schema_type
class ToolInvocationResult(BaseModel):
content: Optional[InterleavedContent] = None
error_message: Optional[str] = None
error_code: Optional[int] = None
metadata: Optional[Dict[str, Any]] = None
content: InterleavedContent | None = None
error_message: str | None = None
error_code: int | None = None
metadata: dict[str, Any] | None = None
class ToolStore(Protocol):
@ -81,11 +81,11 @@ class ToolStore(Protocol):
class ListToolGroupsResponse(BaseModel):
data: List[ToolGroup]
data: list[ToolGroup]
class ListToolsResponse(BaseModel):
data: List[Tool]
data: list[Tool]
class ListToolDefsResponse(BaseModel):
@ -100,8 +100,8 @@ class ToolGroups(Protocol):
self,
toolgroup_id: str,
provider_id: str,
mcp_endpoint: Optional[URL] = None,
args: Optional[Dict[str, Any]] = None,
mcp_endpoint: URL | None = None,
args: dict[str, Any] | None = None,
) -> None:
"""Register a tool group"""
...
@ -118,7 +118,7 @@ class ToolGroups(Protocol):
...
@webmethod(route="/tools", method="GET")
async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
"""List tools with optional tool group"""
...
@ -151,10 +151,10 @@ class ToolRuntime(Protocol):
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET")
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse: ...
@webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
"""Run a tool with the given arguments"""
...