forked from phoenix-oss/llama-stack-mirror
chore: more mypy fixes (#2029)
# What does this PR do? Mainly tried to cover the entire llama_stack/apis directory, we only have one left. Some excludes were just noop. Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
feb9eb8b0d
commit
1a529705da
27 changed files with 581 additions and 166 deletions
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# TODO: use enum.StrEnum when we drop support for python 3.10
|
||||
import sys
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Annotated,
|
||||
|
@ -19,18 +21,27 @@ from llama_stack.apis.common.type_system import ParamType
|
|||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from enum import StrEnum
|
||||
else:
|
||||
|
||||
class StrEnum(str, Enum):
|
||||
"""Backport of StrEnum for Python 3.10 and below."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
|
||||
# with standard metrics so they can be rolled up?
|
||||
@json_schema_type
|
||||
class ScoringFnParamsType(Enum):
|
||||
class ScoringFnParamsType(StrEnum):
|
||||
llm_as_judge = "llm_as_judge"
|
||||
regex_parser = "regex_parser"
|
||||
basic = "basic"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AggregationFunctionType(Enum):
|
||||
class AggregationFunctionType(StrEnum):
|
||||
average = "average"
|
||||
weighted_average = "weighted_average"
|
||||
median = "median"
|
||||
|
@ -40,36 +51,36 @@ class AggregationFunctionType(Enum):
|
|||
|
||||
@json_schema_type
|
||||
class LLMAsJudgeScoringFnParams(BaseModel):
|
||||
type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value
|
||||
type: Literal[ScoringFnParamsType.llm_as_judge] = ScoringFnParamsType.llm_as_judge
|
||||
judge_model: str
|
||||
prompt_template: str | None = None
|
||||
judge_score_regexes: list[str] | None = Field(
|
||||
judge_score_regexes: list[str] = Field(
|
||||
description="Regexes to extract the answer from generated response",
|
||||
default_factory=list,
|
||||
default_factory=lambda: [],
|
||||
)
|
||||
aggregation_functions: list[AggregationFunctionType] | None = Field(
|
||||
aggregation_functions: list[AggregationFunctionType] = Field(
|
||||
description="Aggregation functions to apply to the scores of each row",
|
||||
default_factory=list,
|
||||
default_factory=lambda: [],
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RegexParserScoringFnParams(BaseModel):
|
||||
type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value
|
||||
parsing_regexes: list[str] | None = Field(
|
||||
type: Literal[ScoringFnParamsType.regex_parser] = ScoringFnParamsType.regex_parser
|
||||
parsing_regexes: list[str] = Field(
|
||||
description="Regex to extract the answer from generated response",
|
||||
default_factory=list,
|
||||
default_factory=lambda: [],
|
||||
)
|
||||
aggregation_functions: list[AggregationFunctionType] | None = Field(
|
||||
aggregation_functions: list[AggregationFunctionType] = Field(
|
||||
description="Aggregation functions to apply to the scores of each row",
|
||||
default_factory=list,
|
||||
default_factory=lambda: [],
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BasicScoringFnParams(BaseModel):
|
||||
type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value
|
||||
aggregation_functions: list[AggregationFunctionType] | None = Field(
|
||||
type: Literal[ScoringFnParamsType.basic] = ScoringFnParamsType.basic
|
||||
aggregation_functions: list[AggregationFunctionType] = Field(
|
||||
description="Aggregation functions to apply to the scores of each row",
|
||||
default_factory=list,
|
||||
)
|
||||
|
@ -99,14 +110,14 @@ class CommonScoringFnFields(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class ScoringFn(CommonScoringFnFields, Resource):
|
||||
type: Literal[ResourceType.scoring_function.value] = ResourceType.scoring_function.value
|
||||
type: Literal[ResourceType.scoring_function] = ResourceType.scoring_function
|
||||
|
||||
@property
|
||||
def scoring_fn_id(self) -> str:
|
||||
return self.identifier
|
||||
|
||||
@property
|
||||
def provider_scoring_fn_id(self) -> str:
|
||||
def provider_scoring_fn_id(self) -> str | None:
|
||||
return self.provider_resource_id
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue