alternative

This commit is contained in:
Xi Yan 2025-03-11 23:14:35 -07:00
parent cd3a3a5e26
commit bc71980769

View file

@ -25,18 +25,6 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
# with standard metrics so they can be rolled up?
class ScoringFnParamsType(Enum):
"""
A type of scoring function parameters.
:cvar llm_as_judge: Provide judge model and prompt template.
:cvar regex_parser: Provide regexes to parse the answer from the generated response.
:cvar basic: Parameters for basic non-parameterized scoring function.
"""
custom_llm_as_judge = "custom_llm_as_judge"
regex_parser = "regex_parser"
basic = "basic"
class ScoringFunctionType(Enum):
@ -80,81 +68,33 @@ class AggregationFunctionType(Enum):
accuracy = "accuracy"
@json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel):
"""
Parameters for a scoring function that uses a judge model to score the answer.
# TODO(xiyan):
# ============= OPTION 1: SEPARATE ScoringFnParamsType + ScoringFunctionType =============
# class ScoringFnParamsType(Enum):
# """
# A type of scoring function parameters.
:param judge_model: The model to use for scoring.
:param prompt_template: (Optional) The prompt template to use for scoring.
:param judge_score_regexes: (Optional) Regexes to extract the score from the judge model's response.
:param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided.
"""
# :cvar llm_as_judge: Provide judge model and prompt template.
# :cvar regex_parser: Provide regexes to parse the answer from the generated response.
# :cvar basic: Parameters for basic non-parameterized scoring function.
# """
type: Literal["custom_llm_as_judge"] = "custom_llm_as_judge"
judge_model: str
prompt_template: Optional[str] = None
judge_score_regexes: Optional[List[str]] = Field(
description="Regexes to extract the answer from generated response",
default_factory=list,
)
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
# custom_llm_as_judge = "custom_llm_as_judge"
# regex_parser = "regex_parser"
# basic = "basic"
@json_schema_type
class RegexParserScoringFnParams(BaseModel):
"""
Parameters for a scoring function that parses the answer from the generated response using regexes, and checks against the expected answer.
:param parsing_regexes: Regexes to extract the answer from generated response
:param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided.
"""
type: Literal["regex_parser"] = "regex_parser"
parsing_regexes: Optional[List[str]] = Field(
description="Regexes to extract the answer from generated response",
default_factory=list,
)
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
@json_schema_type
class BasicScoringFnParams(BaseModel):
"""
Parameters for a non-parameterized scoring function.
:param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided.
"""
type: Literal["basic"] = "basic"
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
ScoringFnParams = register_schema(
Annotated[
Union[
LLMAsJudgeScoringFnParams,
RegexParserScoringFnParams,
BasicScoringFnParams,
],
Field(discriminator="type"),
],
name="ScoringFnParams",
)
# TODO(xiyan): ALTERNATIVE OPTION, merge ScoringFnParamsType + ScoringFunctionType
# @json_schema_type
# class LLMAsJudgeScoringFnParams(BaseModel):
# """
# Parameters for a scoring function that uses a judge model to score the answer.
# :param judge_model: The model to use for scoring.
# :param prompt_template: (Optional) The prompt template to use for scoring.
# :param judge_score_regexes: (Optional) Regexes to extract the score from the judge model's response.
# :param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided.
# """
# type: Literal["custom_llm_as_judge"] = "custom_llm_as_judge"
# judge_model: str
# prompt_template: Optional[str] = None
@ -168,7 +108,16 @@ ScoringFnParams = register_schema(
# )
# class RegexParserScoringFnParamsCommon(BaseModel):
# @json_schema_type
# class RegexParserScoringFnParams(BaseModel):
# """
# Parameters for a scoring function that parses the answer from the generated response using regexes, and checks against the expected answer.
# :param parsing_regexes: Regexes to extract the answer from generated response
# :param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided.
# """
# type: Literal["regex_parser"] = "regex_parser"
# parsing_regexes: Optional[List[str]] = Field(
# description="Regexes to extract the answer from generated response",
# default_factory=list,
@ -180,46 +129,104 @@ ScoringFnParams = register_schema(
# @json_schema_type
# class RegexParserScoringFnParams(RegexParserScoringFnParamsCommon):
# type: Literal["regex_parser"] = "regex_parser"
# class BasicScoringFnParams(BaseModel):
# """
# Parameters for a non-parameterized scoring function.
# :param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided.
# """
# @json_schema_type
# class RegexParserMathScoringFnParams(RegexParserScoringFnParamsCommon):
# type: Literal["regex_parser_math_response"] = "regex_parser_math_response"
# class BasicScoringFnParamsCommon(BaseModel):
# type: Literal["basic"] = "basic"
# aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
# description="Aggregation functions to apply to the scores of each row",
# default_factory=list,
# )
# @json_schema_type
# class EqualityScoringFnParams(BasicScoringFnParamsCommon):
# type: Literal["equality"] = "equality"
# @json_schema_type
# class SubsetOfcoringFnParams(BasicScoringFnParamsCommon):
# type: Literal["subset_of"] = "subset_of"
# ScoringFnParams = register_schema(
# Annotated[
# Union[
# LLMAsJudgeScoringFnParams,
# RegexParserScoringFnParams,
# RegexParserMathScoringFnParams,
# EqualityScoringFnParams,
# SubsetOfcoringFnParams,
# BasicScoringFnParams,
# ],
# Field(discriminator="type"),
# ],
# name="ScoringFnParams",
# )
# ============= END OF OPTION 1 =============
# TODO(xiyan):
# ============= OPTION 2: MERGE ScoringFnParamsType + ScoringFunctionType into ScoringFunctionType =============
class RegexParserScoringFnParamsCommon(BaseModel):
parsing_regexes: Optional[List[str]] = Field(
description="Regexes to extract the answer from generated response",
default_factory=list,
)
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
class BasicScoringFnParamsCommon(BaseModel):
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
@json_schema_type
class RegexParserScoringFnParams(RegexParserScoringFnParamsCommon):
type: Literal["regex_parser"] = "regex_parser"
@json_schema_type
class RegexParserMathScoringFnParams(RegexParserScoringFnParamsCommon):
type: Literal["regex_parser_math_response"] = "regex_parser_math_response"
@json_schema_type
class EqualityScoringFnParams(BasicScoringFnParamsCommon):
type: Literal["equality"] = "equality"
@json_schema_type
class SubsetOfcoringFnParams(BasicScoringFnParamsCommon):
type: Literal["subset_of"] = "subset_of"
@json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel):
type: Literal["custom_llm_as_judge"] = "custom_llm_as_judge"
judge_model: str
prompt_template: Optional[str] = None
judge_score_regexes: Optional[List[str]] = Field(
description="Regexes to extract the answer from generated response",
default_factory=list,
)
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
ScoringFnParams = register_schema(
Annotated[
Union[
LLMAsJudgeScoringFnParams,
RegexParserScoringFnParams,
RegexParserMathScoringFnParams,
EqualityScoringFnParams,
SubsetOfcoringFnParams,
],
Field(discriminator="type"),
],
name="ScoringFnParams",
)
class CommonScoringFnFields(BaseModel):
"""
@ -277,7 +284,8 @@ class ScoringFunctions(Protocol):
@webmethod(route="/scoring-functions", method="POST")
async def register_scoring_function(
self,
scoring_fn_type: ScoringFunctionType,
# TODO(xiyan): scoring_fn_type will not be needed for OPTION 2
# scoring_fn_type: ScoringFunctionType,
params: Optional[ScoringFnParams] = None,
scoring_fn_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
@ -286,7 +294,7 @@ class ScoringFunctions(Protocol):
Register a new scoring function with given parameters.
Only valid scoring function type that can be parameterized can be registered.
:param scoring_fn_type: The type of scoring function to register.
# :param scoring_fn_type: The type of scoring function to register.
:param params: The parameters for the scoring function.
:param scoring_fn_id: (Optional) The ID of the scoring function to register. If not provided, a random ID will be generated.
:param metadata: (Optional) Any additional metadata to be associated with the scoring function.