From be0649d79d7113dfe22a02cf7383c289d5258167 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Tue, 5 Nov 2024 16:02:47 -0800 Subject: [PATCH] unwrap context -> config --- .../scoring_functions/scoring_functions.py | 23 ++++++++----------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 742b1d88f..4bb7ea187 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -26,15 +26,15 @@ from llama_stack.apis.common.type_system import ParamType # Perhaps more structure can be imposed on these functions. Maybe they could be associated # with standard metrics so they can be rolled up? @json_schema_type -class ScoringContextType(Enum): +class ScoringConfigType(Enum): llm_as_judge = "llm_as_judge" answer_parsing = "answer_parsing" @json_schema_type -class LLMAsJudgeContext(BaseModel): +class LLMAsJudgeScoringFnConfig(BaseModel): type: Literal[ScoringContextType.llm_as_judge.value] = ( # type: ignore - ScoringContextType.llm_as_judge.value + ScoringConfigType.llm_as_judge.value ) judge_model: str prompt_template: Optional[str] = None @@ -42,9 +42,9 @@ class LLMAsJudgeContext(BaseModel): @json_schema_type -class AnswerParsingContext(BaseModel): +class AnswerParsingScoringFnConfig(BaseModel): type: Literal[ScoringContextType.answer_parsing.value] = ( # type: ignore - ScoringContextType.answer_parsing.value + ScoringConfigType.answer_parsing.value ) parsing_regex: Optional[List[str]] = Field( description="Regex to extract the answer from generated response", @@ -52,20 +52,15 @@ class AnswerParsingContext(BaseModel): ) -ScoringContext = Annotated[ +ScoringFnConfig = Annotated[ Union[ - LLMAsJudgeContext, - AnswerParsingContext, + LLMAsJudgeScoringFnConfig, + AnswerParsingScoringFnConfig, ], Field(discriminator="type"), ] -@json_schema_type -class ScoringFnConfig(BaseModel): - scoring_context: ScoringContext # type: ignore - - @json_schema_type class ScoringFnDef(BaseModel): identifier: str @@ -77,7 +72,7 @@ class ScoringFnDef(BaseModel): return_type: ParamType = Field( description="The return type of the deterministic function", ) - context: Optional[ScoringContext] = None # type: ignore + config: Optional[ScoringFnConfig] = None # type: ignore # We can optionally add information here to support packaging of code, etc.