chore: enable pyupgrade fixes (#1806)

# What does this PR do?

The goal of this PR is code base modernization.

Schema reflection code needed a minor adjustment to handle UnionTypes
and collections.abc.AsyncIterator. (Both are preferred for latest Python
releases.)

Note to reviewers: almost all changes here are automatically generated
by pyupgrade. Some additional unused imports were cleaned up. The only
change worth of note can be found under `docs/openapi_generator` and
`llama_stack/strong_typing/schema.py` where reflection code was updated
to deal with "newer" types.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-05-01 17:23:50 -04:00 committed by GitHub
parent ffe3d0b2cd
commit 9e6561a1ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
319 changed files with 2843 additions and 3033 deletions

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.telemetry import QueryCondition, QuerySpansResponse, Span
@ -17,10 +16,10 @@ class TelemetryDatasetMixin:
async def save_spans_to_dataset(
self,
attribute_filters: List[QueryCondition],
attributes_to_save: List[str],
attribute_filters: list[QueryCondition],
attributes_to_save: list[str],
dataset_id: str,
max_depth: Optional[int] = None,
max_depth: int | None = None,
) -> None:
if self.datasetio_api is None:
raise RuntimeError("DatasetIO API not available")
@ -48,9 +47,9 @@ class TelemetryDatasetMixin:
async def query_spans(
self,
attribute_filters: List[QueryCondition],
attributes_to_return: List[str],
max_depth: Optional[int] = None,
attribute_filters: list[QueryCondition],
attributes_to_return: list[str],
max_depth: int | None = None,
) -> QuerySpansResponse:
traces = await self.query_traces(attribute_filters=attribute_filters)
spans = []

View file

@ -6,7 +6,7 @@
import json
from datetime import datetime
from typing import Dict, List, Optional, Protocol
from typing import Protocol
import aiosqlite
@ -16,18 +16,18 @@ from llama_stack.apis.telemetry import QueryCondition, Span, SpanWithStatus, Tra
class TraceStore(Protocol):
async def query_traces(
self,
attribute_filters: Optional[List[QueryCondition]] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
) -> List[Trace]: ...
attribute_filters: list[QueryCondition] | None = None,
limit: int | None = 100,
offset: int | None = 0,
order_by: list[str] | None = None,
) -> list[Trace]: ...
async def get_span_tree(
self,
span_id: str,
attributes_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
) -> Dict[str, SpanWithStatus]: ...
attributes_to_return: list[str] | None = None,
max_depth: int | None = None,
) -> dict[str, SpanWithStatus]: ...
class SQLiteTraceStore(TraceStore):
@ -36,11 +36,11 @@ class SQLiteTraceStore(TraceStore):
async def query_traces(
self,
attribute_filters: Optional[List[QueryCondition]] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
) -> List[Trace]:
attribute_filters: list[QueryCondition] | None = None,
limit: int | None = 100,
offset: int | None = 0,
order_by: list[str] | None = None,
) -> list[Trace]:
def build_where_clause() -> tuple[str, list]:
if not attribute_filters:
return "", []
@ -112,9 +112,9 @@ class SQLiteTraceStore(TraceStore):
async def get_span_tree(
self,
span_id: str,
attributes_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
) -> Dict[str, SpanWithStatus]:
attributes_to_return: list[str] | None = None,
max_depth: int | None = None,
) -> dict[str, SpanWithStatus]:
# Build the attributes selection
attributes_select = "s.attributes"
if attributes_to_return:

View file

@ -7,8 +7,9 @@
import asyncio
import inspect
import json
from collections.abc import AsyncGenerator, Callable
from functools import wraps
from typing import Any, AsyncGenerator, Callable, Type, TypeVar
from typing import Any, TypeVar
from pydantic import BaseModel
@ -25,13 +26,13 @@ def _prepare_for_json(value: Any) -> str:
"""Serialize a single value into JSON-compatible format."""
if value is None:
return ""
elif isinstance(value, (str, int, float, bool)):
elif isinstance(value, str | int | float | bool):
return value
elif hasattr(value, "_name_"):
return value._name_
elif isinstance(value, BaseModel):
return json.loads(value.model_dump_json())
elif isinstance(value, (list, tuple, set)):
elif isinstance(value, list | tuple | set):
return [_prepare_for_json(item) for item in value]
elif isinstance(value, dict):
return {str(k): _prepare_for_json(v) for k, v in value.items()}
@ -43,7 +44,7 @@ def _prepare_for_json(value: Any) -> str:
return str(value)
def trace_protocol(cls: Type[T]) -> Type[T]:
def trace_protocol(cls: type[T]) -> type[T]:
"""
A class decorator that automatically traces all methods in a protocol/base class
and its inheriting classes.

View file

@ -10,9 +10,10 @@ import logging
import queue
import random
import threading
from collections.abc import Callable
from datetime import datetime, timezone
from functools import wraps
from typing import Any, Callable, Dict, List, Optional
from typing import Any
from llama_stack.apis.telemetry import (
LogSeverity,
@ -106,13 +107,13 @@ class BackgroundLogger:
class TraceContext:
spans: List[Span] = []
spans: list[Span] = []
def __init__(self, logger: BackgroundLogger, trace_id: str):
self.logger = logger
self.trace_id = trace_id
def push_span(self, name: str, attributes: Dict[str, Any] = None) -> Span:
def push_span(self, name: str, attributes: dict[str, Any] = None) -> Span:
current_span = self.get_current_span()
span = Span(
span_id=generate_span_id(),
@ -168,7 +169,7 @@ def setup_logger(api: Telemetry, level: int = logging.INFO):
root_logger.addHandler(TelemetryHandler())
async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceContext:
async def start_trace(name: str, attributes: dict[str, Any] = None) -> TraceContext:
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None:
@ -246,7 +247,7 @@ class TelemetryHandler(logging.Handler):
class SpanContextManager:
def __init__(self, name: str, attributes: Dict[str, Any] = None):
def __init__(self, name: str, attributes: dict[str, Any] = None):
self.name = name
self.attributes = attributes
self.span = None
@ -316,11 +317,11 @@ class SpanContextManager:
return wrapper
def span(name: str, attributes: Dict[str, Any] = None):
def span(name: str, attributes: dict[str, Any] = None):
return SpanContextManager(name, attributes)
def get_current_span() -> Optional[Span]:
def get_current_span() -> Span | None:
global CURRENT_TRACE_CONTEXT
if CURRENT_TRACE_CONTEXT is None:
logger.debug("No trace context to get current span")