forked from phoenix-oss/llama-stack-mirror
chore: enable pyupgrade fixes (#1806)
# What does this PR do? The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle UnionTypes and collections.abc.AsyncIterator. (Both are preferred for latest Python releases.) Note to reviewers: almost all changes here are automatically generated by pyupgrade. Some additional unused imports were cleaned up. The only change worth of note can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py` where reflection code was updated to deal with "newer" types. Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
ffe3d0b2cd
commit
9e6561a1ec
319 changed files with 2843 additions and 3033 deletions
|
@ -4,13 +4,13 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import statistics
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import AggregationFunctionType
|
||||
|
||||
|
||||
def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
|
||||
def aggregate_accuracy(scoring_results: list[ScoringResultRow]) -> dict[str, Any]:
|
||||
num_correct = sum(result["score"] for result in scoring_results)
|
||||
avg_score = num_correct / len(scoring_results)
|
||||
|
||||
|
@ -21,14 +21,14 @@ def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any
|
|||
}
|
||||
|
||||
|
||||
def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
|
||||
def aggregate_average(scoring_results: list[ScoringResultRow]) -> dict[str, Any]:
|
||||
return {
|
||||
"average": sum(result["score"] for result in scoring_results if result["score"] is not None)
|
||||
/ len([_ for _ in scoring_results if _["score"] is not None]),
|
||||
}
|
||||
|
||||
|
||||
def aggregate_weighted_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
|
||||
def aggregate_weighted_average(scoring_results: list[ScoringResultRow]) -> dict[str, Any]:
|
||||
return {
|
||||
"weighted_average": sum(
|
||||
result["score"] * result["weight"]
|
||||
|
@ -40,14 +40,14 @@ def aggregate_weighted_average(scoring_results: List[ScoringResultRow]) -> Dict[
|
|||
|
||||
|
||||
def aggregate_categorical_count(
|
||||
scoring_results: List[ScoringResultRow],
|
||||
) -> Dict[str, Any]:
|
||||
scoring_results: list[ScoringResultRow],
|
||||
) -> dict[str, Any]:
|
||||
scores = [str(r["score"]) for r in scoring_results]
|
||||
unique_scores = sorted(set(scores))
|
||||
return {"categorical_count": {s: scores.count(s) for s in unique_scores}}
|
||||
|
||||
|
||||
def aggregate_median(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
|
||||
def aggregate_median(scoring_results: list[ScoringResultRow]) -> dict[str, Any]:
|
||||
scores = [r["score"] for r in scoring_results if r["score"] is not None]
|
||||
median = statistics.median(scores) if scores else None
|
||||
return {"median": median}
|
||||
|
@ -64,8 +64,8 @@ AGGREGATION_FUNCTIONS = {
|
|||
|
||||
|
||||
def aggregate_metrics(
|
||||
scoring_results: List[ScoringResultRow], metrics: List[AggregationFunctionType]
|
||||
) -> Dict[str, Any]:
|
||||
scoring_results: list[ScoringResultRow], metrics: list[AggregationFunctionType]
|
||||
) -> dict[str, Any]:
|
||||
agg_results = {}
|
||||
for metric in metrics:
|
||||
if metric not in AGGREGATION_FUNCTIONS:
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringFnParams, ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
|
@ -28,28 +28,28 @@ class BaseScoringFn(ABC):
|
|||
@abstractmethod
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = None,
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
async def aggregate(
|
||||
self,
|
||||
scoring_results: List[ScoringResultRow],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> Dict[str, Any]:
|
||||
scoring_results: list[ScoringResultRow],
|
||||
scoring_fn_identifier: str | None = None,
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> dict[str, Any]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
async def score(
|
||||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> List[ScoringResultRow]:
|
||||
input_rows: list[dict[str, Any]],
|
||||
scoring_fn_identifier: str | None = None,
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> list[ScoringResultRow]:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
|
@ -65,7 +65,7 @@ class RegisteredBaseScoringFn(BaseScoringFn):
|
|||
def __str__(self) -> str:
|
||||
return self.__class__.__name__
|
||||
|
||||
def get_supported_scoring_fn_defs(self) -> List[ScoringFn]:
|
||||
def get_supported_scoring_fn_defs(self) -> list[ScoringFn]:
|
||||
return list(self.supported_fn_defs_registry.values())
|
||||
|
||||
def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None:
|
||||
|
@ -81,18 +81,18 @@ class RegisteredBaseScoringFn(BaseScoringFn):
|
|||
@abstractmethod
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = None,
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def aggregate(
|
||||
self,
|
||||
scoring_results: List[ScoringResultRow],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> Dict[str, Any]:
|
||||
scoring_results: list[ScoringResultRow],
|
||||
scoring_fn_identifier: str | None = None,
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> dict[str, Any]:
|
||||
params = self.supported_fn_defs_registry[scoring_fn_identifier].params
|
||||
if scoring_params is not None:
|
||||
if params is None:
|
||||
|
@ -107,8 +107,8 @@ class RegisteredBaseScoringFn(BaseScoringFn):
|
|||
|
||||
async def score(
|
||||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> List[ScoringResultRow]:
|
||||
input_rows: list[dict[str, Any]],
|
||||
scoring_fn_identifier: str | None = None,
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> list[ScoringResultRow]:
|
||||
return [await self.score_row(input_row, scoring_fn_identifier, scoring_params) for input_row in input_rows]
|
||||
|
|
|
@ -5,8 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
import contextlib
|
||||
import signal
|
||||
from collections.abc import Iterator
|
||||
from types import FrameType
|
||||
from typing import Iterator, Optional
|
||||
|
||||
|
||||
class TimeoutError(Exception):
|
||||
|
@ -15,7 +15,7 @@ class TimeoutError(Exception):
|
|||
|
||||
@contextlib.contextmanager
|
||||
def time_limit(seconds: float) -> Iterator[None]:
|
||||
def signal_handler(signum: int, frame: Optional[FrameType]) -> None:
|
||||
def signal_handler(signum: int, frame: FrameType | None) -> None:
|
||||
raise TimeoutError("Timed out!")
|
||||
|
||||
signal.setitimer(signal.ITIMER_REAL, seconds)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue