scoring updates

This commit is contained in:
Xi Yan 2025-03-12 21:54:12 -07:00
parent 7b50fdb2b1
commit 3a87562e8d
6 changed files with 1346 additions and 1466 deletions

View file

@@ -67,7 +67,7 @@ class AggregationFunctionType(Enum):
 accuracy = "accuracy"
-class BasicScoringFnParamsFields(BaseModel):
+class BasicScoringFnParams(BaseModel):
 """
 :param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. If not provided, no aggregation will be performed.
 """
@@ -78,7 +78,7 @@ class BasicScoringFnParamsFields(BaseModel):
 )
-class RegexParserScoringFnParamsFields(BaseModel):
+class RegexParserScoringFnParams(BaseModel):
 """
 :param parsing_regexes: (Optional) Regexes to extract the answer from generated response.
 :param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. If not provided, no aggregation will be performed.
@@ -93,7 +93,7 @@ class RegexParserScoringFnParamsFields(BaseModel):
 default_factory=list,
 )
-class CustomLLMAsJudgeScoringFnParamsFields(BaseModel):
+class CustomLLMAsJudgeScoringFnParams(BaseModel):
 type: Literal["custom_llm_as_judge"] = "custom_llm_as_judge"
 judge_model: str
 prompt_template: Optional[str] = None
@@ -103,103 +103,103 @@ class CustomLLMAsJudgeScoringFnParamsFields(BaseModel):
 )
 @json_schema_type
-class RegexParserScoringFnParams(BaseModel):
+class RegexParserScoringFn(BaseModel):
 type: Literal["regex_parser"] = "regex_parser"
-regex_parser: RegexParserScoringFnParamsFields
+regex_parser: RegexParserScoringFnParams
 @json_schema_type
-class RegexParserMathScoringFnParams(BaseModel):
+class RegexParserMathScoringFn(BaseModel):
 type: Literal["regex_parser_math_response"] = "regex_parser_math_response"
-regex_parser_math_response: RegexParserScoringFnParamsFields
+regex_parser_math_response: RegexParserScoringFnParams
 @json_schema_type
-class EqualityScoringFnParams(BaseModel):
+class EqualityScoringFn(BaseModel):
 type: Literal["equality"] = "equality"
-equality: BasicScoringFnParamsFields
+equality: BasicScoringFnParams
 @json_schema_type
-class SubsetOfcoringFnParams(BaseModel):
+class SubsetOfScoringFn(BaseModel):
 type: Literal["subset_of"] = "subset_of"
-subset_of: BasicScoringFnParamsFields
+subset_of: BasicScoringFnParams
 @json_schema_type
-class FactualityScoringFnParams(BaseModel):
+class FactualityScoringFn(BaseModel):
 type: Literal["factuality"] = "factuality"
-factuality: BasicScoringFnParamsFields
+factuality: BasicScoringFnParams
 @json_schema_type
-class FaithfulnessScoringFnParams(BaseModel):
+class FaithfulnessScoringFn(BaseModel):
 type: Literal["faithfulness"] = "faithfulness"
-faithfulness: BasicScoringFnParamsFields
+faithfulness: BasicScoringFnParams
 @json_schema_type
-class AnswerCorrectnessScoringFnParams(BaseModel):
+class AnswerCorrectnessScoringFn(BaseModel):
 type: Literal["answer_correctness"] = "answer_correctness"
-answer_correctness: BasicScoringFnParamsFields
+answer_correctness: BasicScoringFnParams
 @json_schema_type
-class AnswerRelevancyScoringFnParams(BaseModel):
+class AnswerRelevancyScoringFn(BaseModel):
 type: Literal["answer_relevancy"] = "answer_relevancy"
-answer_relevancy: BasicScoringFnParamsFields
+answer_relevancy: BasicScoringFnParams
 @json_schema_type
-class AnswerSimilarityScoringFnParams(BaseModel):
+class AnswerSimilarityScoringFn(BaseModel):
 type: Literal["answer_similarity"] = "answer_similarity"
-answer_similarity: BasicScoringFnParamsFields
+answer_similarity: BasicScoringFnParams
 @json_schema_type
-class ContextEntityRecallScoringFnParams(BaseModel):
+class ContextEntityRecallScoringFn(BaseModel):
 type: Literal["context_entity_recall"] = "context_entity_recall"
-context_entity_recall: BasicScoringFnParamsFields
+context_entity_recall: BasicScoringFnParams
 @json_schema_type
-class ContextPrecisionScoringFnParams(BaseModel):
+class ContextPrecisionScoringFn(BaseModel):
 type: Literal["context_precision"] = "context_precision"
-context_precision: BasicScoringFnParamsFields
+context_precision: BasicScoringFnParams
 @json_schema_type
-class ContextRecallScoringFnParams(BaseModel):
+class ContextRecallScoringFn(BaseModel):
 type: Literal["context_recall"] = "context_recall"
-context_recall: BasicScoringFnParamsFields
+context_recall: BasicScoringFnParams
 @json_schema_type
-class ContextRelevancyScoringFnParams(BaseModel):
+class ContextRelevancyScoringFn(BaseModel):
 type: Literal["context_relevancy"] = "context_relevancy"
-context_relevancy: BasicScoringFnParamsFields
+context_relevancy: BasicScoringFnParams
 @json_schema_type
-class CustomLLMAsJudgeScoringFnParams(BaseModel):
+class CustomLLMAsJudgeScoringFn(BaseModel):
 type: Literal["custom_llm_as_judge"] = "custom_llm_as_judge"
-custom_llm_as_judge: CustomLLMAsJudgeScoringFnParamsFields
+custom_llm_as_judge: CustomLLMAsJudgeScoringFnParams
-ScoringFnParams = register_schema(
+ScoringFnDefinition = register_schema(
 Annotated[
 Union[
-CustomLLMAsJudgeScoringFnParams,
-RegexParserScoringFnParams,
-RegexParserMathScoringFnParams,
-EqualityScoringFnParams,
-SubsetOfcoringFnParams,
-FactualityScoringFnParams,
-FaithfulnessScoringFnParams,
-AnswerCorrectnessScoringFnParams,
-AnswerRelevancyScoringFnParams,
-AnswerSimilarityScoringFnParams,
-ContextEntityRecallScoringFnParams,
-ContextPrecisionScoringFnParams,
-ContextRecallScoringFnParams,
-ContextRelevancyScoringFnParams,
+CustomLLMAsJudgeScoringFn,
+RegexParserScoringFn,
+RegexParserMathScoringFn,
+EqualityScoringFn,
+SubsetOfScoringFn,
+FactualityScoringFn,
+FaithfulnessScoringFn,
+AnswerCorrectnessScoringFn,
+AnswerRelevancyScoringFn,
+AnswerSimilarityScoringFn,
+ContextEntityRecallScoringFn,
+ContextPrecisionScoringFn,
+ContextRecallScoringFn,
+ContextRelevancyScoringFn,
 ],
 Field(discriminator="type"),
 ],
-name="ScoringFnParams",
+name="ScoringFnDefinition",
 )
@@ -208,7 +208,7 @@ class CommonScoringFnFields(BaseModel):
 :param fn: The scoring function type and parameters.
 :param metadata: (Optional) Any additional metadata for this definition (e.g. description).
 """
-fn: ScoringFnParams
+fn: ScoringFnDefinition
 metadata: Dict[str, Any] = Field(
 default_factory=dict,
 description="Any additional metadata for this definition (e.g. description)",
@@ -288,7 +288,7 @@ class ScoringFunctions(Protocol):
 @webmethod(route="/scoring-functions", method="POST")
 async def register_scoring_function(
 self,
-fn: ScoringFnParams,
+fn: ScoringFnDefinition,
 scoring_fn_id: Optional[str] = None,
 metadata: Optional[Dict[str, Any]] = None,
 ) -> ScoringFn: