Merge branch 'evals_8' into evals_9

This commit is contained in:
Xi Yan 2024-10-28 11:43:04 -07:00
commit 488b967d33
8 changed files with 20 additions and 16 deletions

View file

@ -49,12 +49,10 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def initialize(self) -> None: async def initialize(self) -> None:
for x in FIXED_FNS: for x in FIXED_FNS:
impl = x() impl = x()
await impl.initialize()
for fn_defs in impl.get_supported_scoring_fn_defs(): for fn_defs in impl.get_supported_scoring_fn_defs():
self.scoring_fn_id_impls[fn_defs.identifier] = impl self.scoring_fn_id_impls[fn_defs.identifier] = impl
for x in LLM_JUDGE_FNS: for x in LLM_JUDGE_FNS:
impl = x(inference_api=self.inference_api) impl = x(inference_api=self.inference_api)
await impl.initialize()
for fn_defs in impl.get_supported_scoring_fn_defs(): for fn_defs in impl.get_supported_scoring_fn_defs():
self.scoring_fn_id_impls[fn_defs.identifier] = impl self.scoring_fn_id_impls[fn_defs.identifier] = impl
self.llm_as_judge_fn = impl self.llm_as_judge_fn = impl
@ -62,12 +60,19 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def shutdown(self) -> None: ... async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFnDef]: async def list_scoring_functions(self) -> List[ScoringFnDef]:
return [ scoring_fn_defs_list = [
fn_defs fn_def
for impl in self.scoring_fn_id_impls.values() for impl in self.scoring_fn_id_impls.values()
for fn_defs in impl.get_supported_scoring_fn_defs() for fn_def in impl.get_supported_scoring_fn_defs()
] ]
for f in scoring_fn_defs_list:
assert f.identifier.startswith(
"meta-reference"
), "All meta-reference scoring fn must have identifier prefixed with 'meta-reference'! "
return scoring_fn_defs_list
async def register_scoring_function(self, function_def: ScoringFnDef) -> None: async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
self.llm_as_judge_fn.register_scoring_fn_def(function_def) self.llm_as_judge_fn.register_scoring_fn_def(function_def)
self.scoring_fn_id_impls[function_def.identifier] = self.llm_as_judge_fn self.scoring_fn_id_impls[function_def.identifier] = self.llm_as_judge_fn

View file

@ -24,8 +24,6 @@ class BaseScoringFn(ABC):
def __str__(self) -> str: def __str__(self) -> str:
return self.__class__.__name__ return self.__class__.__name__
async def initialize(self) -> None: ...
def get_supported_scoring_fn_defs(self) -> List[ScoringFnDef]: def get_supported_scoring_fn_defs(self) -> List[ScoringFnDef]:
return [x for x in self.supported_fn_defs_registry.values()] return [x for x in self.supported_fn_defs_registry.values()]

View file

@ -16,7 +16,7 @@ from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import
) )
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.equality import ( from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.equality import (
equality_fn_def, equality,
) )
@ -28,7 +28,7 @@ class EqualityScoringFn(BaseScoringFn):
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = { self.supported_fn_defs_registry = {
equality_fn_def.identifier: equality_fn_def, equality.identifier: equality,
} }
async def score_row( async def score_row(

View file

@ -8,7 +8,7 @@ from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFnDef from llama_stack.apis.scoring_functions import ScoringFnDef
equality_fn_def = ScoringFnDef( equality = ScoringFnDef(
identifier="meta-reference::equality", identifier="meta-reference::equality",
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.", description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
parameters=[], parameters=[],

View file

@ -22,7 +22,8 @@ System Answer: {generated_answer}
Feedback::: Feedback:::
Total rating: Total rating:
""" """
llm_as_judge_8b_correctness_fn_def = ScoringFnDef(
llm_as_judge_8b_correctness = ScoringFnDef(
identifier="meta-reference::llm_as_judge_8b_correctness", identifier="meta-reference::llm_as_judge_8b_correctness",
description="Llm As Judge Scoring Function", description="Llm As Judge Scoring Function",
parameters=[], parameters=[],

View file

@ -8,7 +8,7 @@ from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFnDef from llama_stack.apis.scoring_functions import ScoringFnDef
subset_of_fn_def = ScoringFnDef( subset_of = ScoringFnDef(
identifier="meta-reference::subset_of", identifier="meta-reference::subset_of",
description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.", description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
parameters=[], parameters=[],

View file

@ -16,7 +16,7 @@ from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import
aggregate_average, aggregate_average,
) )
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.llm_as_judge_8b_correctness import ( from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.llm_as_judge_8b_correctness import (
llm_as_judge_8b_correctness_fn_def, llm_as_judge_8b_correctness,
) )
@ -29,7 +29,7 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
super().__init__(*arg, **kwargs) super().__init__(*arg, **kwargs)
self.inference_api = inference_api self.inference_api = inference_api
self.supported_fn_defs_registry = { self.supported_fn_defs_registry = {
llm_as_judge_8b_correctness_fn_def.identifier: llm_as_judge_8b_correctness_fn_def, llm_as_judge_8b_correctness.identifier: llm_as_judge_8b_correctness,
} }
async def score_row( async def score_row(

View file

@ -15,7 +15,7 @@ from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import
) )
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.subset_of import ( from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.subset_of import (
subset_of_fn_def, subset_of,
) )
@ -27,7 +27,7 @@ class SubsetOfScoringFn(BaseScoringFn):
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = { self.supported_fn_defs_registry = {
subset_of_fn_def.identifier: subset_of_fn_def, subset_of.identifier: subset_of,
} }
async def score_row( async def score_row(