diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 52508d2ec..aba597bf3 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -28,11 +28,37 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho # with standard metrics so they can be rolled up? @json_schema_type class ScoringFnParamsType(Enum): + """ + A type of scoring function parameters. + + :cvar llm_as_judge: Provide judge model and prompt template. + :cvar regex_parser: Provide regexes to parse the answer from the generated response. + :cvar basic: Parameters for basic non-parameterized scoring function. + """ + llm_as_judge = "llm_as_judge" regex_parser = "regex_parser" basic = "basic" +@json_schema_type +class ScoringFunctionType(Enum): + """ + A type of scoring function. Each type is a criterion for evaluating answers. + + :cvar llm_as_judge: Scoring function that uses a judge model to score the answer. + :cvar regex_parser: Scoring function that parses the answer from the generated response using regexes, and checks against the expected answer. 
+ """ + + llm_as_judge = "llm_as_judge" + regex_parser = "regex_parser" + # NOTE: add additional scoring function types that can be registered + # equality = "equality" + # subset_of = "subset_of" + # valid_json = "valid_json" + # text_quality = "text_quality" + + @json_schema_type class AggregationFunctionType(Enum): average = "average" @@ -43,7 +69,7 @@ class AggregationFunctionType(Enum): @json_schema_type class LLMAsJudgeScoringFnParams(BaseModel): - type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value + type: Literal["llm_as_judge"] = "llm_as_judge" judge_model: str prompt_template: Optional[str] = None judge_score_regexes: Optional[List[str]] = Field( @@ -58,9 +84,9 @@ class LLMAsJudgeScoringFnParams(BaseModel): @json_schema_type class RegexParserScoringFnParams(BaseModel): - type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value + type: Literal["regex_parser"] = "regex_parser" parsing_regexes: Optional[List[str]] = Field( - description="Regex to extract the answer from generated response", + description="Regexes to extract the answer from generated response", default_factory=list, ) aggregation_functions: Optional[List[AggregationFunctionType]] = Field( @@ -71,7 +97,7 @@ class RegexParserScoringFnParams(BaseModel): @json_schema_type class BasicScoringFnParams(BaseModel): - type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value + type: Literal["basic"] = "basic" aggregation_functions: Optional[List[AggregationFunctionType]] = Field( description="Aggregation functions to apply to the scores of each row", default_factory=list, @@ -140,10 +166,18 @@ class ScoringFunctions(Protocol): @webmethod(route="/scoring-functions", method="POST") async def register_scoring_function( self, - scoring_fn_id: str, - description: str, - return_type: ParamType, - provider_scoring_fn_id: Optional[str] = None, - provider_id: Optional[str] = None, + scoring_fn_type: 
ScoringFunctionType, params: Optional[ScoringFnParams] = None, - ) -> None: ... + scoring_fn_id: Optional[str] = None, + description: Optional[str] = None, + ): + """ + Register a new scoring function with the given parameters. + Only valid scoring function types that can be parameterized can be registered. + + :param scoring_fn_type: The type of scoring function to register. A function type can only be registered if it is a valid type. + :param params: The parameters for the scoring function. + :param scoring_fn_id: (Optional) The ID of the scoring function to register. If not provided, a random ID will be generated. + :param description: (Optional) The description of the scoring function. + """ + ...