Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-30 23:51:00 +00:00)
wip scoring refactor

parent 8a74e400d6, commit bf8bc7a781
11 changed files with 137 additions and 13 deletions
@@ -102,8 +102,8 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
             if scoring_fn_id not in SCORER_REGISTRY:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
             scoring_fn = SCORER_REGISTRY[scoring_fn_id]()
-            score_results = scoring_fn.score(input_rows)
-            agg_results = scoring_fn.aggregate(score_results)
+            score_results = await scoring_fn.score(input_rows, scoring_fn_id)
+            agg_results = await scoring_fn.aggregate(score_results)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
                 aggregated_results=agg_results,
@@ -26,12 +26,23 @@ class BaseScoringFn(ABC):
         return self.__class__.__name__

     @abstractmethod
-    def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
+    async def score_row(
+        self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
+    ) -> ScoringResultRow:
         raise NotImplementedError()

     @abstractmethod
-    def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
+    async def aggregate(
+        self, scoring_results: List[ScoringResultRow]
+    ) -> Dict[str, Any]:
         raise NotImplementedError()

-    def score(self, input_rows: List[Dict[str, Any]]) -> List[ScoringResultRow]:
-        return [self.score_row(input_row) for input_row in input_rows]
+    async def score(
+        self,
+        input_rows: List[Dict[str, Any]],
+        scoring_fn_identifier: Optional[str] = None,
+    ) -> List[ScoringResultRow]:
+        return [
+            await self.score_row(input_row, scoring_fn_identifier)
+            for input_row in input_rows
+        ]
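To make the refactored contract concrete, here is a minimal, self-contained sketch of a scoring function that follows the new async interface, plus a small driver. It is illustrative only: ExactMatchExample is not a class from this commit, and ScoringResultRow is assumed to behave like a plain per-row dict (the hunks in this diff only ever return rows such as {"score": ...}).

    import asyncio
    from typing import Any, Dict, List, Optional

    # Assumption: ScoringResultRow behaves like a plain per-row dict, e.g. {"score": 1.0}.
    ScoringResultRow = Dict[str, Any]


    class ExactMatchExample:
        """Illustrative stand-in mirroring the refactored async scoring interface."""

        async def score_row(
            self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
        ) -> ScoringResultRow:
            # One input row in, one result row out.
            matched = input_row["expected_answer"] == input_row["generated_answer"]
            return {"score": 1.0 if matched else 0.0}

        async def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
            # Collapse per-row scores into a single summary metric.
            total = max(len(scoring_results), 1)
            return {"accuracy": sum(r["score"] for r in scoring_results) / total}

        async def score(
            self,
            input_rows: List[Dict[str, Any]],
            scoring_fn_identifier: Optional[str] = None,
        ) -> List[ScoringResultRow]:
            return [await self.score_row(row, scoring_fn_identifier) for row in input_rows]


    async def demo() -> None:
        fn = ExactMatchExample()
        rows = [
            {"expected_answer": "4", "generated_answer": "4"},
            {"expected_answer": "4", "generated_answer": "5"},
        ]
        results = await fn.score(rows, "equality")
        print(results)                      # [{'score': 1.0}, {'score': 0.0}]
        print(await fn.aggregate(results))  # {'accuracy': 0.5}


    asyncio.run(demo())

The awaited score/aggregate pair is exactly what the MetaReferenceScoringImpl hunk above now calls.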
@@ -27,7 +27,11 @@ class EqualityScoringFn(BaseScoringFn):
             return_type=NumberType(),
         )

-    def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
+    async def score_row(
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = "equality",
+    ) -> ScoringResultRow:
         assert "expected_answer" in input_row, "Expected answer not found in input row."
         assert (
             "generated_answer" in input_row
@@ -40,5 +44,7 @@ class EqualityScoringFn(BaseScoringFn):
             "score": score,
         }

-    def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
+    async def aggregate(
+        self, scoring_results: List[ScoringResultRow]
+    ) -> Dict[str, Any]:
         return aggregate_accuracy(scoring_results)
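aggregate_accuracy is imported from scoring_fn.common but its body is not part of this diff. A plausible minimal sketch, assuming it simply averages the per-row "score" values; the exact keys of the returned dict are an assumption:

    from typing import Any, Dict, List


    def aggregate_accuracy(scoring_results: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Assumed behavior: treat each row's "score" as 0/1 correctness and average it.
        num_total = len(scoring_results)
        num_correct = sum(row["score"] for row in scoring_results)
        return {
            "accuracy": num_correct / num_total if num_total else 0.0,
            "num_correct": num_correct,
            "num_total": num_total,
        }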
New file (LLM-as-judge scoring function):

@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
+    BaseScoringFn,
+)
+from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
+from llama_stack.apis.scoring import *  # noqa: F401, F403
+from llama_stack.apis.common.type_system import *  # noqa: F403
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
+    aggregate_accuracy,
+)
+
+JUDGE_PROMPT = """
+You will be given a question, an expected_answer, and a system_answer.
+Your task is to provide a 'total rating' scoring how well the system_answer answers the question compared with the ground truth in expected_answer, in terms of factual correctness.
+Give your answer as an integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.
+Provide your feedback as follows:
+Feedback:::
+Total rating: (your rating, as an int between 0 and 5)
+Now here are the question, expected_answer, system_answer.
+Question: {question}
+Expected Answer: {expected_answer}
+System Answer: {answer}
+Feedback:::
+Total rating:
+"""
+
+
+class LlmAsJudgeScoringFn(BaseScoringFn):
+    """
+    A scoring_fn that uses an LLM as judge to score rows.
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.scoring_fn_def_registry = {}
+
+    def register_scoring_def(self, scoring_fn_def: ScoringFnDef) -> None:
+        self.scoring_fn_def_registry[scoring_fn_def.identifier] = scoring_fn_def
+
+    async def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
+        assert "expected_answer" in input_row, "Expected answer not found in input row."
+        assert (
+            "generated_answer" in input_row
+        ), "Generated answer not found in input row."
+
+        expected_answer = input_row["expected_answer"]
+        generated_answer = input_row["generated_answer"]
+        score = 1.0 if expected_answer == generated_answer else 0.0
+        return {
+            "score": score,
+        }
+
+    async def aggregate(
+        self, scoring_results: List[ScoringResultRow]
+    ) -> Dict[str, Any]:
+        return aggregate_accuracy(scoring_results)
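Note that in this WIP commit LlmAsJudgeScoringFn.score_row is still the equality placeholder and JUDGE_PROMPT is defined but unused. A hedged sketch of how the prompt might eventually be wired to the inference API (added as a provider dependency below); llm_judge_score_row, inference_api.complete, and the "input_query" column name are assumptions, not APIs from this commit:

    import re
    from typing import Any, Dict


    async def llm_judge_score_row(
        inference_api: Any, judge_prompt: str, input_row: Dict[str, Any]
    ) -> Dict[str, Any]:
        # judge_prompt is expected to be the JUDGE_PROMPT template defined above.
        prompt = judge_prompt.format(
            question=input_row["input_query"],  # assumed name of the question column
            expected_answer=input_row["expected_answer"],
            answer=input_row["generated_answer"],
        )
        # Hypothetical call; the real inference surface is not shown in this diff.
        judge_response: str = await inference_api.complete(prompt)
        # Naive extraction of the first 0-5 digit after "Total rating:"; real parsing
        # would need to be more robust to judge formatting.
        match = re.search(r"Total rating:\s*([0-5])", judge_response) or re.search(
            r"([0-5])", judge_response
        )
        rating = int(match.group(1)) if match else 0
        # Normalize the 0-5 rating to a 0-1 score (an assumption about the score convention).
        return {"score": rating / 5.0, "judge_feedback": judge_response}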
@@ -27,7 +27,11 @@ class SubsetOfScoringFn(BaseScoringFn):
             return_type=NumberType(),
         )

-    def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
+    async def score_row(
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = "subset_of",
+    ) -> ScoringResultRow:
         assert "expected_answer" in input_row, "Expected answer not found in input row."
         assert (
             "generated_answer" in input_row
@@ -40,5 +44,7 @@ class SubsetOfScoringFn(BaseScoringFn):
             "score": score,
         }

-    def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
+    async def aggregate(
+        self, scoring_results: List[ScoringResultRow]
+    ) -> Dict[str, Any]:
         return aggregate_accuracy(scoring_results)
llama_stack/providers/impls/third_party/scoring/braintrust/__init__.py (new file, 18 lines, vendored)

@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any
+
+from .config import BraintrustScoringConfig
+
+
+async def get_provider_impl(config: BraintrustScoringConfig, _deps) -> Any:
+    pass
+    # from .braintrust import VLLMInferenceImpl
+
+    # impl = VLLMInferenceImpl(config)
+    # await impl.initialize()
+    # return impl
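The commented-out scaffolding still refers to VLLMInferenceImpl, which looks copied from another provider template. A hedged sketch of what the filled-in entrypoint would presumably look like; BraintrustScoringImpl and its .braintrust module are hypothetical names, not something this commit adds:

    from typing import Any

    from .config import BraintrustScoringConfig


    async def get_provider_impl(config: BraintrustScoringConfig, _deps) -> Any:
        # Hypothetical: the concrete implementation module does not exist yet in this commit.
        from .braintrust import BraintrustScoringImpl

        impl = BraintrustScoringImpl(config)
        await impl.initialize()
        return impl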
llama_stack/providers/impls/third_party/scoring/braintrust/config.py (new file, 9 lines, vendored)

@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from llama_stack.apis.eval import *  # noqa: F401, F403
+
+
+class BraintrustScoringConfig(BaseModel): ...
@@ -20,6 +20,7 @@ def available_providers() -> List[ProviderSpec]:
             api_dependencies=[
                 Api.datasetio,
                 Api.datasets,
+                Api.inference,
             ],
         ),
     ]
@@ -7,3 +7,8 @@ providers:
   - provider_id: test-meta
     provider_type: meta-reference
     config: {}
+  inference:
+  - provider_id: tgi0
+    provider_type: remote::tgi
+    config:
+      url: http://127.0.0.1:5009
@@ -33,7 +33,9 @@ from llama_stack.providers.tests.resolver import resolve_impls_for_test

 @pytest_asyncio.fixture(scope="session")
 async def scoring_settings():
-    impls = await resolve_impls_for_test(Api.scoring, deps=[Api.datasetio])
+    impls = await resolve_impls_for_test(
+        Api.scoring, deps=[Api.datasetio, Api.inference]
+    )
     return {
         "scoring_impl": impls[Api.scoring],
         "scoring_functions_impl": impls[Api.scoring_functions],
@@ -62,8 +64,9 @@ async def test_scoring_score(scoring_settings):

     response = await scoring_impl.score_batch(
         dataset_id=response[0].identifier,
-        scoring_functions=["equality"],
+        scoring_functions=["equality", "subset_of"],
     )

-    assert len(response.results) == 1
+    assert len(response.results) == 2
     assert "equality" in response.results
+    assert "subset_of" in response.results
@@ -33,6 +33,10 @@ providers:
     provider_type: remote::tgi
     config:
       url: http://127.0.0.1:5009
+  - provider_id: tgi1
+    provider_type: remote::tgi
+    config:
+      url: http://127.0.0.1:5010
   memory:
   - provider_id: meta-reference
     provider_type: meta-reference