add alternative

This commit is contained in:
Xi Yan 2025-03-11 23:10:17 -07:00
parent 4236769b65
commit cd3a3a5e26

View file

@ -12,8 +12,8 @@ from typing import (
Literal, Literal,
Optional, Optional,
Protocol, Protocol,
Union,
runtime_checkable, runtime_checkable,
Union,
) )
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -152,6 +152,75 @@ ScoringFnParams = register_schema(
) )
# TODO(xiyan): ALTERNATIVE OPTION, merge ScoringFnParamsType + ScoringFunctionType
# @json_schema_type
# class LLMAsJudgeScoringFnParams(BaseModel):
# type: Literal["custom_llm_as_judge"] = "custom_llm_as_judge"
# judge_model: str
# prompt_template: Optional[str] = None
# judge_score_regexes: Optional[List[str]] = Field(
# description="Regexes to extract the answer from generated response",
# default_factory=list,
# )
# aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
# description="Aggregation functions to apply to the scores of each row",
# default_factory=list,
# )
# class RegexParserScoringFnParamsCommon(BaseModel):
# parsing_regexes: Optional[List[str]] = Field(
# description="Regexes to extract the answer from generated response",
# default_factory=list,
# )
# aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
# description="Aggregation functions to apply to the scores of each row",
# default_factory=list,
# )
# @json_schema_type
# class RegexParserScoringFnParams(RegexParserScoringFnParamsCommon):
# type: Literal["regex_parser"] = "regex_parser"
# @json_schema_type
# class RegexParserMathScoringFnParams(RegexParserScoringFnParamsCommon):
# type: Literal["regex_parser_math_response"] = "regex_parser_math_response"
# class BasicScoringFnParamsCommon(BaseModel):
# aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
# description="Aggregation functions to apply to the scores of each row",
# default_factory=list,
# )
# @json_schema_type
# class EqualityScoringFnParams(BasicScoringFnParamsCommon):
# type: Literal["equality"] = "equality"
# @json_schema_type
# class SubsetOfcoringFnParams(BasicScoringFnParamsCommon):
# type: Literal["subset_of"] = "subset_of"
# ScoringFnParams = register_schema(
# Annotated[
# Union[
# LLMAsJudgeScoringFnParams,
# RegexParserScoringFnParams,
# RegexParserMathScoringFnParams,
# EqualityScoringFnParams,
# SubsetOfcoringFnParams,
# ],
# Field(discriminator="type"),
# ],
# name="ScoringFnParams",
# )
class CommonScoringFnFields(BaseModel): class CommonScoringFnFields(BaseModel):
""" """
:param scoring_fn_type: The type of scoring function. :param scoring_fn_type: The type of scoring function.
@ -172,7 +241,9 @@ class CommonScoringFnFields(BaseModel):
@json_schema_type @json_schema_type
class ScoringFn(CommonScoringFnFields, Resource): class ScoringFn(CommonScoringFnFields, Resource):
type: Literal[ResourceType.scoring_function.value] = ResourceType.scoring_function.value type: Literal[ResourceType.scoring_function.value] = (
ResourceType.scoring_function.value
)
@property @property
def scoring_fn_id(self) -> str: def scoring_fn_id(self) -> str:
@ -199,7 +270,9 @@ class ScoringFunctions(Protocol):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ... async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET") @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ... async def get_scoring_function(
self, scoring_fn_id: str, /
) -> Optional[ScoringFn]: ...
@webmethod(route="/scoring-functions", method="POST") @webmethod(route="/scoring-functions", method="POST")
async def register_scoring_function( async def register_scoring_function(