forked from phoenix-oss/llama-stack-mirror
chore: enable pyupgrade fixes (#1806)
# What does this PR do? The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle UnionTypes and collections.abc.AsyncIterator. (Both are preferred for latest Python releases.) Note to reviewers: almost all changes here are automatically generated by pyupgrade. Some additional unused imports were cleaned up. The only change worth of note can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py` where reflection code was updated to deal with "newer" types. Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
ffe3d0b2cd
commit
9e6561a1ec
319 changed files with 2843 additions and 3033 deletions
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
|
@ -12,7 +12,7 @@ from .config import BasicScoringConfig
|
|||
|
||||
async def get_provider_impl(
|
||||
config: BasicScoringConfig,
|
||||
deps: Dict[Api, Any],
|
||||
deps: dict[Api, Any],
|
||||
):
|
||||
from .scoring import BasicScoringImpl
|
||||
|
||||
|
|
|
@ -3,12 +3,12 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BasicScoringConfig(BaseModel):
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
|
@ -66,7 +66,7 @@ class BasicScoringImpl(
|
|||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def list_scoring_functions(self) -> List[ScoringFn]:
|
||||
async def list_scoring_functions(self) -> list[ScoringFn]:
|
||||
scoring_fn_defs_list = [
|
||||
fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs()
|
||||
]
|
||||
|
@ -82,7 +82,7 @@ class BasicScoringImpl(
|
|||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
scoring_functions: dict[str, ScoringFnParams | None] = None,
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
|
@ -107,8 +107,8 @@ class BasicScoringImpl(
|
|||
|
||||
async def score(
|
||||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
input_rows: list[dict[str, Any]],
|
||||
scoring_functions: dict[str, ScoringFnParams | None] = None,
|
||||
) -> ScoreResponse:
|
||||
res = {}
|
||||
for scoring_fn_id in scoring_functions.keys():
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
|
@ -17,7 +17,7 @@ from ..utils.bfcl.checker import ast_checker, is_empty_output
|
|||
from .fn_defs.bfcl import bfcl
|
||||
|
||||
|
||||
def postprocess(x: Dict[str, Any], test_category: str) -> Dict[str, Any]:
|
||||
def postprocess(x: dict[str, Any], test_category: str) -> dict[str, Any]:
|
||||
contain_func_call = False
|
||||
error = None
|
||||
error_type = None
|
||||
|
@ -52,11 +52,11 @@ def postprocess(x: Dict[str, Any], test_category: str) -> Dict[str, Any]:
|
|||
}
|
||||
|
||||
|
||||
def gen_valid(x: Dict[str, Any]) -> Dict[str, float]:
|
||||
def gen_valid(x: dict[str, Any]) -> dict[str, float]:
|
||||
return {"valid": x["valid"]}
|
||||
|
||||
|
||||
def gen_relevance_acc(x: Dict[str, Any]) -> Dict[str, float]:
|
||||
def gen_relevance_acc(x: dict[str, Any]) -> dict[str, float]:
|
||||
# This function serves for both relevance and irrelevance tests, which share the exact opposite logic.
|
||||
# If `test_category` is "irrelevance", the model is expected to output no function call.
|
||||
# No function call means either the AST decoding fails (a error message is generated) or the decoded AST does not contain any function call (such as a empty list, `[]`).
|
||||
|
@ -78,9 +78,9 @@ class BFCLScoringFn(RegisteredBaseScoringFn):
|
|||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = "bfcl",
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = "bfcl",
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
test_category = re.sub(r"_[0-9_-]+$", "", input_row["id"])
|
||||
score_result = postprocess(input_row, test_category)
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
|
@ -228,9 +228,9 @@ class DocVQAScoringFn(RegisteredBaseScoringFn):
|
|||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = "docvqa",
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = "docvqa",
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
expected_answers = json.loads(input_row["expected_answer"])
|
||||
generated_answer = input_row["generated_answer"]
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
|
@ -26,9 +26,9 @@ class EqualityScoringFn(RegisteredBaseScoringFn):
|
|||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = "equality",
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = "equality",
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
assert "expected_answer" in input_row, "Expected answer not found in input row."
|
||||
assert "generated_answer" in input_row, "Generated answer not found in input row."
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
|
@ -28,9 +28,9 @@ class IfEvalScoringFn(RegisteredBaseScoringFn):
|
|||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = None,
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
from ..utils.ifeval_utils import INSTRUCTION_DICT, INSTRUCTION_LIST
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
|
||||
|
@ -28,9 +28,9 @@ class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn):
|
|||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = None,
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
|
||||
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
|
||||
|
@ -28,9 +28,9 @@ class RegexParserScoringFn(RegisteredBaseScoringFn):
|
|||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = None,
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
|
||||
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
|
@ -26,9 +26,9 @@ class SubsetOfScoringFn(RegisteredBaseScoringFn):
|
|||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = "subset_of",
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = "subset_of",
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
expected_answer = input_row["expected_answer"]
|
||||
generated_answer = input_row["generated_answer"]
|
||||
|
|
|
@ -11,8 +11,8 @@ import logging
|
|||
import random
|
||||
import re
|
||||
import string
|
||||
from collections.abc import Iterable, Sequence
|
||||
from types import MappingProxyType
|
||||
from typing import Dict, Iterable, List, Optional, Sequence, Union
|
||||
|
||||
import emoji
|
||||
import langdetect
|
||||
|
@ -1673,12 +1673,11 @@ def split_chinese_japanese_hindi(lines: str) -> Iterable[str]:
|
|||
The separator for hindi is '।'
|
||||
"""
|
||||
for line in lines.splitlines():
|
||||
for sent in re.findall(
|
||||
yield from re.findall(
|
||||
r"[^!?。\.\!\?\!\?\.\n।]+[!?。\.\!\?\!\?\.\n।]?",
|
||||
line.strip(),
|
||||
flags=re.U,
|
||||
):
|
||||
yield sent
|
||||
)
|
||||
|
||||
|
||||
def count_words_cjk(text: str) -> int:
|
||||
|
@ -1707,7 +1706,7 @@ def count_words_cjk(text: str) -> int:
|
|||
return non_asian_words_cnt + asian_chars_cnt + emoji_cnt
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
@functools.cache
|
||||
def _get_sentence_tokenizer():
|
||||
return nltk.data.load("nltk:tokenizers/punkt/english.pickle")
|
||||
|
||||
|
@ -1719,8 +1718,8 @@ def count_sentences(text):
|
|||
return len(tokenized_sentences)
|
||||
|
||||
|
||||
def get_langid(text: str, lid_path: Optional[str] = None) -> str:
|
||||
line_langs: List[str] = []
|
||||
def get_langid(text: str, lid_path: str | None = None) -> str:
|
||||
line_langs: list[str] = []
|
||||
lines = [line.strip() for line in text.split("\n") if len(line.strip()) >= 4]
|
||||
|
||||
for line in lines:
|
||||
|
@ -1741,7 +1740,7 @@ def generate_keywords(num_keywords):
|
|||
|
||||
|
||||
"""Library of instructions"""
|
||||
_InstructionArgsDtype = Optional[Dict[str, Union[int, str, Sequence[str]]]]
|
||||
_InstructionArgsDtype = dict[str, int | str | Sequence[str]] | None
|
||||
|
||||
_LANGUAGES = LANGUAGE_CODES
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import re
|
||||
from typing import Sequence
|
||||
from collections.abc import Sequence
|
||||
|
||||
from llama_stack.providers.utils.scoring.basic_scoring_utils import time_limit
|
||||
|
||||
|
@ -323,7 +323,7 @@ def _fix_a_slash_b(string: str) -> str:
|
|||
try:
|
||||
ia = int(a)
|
||||
ib = int(b)
|
||||
assert string == "{}/{}".format(ia, ib)
|
||||
assert string == f"{ia}/{ib}"
|
||||
new_string = "\\frac{" + str(ia) + "}{" + str(ib) + "}"
|
||||
return new_string
|
||||
except (ValueError, AssertionError):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue