unwrap context -> config

This commit is contained in:
Xi Yan 2024-11-05 16:02:47 -08:00
parent 04eebd8a36
commit be0649d79d

View file

@ -26,15 +26,15 @@ from llama_stack.apis.common.type_system import ParamType
# Perhaps more structure can be imposed on these functions. Maybe they could be associated # Perhaps more structure can be imposed on these functions. Maybe they could be associated
# with standard metrics so they can be rolled up? # with standard metrics so they can be rolled up?
@json_schema_type @json_schema_type
class ScoringContextType(Enum): class ScoringConfigType(Enum):
llm_as_judge = "llm_as_judge" llm_as_judge = "llm_as_judge"
answer_parsing = "answer_parsing" answer_parsing = "answer_parsing"
@json_schema_type @json_schema_type
class LLMAsJudgeContext(BaseModel): class LLMAsJudgeScoringFnConfig(BaseModel):
type: Literal[ScoringContextType.llm_as_judge.value] = ( # type: ignore type: Literal[ScoringContextType.llm_as_judge.value] = ( # type: ignore
ScoringContextType.llm_as_judge.value ScoringConfigType.llm_as_judge.value
) )
judge_model: str judge_model: str
prompt_template: Optional[str] = None prompt_template: Optional[str] = None
@ -42,9 +42,9 @@ class LLMAsJudgeContext(BaseModel):
@json_schema_type @json_schema_type
class AnswerParsingContext(BaseModel): class AnswerParsingScoringFnConfig(BaseModel):
type: Literal[ScoringContextType.answer_parsing.value] = ( # type: ignore type: Literal[ScoringContextType.answer_parsing.value] = ( # type: ignore
ScoringContextType.answer_parsing.value ScoringConfigType.answer_parsing.value
) )
parsing_regex: Optional[List[str]] = Field( parsing_regex: Optional[List[str]] = Field(
description="Regex to extract the answer from generated response", description="Regex to extract the answer from generated response",
@ -52,20 +52,15 @@ class AnswerParsingContext(BaseModel):
) )
ScoringContext = Annotated[ ScoringFnConfig = Annotated[
Union[ Union[
LLMAsJudgeContext, LLMAsJudgeScoringFnConfig,
AnswerParsingContext, AnswerParsingScoringFnConfig,
], ],
Field(discriminator="type"), Field(discriminator="type"),
] ]
@json_schema_type
class ScoringFnConfig(BaseModel):
scoring_context: ScoringContext # type: ignore
@json_schema_type @json_schema_type
class ScoringFnDef(BaseModel): class ScoringFnDef(BaseModel):
identifier: str identifier: str
@ -77,7 +72,7 @@ class ScoringFnDef(BaseModel):
return_type: ParamType = Field( return_type: ParamType = Field(
description="The return type of the deterministic function", description="The return type of the deterministic function",
) )
context: Optional[ScoringContext] = None # type: ignore config: Optional[ScoringFnConfig] = None # type: ignore
# We can optionally add information here to support packaging of code, etc. # We can optionally add information here to support packaging of code, etc.