scoring function type

This commit is contained in:
Xi Yan 2025-03-11 21:50:25 -07:00
parent 70fdf6c04b
commit b3ee4c00ce

View file

@ -36,7 +36,7 @@ class ScoringFnParamsType(Enum):
:cvar basic: Parameters for basic non-parameterized scoring function. :cvar basic: Parameters for basic non-parameterized scoring function.
""" """
llm_as_judge = "llm_as_judge" custom_llm_as_judge = "custom_llm_as_judge"
regex_parser = "regex_parser" regex_parser = "regex_parser"
basic = "basic" basic = "basic"
@ -50,13 +50,20 @@ class ScoringFunctionType(Enum):
:cvar regex_parser: Scoring function that parses the answer from the generated response using regexes, and checks against the expected answer. :cvar regex_parser: Scoring function that parses the answer from the generated response using regexes, and checks against the expected answer.
""" """
llm_as_judge = "llm_as_judge" custom_llm_as_judge = "custom_llm_as_judge"
regex_parser = "regex_parser" regex_parser = "regex_parser"
# NOTE: add additional scoring function types that can be registered regex_parser_math_response = "regex_parser_math_response"
# equality = "equality" equality = "equality"
# subset_of = "subset_of" subset_of = "subset_of"
# valid_json = "valid_json" factuality = "factuality"
# text_quality = "text_quality" faithfulness = "faithfulness"
answer_correctness = "answer_correctness"
answer_relevancy = "answer_relevancy"
answer_similarity = "answer_similarity"
context_entity_recall = "context_entity_recall"
context_precision = "context_precision"
context_recall = "context_recall"
context_relevancy = "context_relevancy"
@json_schema_type @json_schema_type
@ -118,18 +125,16 @@ ScoringFnParams = register_schema(
class CommonScoringFnFields(BaseModel): class CommonScoringFnFields(BaseModel):
scoring_fn_type: ScoringFunctionType
description: Optional[str] = None description: Optional[str] = None
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this definition",
)
return_type: ParamType = Field(
description="The return type of the deterministic function",
)
params: Optional[ScoringFnParams] = Field( params: Optional[ScoringFnParams] = Field(
description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval", description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
default=None, default=None,
) )
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this definition",
)
@json_schema_type @json_schema_type
@ -170,6 +175,7 @@ class ScoringFunctions(Protocol):
params: Optional[ScoringFnParams] = None, params: Optional[ScoringFnParams] = None,
scoring_fn_id: Optional[str] = None, scoring_fn_id: Optional[str] = None,
description: Optional[str] = None, description: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
): ):
""" """
Register a new scoring function with given parameters. Register a new scoring function with given parameters.
@ -179,5 +185,6 @@ class ScoringFunctions(Protocol):
:param params: The parameters for the scoring function. :param params: The parameters for the scoring function.
:param scoring_fn_id: (Optional) The ID of the scoring function to register. If not provided, a random ID will be generated. :param scoring_fn_id: (Optional) The ID of the scoring function to register. If not provided, a random ID will be generated.
:param description: (Optional) The description of the scoring function. :param description: (Optional) The description of the scoring function.
:param metadata: (Optional) Any additional metadata to be associated with the scoring function.
""" """
... ...