From 8627e27b177629fa7e7c96a18d7528d225711162 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 28 Oct 2024 11:35:08 -0700 Subject: [PATCH 1/3] remove initialize --- llama_stack/providers/impls/meta_reference/scoring/scoring.py | 2 -- .../impls/meta_reference/scoring/scoring_fn/base_scoring_fn.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring.py b/llama_stack/providers/impls/meta_reference/scoring/scoring.py index 3168a5282..b037e5359 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring.py @@ -49,12 +49,10 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): async def initialize(self) -> None: for x in FIXED_FNS: impl = x() - await impl.initialize() for fn_defs in impl.get_supported_scoring_fn_defs(): self.scoring_fn_id_impls[fn_defs.identifier] = impl for x in LLM_JUDGE_FNS: impl = x(inference_api=self.inference_api) - await impl.initialize() for fn_defs in impl.get_supported_scoring_fn_defs(): self.scoring_fn_id_impls[fn_defs.identifier] = impl self.llm_as_judge_fn = impl diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/base_scoring_fn.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/base_scoring_fn.py index 52b48e5bb..cbd875be6 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/base_scoring_fn.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/base_scoring_fn.py @@ -24,8 +24,6 @@ class BaseScoringFn(ABC): def __str__(self) -> str: return self.__class__.__name__ - async def initialize(self) -> None: ... - def get_supported_scoring_fn_defs(self) -> List[ScoringFnDef]: return [x for x in self.supported_fn_defs_registry.values()] From e3f80fa4aa30ac17d28c89cc81c3d08f075c61f0 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 28 Oct 2024 11:38:14 -0700 Subject: [PATCH 2/3] address nits --- llama_stack/providers/impls/meta_reference/scoring/scoring.py | 4 ++-- .../meta_reference/scoring/scoring_fn/equality_scoring_fn.py | 4 ++-- .../meta_reference/scoring/scoring_fn/fn_defs/equality.py | 2 +- .../scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py | 3 ++- .../meta_reference/scoring/scoring_fn/fn_defs/subset_of.py | 2 +- .../scoring/scoring_fn/llm_as_judge_scoring_fn.py | 4 ++-- .../meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py | 4 ++-- 7 files changed, 12 insertions(+), 11 deletions(-) diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring.py b/llama_stack/providers/impls/meta_reference/scoring/scoring.py index b037e5359..8797e898d 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring.py @@ -61,9 +61,9 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): async def list_scoring_functions(self) -> List[ScoringFnDef]: return [ - fn_defs + fn_def for impl in self.scoring_fn_id_impls.values() - for fn_defs in impl.get_supported_scoring_fn_defs() + for fn_def in impl.get_supported_scoring_fn_defs() ] async def register_scoring_function(self, function_def: ScoringFnDef) -> None: diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py index 1b8d531aa..556436286 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py @@ -16,7 +16,7 @@ from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import ) from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.equality import ( - equality_fn_def, + equality, ) @@ -28,7 +28,7 @@ class EqualityScoringFn(BaseScoringFn): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.supported_fn_defs_registry = { - equality_fn_def.identifier: equality_fn_def, + equality.identifier: equality, } async def score_row( diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.py index cdc4fdc81..99fa6cc3a 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.py @@ -8,7 +8,7 @@ from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.scoring_functions import ScoringFnDef -equality_fn_def = ScoringFnDef( +equality = ScoringFnDef( identifier="meta-reference::equality", description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.", parameters=[], diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py index 215f4649e..20a67edc7 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py @@ -22,7 +22,8 @@ System Answer: {generated_answer} Feedback::: Total rating: """ -llm_as_judge_8b_correctness_fn_def = ScoringFnDef( + +llm_as_judge_8b_correctness = ScoringFnDef( identifier="meta-reference::llm_as_judge_8b_correctness", description="Llm As Judge Scoring Function", parameters=[], diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.py index c3cf8e960..5a3e2e8fb 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.py @@ -8,7 +8,7 @@ from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.scoring_functions import ScoringFnDef -subset_of_fn_def = ScoringFnDef( +subset_of = ScoringFnDef( identifier="meta-reference::subset_of", description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.", parameters=[], diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py index cc8e04048..5a5ce2550 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py @@ -16,7 +16,7 @@ from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import aggregate_average, ) from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.llm_as_judge_8b_correctness import ( - llm_as_judge_8b_correctness_fn_def, + llm_as_judge_8b_correctness, ) @@ -29,7 +29,7 @@ class LlmAsJudgeScoringFn(BaseScoringFn): super().__init__(*arg, **kwargs) self.inference_api = inference_api self.supported_fn_defs_registry = { - llm_as_judge_8b_correctness_fn_def.identifier: llm_as_judge_8b_correctness_fn_def, + llm_as_judge_8b_correctness.identifier: llm_as_judge_8b_correctness, } async def score_row( diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py index 394aa8177..fcef2ead7 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py @@ -15,7 +15,7 @@ from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import ) from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.subset_of import ( - subset_of_fn_def, + subset_of, ) @@ -27,7 +27,7 @@ class SubsetOfScoringFn(BaseScoringFn): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.supported_fn_defs_registry = { - subset_of_fn_def.identifier: subset_of_fn_def, + subset_of.identifier: subset_of, } async def score_row( From 3a4039aea9dbe588e9dbcd3a9cf6c48aaf9dfe86 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 28 Oct 2024 11:42:07 -0700 Subject: [PATCH 3/3] check identifier prefix --- .../providers/impls/meta_reference/scoring/scoring.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring.py b/llama_stack/providers/impls/meta_reference/scoring/scoring.py index 8797e898d..41b24a512 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring.py @@ -60,12 +60,19 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): async def shutdown(self) -> None: ... async def list_scoring_functions(self) -> List[ScoringFnDef]: - return [ + scoring_fn_defs_list = [ fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs() ] + for f in scoring_fn_defs_list: + assert f.identifier.startswith( + "meta-reference" + ), "All meta-reference scoring fn must have identifier prefixed with 'meta-reference'! " + + return scoring_fn_defs_list + async def register_scoring_function(self, function_def: ScoringFnDef) -> None: self.llm_as_judge_fn.register_scoring_fn_def(function_def) self.scoring_fn_id_impls[function_def.identifier] = self.llm_as_judge_fn