mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 07:39:38 +00:00
Merge branch 'evals_8' into evals_9
This commit is contained in:
commit
488b967d33
8 changed files with 20 additions and 16 deletions
|
@ -49,12 +49,10 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
for x in FIXED_FNS:
|
for x in FIXED_FNS:
|
||||||
impl = x()
|
impl = x()
|
||||||
await impl.initialize()
|
|
||||||
for fn_defs in impl.get_supported_scoring_fn_defs():
|
for fn_defs in impl.get_supported_scoring_fn_defs():
|
||||||
self.scoring_fn_id_impls[fn_defs.identifier] = impl
|
self.scoring_fn_id_impls[fn_defs.identifier] = impl
|
||||||
for x in LLM_JUDGE_FNS:
|
for x in LLM_JUDGE_FNS:
|
||||||
impl = x(inference_api=self.inference_api)
|
impl = x(inference_api=self.inference_api)
|
||||||
await impl.initialize()
|
|
||||||
for fn_defs in impl.get_supported_scoring_fn_defs():
|
for fn_defs in impl.get_supported_scoring_fn_defs():
|
||||||
self.scoring_fn_id_impls[fn_defs.identifier] = impl
|
self.scoring_fn_id_impls[fn_defs.identifier] = impl
|
||||||
self.llm_as_judge_fn = impl
|
self.llm_as_judge_fn = impl
|
||||||
|
@ -62,12 +60,19 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
||||||
async def shutdown(self) -> None: ...
|
async def shutdown(self) -> None: ...
|
||||||
|
|
||||||
async def list_scoring_functions(self) -> List[ScoringFnDef]:
|
async def list_scoring_functions(self) -> List[ScoringFnDef]:
|
||||||
return [
|
scoring_fn_defs_list = [
|
||||||
fn_defs
|
fn_def
|
||||||
for impl in self.scoring_fn_id_impls.values()
|
for impl in self.scoring_fn_id_impls.values()
|
||||||
for fn_defs in impl.get_supported_scoring_fn_defs()
|
for fn_def in impl.get_supported_scoring_fn_defs()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
for f in scoring_fn_defs_list:
|
||||||
|
assert f.identifier.startswith(
|
||||||
|
"meta-reference"
|
||||||
|
), "All meta-reference scoring fn must have identifier prefixed with 'meta-reference'! "
|
||||||
|
|
||||||
|
return scoring_fn_defs_list
|
||||||
|
|
||||||
async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
|
async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
|
||||||
self.llm_as_judge_fn.register_scoring_fn_def(function_def)
|
self.llm_as_judge_fn.register_scoring_fn_def(function_def)
|
||||||
self.scoring_fn_id_impls[function_def.identifier] = self.llm_as_judge_fn
|
self.scoring_fn_id_impls[function_def.identifier] = self.llm_as_judge_fn
|
||||||
|
|
|
@ -24,8 +24,6 @@ class BaseScoringFn(ABC):
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.__class__.__name__
|
return self.__class__.__name__
|
||||||
|
|
||||||
async def initialize(self) -> None: ...
|
|
||||||
|
|
||||||
def get_supported_scoring_fn_defs(self) -> List[ScoringFnDef]:
|
def get_supported_scoring_fn_defs(self) -> List[ScoringFnDef]:
|
||||||
return [x for x in self.supported_fn_defs_registry.values()]
|
return [x for x in self.supported_fn_defs_registry.values()]
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import
|
||||||
)
|
)
|
||||||
|
|
||||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.equality import (
|
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.equality import (
|
||||||
equality_fn_def,
|
equality,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ class EqualityScoringFn(BaseScoringFn):
|
||||||
def __init__(self, *args, **kwargs) -> None:
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.supported_fn_defs_registry = {
|
self.supported_fn_defs_registry = {
|
||||||
equality_fn_def.identifier: equality_fn_def,
|
equality.identifier: equality,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def score_row(
|
async def score_row(
|
||||||
|
|
|
@ -8,7 +8,7 @@ from llama_stack.apis.common.type_system import NumberType
|
||||||
from llama_stack.apis.scoring_functions import ScoringFnDef
|
from llama_stack.apis.scoring_functions import ScoringFnDef
|
||||||
|
|
||||||
|
|
||||||
equality_fn_def = ScoringFnDef(
|
equality = ScoringFnDef(
|
||||||
identifier="meta-reference::equality",
|
identifier="meta-reference::equality",
|
||||||
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
|
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
|
||||||
parameters=[],
|
parameters=[],
|
||||||
|
|
|
@ -22,7 +22,8 @@ System Answer: {generated_answer}
|
||||||
Feedback:::
|
Feedback:::
|
||||||
Total rating:
|
Total rating:
|
||||||
"""
|
"""
|
||||||
llm_as_judge_8b_correctness_fn_def = ScoringFnDef(
|
|
||||||
|
llm_as_judge_8b_correctness = ScoringFnDef(
|
||||||
identifier="meta-reference::llm_as_judge_8b_correctness",
|
identifier="meta-reference::llm_as_judge_8b_correctness",
|
||||||
description="Llm As Judge Scoring Function",
|
description="Llm As Judge Scoring Function",
|
||||||
parameters=[],
|
parameters=[],
|
||||||
|
|
|
@ -8,7 +8,7 @@ from llama_stack.apis.common.type_system import NumberType
|
||||||
from llama_stack.apis.scoring_functions import ScoringFnDef
|
from llama_stack.apis.scoring_functions import ScoringFnDef
|
||||||
|
|
||||||
|
|
||||||
subset_of_fn_def = ScoringFnDef(
|
subset_of = ScoringFnDef(
|
||||||
identifier="meta-reference::subset_of",
|
identifier="meta-reference::subset_of",
|
||||||
description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
|
description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
|
||||||
parameters=[],
|
parameters=[],
|
||||||
|
|
|
@ -16,7 +16,7 @@ from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import
|
||||||
aggregate_average,
|
aggregate_average,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.llm_as_judge_8b_correctness import (
|
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.llm_as_judge_8b_correctness import (
|
||||||
llm_as_judge_8b_correctness_fn_def,
|
llm_as_judge_8b_correctness,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
|
||||||
super().__init__(*arg, **kwargs)
|
super().__init__(*arg, **kwargs)
|
||||||
self.inference_api = inference_api
|
self.inference_api = inference_api
|
||||||
self.supported_fn_defs_registry = {
|
self.supported_fn_defs_registry = {
|
||||||
llm_as_judge_8b_correctness_fn_def.identifier: llm_as_judge_8b_correctness_fn_def,
|
llm_as_judge_8b_correctness.identifier: llm_as_judge_8b_correctness,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def score_row(
|
async def score_row(
|
||||||
|
|
|
@ -15,7 +15,7 @@ from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import
|
||||||
)
|
)
|
||||||
|
|
||||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.subset_of import (
|
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.subset_of import (
|
||||||
subset_of_fn_def,
|
subset_of,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ class SubsetOfScoringFn(BaseScoringFn):
|
||||||
def __init__(self, *args, **kwargs) -> None:
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.supported_fn_defs_registry = {
|
self.supported_fn_defs_registry = {
|
||||||
subset_of_fn_def.identifier: subset_of_fn_def,
|
subset_of.identifier: subset_of,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def score_row(
|
async def score_row(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue