inclusion->subsetof: rename InclusionScorer to SubsetOfScorer (scoring function identifier "inclusion" -> "subset_of")

This commit is contained in:
Xi Yan 2024-10-24 16:13:49 -07:00
parent 29e48cc5c1
commit f6340a47d1
3 changed files with 7 additions and 7 deletions

View file

@ -15,13 +15,13 @@ from llama_stack.providers.impls.meta_reference.scoring.scorer.common import (
) )
class InclusionScorer(BaseScorer): class SubsetOfScorer(BaseScorer):
""" """
A scorer that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise. A scorer that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise.
""" """
scoring_function_def = ScoringFunctionDef( scoring_function_def = ScoringFunctionDef(
identifier="inclusion", identifier="subset_of",
description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.", description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
parameters=[], parameters=[],
return_type=NumberType(), return_type=NumberType(),

View file

@ -17,15 +17,15 @@ from llama_stack.providers.impls.meta_reference.scoring.scorer.equality_scorer i
EqualityScorer, EqualityScorer,
) )
from llama_stack.providers.impls.meta_reference.scoring.scorer.inclusion_scorer import ( from llama_stack.providers.impls.meta_reference.scoring.scorer.subset_of_scorer import (
InclusionScorer, SubsetOfScorer,
) )
from .config import MetaReferenceScoringConfig from .config import MetaReferenceScoringConfig
SUPPORTED_SCORERS = [ SUPPORTED_SCORERS = [
EqualityScorer, EqualityScorer,
InclusionScorer, SubsetOfScorer,
] ]
SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORERS} SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORERS}

View file

@ -65,7 +65,7 @@ async def test_eval(eval_settings):
model="Llama3.1-8B-Instruct", model="Llama3.1-8B-Instruct",
sampling_params=SamplingParams(), sampling_params=SamplingParams(),
), ),
scoring_functions=["inclusion"], scoring_functions=["subset_of"],
) )
assert response.job_id == "0" assert response.job_id == "0"
job_status = await eval_impl.job_status(response.job_id) job_status = await eval_impl.job_status(response.job_id)
@ -76,4 +76,4 @@ async def test_eval(eval_settings):
assert eval_response is not None assert eval_response is not None
assert len(eval_response.generations) == 5 assert len(eval_response.generations) == 5
assert "inclusion" in eval_response.scores assert "subset_of" in eval_response.scores