remove aggregation functions

This commit is contained in:
Xi Yan 2025-03-23 16:17:09 -07:00
parent 64388de068
commit 2723b05164
3 changed files with 7 additions and 262 deletions

View file

@ -13,8 +13,8 @@ from typing import (
Literal,
Optional,
Protocol,
Union,
runtime_checkable,
Union,
)
from pydantic import BaseModel, Field
@ -63,35 +63,14 @@ class GraderTypeInfo(BaseModel):
)
class AggregationFunctionType(Enum):
"""
A type of aggregation function.
:cvar average: Average the scores of each row.
:cvar median: Median the scores of each row.
:cvar categorical_count: Count the number of rows that match each category.
:cvar accuracy: Number of correct results over total results.
"""
average = "average"
median = "median"
categorical_count = "categorical_count"
accuracy = "accuracy"
class BasicGraderParams(BaseModel):
aggregation_functions: List[AggregationFunctionType]
class LlmGraderParams(BaseModel):
model: str
prompt: str
score_regexes: List[str]
aggregation_functions: List[AggregationFunctionType]
class RegexParserGraderParams(BaseModel):
parsing_regexes: List[str]
aggregation_functions: List[AggregationFunctionType]
@json_schema_type
@ -109,25 +88,21 @@ class RegexParserGrader(BaseModel):
@json_schema_type
class EqualityGrader(BaseModel):
type: Literal["equality"] = "equality"
equality: BasicGraderParams
@json_schema_type
class SubsetOfGrader(BaseModel):
type: Literal["subset_of"] = "subset_of"
subset_of: BasicGraderParams
@json_schema_type
class FactualityGrader(BaseModel):
type: Literal["factuality"] = "factuality"
factuality: BasicGraderParams
@json_schema_type
class FaithfulnessGrader(BaseModel):
type: Literal["faithfulness"] = "faithfulness"
faithfulness: BasicGraderParams
GraderDefinition = register_schema(