migrate scoring fns to resource (#422)

* fix after rebase

* remove print

---------

Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
Dinesh Yeduguru 2024-11-11 17:28:48 -08:00 committed by GitHub
parent 3802edfc50
commit 0a3b3d5fb6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 113 additions and 62 deletions

View file

@ -15,7 +15,7 @@ from llama_stack.apis.datasets import Dataset
from llama_stack.apis.eval_tasks import EvalTask
from llama_stack.apis.memory_banks.memory_banks import MemoryBank
from llama_stack.apis.models import Model
from llama_stack.apis.scoring_functions import ScoringFnDef
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.shields import Shield
@ -61,9 +61,9 @@ class DatasetsProtocolPrivate(Protocol):
class ScoringFunctionsProtocolPrivate(Protocol):
async def list_scoring_functions(self) -> List[ScoringFnDef]: ...
async def list_scoring_functions(self) -> List[ScoringFn]: ...
async def register_scoring_function(self, function_def: ScoringFnDef) -> None: ...
async def register_scoring_function(self, scoring_fn: ScoringFn) -> None: ...
class EvalTasksProtocolPrivate(Protocol):

View file

@ -48,7 +48,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFnDef]:
async def list_scoring_functions(self) -> List[ScoringFn]:
scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()]
for f in scoring_fn_defs_list:
assert f.identifier.startswith(
@ -57,7 +57,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
return scoring_fn_defs_list
async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
async def register_scoring_function(self, scoring_fn: ScoringFn) -> None:
raise NotImplementedError(
"Registering scoring function not allowed for braintrust provider"
)

View file

@ -5,12 +5,14 @@
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFnDef
from llama_stack.apis.scoring_functions import ScoringFn
answer_correctness_fn_def = ScoringFnDef(
answer_correctness_fn_def = ScoringFn(
identifier="braintrust::answer-correctness",
description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
parameters=[],
params=None,
provider_id="braintrust",
provider_resource_id="answer-correctness",
return_type=NumberType(),
)

View file

@ -5,12 +5,14 @@
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFnDef
from llama_stack.apis.scoring_functions import ScoringFn
factuality_fn_def = ScoringFnDef(
factuality_fn_def = ScoringFn(
identifier="braintrust::factuality",
description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
parameters=[],
params=None,
provider_id="braintrust",
provider_resource_id="factuality",
return_type=NumberType(),
)

View file

@ -52,7 +52,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFnDef]:
async def list_scoring_functions(self) -> List[ScoringFn]:
scoring_fn_defs_list = [
fn_def
for impl in self.scoring_fn_id_impls.values()
@ -66,7 +66,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
return scoring_fn_defs_list
async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
async def register_scoring_function(self, function_def: ScoringFn) -> None:
raise NotImplementedError("Register scoring function not implemented yet")
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:

View file

@ -24,15 +24,15 @@ class BaseScoringFn(ABC):
def __str__(self) -> str:
return self.__class__.__name__
def get_supported_scoring_fn_defs(self) -> List[ScoringFnDef]:
def get_supported_scoring_fn_defs(self) -> List[ScoringFn]:
return [x for x in self.supported_fn_defs_registry.values()]
def register_scoring_fn_def(self, scoring_fn_def: ScoringFnDef) -> None:
if scoring_fn_def.identifier in self.supported_fn_defs_registry:
def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None:
if scoring_fn.identifier in self.supported_fn_defs_registry:
raise ValueError(
f"Scoring function def with identifier {scoring_fn_def.identifier} already exists."
f"Scoring function def with identifier {scoring_fn.identifier} already exists."
)
self.supported_fn_defs_registry[scoring_fn_def.identifier] = scoring_fn_def
self.supported_fn_defs_registry[scoring_fn.identifier] = scoring_fn
@abstractmethod
async def score_row(

View file

@ -5,11 +5,14 @@
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFnDef
from llama_stack.apis.scoring_functions import ScoringFn
equality = ScoringFnDef(
equality = ScoringFn(
identifier="meta-reference::equality",
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
params=None,
provider_id="meta-reference",
provider_resource_id="equality",
return_type=NumberType(),
)

View file

@ -5,11 +5,13 @@
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFnDef
from llama_stack.apis.scoring_functions import ScoringFn
llm_as_judge_base = ScoringFnDef(
llm_as_judge_base = ScoringFn(
identifier="meta-reference::llm_as_judge_base",
description="Llm As Judge Scoring Function",
return_type=NumberType(),
provider_id="meta-reference",
provider_resource_id="llm-as-judge-base",
)

View file

@ -56,10 +56,12 @@ MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[]|[]|[]|[])"
)
regex_parser_multiple_choice_answer = ScoringFnDef(
regex_parser_multiple_choice_answer = ScoringFn(
identifier="meta-reference::regex_parser_multiple_choice_answer",
description="Extract answer from response matching Answer: [the_answer_letter], and compare with expected result",
return_type=NumberType(),
provider_id="meta-reference",
provider_resource_id="regex-parser-multiple-choice-answer",
params=RegexParserScoringFnParams(
parsing_regexes=[
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)

View file

@ -5,12 +5,13 @@
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFnDef
from llama_stack.apis.scoring_functions import ScoringFn
subset_of = ScoringFnDef(
subset_of = ScoringFn(
identifier="meta-reference::subset_of",
description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
parameters=[],
return_type=NumberType(),
provider_id="meta-reference",
provider_resource_id="subset-of",
)

View file

@ -42,7 +42,7 @@ class RegexParserScoringFn(BaseScoringFn):
assert (
fn_def.params is not None
and fn_def.params.type == ScoringConfigType.regex_parser.value
and fn_def.params.type == ScoringFnParamsType.regex_parser.value
), f"RegexParserScoringFnParams not found for {fn_def}."
expected_answer = input_row["expected_answer"]

View file

@ -48,7 +48,7 @@ SCORING_FIXTURES = ["meta_reference", "remote", "braintrust"]
@pytest_asyncio.fixture(scope="session")
async def scoring_stack(request):
async def scoring_stack(request, inference_model):
fixture_dict = request.param
providers = {}
@ -65,4 +65,18 @@ async def scoring_stack(request):
provider_data,
)
provider_id = providers["inference"][0].provider_id
await impls[Api.models].register_model(
model_id=inference_model,
provider_id=provider_id,
)
await impls[Api.models].register_model(
model_id="Llama3.1-405B-Instruct",
provider_id=provider_id,
)
await impls[Api.models].register_model(
model_id="Llama3.1-8B-Instruct",
provider_id=provider_id,
)
return impls