chore: enable pyupgrade fixes (#1806)

# What does this PR do?

The goal of this PR is to modernize the code base by applying pyupgrade fixes across the tree.

The schema reflection code needed a minor adjustment to handle
`types.UnionType` (what the `X | Y` union syntax produces) and
`collections.abc.AsyncIterator`. (Both are the preferred spellings in
recent Python releases.)
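
To make that concrete, here is a minimal sketch (illustrative only, not the actual reflection code in `llama_stack/strong_typing/schema.py`) of how stdlib introspection can treat `typing.Union[...]` and `X | Y` uniformly, and recognize `AsyncIterator` however it was spelled:

```python
# Hypothetical sketch of union/async-iterator introspection; the repo's
# real reflection code differs.
import collections.abc
import types
import typing


def union_members(hint: object) -> tuple | None:
    """Return the member types if `hint` is a union annotation, else None."""
    origin = typing.get_origin(hint)
    # typing.Union covers Union[X, Y]; types.UnionType covers X | Y (3.10+).
    if origin is typing.Union or origin is types.UnionType:
        return typing.get_args(hint)
    return None


def is_async_iterator(hint: object) -> bool:
    """True for AsyncIterator[T] spelled via typing or collections.abc."""
    return typing.get_origin(hint) is collections.abc.AsyncIterator


assert union_members(int | None) == (int, type(None))
assert union_members(typing.Union[int, str]) == (int, str)
assert is_async_iterator(collections.abc.AsyncIterator[str])
assert is_async_iterator(typing.AsyncIterator[str])
```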

Note to reviewers: this PR touches 319 files (+2843/−3033), but almost
all of the changes were generated automatically by pyupgrade; some
imports left unused by the rewrite were also cleaned up. The only
changes worth noting are under `docs/openapi_generator` and
`llama_stack/strong_typing/schema.py`, where the reflection code was
updated to deal with the "newer" types.
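
For reviewers unfamiliar with pyupgrade, the bulk of the diff consists of mechanical rewrites like the following (a made-up snippet, not a file from this repo): PEP 585 builtin generics replace `Dict`/`List`, and PEP 604 `X | Y` unions replace `Optional`/`Union`:

```python
from typing import Any, Dict, List, Optional  # pyupgrade trims this to just `Any`


# Before: pre-3.9 style annotations.
def score_old(rows: List[Dict[str, Any]], params: Optional[str] = None) -> Dict[str, float]:
    return {}


# After: the equivalent signature pyupgrade produces.
def score_new(rows: list[dict[str, Any]], params: str | None = None) -> dict[str, float]:
    return {}
```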

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>


@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict
+from typing import Any

 from llama_stack.distribution.datatypes import Api
@@ -12,7 +12,7 @@ from .config import BasicScoringConfig
 async def get_provider_impl(
     config: BasicScoringConfig,
-    deps: Dict[Api, Any],
+    deps: dict[Api, Any],
 ):
     from .scoring import BasicScoringImpl


@@ -3,12 +3,12 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict
+from typing import Any

 from pydantic import BaseModel

 class BasicScoringConfig(BaseModel):
     @classmethod
-    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {}


@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict, List, Optional
+from typing import Any

 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
@@ -66,7 +66,7 @@ class BasicScoringImpl(
     async def shutdown(self) -> None: ...

-    async def list_scoring_functions(self) -> List[ScoringFn]:
+    async def list_scoring_functions(self) -> list[ScoringFn]:
         scoring_fn_defs_list = [
             fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs()
         ]
@@ -82,7 +82,7 @@ class BasicScoringImpl(
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
+        scoring_functions: dict[str, ScoringFnParams | None] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
@@ -107,8 +107,8 @@ class BasicScoringImpl(
     async def score(
         self,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
+        input_rows: list[dict[str, Any]],
+        scoring_functions: dict[str, ScoringFnParams | None] = None,
     ) -> ScoreResponse:
         res = {}
         for scoring_fn_id in scoring_functions.keys():


@@ -6,7 +6,7 @@
 import json
 import re
-from typing import Any, Dict, Optional
+from typing import Any

 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams
@@ -17,7 +17,7 @@ from ..utils.bfcl.checker import ast_checker, is_empty_output
 from .fn_defs.bfcl import bfcl

-def postprocess(x: Dict[str, Any], test_category: str) -> Dict[str, Any]:
+def postprocess(x: dict[str, Any], test_category: str) -> dict[str, Any]:
     contain_func_call = False
     error = None
     error_type = None
@@ -52,11 +52,11 @@ def postprocess(x: Dict[str, Any], test_category: str) -> Dict[str, Any]:
     }


-def gen_valid(x: Dict[str, Any]) -> Dict[str, float]:
+def gen_valid(x: dict[str, Any]) -> dict[str, float]:
     return {"valid": x["valid"]}


-def gen_relevance_acc(x: Dict[str, Any]) -> Dict[str, float]:
+def gen_relevance_acc(x: dict[str, Any]) -> dict[str, float]:
     # This function serves for both relevance and irrelevance tests, which share the exact opposite logic.
     # If `test_category` is "irrelevance", the model is expected to output no function call.
     # No function call means either the AST decoding fails (a error message is generated) or the decoded AST does not contain any function call (such as a empty list, `[]`).
@@ -78,9 +78,9 @@ class BFCLScoringFn(RegisteredBaseScoringFn):
     async def score_row(
         self,
-        input_row: Dict[str, Any],
-        scoring_fn_identifier: Optional[str] = "bfcl",
-        scoring_params: Optional[ScoringFnParams] = None,
+        input_row: dict[str, Any],
+        scoring_fn_identifier: str | None = "bfcl",
+        scoring_params: ScoringFnParams | None = None,
     ) -> ScoringResultRow:
         test_category = re.sub(r"_[0-9_-]+$", "", input_row["id"])
         score_result = postprocess(input_row, test_category)


@@ -6,7 +6,7 @@
 import json
 import re
-from typing import Any, Dict, Optional
+from typing import Any

 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams
@@ -228,9 +228,9 @@ class DocVQAScoringFn(RegisteredBaseScoringFn):
     async def score_row(
         self,
-        input_row: Dict[str, Any],
-        scoring_fn_identifier: Optional[str] = "docvqa",
-        scoring_params: Optional[ScoringFnParams] = None,
+        input_row: dict[str, Any],
+        scoring_fn_identifier: str | None = "docvqa",
+        scoring_params: ScoringFnParams | None = None,
     ) -> ScoringResultRow:
         expected_answers = json.loads(input_row["expected_answer"])
         generated_answer = input_row["generated_answer"]


@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict, Optional
+from typing import Any

 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams
@@ -26,9 +26,9 @@ class EqualityScoringFn(RegisteredBaseScoringFn):
     async def score_row(
         self,
-        input_row: Dict[str, Any],
-        scoring_fn_identifier: Optional[str] = "equality",
-        scoring_params: Optional[ScoringFnParams] = None,
+        input_row: dict[str, Any],
+        scoring_fn_identifier: str | None = "equality",
+        scoring_params: ScoringFnParams | None = None,
     ) -> ScoringResultRow:
         assert "expected_answer" in input_row, "Expected answer not found in input row."
         assert "generated_answer" in input_row, "Generated answer not found in input row."


@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict, Optional
+from typing import Any

 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams
@@ -28,9 +28,9 @@ class IfEvalScoringFn(RegisteredBaseScoringFn):
     async def score_row(
         self,
-        input_row: Dict[str, Any],
-        scoring_fn_identifier: Optional[str] = None,
-        scoring_params: Optional[ScoringFnParams] = None,
+        input_row: dict[str, Any],
+        scoring_fn_identifier: str | None = None,
+        scoring_params: ScoringFnParams | None = None,
     ) -> ScoringResultRow:
         from ..utils.ifeval_utils import INSTRUCTION_DICT, INSTRUCTION_LIST


@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict, Optional
+from typing import Any

 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
@@ -28,9 +28,9 @@ class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn):
     async def score_row(
         self,
-        input_row: Dict[str, Any],
-        scoring_fn_identifier: Optional[str] = None,
-        scoring_params: Optional[ScoringFnParams] = None,
+        input_row: dict[str, Any],
+        scoring_fn_identifier: str | None = None,
+        scoring_params: ScoringFnParams | None = None,
     ) -> ScoringResultRow:
         assert scoring_fn_identifier is not None, "Scoring function identifier not found."
         fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]


@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import re
-from typing import Any, Dict, Optional
+from typing import Any

 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
@@ -28,9 +28,9 @@ class RegexParserScoringFn(RegisteredBaseScoringFn):
     async def score_row(
         self,
-        input_row: Dict[str, Any],
-        scoring_fn_identifier: Optional[str] = None,
-        scoring_params: Optional[ScoringFnParams] = None,
+        input_row: dict[str, Any],
+        scoring_fn_identifier: str | None = None,
+        scoring_params: ScoringFnParams | None = None,
     ) -> ScoringResultRow:
         assert scoring_fn_identifier is not None, "Scoring function identifier not found."
         fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]


@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict, Optional
+from typing import Any

 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams
@@ -26,9 +26,9 @@ class SubsetOfScoringFn(RegisteredBaseScoringFn):
     async def score_row(
         self,
-        input_row: Dict[str, Any],
-        scoring_fn_identifier: Optional[str] = "subset_of",
-        scoring_params: Optional[ScoringFnParams] = None,
+        input_row: dict[str, Any],
+        scoring_fn_identifier: str | None = "subset_of",
+        scoring_params: ScoringFnParams | None = None,
     ) -> ScoringResultRow:
         expected_answer = input_row["expected_answer"]
         generated_answer = input_row["generated_answer"]


@@ -11,8 +11,8 @@ import logging
 import random
 import re
 import string
+from collections.abc import Iterable, Sequence
 from types import MappingProxyType
-from typing import Dict, Iterable, List, Optional, Sequence, Union

 import emoji
 import langdetect
@@ -1673,12 +1673,11 @@ def split_chinese_japanese_hindi(lines: str) -> Iterable[str]:
     The separator for hindi is '।'
     """
     for line in lines.splitlines():
-        for sent in re.findall(
+        yield from re.findall(
             r"[^!?。\.\!\?\\\\n।]+[!?。\.\!\?\\\\n।]?",
             line.strip(),
             flags=re.U,
-        ):
-            yield sent
+        )


 def count_words_cjk(text: str) -> int:
@@ -1707,7 +1706,7 @@ def count_words_cjk(text: str) -> int:
     return non_asian_words_cnt + asian_chars_cnt + emoji_cnt


-@functools.lru_cache(maxsize=None)
+@functools.cache
 def _get_sentence_tokenizer():
     return nltk.data.load("nltk:tokenizers/punkt/english.pickle")
@@ -1719,8 +1718,8 @@ def count_sentences(text):
     return len(tokenized_sentences)


-def get_langid(text: str, lid_path: Optional[str] = None) -> str:
-    line_langs: List[str] = []
+def get_langid(text: str, lid_path: str | None = None) -> str:
+    line_langs: list[str] = []
     lines = [line.strip() for line in text.split("\n") if len(line.strip()) >= 4]
     for line in lines:
@@ -1741,7 +1740,7 @@ def generate_keywords(num_keywords):
 """Library of instructions"""

-_InstructionArgsDtype = Optional[Dict[str, Union[int, str, Sequence[str]]]]
+_InstructionArgsDtype = dict[str, int | str | Sequence[str]] | None

 _LANGUAGES = LANGUAGE_CODES
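
Two of the rewrites in this file go beyond annotation syntax but are still behavior-preserving equivalences; a small self-contained illustration (hypothetical functions, not this repo's code):

```python
import functools


# functools.cache (3.9+) is defined as lru_cache(maxsize=None) — an
# unbounded cache — so the two decorators below are interchangeable.
@functools.lru_cache(maxsize=None)
def fib_old(n: int) -> int:
    return n if n < 2 else fib_old(n - 1) + fib_old(n - 2)


@functools.cache
def fib_new(n: int) -> int:
    return n if n < 2 else fib_new(n - 1) + fib_new(n - 2)


# `yield from iterable` replaces a `for ... yield` loop over that iterable.
def words_old(lines: list[str]):
    for line in lines:
        for word in line.split():
            yield word


def words_new(lines: list[str]):
    for line in lines:
        yield from line.split()


assert fib_old(10) == fib_new(10) == 55
assert list(words_old(["a b", "c"])) == list(words_new(["a b", "c"]))
```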


@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 import re
-from typing import Sequence
+from collections.abc import Sequence

 from llama_stack.providers.utils.scoring.basic_scoring_utils import time_limit
@@ -323,7 +323,7 @@ def _fix_a_slash_b(string: str) -> str:
     try:
         ia = int(a)
         ib = int(b)
-        assert string == "{}/{}".format(ia, ib)
+        assert string == f"{ia}/{ib}"
         new_string = "\\frac{" + str(ia) + "}{" + str(ib) + "}"
         return new_string
     except (ValueError, AssertionError):