Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-31 16:01:46 +00:00)

Commit bf8bc7a781: "wip scoring refactor" (parent 8a74e400d6)
11 changed files with 137 additions and 13 deletions
@@ -102,8 +102,8 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
             if scoring_fn_id not in SCORER_REGISTRY:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
             scoring_fn = SCORER_REGISTRY[scoring_fn_id]()
-            score_results = scoring_fn.score(input_rows)
-            agg_results = scoring_fn.aggregate(score_results)
+            score_results = await scoring_fn.score(input_rows, scoring_fn_id)
+            agg_results = await scoring_fn.aggregate(score_results)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
                 aggregated_results=agg_results,
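Note: the hunk above makes the meta-reference dispatch fully async. Each requested scoring function id is looked up in SCORER_REGISTRY, instantiated, and its score/aggregate coroutines are awaited. A minimal, self-contained sketch of that pattern follows; SCORER_REGISTRY and ScoringResult mirror the names in the hunk, while the toy EqualityScorer and score_all are illustrative and not the provider's actual code.

# Sketch only: async registry dispatch in the style of the hunk above.
import asyncio
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class ScoringResult:
    score_rows: List[Dict[str, Any]]
    aggregated_results: Dict[str, Any]


class EqualityScorer:
    async def score(self, rows: List[Dict[str, Any]], fn_id: str) -> List[Dict[str, Any]]:
        return [
            {"score": 1.0 if r["expected_answer"] == r["generated_answer"] else 0.0}
            for r in rows
        ]

    async def aggregate(self, score_rows: List[Dict[str, Any]]) -> Dict[str, Any]:
        return {"accuracy": sum(r["score"] for r in score_rows) / len(score_rows)}


SCORER_REGISTRY = {"equality": EqualityScorer}


async def score_all(rows, scoring_fn_ids):
    res = {}
    for fn_id in scoring_fn_ids:
        if fn_id not in SCORER_REGISTRY:
            raise ValueError(f"Scoring function {fn_id} is not supported.")
        scoring_fn = SCORER_REGISTRY[fn_id]()
        score_rows = await scoring_fn.score(rows, fn_id)
        res[fn_id] = ScoringResult(score_rows, await scoring_fn.aggregate(score_rows))
    return res


if __name__ == "__main__":
    rows = [{"expected_answer": "4", "generated_answer": "4"}]
    print(asyncio.run(score_all(rows, ["equality"])))

Running the sketch with the single row prints one ScoringResult whose aggregated accuracy is 1.0.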
@@ -26,12 +26,23 @@ class BaseScoringFn(ABC):
         return self.__class__.__name__
 
     @abstractmethod
-    def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
+    async def score_row(
+        self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
+    ) -> ScoringResultRow:
         raise NotImplementedError()
 
     @abstractmethod
-    def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
+    async def aggregate(
+        self, scoring_results: List[ScoringResultRow]
+    ) -> Dict[str, Any]:
         raise NotImplementedError()
 
-    def score(self, input_rows: List[Dict[str, Any]]) -> List[ScoringResultRow]:
-        return [self.score_row(input_row) for input_row in input_rows]
+    async def score(
+        self,
+        input_rows: List[Dict[str, Any]],
+        scoring_fn_identifier: Optional[str] = None,
+    ) -> List[ScoringResultRow]:
+        return [
+            await self.score_row(input_row, scoring_fn_identifier)
+            for input_row in input_rows
+        ]
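Note that the new default BaseScoringFn.score awaits each score_row sequentially. When score_row does real I/O (for example an LLM-judge call), a gather-based variant could score rows concurrently; the sketch below assumes that and is not part of this commit.

# Sketch only (not in this commit): a concurrent variant of BaseScoringFn.score.
# Assumes score_row is safe to run concurrently for the given scoring function.
import asyncio
from typing import Any, Dict, List, Optional


async def score_concurrently(
    scoring_fn,  # any object exposing the async score_row shown above
    input_rows: List[Dict[str, Any]],
    scoring_fn_identifier: Optional[str] = None,
) -> List[Dict[str, Any]]:
    return await asyncio.gather(
        *(scoring_fn.score_row(row, scoring_fn_identifier) for row in input_rows)
    )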
@@ -27,7 +27,11 @@ class EqualityScoringFn(BaseScoringFn):
             return_type=NumberType(),
         )
 
-    def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
+    async def score_row(
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = "equality",
+    ) -> ScoringResultRow:
         assert "expected_answer" in input_row, "Expected answer not found in input row."
         assert (
             "generated_answer" in input_row

@@ -40,5 +44,7 @@ class EqualityScoringFn(BaseScoringFn):
             "score": score,
         }
 
-    def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
+    async def aggregate(
+        self, scoring_results: List[ScoringResultRow]
+    ) -> Dict[str, Any]:
         return aggregate_accuracy(scoring_results)
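aggregate_accuracy is imported from scoring_fn/common.py, whose body is not part of this diff. A plausible stand-in that matches how it is used here, offered only as an assumption about its shape:

# Hypothetical stand-in for aggregate_accuracy; the real helper is not shown in this diff.
from typing import Any, Dict, List


def aggregate_accuracy(score_rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    num_correct = sum(row["score"] for row in score_rows)
    return {
        "accuracy": num_correct / len(score_rows) if score_rows else 0.0,
        "num_correct": num_correct,
        "num_total": len(score_rows),
    }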
@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
+    BaseScoringFn,
+)
+from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
+from llama_stack.apis.scoring import *  # noqa: F401, F403
+from llama_stack.apis.common.type_system import *  # noqa: F403
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
+    aggregate_accuracy,
+)
+
+JUDGE_PROMPT = """
+You will be given a question, a expected_answer, and a system_answer.
+Your task is to provide a 'total rating' scoring how well the system_answer answers compared with ground truth in expected_answer in terms of factual correctness to the question.
+Give your answer as a integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.
+Provide your feedback as follows:
+Feedback:::
+Total rating: (your rating, as a int between 0 and 5)
+Now here are the question, expected_answer, system_answer.
+Question: {question}
+Expected Answer: {expected_answer}
+System Answer: {answer}
+Feedback:::
+Total rating:
+"""
+
+
+class LlmAsJudgeScoringFn(BaseScoringFn):
+    """
+    A scoring_fn that assigns
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.scoring_fn_def_registry = {}
+
+    def register_scoring_def(self, scoring_fn_def: ScoringFnDef) -> None:
+        self.scoring_function_def_registry[scoring_fn_def.identifier] = scoring_fn_def
+
+    async def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
+        assert "expected_answer" in input_row, "Expected answer not found in input row."
+        assert (
+            "generated_answer" in input_row
+        ), "Generated answer not found in input row."
+
+        expected_answer = input_row["expected_answer"]
+        generated_answer = input_row["generated_answer"]
+        score = 1.0 if expected_answer == generated_answer else 0.0
+        return {
+            "score": score,
+        }
+
+    async def aggregate(
+        self, scoring_results: List[ScoringResultRow]
+    ) -> Dict[str, Any]:
+        return aggregate_accuracy(scoring_results)
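In this WIP state the class's score_row is still a copy of the equality check and JUDGE_PROMPT is unused. A hedged sketch of how the prompt might eventually be filled in and the judge's rating parsed; the inference call itself is left out because it is not part of this diff, and the helper names below are hypothetical.

# Hypothetical helpers (not in this commit): fill JUDGE_PROMPT and parse the rating.
import re
from typing import Optional


def build_judge_prompt(question: str, expected_answer: str, answer: str) -> str:
    # JUDGE_PROMPT is the template defined in the new file above.
    return JUDGE_PROMPT.format(
        question=question, expected_answer=expected_answer, answer=answer
    )


def parse_total_rating(judge_output: str) -> Optional[float]:
    # Expects the judge to reply with a line like "Total rating: 4".
    match = re.search(r"Total rating:\s*([0-5])", judge_output)
    return float(match.group(1)) if match else None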
@@ -27,7 +27,11 @@ class SubsetOfScoringFn(BaseScoringFn):
             return_type=NumberType(),
         )
 
-    def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
+    async def score_row(
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = "subset_of",
+    ) -> ScoringResultRow:
         assert "expected_answer" in input_row, "Expected answer not found in input row."
         assert (
             "generated_answer" in input_row

@@ -40,5 +44,7 @@ class SubsetOfScoringFn(BaseScoringFn):
             "score": score,
         }
 
-    def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
+    async def aggregate(
+        self, scoring_results: List[ScoringResultRow]
+    ) -> Dict[str, Any]:
         return aggregate_accuracy(scoring_results)
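The subset_of hunks only touch signatures; the row-level check itself sits outside this diff. The name and the accuracy aggregation suggest a containment test, sketched here purely as an assumption.

# Assumed semantics for "subset_of" (the real check is not shown in this diff):
# the row scores 1.0 when the expected answer is contained in the generated answer.
from typing import Any, Dict


def subset_of_score_row(input_row: Dict[str, Any]) -> Dict[str, Any]:
    expected = input_row["expected_answer"]
    generated = input_row["generated_answer"]
    return {"score": 1.0 if expected in generated else 0.0}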
llama_stack/providers/impls/third_party/scoring/braintrust/__init__.py (vendored, new file, 18 lines)

@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any
+
+from .config import BraintrustScoringConfig
+
+
+async def get_provider_impl(config: BraintrustScoringConfig, _deps) -> Any:
+    pass
+    # from .braintrust import VLLMInferenceImpl
+
+    # impl = VLLMInferenceImpl(config)
+    # await impl.initialize()
+    # return impl
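The commented-out body (with class names apparently copied from the vLLM provider) points at the usual construct/initialize/return shape of a provider entry point. A hypothetical filled-in version, with BraintrustScoringImpl as an assumed class name that does not exist in this commit:

# Hypothetical filled-in entry point; BraintrustScoringImpl is an assumed name.
from typing import Any

from .config import BraintrustScoringConfig


async def get_provider_impl(config: BraintrustScoringConfig, _deps) -> Any:
    from .braintrust import BraintrustScoringImpl  # assumed module/class

    impl = BraintrustScoringImpl(config)
    await impl.initialize()
    return impl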
llama_stack/providers/impls/third_party/scoring/braintrust/config.py (vendored, new file, 9 lines)

@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from llama_stack.apis.eval import *  # noqa: F401, F403
+
+
+class BraintrustScoringConfig(BaseModel): ...
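BraintrustScoringConfig is an empty pydantic placeholder for now. If the provider needs configuration later, fields would presumably be added in the usual BaseModel way; the field below is purely illustrative and not part of the commit.

# Purely illustrative: an example field the placeholder config could grow.
from typing import Optional

from pydantic import BaseModel


class BraintrustScoringConfig(BaseModel):
    openai_api_key: Optional[str] = None  # hypothetical field, not in the commit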
@@ -20,6 +20,7 @@ def available_providers() -> List[ProviderSpec]:
             api_dependencies=[
                 Api.datasetio,
                 Api.datasets,
+                Api.inference,
             ],
         ),
     ]
@@ -7,3 +7,8 @@ providers:
   - provider_id: test-meta
     provider_type: meta-reference
     config: {}
+  inference:
+  - provider_id: tgi0
+    provider_type: remote::tgi
+    config:
+      url: http://127.0.0.1:5009
@@ -33,7 +33,9 @@ from llama_stack.providers.tests.resolver import resolve_impls_for_test
 
 @pytest_asyncio.fixture(scope="session")
 async def scoring_settings():
-    impls = await resolve_impls_for_test(Api.scoring, deps=[Api.datasetio])
+    impls = await resolve_impls_for_test(
+        Api.scoring, deps=[Api.datasetio, Api.inference]
+    )
     return {
         "scoring_impl": impls[Api.scoring],
         "scoring_functions_impl": impls[Api.scoring_functions],

@@ -62,8 +64,9 @@ async def test_scoring_score(scoring_settings):
 
     response = await scoring_impl.score_batch(
         dataset_id=response[0].identifier,
-        scoring_functions=["equality"],
+        scoring_functions=["equality", "subset_of"],
     )
 
-    assert len(response.results) == 1
+    assert len(response.results) == 2
     assert "equality" in response.results
+    assert "subset_of" in response.results
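The fixture now also resolves Api.inference (needed once LLM-as-judge scoring calls a model), and the test asserts that both registered functions appear in the score_batch results. The same assertion pattern against a toy in-memory scorer, as a self-contained sketch (assumes pytest-asyncio; none of these names belong to the repo's test harness):

# Sketch only: the assertion pattern above, exercised against a toy scorer.
from typing import Any, Dict, List

import pytest


class ToyScoringImpl:
    async def score_batch(
        self, rows: List[Dict[str, Any]], scoring_functions: List[str]
    ) -> Dict[str, Any]:
        return {fn: [{"score": 1.0} for _ in rows] for fn in scoring_functions}


@pytest.mark.asyncio
async def test_scoring_score_toy():
    impl = ToyScoringImpl()
    results = await impl.score_batch(
        [{"expected_answer": "a", "generated_answer": "a"}],
        scoring_functions=["equality", "subset_of"],
    )
    assert len(results) == 2
    assert "equality" in results
    assert "subset_of" in results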
@@ -33,6 +33,10 @@ providers:
     provider_type: remote::tgi
     config:
       url: http://127.0.0.1:5009
+  - provider_id: tgi1
+    provider_type: remote::tgi
+    config:
+      url: http://127.0.0.1:5010
   memory:
   - provider_id: meta-reference
     provider_type: meta-reference