wip scoring refactor

Xi Yan 2024-10-25 15:03:03 -07:00
parent 8a74e400d6
commit bf8bc7a781
11 changed files with 137 additions and 13 deletions


@@ -102,8 +102,8 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
             if scoring_fn_id not in SCORER_REGISTRY:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
             scoring_fn = SCORER_REGISTRY[scoring_fn_id]()
-            score_results = scoring_fn.score(input_rows)
-            agg_results = scoring_fn.aggregate(score_results)
+            score_results = await scoring_fn.score(input_rows, scoring_fn_id)
+            agg_results = await scoring_fn.aggregate(score_results)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
                 aggregated_results=agg_results,
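
As a rough illustration of the new call shape (a minimal sketch; the asyncio scaffolding, sample row, and run_one helper are invented — only the awaited score/aggregate calls come from the hunk above):

    import asyncio

    async def run_one(scoring_fn_id, input_rows):
        # score() and aggregate() are coroutines after this refactor, so both
        # calls must be awaited; the function id is threaded into score().
        scoring_fn = SCORER_REGISTRY[scoring_fn_id]()
        score_results = await scoring_fn.score(input_rows, scoring_fn_id)
        return await scoring_fn.aggregate(score_results)

    # e.g. asyncio.run(run_one("equality", [{"expected_answer": "4", "generated_answer": "4"}]))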


@@ -26,12 +26,23 @@ class BaseScoringFn(ABC):
         return self.__class__.__name__

     @abstractmethod
-    def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
+    async def score_row(
+        self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
+    ) -> ScoringResultRow:
         raise NotImplementedError()

     @abstractmethod
-    def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
+    async def aggregate(
+        self, scoring_results: List[ScoringResultRow]
+    ) -> Dict[str, Any]:
         raise NotImplementedError()

-    def score(self, input_rows: List[Dict[str, Any]]) -> List[ScoringResultRow]:
-        return [self.score_row(input_row) for input_row in input_rows]
+    async def score(
+        self,
+        input_rows: List[Dict[str, Any]],
+        scoring_fn_identifier: Optional[str] = None,
+    ) -> List[ScoringResultRow]:
+        return [
+            await self.score_row(input_row, scoring_fn_identifier)
+            for input_row in input_rows
+        ]
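
For reference, a minimal sketch of a concrete subclass against the new async interface (ExactMatchScoringFn and its logic are invented for illustration, not part of this commit):

    from typing import Any, Dict, List, Optional

    class ExactMatchScoringFn(BaseScoringFn):
        # Hypothetical scorer: 1.0 when the generated answer matches exactly.
        async def score_row(
            self,
            input_row: Dict[str, Any],
            scoring_fn_identifier: Optional[str] = None,
        ) -> ScoringResultRow:
            match = input_row["expected_answer"] == input_row["generated_answer"]
            return {"score": 1.0 if match else 0.0}

        async def aggregate(
            self, scoring_results: List[ScoringResultRow]
        ) -> Dict[str, Any]:
            # Mean of the per-row scores; empty input aggregates to 0.0.
            n = len(scoring_results)
            return {"accuracy": (sum(r["score"] for r in scoring_results) / n) if n else 0.0}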


@@ -27,7 +27,11 @@ class EqualityScoringFn(BaseScoringFn):
         return_type=NumberType(),
     )

-    def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
+    async def score_row(
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = "equality",
+    ) -> ScoringResultRow:
         assert "expected_answer" in input_row, "Expected answer not found in input row."
         assert (
             "generated_answer" in input_row
@@ -40,5 +44,7 @@ class EqualityScoringFn(BaseScoringFn):
             "score": score,
         }

-    def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
+    async def aggregate(
+        self, scoring_results: List[ScoringResultRow]
+    ) -> Dict[str, Any]:
         return aggregate_accuracy(scoring_results)
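
A quick usage sketch for the updated EqualityScoringFn (sample rows invented; assumes the class constructs without arguments):

    import asyncio

    rows = [
        {"expected_answer": "4", "generated_answer": "4"},
        {"expected_answer": "4", "generated_answer": "5"},
    ]
    fn = EqualityScoringFn()
    score_rows = asyncio.run(fn.score(rows, "equality"))
    # Per the diff, each row comes back as {"score": ...} -> [{"score": 1.0}, {"score": 0.0}]
    print(asyncio.run(fn.aggregate(score_rows)))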


@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
    BaseScoringFn,
)
from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
from llama_stack.apis.scoring import *  # noqa: F401, F403
from llama_stack.apis.common.type_system import *  # noqa: F403
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
    aggregate_accuracy,
)

JUDGE_PROMPT = """
You will be given a question, an expected_answer, and a system_answer.
Your task is to provide a 'total rating' scoring how well the system_answer answers the question compared with the ground truth in expected_answer, in terms of factual correctness.
Give your answer as an integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.

Provide your feedback as follows:

Feedback:::
Total rating: (your rating, as an int between 0 and 5)

Now here are the question, expected_answer, and system_answer.

Question: {question}
Expected Answer: {expected_answer}
System Answer: {answer}

Feedback:::
Total rating:
"""


class LlmAsJudgeScoringFn(BaseScoringFn):
    """
    A scoring_fn that assigns a score to a generated answer by asking an LLM
    judge to grade it against the expected answer.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.scoring_fn_def_registry = {}

    def register_scoring_def(self, scoring_fn_def: ScoringFnDef) -> None:
        self.scoring_fn_def_registry[scoring_fn_def.identifier] = scoring_fn_def

    async def score_row(
        self,
        input_row: Dict[str, Any],
        scoring_fn_identifier: Optional[str] = None,
    ) -> ScoringResultRow:
        # WIP: placeholder exact-match scoring until the judge call is wired up.
        assert "expected_answer" in input_row, "Expected answer not found in input row."
        assert (
            "generated_answer" in input_row
        ), "Generated answer not found in input row."
        expected_answer = input_row["expected_answer"]
        generated_answer = input_row["generated_answer"]
        score = 1.0 if expected_answer == generated_answer else 0.0
        return {
            "score": score,
        }

    async def aggregate(
        self, scoring_results: List[ScoringResultRow]
    ) -> Dict[str, Any]:
        return aggregate_accuracy(scoring_results)
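
The judge call itself is not wired up in this WIP (score_row still does an exact-match placeholder). A hedged sketch of how JUDGE_PROMPT might eventually be consumed — the question field name and the inference call are assumptions, not from this commit:

    import re

    def parse_total_rating(judge_text: str) -> float:
        # Extract the integer after "Total rating:" per the 0-5 scale that
        # JUDGE_PROMPT requests; fall back to 0.0 on a malformed reply.
        m = re.search(r"Total rating:\s*(\d+)", judge_text)
        return float(m.group(1)) if m else 0.0

    # Hypothetical body for score_row once an inference dependency exists:
    # prompt = JUDGE_PROMPT.format(
    #     question=input_row["input_query"],  # assumed field name
    #     expected_answer=expected_answer,
    #     answer=generated_answer,
    # )
    # judge_text = await self.inference_api.completion(prompt)  # assumed API
    # return {"score": parse_total_rating(judge_text)}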


@@ -27,7 +27,11 @@ class SubsetOfScoringFn(BaseScoringFn):
         return_type=NumberType(),
     )

-    def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
+    async def score_row(
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = "subset_of",
+    ) -> ScoringResultRow:
         assert "expected_answer" in input_row, "Expected answer not found in input row."
         assert (
             "generated_answer" in input_row
@@ -40,5 +44,7 @@ class SubsetOfScoringFn(BaseScoringFn):
             "score": score,
         }

-    def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
+    async def aggregate(
+        self, scoring_results: List[ScoringResultRow]
+    ) -> Dict[str, Any]:
         return aggregate_accuracy(scoring_results)
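
By name, subset_of presumably passes when the expected answer is contained in the generated answer (the comparison itself falls outside this hunk); a tiny usage sketch under that assumption:

    import asyncio

    rows = [{"expected_answer": "Paris", "generated_answer": "The capital is Paris."}]
    fn = SubsetOfScoringFn()
    print(asyncio.run(fn.score(rows, "subset_of")))  # [{"score": 1.0}] if containment-based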


@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any

from .config import BraintrustScoringConfig


async def get_provider_impl(config: BraintrustScoringConfig, _deps) -> Any:
    pass
    # from .braintrust import VLLMInferenceImpl
    # impl = VLLMInferenceImpl(config)
    # await impl.initialize()
    # return impl
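
The stub above returns None for now, and the commented-out VLLMInferenceImpl import reads like a template carried over from another provider. A hedged sketch of the eventual shape, with BraintrustScoringImpl as a purely hypothetical name:

    # async def get_provider_impl(config: BraintrustScoringConfig, _deps) -> Any:
    #     from .braintrust import BraintrustScoringImpl  # hypothetical module/class
    #     impl = BraintrustScoringImpl(config)
    #     await impl.initialize()
    #     return impl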


@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel

from llama_stack.apis.eval import *  # noqa: F401, F403


class BraintrustScoringConfig(BaseModel): ...


@@ -20,6 +20,7 @@ def available_providers() -> List[ProviderSpec]:
             api_dependencies=[
                 Api.datasetio,
                 Api.datasets,
+                Api.inference,
             ],
         ),
     ]


@@ -7,3 +7,8 @@ providers:
   - provider_id: test-meta
     provider_type: meta-reference
     config: {}
+inference:
+  - provider_id: tgi0
+    provider_type: remote::tgi
+    config:
+      url: http://127.0.0.1:5009


@@ -33,7 +33,9 @@ from llama_stack.providers.tests.resolver import resolve_impls_for_test
 @pytest_asyncio.fixture(scope="session")
 async def scoring_settings():
-    impls = await resolve_impls_for_test(Api.scoring, deps=[Api.datasetio])
+    impls = await resolve_impls_for_test(
+        Api.scoring, deps=[Api.datasetio, Api.inference]
+    )
     return {
         "scoring_impl": impls[Api.scoring],
         "scoring_functions_impl": impls[Api.scoring_functions],
@@ -62,8 +64,9 @@ async def test_scoring_score(scoring_settings):
     response = await scoring_impl.score_batch(
         dataset_id=response[0].identifier,
-        scoring_functions=["equality"],
+        scoring_functions=["equality", "subset_of"],
     )
-    assert len(response.results) == 1
+    assert len(response.results) == 2
     assert "equality" in response.results
+    assert "subset_of" in response.results


@@ -33,6 +33,10 @@ providers:
     provider_type: remote::tgi
     config:
       url: http://127.0.0.1:5009
+  - provider_id: tgi1
+    provider_type: remote::tgi
+    config:
+      url: http://127.0.0.1:5010
   memory:
   - provider_id: meta-reference
     provider_type: meta-reference