From cd3a3a5e263229cea2419fee19da3c23bdde378a Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Tue, 11 Mar 2025 23:10:17 -0700 Subject: [PATCH] add alternative --- .../scoring_functions/scoring_functions.py | 79 ++++++++++++++++++- 1 file changed, 76 insertions(+), 3 deletions(-) diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 6724cbda7..ddfb720f2 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -12,8 +12,8 @@ from typing import ( Literal, Optional, Protocol, - Union, runtime_checkable, + Union, ) from pydantic import BaseModel, Field @@ -152,6 +152,75 @@ ScoringFnParams = register_schema( ) +# TODO(xiyan): ALTERNATIVE OPTION, merge ScoringFnParamsType + ScoringFunctionType +# @json_schema_type +# class LLMAsJudgeScoringFnParams(BaseModel): +# type: Literal["custom_llm_as_judge"] = "custom_llm_as_judge" +# judge_model: str +# prompt_template: Optional[str] = None +# judge_score_regexes: Optional[List[str]] = Field( +# description="Regexes to extract the answer from generated response", +# default_factory=list, +# ) +# aggregation_functions: Optional[List[AggregationFunctionType]] = Field( +# description="Aggregation functions to apply to the scores of each row", +# default_factory=list, +# ) + + +# class RegexParserScoringFnParamsCommon(BaseModel): +# parsing_regexes: Optional[List[str]] = Field( +# description="Regexes to extract the answer from generated response", +# default_factory=list, +# ) +# aggregation_functions: Optional[List[AggregationFunctionType]] = Field( +# description="Aggregation functions to apply to the scores of each row", +# default_factory=list, +# ) + + +# @json_schema_type +# class RegexParserScoringFnParams(RegexParserScoringFnParamsCommon): +# type: Literal["regex_parser"] = "regex_parser" + + +# @json_schema_type +# class RegexParserMathScoringFnParams(RegexParserScoringFnParamsCommon): +# type: Literal["regex_parser_math_response"] = "regex_parser_math_response" + + +# class BasicScoringFnParamsCommon(BaseModel): +# aggregation_functions: Optional[List[AggregationFunctionType]] = Field( +# description="Aggregation functions to apply to the scores of each row", +# default_factory=list, +# ) + + +# @json_schema_type +# class EqualityScoringFnParams(BasicScoringFnParamsCommon): +# type: Literal["equality"] = "equality" + + +# @json_schema_type +# class SubsetOfcoringFnParams(BasicScoringFnParamsCommon): +# type: Literal["subset_of"] = "subset_of" + + +# ScoringFnParams = register_schema( +# Annotated[ +# Union[ +# LLMAsJudgeScoringFnParams, +# RegexParserScoringFnParams, +# RegexParserMathScoringFnParams, +# EqualityScoringFnParams, +# SubsetOfcoringFnParams, +# ], +# Field(discriminator="type"), +# ], +# name="ScoringFnParams", +# ) + + class CommonScoringFnFields(BaseModel): """ :param scoring_fn_type: The type of scoring function. @@ -172,7 +241,9 @@ class CommonScoringFnFields(BaseModel): @json_schema_type class ScoringFn(CommonScoringFnFields, Resource): - type: Literal[ResourceType.scoring_function.value] = ResourceType.scoring_function.value + type: Literal[ResourceType.scoring_function.value] = ( + ResourceType.scoring_function.value + ) @property def scoring_fn_id(self) -> str: @@ -199,7 +270,9 @@ class ScoringFunctions(Protocol): async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ... @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET") - async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ... + async def get_scoring_function( + self, scoring_fn_id: str, / + ) -> Optional[ScoringFn]: ... @webmethod(route="/scoring-functions", method="POST") async def register_scoring_function(