From fe91608321ea9058743983f5468e4815a688c81e Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Tue, 5 Nov 2024 14:34:56 -0800
Subject: [PATCH] scoring fn api

---
 .../scoring_functions/scoring_functions.py   | 62 ++++++++++++++-----
 1 file changed, 46 insertions(+), 16 deletions(-)

diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py
index d0a9cc597..742b1d88f 100644
--- a/llama_stack/apis/scoring_functions/scoring_functions.py
+++ b/llama_stack/apis/scoring_functions/scoring_functions.py
@@ -4,32 +4,66 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
+from enum import Enum
+from typing import (
+    Any,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Protocol,
+    runtime_checkable,
+    Union,
+)
 
 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
+from typing_extensions import Annotated
 
 from llama_stack.apis.common.type_system import ParamType
 
 
-@json_schema_type
-class Parameter(BaseModel):
-    name: str
-    type: ParamType
-    description: Optional[str] = None
-
-
 # Perhaps more structure can be imposed on these functions. Maybe they could be associated
 # with standard metrics so they can be rolled up?
+@json_schema_type
+class ScoringContextType(Enum):
+    llm_as_judge = "llm_as_judge"
+    answer_parsing = "answer_parsing"
 
 
+@json_schema_type
 class LLMAsJudgeContext(BaseModel):
+    type: Literal[ScoringContextType.llm_as_judge.value] = (  # type: ignore
+        ScoringContextType.llm_as_judge.value
+    )
     judge_model: str
     prompt_template: Optional[str] = None
-    judge_score_regex: Optional[List[str]] = Field(
-        description="Regex to extract the score from the judge response",
-        default=None,
+    judge_score_regex: Optional[List[str]] = Field(
+        description="Regex to extract the score from the judge response",
+        default_factory=list,
     )
+
+
+@json_schema_type
+class AnswerParsingContext(BaseModel):
+    type: Literal[ScoringContextType.answer_parsing.value] = (  # type: ignore
+        ScoringContextType.answer_parsing.value
+    )
+    parsing_regex: Optional[List[str]] = Field(
+        description="Regex to extract the answer from generated response",
+        default_factory=list,
+    )
+
+
+ScoringContext = Annotated[
+    Union[
+        LLMAsJudgeContext,
+        AnswerParsingContext,
+    ],
+    Field(discriminator="type"),
+]
+
+
+@json_schema_type
+class ScoringFnConfig(BaseModel):
+    scoring_context: ScoringContext  # type: ignore
 
 
 @json_schema_type
@@ -40,14 +74,10 @@ class ScoringFnDef(BaseModel):
         default_factory=dict,
         description="Any additional metadata for this definition",
     )
-    parameters: List[Parameter] = Field(
-        description="List of parameters for the deterministic function",
-        default_factory=list,
-    )
     return_type: ParamType = Field(
        description="The return type of the deterministic function",
     )
-    context: Optional[LLMAsJudgeContext] = None
+    context: Optional[ScoringContext] = None  # type: ignore
 
 
 # We can optionally add information here to support packaging of code, etc.
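
Reviewer note (not part of the patch): the heart of this change is the discriminated union. Each context model now carries a constant "type" field, and Field(discriminator="type") lets pydantic pick the concrete model when deserializing a ScoringContext. Below is a minimal usage sketch, assuming pydantic's discriminated-union support (v1.9+ or v2); the judge model id and regexes are placeholders, not values from the patch.

    from llama_stack.apis.scoring_functions.scoring_functions import (
        AnswerParsingContext,
        LLMAsJudgeContext,
        ScoringFnConfig,
    )

    # pydantic reads the "type" discriminator and constructs LLMAsJudgeContext.
    judge_cfg = ScoringFnConfig(
        scoring_context={
            "type": "llm_as_judge",
            "judge_model": "Llama3.1-8B-Instruct",  # placeholder model id
            "judge_score_regex": [r"Score:\s*(\d+)"],  # placeholder regex
        }
    )
    assert isinstance(judge_cfg.scoring_context, LLMAsJudgeContext)

    # "answer_parsing" routes to AnswerParsingContext instead.
    parsing_cfg = ScoringFnConfig(
        scoring_context={
            "type": "answer_parsing",
            "parsing_regex": [r"Answer:\s*([A-D])"],  # placeholder regex
        }
    )
    assert isinstance(parsing_cfg.scoring_context, AnswerParsingContext)

Compared with the previous context: Optional[LLMAsJudgeContext] field, a payload whose "type" matches neither member now fails validation with an explicit discriminator error instead of being silently coerced, and new context kinds can be added by extending the union.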