This commit is contained in:
Xi Yan 2025-03-11 22:45:48 -07:00
parent 11e57e17e6
commit f9ea90c4f7
3 changed files with 90 additions and 20 deletions

View file

@ -12,8 +12,8 @@ from typing import (
Literal,
Optional,
Protocol,
Union,
runtime_checkable,
Union,
)
from pydantic import BaseModel, Field
@ -65,6 +65,15 @@ class ScoringFunctionType(Enum):
@json_schema_type
class AggregationFunctionType(Enum):
"""
A type of aggregation function.
:cvar average: Average the scores of each row.
:cvar median: Take the median of the scores of each row.
:cvar categorical_count: Count the number of rows that match each category.
:cvar accuracy: Number of correct results over total results.
"""
average = "average"
median = "median"
categorical_count = "categorical_count"
@ -73,6 +82,15 @@ class AggregationFunctionType(Enum):
@json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel):
"""
Parameters for a scoring function that uses a judge model to score the answer.
:param judge_model: The model to use for scoring.
:param prompt_template: (Optional) The prompt template to use for scoring.
:param judge_score_regexes: (Optional) Regexes to extract the score from the judge model's response.
:param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided.
"""
type: Literal["custom_llm_as_judge"] = "custom_llm_as_judge"
judge_model: str
prompt_template: Optional[str] = None
@ -88,6 +106,13 @@ class LLMAsJudgeScoringFnParams(BaseModel):
@json_schema_type
class RegexParserScoringFnParams(BaseModel):
"""
Parameters for a scoring function that parses the answer from the generated response using regexes, and checks against the expected answer.
:param parsing_regexes: Regexes to extract the answer from the generated response.
:param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided.
"""
type: Literal["regex_parser"] = "regex_parser"
parsing_regexes: Optional[List[str]] = Field(
description="Regexes to extract the answer from generated response",
@ -101,6 +126,12 @@ class RegexParserScoringFnParams(BaseModel):
@json_schema_type
class BasicScoringFnParams(BaseModel):
"""
Parameters for a non-parameterized scoring function.
:param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided.
"""
type: Literal["basic"] = "basic"
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
description="Aggregation functions to apply to the scores of each row",
@ -135,7 +166,9 @@ class CommonScoringFnFields(BaseModel):
@json_schema_type
class ScoringFn(CommonScoringFnFields, Resource):
type: Literal[ResourceType.scoring_function.value] = ResourceType.scoring_function.value
type: Literal[ResourceType.scoring_function.value] = (
ResourceType.scoring_function.value
)
@property
def scoring_fn_id(self) -> str:
@ -162,7 +195,9 @@ class ScoringFunctions(Protocol):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ...
async def get_scoring_function(
self, scoring_fn_id: str, /
) -> Optional[ScoringFn]: ...
@webmethod(route="/scoring-functions", method="POST")
async def register_scoring_function(