Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-10 07:35:59 +00:00
migrate scoring fns to resource (#422)
* fix after rebase
* remove print

Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
parent 3802edfc50
commit 0a3b3d5fb6
16 changed files with 113 additions and 62 deletions
@@ -15,7 +15,7 @@ from llama_stack.apis.datasets import Dataset
 from llama_stack.apis.eval_tasks import EvalTask
 from llama_stack.apis.memory_banks.memory_banks import MemoryBank
 from llama_stack.apis.models import Model
-from llama_stack.apis.scoring_functions import ScoringFnDef
+from llama_stack.apis.scoring_functions import ScoringFn
 from llama_stack.apis.shields import Shield
@@ -61,9 +61,9 @@ class DatasetsProtocolPrivate(Protocol):


 class ScoringFunctionsProtocolPrivate(Protocol):
-    async def list_scoring_functions(self) -> List[ScoringFnDef]: ...
+    async def list_scoring_functions(self) -> List[ScoringFn]: ...

-    async def register_scoring_function(self, function_def: ScoringFnDef) -> None: ...
+    async def register_scoring_function(self, scoring_fn: ScoringFn) -> None: ...


 class EvalTasksProtocolPrivate(Protocol):
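With the protocol now trafficking in ScoringFn resources instead of ScoringFnDef, a provider only needs to swap the type in its signatures. Below is a minimal sketch of a conforming provider; the class and its in-memory registry are hypothetical, only the method signatures come from the hunk above.

from typing import Dict, List

from llama_stack.apis.scoring_functions import ScoringFn


class InMemoryScoringFnProvider:
    # Hypothetical provider used only to illustrate the updated
    # ScoringFunctionsProtocolPrivate surface.
    def __init__(self) -> None:
        self._registry: Dict[str, ScoringFn] = {}

    async def list_scoring_functions(self) -> List[ScoringFn]:
        return list(self._registry.values())

    async def register_scoring_function(self, scoring_fn: ScoringFn) -> None:
        self._registry[scoring_fn.identifier] = scoring_fn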
@@ -48,7 +48,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):

     async def shutdown(self) -> None: ...

-    async def list_scoring_functions(self) -> List[ScoringFnDef]:
+    async def list_scoring_functions(self) -> List[ScoringFn]:
         scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()]
         for f in scoring_fn_defs_list:
             assert f.identifier.startswith(
@@ -57,7 +57,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):

         return scoring_fn_defs_list

-    async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
+    async def register_scoring_function(self, scoring_fn: ScoringFn) -> None:
         raise NotImplementedError(
             "Registering scoring function not allowed for braintrust provider"
         )
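Because the braintrust provider ships a fixed catalogue of scorers, dynamic registration is rejected outright. A caller-side sketch follows; the helper function and its arguments are hypothetical, while the exception behavior comes from the hunk above.

async def register_or_list(braintrust_impl, scoring_fn):
    # Hypothetical helper: fall back to the provider's built-in catalogue when
    # it refuses dynamic registration.
    try:
        await braintrust_impl.register_scoring_function(scoring_fn)
        return [scoring_fn]
    except NotImplementedError:
        return await braintrust_impl.list_scoring_functions()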
@@ -5,12 +5,14 @@
 # the root directory of this source tree.

 from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import ScoringFnDef
+from llama_stack.apis.scoring_functions import ScoringFn


-answer_correctness_fn_def = ScoringFnDef(
+answer_correctness_fn_def = ScoringFn(
     identifier="braintrust::answer-correctness",
     description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
-    parameters=[],
+    params=None,
+    provider_id="braintrust",
+    provider_resource_id="answer-correctness",
     return_type=NumberType(),
 )
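The migration turns each definition into a full resource: identifier stays the provider-namespaced name, provider_id and provider_resource_id record which provider serves it and under what local name, and params replaces the old parameters list. As a hedged sketch, a downstream definition built the same way would look like this; the scorer itself is hypothetical and not part of this commit, while the field names come from the diff.

from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFn

# Hypothetical extra braintrust-style definition, mirroring the resource
# fields introduced above.
answer_similarity_fn_def = ScoringFn(
    identifier="braintrust::answer-similarity",
    description="Illustrative only; not part of this commit.",
    params=None,
    provider_id="braintrust",
    provider_resource_id="answer-similarity",
    return_type=NumberType(),
)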
@@ -5,12 +5,14 @@
 # the root directory of this source tree.

 from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import ScoringFnDef
+from llama_stack.apis.scoring_functions import ScoringFn


-factuality_fn_def = ScoringFnDef(
+factuality_fn_def = ScoringFn(
     identifier="braintrust::factuality",
     description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
-    parameters=[],
+    params=None,
+    provider_id="braintrust",
+    provider_resource_id="factuality",
     return_type=NumberType(),
 )
@@ -52,7 +52,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):

     async def shutdown(self) -> None: ...

-    async def list_scoring_functions(self) -> List[ScoringFnDef]:
+    async def list_scoring_functions(self) -> List[ScoringFn]:
         scoring_fn_defs_list = [
             fn_def
             for impl in self.scoring_fn_id_impls.values()
@@ -66,7 +66,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):

         return scoring_fn_defs_list

-    async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
+    async def register_scoring_function(self, function_def: ScoringFn) -> None:
         raise NotImplementedError("Register scoring function not implemented yet")

     async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
@@ -24,15 +24,15 @@ class BaseScoringFn(ABC):
     def __str__(self) -> str:
         return self.__class__.__name__

-    def get_supported_scoring_fn_defs(self) -> List[ScoringFnDef]:
+    def get_supported_scoring_fn_defs(self) -> List[ScoringFn]:
         return [x for x in self.supported_fn_defs_registry.values()]

-    def register_scoring_fn_def(self, scoring_fn_def: ScoringFnDef) -> None:
-        if scoring_fn_def.identifier in self.supported_fn_defs_registry:
+    def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None:
+        if scoring_fn.identifier in self.supported_fn_defs_registry:
             raise ValueError(
-                f"Scoring function def with identifier {scoring_fn_def.identifier} already exists."
+                f"Scoring function def with identifier {scoring_fn.identifier} already exists."
             )
-        self.supported_fn_defs_registry[scoring_fn_def.identifier] = scoring_fn_def
+        self.supported_fn_defs_registry[scoring_fn.identifier] = scoring_fn

     @abstractmethod
     async def score_row(
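BaseScoringFn keeps its registry keyed by identifier and now stores ScoringFn resources, rejecting a second registration under the same identifier. A standalone sketch of that duplicate check follows; the helper function is hypothetical and only reproduces the behavior shown above.

from typing import Dict, List

from llama_stack.apis.scoring_functions import ScoringFn


def build_registry(scoring_fns: List[ScoringFn]) -> Dict[str, ScoringFn]:
    # Hypothetical helper mirroring the identifier-keyed, duplicate-rejecting
    # behavior of BaseScoringFn.register_scoring_fn_def.
    registry: Dict[str, ScoringFn] = {}
    for fn in scoring_fns:
        if fn.identifier in registry:
            raise ValueError(
                f"Scoring function def with identifier {fn.identifier} already exists."
            )
        registry[fn.identifier] = fn
    return registry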
@@ -5,11 +5,14 @@
 # the root directory of this source tree.

 from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import ScoringFnDef
+from llama_stack.apis.scoring_functions import ScoringFn


-equality = ScoringFnDef(
+equality = ScoringFn(
     identifier="meta-reference::equality",
     description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
+    params=None,
+    provider_id="meta-reference",
+    provider_resource_id="equality",
     return_type=NumberType(),
 )
@@ -5,11 +5,13 @@
 # the root directory of this source tree.

 from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import ScoringFnDef
+from llama_stack.apis.scoring_functions import ScoringFn


-llm_as_judge_base = ScoringFnDef(
+llm_as_judge_base = ScoringFn(
     identifier="meta-reference::llm_as_judge_base",
     description="Llm As Judge Scoring Function",
     return_type=NumberType(),
+    provider_id="meta-reference",
+    provider_resource_id="llm-as-judge-base",
 )
@@ -56,10 +56,12 @@ MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
     r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])"
 )

-regex_parser_multiple_choice_answer = ScoringFnDef(
+regex_parser_multiple_choice_answer = ScoringFn(
     identifier="meta-reference::regex_parser_multiple_choice_answer",
     description="Extract answer from response matching Answer: [the_answer_letter], and compare with expected result",
     return_type=NumberType(),
+    provider_id="meta-reference",
+    provider_resource_id="regex-parser-multiple-choice-answer",
     params=RegexParserScoringFnParams(
         parsing_regexes=[
             MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
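The parsing regexes are built by formatting the multilingual template with per-language answer prefixes. A small standalone sketch of how one formatted pattern extracts a letter; the "Answer\s*:" prefix and the sample sentence are illustrative assumptions, while the template itself is copied from the diff.

import re

MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
    r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])"
)

# Hypothetical prefix; the real list of prefixes is defined elsewhere in this file.
pattern = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(r"Answer\s*:")
match = re.search(pattern, "Reasoning omitted. Answer: B")
if match:
    extracted_letter = match.group(1)  # "B", later compared to expected_answer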
@@ -5,12 +5,13 @@
 # the root directory of this source tree.

 from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import ScoringFnDef
+from llama_stack.apis.scoring_functions import ScoringFn


-subset_of = ScoringFnDef(
+subset_of = ScoringFn(
     identifier="meta-reference::subset_of",
     description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
-    parameters=[],
     return_type=NumberType(),
+    provider_id="meta-reference",
+    provider_resource_id="subset-of",
 )
@@ -42,7 +42,7 @@ class RegexParserScoringFn(BaseScoringFn):

         assert (
             fn_def.params is not None
-            and fn_def.params.type == ScoringConfigType.regex_parser.value
+            and fn_def.params.type == ScoringFnParamsType.regex_parser.value
         ), f"RegexParserScoringFnParams not found for {fn_def}."

         expected_answer = input_row["expected_answer"]
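The assert now checks against the renamed ScoringFnParamsType enum. As a sketch, params that satisfy the guard would be built like this; the import path for RegexParserScoringFnParams is an assumption based on the scoring_functions API used elsewhere in this diff, and the regex is illustrative.

from llama_stack.apis.scoring_functions import RegexParserScoringFnParams

# Illustrative params object; its .type should equal
# ScoringFnParamsType.regex_parser.value, so the assert above passes.
params = RegexParserScoringFnParams(
    parsing_regexes=[r"Answer\s*:\s*([A-D])"],
)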
@@ -48,7 +48,7 @@ SCORING_FIXTURES = ["meta_reference", "remote", "braintrust"]


 @pytest_asyncio.fixture(scope="session")
-async def scoring_stack(request):
+async def scoring_stack(request, inference_model):
     fixture_dict = request.param

     providers = {}
@@ -65,4 +65,18 @@ async def scoring_stack(request):
         provider_data,
     )

+    provider_id = providers["inference"][0].provider_id
+    await impls[Api.models].register_model(
+        model_id=inference_model,
+        provider_id=provider_id,
+    )
+    await impls[Api.models].register_model(
+        model_id="Llama3.1-405B-Instruct",
+        provider_id=provider_id,
+    )
+    await impls[Api.models].register_model(
+        model_id="Llama3.1-8B-Instruct",
+        provider_id=provider_id,
+    )
+
     return impls
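The fixture now takes the inference_model fixture and pre-registers it, along with the two judge models, against the models API, so scoring tests that depend on LLM-as-judge have models available up front. A sketch of a test consuming the fixture; the test body and the Api import path are assumptions, while the fixture name and the impls registry access pattern come from the diff.

import pytest

# Import path assumed; Api is the enum used as impls[Api.models] in the fixture above.
from llama_stack.providers.datatypes import Api
from llama_stack.apis.scoring_functions import ScoringFn


@pytest.mark.asyncio
async def test_scoring_fns_are_resources(scoring_stack):
    impls = scoring_stack
    scoring_fns = await impls[Api.scoring_functions].list_scoring_functions()
    assert all(isinstance(fn, ScoringFn) for fn in scoring_fns)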