[Evals API][6/n] meta-reference llm as judge, registration for ScoringFnDefs (#330)

* wip scoring refactor

* llm as judge, move folders

* test full generation + eval

* extract score regex to llm context

* remove prints, cleanup braintrust in this branch

* change json -> class

* remove initialize

* address nits

* check identifier prefix

* update MANIFEST
Xi Yan 2024-10-28 14:08:42 -07:00 committed by GitHub
parent 04a4784287
commit 7b8748c53e
20 changed files with 360 additions and 50 deletions

View file

@@ -26,6 +26,10 @@ class Parameter(BaseModel):
class LLMAsJudgeContext(BaseModel):
judge_model: str
prompt_template: Optional[str] = None
judge_score_regex: Optional[List[str]] = Field(
description="Regex to extract the score from the judge response",
default=None,
)
@json_schema_type
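
The extraction that `judge_score_regex` drives is first-match-wins over the judge's raw text. A minimal sketch of that logic, with a made-up response string:

```python
import re
from typing import List, Optional

def extract_score(response: str, patterns: List[str]) -> Optional[int]:
    # Try each regex in order; the first capture group of the first match wins.
    for pattern in patterns:
        match = re.search(pattern, response)
        if match:
            return int(match.group(1))
    return None  # no pattern matched; the row's score stays None

assert extract_score("Feedback:::\nTotal rating: 4", [r"Total rating: (\d+)"]) == 4
```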

View file

@@ -18,6 +18,7 @@ from .config import MetaReferenceEvalConfig
class ColumnName(Enum):
input_query = "input_query"
expected_answer = "expected_answer"
chat_completion_input = "chat_completion_input"
completion_input = "completion_input"
@@ -53,10 +54,12 @@ class MetaReferenceEvalImpl(Eval):
expected_schemas = [
{
ColumnName.input_query.value: StringType(),
ColumnName.expected_answer.value: StringType(),
ColumnName.chat_completion_input.value: ChatCompletionInputType(),
},
{
ColumnName.input_query.value: StringType(),
ColumnName.expected_answer.value: StringType(),
ColumnName.completion_input.value: CompletionInputType(),
},
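
Concretely, rows conforming to the two accepted schemas would look roughly like this (the values, and the exact shape of the chat input, are illustrative assumptions):

```python
# Assumed chat-style row (ChatCompletionInputType).
chat_row = {
    "input_query": "What is the capital of France?",
    "expected_answer": "Paris",
    "chat_completion_input": [
        {"role": "user", "content": "What is the capital of France?"}
    ],
}

# Assumed plain-completion row (CompletionInputType).
completion_row = {
    "input_query": "What is the capital of France?",
    "expected_answer": "Paris",
    "completion_input": "Q: What is the capital of France?\nA:",
}
```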

View file

@@ -16,6 +16,8 @@ async def get_provider_impl(
):
from .scoring import MetaReferenceScoringImpl
impl = MetaReferenceScoringImpl(config, deps[Api.datasetio], deps[Api.datasets])
impl = MetaReferenceScoringImpl(
config, deps[Api.datasetio], deps[Api.datasets], deps[Api.inference]
)
await impl.initialize()
return impl

View file

@@ -11,24 +11,25 @@ from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.inference.inference import Inference
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.equality_scoring_fn import (
EqualityScoringFn,
)
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.llm_as_judge_scoring_fn import (
LlmAsJudgeScoringFn,
)
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import (
SubsetOfScoringFn,
)
from .config import MetaReferenceScoringConfig
SUPPORTED_SCORING_FNS = [
EqualityScoringFn,
SubsetOfScoringFn,
]
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn]
SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORING_FNS}
LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
@@ -37,22 +38,44 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
config: MetaReferenceScoringConfig,
datasetio_api: DatasetIO,
datasets_api: Datasets,
inference_api: Inference,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.inference_api = inference_api
self.scoring_fn_id_impls = {}
async def initialize(self) -> None: ...
async def initialize(self) -> None:
for x in FIXED_FNS:
impl = x()
for fn_defs in impl.get_supported_scoring_fn_defs():
self.scoring_fn_id_impls[fn_defs.identifier] = impl
for x in LLM_JUDGE_FNS:
impl = x(inference_api=self.inference_api)
for fn_defs in impl.get_supported_scoring_fn_defs():
self.scoring_fn_id_impls[fn_defs.identifier] = impl
self.llm_as_judge_fn = impl
async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFnDef]:
return [x.scoring_function_def for x in SUPPORTED_SCORING_FNS]
scoring_fn_defs_list = [
fn_def
for impl in self.scoring_fn_id_impls.values()
for fn_def in impl.get_supported_scoring_fn_defs()
]
for f in scoring_fn_defs_list:
assert f.identifier.startswith(
"meta-reference"
), "All meta-reference scoring fn must have identifier prefixed with 'meta-reference'! "
return scoring_fn_defs_list
async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
raise NotImplementedError(
"Dynamically registering scoring functions is not supported"
)
self.llm_as_judge_fn.register_scoring_fn_def(function_def)
self.scoring_fn_id_impls[function_def.identifier] = self.llm_as_judge_fn
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
@@ -99,11 +122,11 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
) -> ScoreResponse:
res = {}
for scoring_fn_id in scoring_functions:
if scoring_fn_id not in SCORER_REGISTRY:
if scoring_fn_id not in self.scoring_fn_id_impls:
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
scoring_fn = SCORER_REGISTRY[scoring_fn_id]()
score_results = scoring_fn.score(input_rows)
agg_results = scoring_fn.aggregate(score_results)
scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
score_results = await scoring_fn.score(input_rows, scoring_fn_id)
agg_results = await scoring_fn.aggregate(score_results)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,
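
The lookup now goes through the per-instance `scoring_fn_id_impls` map instead of the old module-level `SCORER_REGISTRY`, so functions registered at runtime dispatch exactly like the built-ins. A sketch of the call path, assuming the map is already populated:

```python
import asyncio

async def score_with(scoring_fn_id_impls, fn_id, rows):
    # Resolve the implementation, then score and aggregate (both async now).
    scoring_fn = scoring_fn_id_impls[fn_id]
    score_results = await scoring_fn.score(rows, fn_id)
    return await scoring_fn.aggregate(score_results)

# e.g.: asyncio.run(score_with(impls, "meta-reference::equality", rows))
```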

View file

@@ -17,21 +17,41 @@ class BaseScoringFn(ABC):
- aggregate(self, scoring_fn_results)
"""
scoring_function_def: ScoringFnDef
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {}
def __str__(self) -> str:
return self.__class__.__name__
def get_supported_scoring_fn_defs(self) -> List[ScoringFnDef]:
return [x for x in self.supported_fn_defs_registry.values()]
def register_scoring_fn_def(self, scoring_fn_def: ScoringFnDef) -> None:
if scoring_fn_def.identifier in self.supported_fn_defs_registry:
raise ValueError(
f"Scoring function def with identifier {scoring_fn_def.identifier} already exists."
)
self.supported_fn_defs_registry[scoring_fn_def.identifier] = scoring_fn_def
@abstractmethod
def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
async def score_row(
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
) -> ScoringResultRow:
raise NotImplementedError()
@abstractmethod
def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
async def aggregate(
self, scoring_results: List[ScoringResultRow]
) -> Dict[str, Any]:
raise NotImplementedError()
def score(self, input_rows: List[Dict[str, Any]]) -> List[ScoringResultRow]:
return [self.score_row(input_row) for input_row in input_rows]
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_fn_identifier: Optional[str] = None,
) -> List[ScoringResultRow]:
return [
await self.score_row(input_row, scoring_fn_identifier)
for input_row in input_rows
]
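
With this base class, adding a scoring function means subclassing `BaseScoringFn`, seeding `supported_fn_defs_registry` with one or more `ScoringFnDef`s, and implementing the async hooks. A minimal sketch; the identifier and scoring logic here are invented for illustration:

```python
from typing import Any, Dict, List, Optional

from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFnDef
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
    BaseScoringFn,
)
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
    aggregate_accuracy,
)

class ExactLengthScoringFn(BaseScoringFn):
    """Hypothetical example: 1.0 when generated and expected answers have equal length."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        exact_length = ScoringFnDef(
            identifier="example::exact_length",
            description="Returns 1.0 if answers have equal length, 0.0 otherwise.",
            parameters=[],
            return_type=NumberType(),
        )
        self.supported_fn_defs_registry = {exact_length.identifier: exact_length}

    async def score_row(
        self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
    ) -> Dict[str, Any]:
        same = len(input_row["generated_answer"]) == len(input_row["expected_answer"])
        return {"score": 1.0 if same else 0.0}

    async def aggregate(self, scoring_results: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Reuse the accuracy aggregation from scoring_fn/common.py.
        return aggregate_accuracy(scoring_results)
```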

View file

@@ -3,10 +3,13 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from typing import Any, Dict, List
from llama_stack.apis.scoring import ScoringResultRow
FN_DEFS_PATH = Path(__file__).parent / "fn_defs"
def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
num_correct = sum(result["score"] for result in scoring_results)
@@ -17,3 +20,12 @@ def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
"num_correct": num_correct,
"num_total": len(scoring_results),
}
def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
return {
"average": sum(
result["score"] for result in scoring_results if result["score"] is not None
)
/ len([_ for _ in scoring_results if _["score"] is not None]),
}
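
As a quick worked example, `aggregate_average` drops `None` scores (e.g. rows where no judge regex matched) from both the numerator and the denominator; note it would raise `ZeroDivisionError` if every score were `None`:

```python
rows = [{"score": 4}, {"score": None}, {"score": 2}]
# None is excluded from both the sum and the count: (4 + 2) / 2 == 3.0
assert aggregate_average(rows) == {"average": 3.0}
```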

View file

@@ -10,24 +10,32 @@ from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
aggregate_accuracy,
)
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.equality import (
equality,
)
class EqualityScoringFn(BaseScoringFn):
"""
A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
"""
scoring_function_def = ScoringFnDef(
identifier="equality",
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
parameters=[],
return_type=NumberType(),
)
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
equality.identifier: equality,
}
def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = "equality",
) -> ScoringResultRow:
assert "expected_answer" in input_row, "Expected answer not found in input row."
assert (
"generated_answer" in input_row
@@ -40,5 +48,7 @@ class EqualityScoringFn(BaseScoringFn):
"score": score,
}
def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
async def aggregate(
self, scoring_results: List[ScoringResultRow]
) -> Dict[str, Any]:
return aggregate_accuracy(scoring_results)

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFnDef
equality = ScoringFnDef(
identifier="meta-reference::equality",
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
parameters=[],
return_type=NumberType(),
)

View file

@@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import NumberType
JUDGE_PROMPT = """
You will be given a question, an expected_answer, and a system_answer.
Your task is to provide a 'total rating' scoring how well the system_answer matches the ground truth in expected_answer, in terms of factual correctness, for the given question.
Give your answer as an integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.
Provide your feedback as follows:
Feedback:::
Total rating: (your rating, as an int between 0 and 5)
Now here are the question, expected_answer, and system_answer.
Question: {input_query}
Expected Answer: {expected_answer}
System Answer: {generated_answer}
Feedback:::
Total rating:
"""
llm_as_judge_8b_correctness = ScoringFnDef(
identifier="meta-reference::llm_as_judge_8b_correctness",
description="Llm As Judge Scoring Function",
parameters=[],
return_type=NumberType(),
context=LLMAsJudgeContext(
prompt_template=JUDGE_PROMPT,
judge_model="Llama3.1-8B-Instruct",
judge_score_regex=[r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"],
),
)
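
Tying the pieces together: the prompt template is rendered from a dataset row, sent to the judge model, and the rating is recovered via `judge_score_regex`. A sketch with hypothetical row values:

```python
row = {
    "input_query": "What is 2 + 2?",
    "expected_answer": "4",
    "generated_answer": "four",
}
judge_input = JUDGE_PROMPT.format(**row)
# The judge's reply (e.g. "Feedback:::\nTotal rating: 3") is then matched
# against the judge_score_regex patterns in order, yielding score = 3.
```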

View file

@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFnDef
subset_of = ScoringFnDef(
identifier="meta-reference::subset_of",
description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
parameters=[],
return_type=NumberType(),
)

View file

@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference.inference import Inference
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
BaseScoringFn,
)
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
import re
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
aggregate_average,
)
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.llm_as_judge_8b_correctness import (
llm_as_judge_8b_correctness,
)
class LlmAsJudgeScoringFn(BaseScoringFn):
"""
A scoring_fn that assigns a score to each row by prompting an LLM judge model and extracting a numeric rating from its response.
"""
def __init__(self, inference_api: Inference, *arg, **kwargs) -> None:
super().__init__(*arg, **kwargs)
self.inference_api = inference_api
self.supported_fn_defs_registry = {
llm_as_judge_8b_correctness.identifier: llm_as_judge_8b_correctness,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
) -> ScoringResultRow:
assert (
scoring_fn_identifier is not None
), "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
assert fn_def.context is not None, f"LLMAsJudgeContext not found for {fn_def}."
assert (
fn_def.context.prompt_template is not None
), "LLM Judge prompt_template not found."
assert (
fn_def.context.judge_score_regex is not None
), "LLM Judge judge_score_regex not found."
input_query = input_row["input_query"]
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
judge_input_msg = fn_def.context.prompt_template.format(
input_query=input_query,
expected_answer=expected_answer,
generated_answer=generated_answer,
)
judge_response = await self.inference_api.chat_completion(
model=fn_def.context.judge_model,
messages=[
{
"role": "user",
"content": judge_input_msg,
}
],
)
content = judge_response.completion_message.content
rating_regexes = fn_def.context.judge_score_regex
judge_rating = None
for regex in rating_regexes:
match = re.search(regex, content)
if match:
judge_rating = int(match.group(1))
break
return {
"score": judge_rating,
"judge_feedback": content,
}
async def aggregate(
self, scoring_results: List[ScoringResultRow]
) -> Dict[str, Any]:
return aggregate_average(scoring_results)
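
In use, the judge function needs both an inference API handle (injected at construction, as in `initialize` above) and the function identifier at scoring time, since one implementation can back several registered defs. A usage sketch, assuming an `inference_api` instance is available:

```python
async def judge_one(inference_api, row):
    fn = LlmAsJudgeScoringFn(inference_api=inference_api)
    result = await fn.score_row(
        row, "meta-reference::llm_as_judge_8b_correctness"
    )
    # result looks like {"score": 4, "judge_feedback": "..."}
    return result
```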

View file

@@ -14,25 +14,27 @@ from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import
aggregate_accuracy,
)
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.subset_of import (
subset_of,
)
class SubsetOfScoringFn(BaseScoringFn):
"""
A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise.
"""
scoring_function_def = ScoringFnDef(
identifier="subset_of",
description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
parameters=[],
return_type=NumberType(),
)
def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
assert "expected_answer" in input_row, "Expected answer not found in input row."
assert (
"generated_answer" in input_row
), "Generated answer not found in input row."
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
subset_of.identifier: subset_of,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = "subset_of",
) -> ScoringResultRow:
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
score = 1.0 if expected_answer in generated_answer else 0.0
@@ -40,5 +42,7 @@ class SubsetOfScoringFn(BaseScoringFn):
"score": score,
}
def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
async def aggregate(
self, scoring_results: List[ScoringResultRow]
) -> Dict[str, Any]:
return aggregate_accuracy(scoring_results)

View file

@@ -20,6 +20,7 @@ def available_providers() -> List[ProviderSpec]:
api_dependencies=[
Api.datasetio,
Api.datasets,
Api.inference,
],
),
]

View file

@@ -70,6 +70,7 @@ async def register_dataset(
if for_generation:
dataset_schema = {
"expected_answer": StringType(),
"input_query": StringType(),
"chat_completion_input": ChatCompletionInputType(),
}
else:

View file

@@ -16,3 +16,7 @@ providers:
provider_type: remote::tgi
config:
url: http://127.0.0.1:5009
- provider_id: test-tgi-2
provider_type: remote::tgi
config:
url: http://127.0.0.1:5010

View file

@@ -65,7 +65,10 @@ async def test_eval(eval_settings):
model="Llama3.2-1B-Instruct",
sampling_params=SamplingParams(),
),
scoring_functions=["subset_of"],
scoring_functions=[
"meta-reference::subset_of",
"meta-reference::llm_as_judge_8b_correctness",
],
)
assert response.job_id == "0"
job_status = await eval_impl.job_status(response.job_id)
@@ -76,4 +79,5 @@
assert eval_response is not None
assert len(eval_response.generations) == 5
assert "subset_of" in eval_response.scores
assert "meta-reference::subset_of" in eval_response.scores
assert "meta-reference::llm_as_judge_8b_correctness" in eval_response.scores

View file

@@ -7,3 +7,8 @@ providers:
- provider_id: test-meta
provider_type: meta-reference
config: {}
inference:
- provider_id: tgi0
provider_type: remote::tgi
config:
url: http://127.0.0.1:5009

View file

@@ -33,7 +33,9 @@ from llama_stack.providers.tests.resolver import resolve_impls_for_test
@pytest_asyncio.fixture(scope="session")
async def scoring_settings():
impls = await resolve_impls_for_test(Api.scoring, deps=[Api.datasetio])
impls = await resolve_impls_for_test(
Api.scoring, deps=[Api.datasetio, Api.inference]
)
return {
"scoring_impl": impls[Api.scoring],
"scoring_functions_impl": impls[Api.scoring_functions],
@@ -48,7 +50,50 @@ async def test_scoring_functions_list(scoring_settings):
assert isinstance(scoring_functions, list)
assert len(scoring_functions) > 0
function_ids = [f.identifier for f in scoring_functions]
assert "equality" in function_ids
assert "meta-reference::equality" in function_ids
assert "meta-reference::subset_of" in function_ids
assert "meta-reference::llm_as_judge_8b_correctness" in function_ids
@pytest.mark.asyncio
async def test_scoring_functions_register(scoring_settings):
scoring_impl = scoring_settings["scoring_impl"]
scoring_functions_impl = scoring_settings["scoring_functions_impl"]
datasets_impl = scoring_settings["datasets_impl"]
test_prompt = """Output a number between 0 and 10. Your answer must match the format \n Number: <answer>"""
# register the scoring function
await scoring_functions_impl.register_scoring_function(
ScoringFnDefWithProvider(
identifier="meta-reference::llm_as_judge_8b_random",
description="Llm As Judge Scoring Function",
parameters=[],
return_type=NumberType(),
context=LLMAsJudgeContext(
prompt_template=test_prompt,
judge_model="Llama3.1-8B-Instruct",
judge_score_regex=[r"Number: (\d+)"],
),
provider_id="test-meta",
)
)
scoring_functions = await scoring_functions_impl.list_scoring_functions()
assert isinstance(scoring_functions, list)
assert len(scoring_functions) > 0
function_ids = [f.identifier for f in scoring_functions]
assert "meta-reference::llm_as_judge_8b_random" in function_ids
# test score using newly registered scoring function
await register_dataset(datasets_impl)
response = await datasets_impl.list_datasets()
assert len(response) == 1
response = await scoring_impl.score_batch(
dataset_id=response[0].identifier,
scoring_functions=[
"meta-reference::llm_as_judge_8b_random",
],
)
assert "meta-reference::llm_as_judge_8b_random" in response.results
@pytest.mark.asyncio
@@ -62,8 +107,14 @@ async def test_scoring_score(scoring_settings):
response = await scoring_impl.score_batch(
dataset_id=response[0].identifier,
scoring_functions=["equality"],
scoring_functions=[
"meta-reference::equality",
"meta-reference::subset_of",
"meta-reference::llm_as_judge_8b_correctness",
],
)
assert len(response.results) == 1
assert "equality" in response.results
assert len(response.results) == 3
assert "meta-reference::equality" in response.results
assert "meta-reference::subset_of" in response.results
assert "meta-reference::llm_as_judge_8b_correctness" in response.results

View file

@@ -33,6 +33,10 @@ providers:
provider_type: remote::tgi
config:
url: http://127.0.0.1:5009
- provider_id: tgi1
provider_type: remote::tgi
config:
url: http://127.0.0.1:5010
memory:
- provider_id: meta-reference
provider_type: meta-reference