From f1a2548ad5b84f62e0902721c61faae18a2760bd Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 31 Oct 2024 15:26:15 -0700 Subject: [PATCH] scoring fn api update --- .../scoring_functions/scoring_functions.py | 48 +++++++++++++------ 1 file changed, 34 insertions(+), 14 deletions(-) diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 2e5bf0aef..597a1abbe 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -3,27 +3,31 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. - -from typing import Any, Dict, List, Optional, Protocol, runtime_checkable +from enum import Enum +from typing import Any, Dict, List, Literal, Optional, Protocol, Union, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field +from typing_extensions import Annotated from llama_stack.apis.common.type_system import ParamType -@json_schema_type -class Parameter(BaseModel): - name: str - type: ParamType - description: Optional[str] = None - - # Perhaps more structure can be imposed on these functions. Maybe they could be associated # with standard metrics so they can be rolled up?
+@json_schema_type +class ScoringContextType(Enum): + llm_as_judge = "llm_as_judge" + answer_parsing = "answer_parsing" + + +@json_schema_type class LLMAsJudgeContext(BaseModel): + type: Literal[ScoringContextType.llm_as_judge.value] = ( + ScoringContextType.llm_as_judge.value + ) judge_model: str prompt_template: Optional[str] = None judge_score_regex: Optional[List[str]] = Field( @@ -32,6 +36,26 @@ class LLMAsJudgeContext(BaseModel): ) +@json_schema_type +class AnswerParsingContext(BaseModel): + type: Literal[ScoringContextType.answer_parsing.value] = ( + ScoringContextType.answer_parsing.value + ) + parsing_regex: Optional[List[str]] = Field( + description="Regex to extract the answer from generated response", + default_factory=list, + ) + + +ScoringContext = Annotated[ + Union[ + LLMAsJudgeContext, + AnswerParsingContext, + ], + Field(discriminator="type"), +] + + @json_schema_type class ScoringFnDef(BaseModel): identifier: str @@ -40,14 +64,10 @@ class ScoringFnDef(BaseModel): default_factory=dict, description="Any additional metadata for this definition", ) - parameters: List[Parameter] = Field( - description="List of parameters for the deterministic function", - default_factory=list, - ) return_type: ParamType = Field( description="The return type of the deterministic function", ) - context: Optional[LLMAsJudgeContext] = None + context: Optional[ScoringContext] = None # We can optionally add information here to support packaging of code, etc.