From bc71980769da24a9d52b2b6f678991b886196272 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Tue, 11 Mar 2025 23:14:35 -0700 Subject: [PATCH] alternative --- .../scoring_functions/scoring_functions.py | 218 +++++++++--------- 1 file changed, 113 insertions(+), 105 deletions(-) diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index ddfb720f2..85c5ad403 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -25,18 +25,6 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho # Perhaps more structure can be imposed on these functions. Maybe they could be associated # with standard metrics so they can be rolled up? -class ScoringFnParamsType(Enum): - """ - A type of scoring function parameters. - - :cvar llm_as_judge: Provide judge model and prompt template. - :cvar regex_parser: Provide regexes to parse the answer from the generated response. - :cvar basic: Parameters for basic non-parameterized scoring function. - """ - - custom_llm_as_judge = "custom_llm_as_judge" - regex_parser = "regex_parser" - basic = "basic" class ScoringFunctionType(Enum): @@ -80,81 +68,33 @@ class AggregationFunctionType(Enum): accuracy = "accuracy" -@json_schema_type -class LLMAsJudgeScoringFnParams(BaseModel): - """ - Parameters for a scoring function that uses a judge model to score the answer. +# TODO(xiyan): +# ============= OPTION 1: SEPARATE ScoringFnParamsType + ScoringFunctionType ============= +# class ScoringFnParamsType(Enum): +# """ +# A type of scoring function parameters. - :param judge_model: The model to use for scoring. - :param prompt_template: (Optional) The prompt template to use for scoring. - :param judge_score_regexes: (Optional) Regexes to extract the score from the judge model's response. - :param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. 
No aggregation for results is calculated if not provided. - """ +# :cvar llm_as_judge: Provide judge model and prompt template. +# :cvar regex_parser: Provide regexes to parse the answer from the generated response. +# :cvar basic: Parameters for basic non-parameterized scoring function. +# """ - type: Literal["custom_llm_as_judge"] = "custom_llm_as_judge" - judge_model: str - prompt_template: Optional[str] = None - judge_score_regexes: Optional[List[str]] = Field( - description="Regexes to extract the answer from generated response", - default_factory=list, - ) - aggregation_functions: Optional[List[AggregationFunctionType]] = Field( - description="Aggregation functions to apply to the scores of each row", - default_factory=list, - ) +# custom_llm_as_judge = "custom_llm_as_judge" +# regex_parser = "regex_parser" +# basic = "basic" -@json_schema_type -class RegexParserScoringFnParams(BaseModel): - """ - Parameters for a scoring function that parses the answer from the generated response using regexes, and checks against the expected answer. - - :param parsing_regexes: Regexes to extract the answer from generated response - :param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided. - """ - - type: Literal["regex_parser"] = "regex_parser" - parsing_regexes: Optional[List[str]] = Field( - description="Regexes to extract the answer from generated response", - default_factory=list, - ) - aggregation_functions: Optional[List[AggregationFunctionType]] = Field( - description="Aggregation functions to apply to the scores of each row", - default_factory=list, - ) - - -@json_schema_type -class BasicScoringFnParams(BaseModel): - """ - Parameters for a non-parameterized scoring function. - - :param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided. 
- """ - - type: Literal["basic"] = "basic" - aggregation_functions: Optional[List[AggregationFunctionType]] = Field( - description="Aggregation functions to apply to the scores of each row", - default_factory=list, - ) - - -ScoringFnParams = register_schema( - Annotated[ - Union[ - LLMAsJudgeScoringFnParams, - RegexParserScoringFnParams, - BasicScoringFnParams, - ], - Field(discriminator="type"), - ], - name="ScoringFnParams", -) - - -# TODO(xiyan): ALTERNATIVE OPTION, merge ScoringFnParamsType + ScoringFunctionType # @json_schema_type # class LLMAsJudgeScoringFnParams(BaseModel): +# """ +# Parameters for a scoring function that uses a judge model to score the answer. + +# :param judge_model: The model to use for scoring. +# :param prompt_template: (Optional) The prompt template to use for scoring. +# :param judge_score_regexes: (Optional) Regexes to extract the score from the judge model's response. +# :param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided. +# """ + # type: Literal["custom_llm_as_judge"] = "custom_llm_as_judge" # judge_model: str # prompt_template: Optional[str] = None @@ -168,7 +108,16 @@ ScoringFnParams = register_schema( # ) -# class RegexParserScoringFnParamsCommon(BaseModel): +# @json_schema_type +# class RegexParserScoringFnParams(BaseModel): +# """ +# Parameters for a scoring function that parses the answer from the generated response using regexes, and checks against the expected answer. + +# :param parsing_regexes: Regexes to extract the answer from generated response +# :param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided. 
+# """ + +# type: Literal["regex_parser"] = "regex_parser" # parsing_regexes: Optional[List[str]] = Field( # description="Regexes to extract the answer from generated response", # default_factory=list, @@ -180,46 +129,104 @@ ScoringFnParams = register_schema( # @json_schema_type -# class RegexParserScoringFnParams(RegexParserScoringFnParamsCommon): -# type: Literal["regex_parser"] = "regex_parser" +# class BasicScoringFnParams(BaseModel): +# """ +# Parameters for a non-parameterized scoring function. +# :param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided. +# """ -# @json_schema_type -# class RegexParserMathScoringFnParams(RegexParserScoringFnParamsCommon): -# type: Literal["regex_parser_math_response"] = "regex_parser_math_response" - - -# class BasicScoringFnParamsCommon(BaseModel): +# type: Literal["basic"] = "basic" # aggregation_functions: Optional[List[AggregationFunctionType]] = Field( # description="Aggregation functions to apply to the scores of each row", # default_factory=list, # ) -# @json_schema_type -# class EqualityScoringFnParams(BasicScoringFnParamsCommon): -# type: Literal["equality"] = "equality" - - -# @json_schema_type -# class SubsetOfcoringFnParams(BasicScoringFnParamsCommon): -# type: Literal["subset_of"] = "subset_of" - - # ScoringFnParams = register_schema( # Annotated[ # Union[ # LLMAsJudgeScoringFnParams, # RegexParserScoringFnParams, -# RegexParserMathScoringFnParams, -# EqualityScoringFnParams, -# SubsetOfcoringFnParams, +# BasicScoringFnParams, # ], # Field(discriminator="type"), # ], # name="ScoringFnParams", # ) +# ============= END OF OPTION 1 ============= + + +# TODO(xiyan): +# ============= OPTION 2: MERGE ScoringFnParamsType + ScoringFunctionType into ScoringFunctionType ============= +class RegexParserScoringFnParamsCommon(BaseModel): + parsing_regexes: Optional[List[str]] = Field( + description="Regexes to extract the answer 
from generated response", + default_factory=list, + ) + aggregation_functions: Optional[List[AggregationFunctionType]] = Field( + description="Aggregation functions to apply to the scores of each row", + default_factory=list, + ) + + +class BasicScoringFnParamsCommon(BaseModel): + aggregation_functions: Optional[List[AggregationFunctionType]] = Field( + description="Aggregation functions to apply to the scores of each row", + default_factory=list, + ) + + +@json_schema_type +class RegexParserScoringFnParams(RegexParserScoringFnParamsCommon): + type: Literal["regex_parser"] = "regex_parser" + + +@json_schema_type +class RegexParserMathScoringFnParams(RegexParserScoringFnParamsCommon): + type: Literal["regex_parser_math_response"] = "regex_parser_math_response" + + +@json_schema_type +class EqualityScoringFnParams(BasicScoringFnParamsCommon): + type: Literal["equality"] = "equality" + + +@json_schema_type +class SubsetOfcoringFnParams(BasicScoringFnParamsCommon): + type: Literal["subset_of"] = "subset_of" + + +@json_schema_type +class LLMAsJudgeScoringFnParams(BaseModel): + type: Literal["custom_llm_as_judge"] = "custom_llm_as_judge" + judge_model: str + prompt_template: Optional[str] = None + judge_score_regexes: Optional[List[str]] = Field( + description="Regexes to extract the answer from generated response", + default_factory=list, + ) + aggregation_functions: Optional[List[AggregationFunctionType]] = Field( + description="Aggregation functions to apply to the scores of each row", + default_factory=list, + ) + + +ScoringFnParams = register_schema( + Annotated[ + Union[ + LLMAsJudgeScoringFnParams, + RegexParserScoringFnParams, + RegexParserMathScoringFnParams, + EqualityScoringFnParams, + SubsetOfcoringFnParams, + ], + Field(discriminator="type"), + ], + name="ScoringFnParams", +) + class CommonScoringFnFields(BaseModel): """ @@ -277,7 +284,8 @@ class ScoringFunctions(Protocol): @webmethod(route="/scoring-functions", method="POST") async def 
 register_scoring_function( self, - scoring_fn_type: ScoringFunctionType, + # TODO(xiyan): scoring_fn_type will not be needed for OPTION 2 + # scoring_fn_type: ScoringFunctionType, params: Optional[ScoringFnParams] = None, scoring_fn_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, @@ -286,7 +294,7 @@ class ScoringFunctions(Protocol): Register a new scoring function with given parameters. Only valid scoring function type that can be parameterized can be registered. - :param scoring_fn_type: The type of scoring function to register. + NOTE: scoring_fn_type is omitted; per the TODO on the signature it will not be needed for OPTION 2. :param params: The parameters for the scoring function. :param scoring_fn_id: (Optional) The ID of the scoring function to register. If not provided, a random ID will be generated. :param metadata: (Optional) Any additional metadata to be associated with the scoring function.