aggregation function config

This commit is contained in:
Xi Yan 2024-12-10 15:46:46 -08:00
parent e2054d53e4
commit fbc3888fd7
10 changed files with 189 additions and 26 deletions

View file

@ -31,6 +31,15 @@ from llama_stack.apis.resource import Resource, ResourceType
class ScoringFnParamsType(Enum):
llm_as_judge = "llm_as_judge"
regex_parser = "regex_parser"
basic = "basic"
@json_schema_type
class AggregationFunctionType(Enum):
average = "average"
median = "median"
categorical_count = "categorical_count"
accuracy = "accuracy"
@json_schema_type
@ -44,6 +53,10 @@ class LLMAsJudgeScoringFnParams(BaseModel):
description="Regexes to extract the answer from generated response",
default_factory=list,
)
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
@json_schema_type
@ -55,12 +68,26 @@ class RegexParserScoringFnParams(BaseModel):
description="Regex to extract the answer from generated response",
default_factory=list,
)
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
@json_schema_type
class BasicScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
ScoringFnParams = Annotated[
Union[
LLMAsJudgeScoringFnParams,
RegexParserScoringFnParams,
BasicScoringFnParams,
],
Field(discriminator="type"),
]