Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-06-28 02:53:30 +00:00
[Evals API][6/n] meta-reference llm as judge, registration for ScoringFnDefs (#330)
* wip scoring refactor
* llm as judge, move folders
* test full generation + eval
* extract score regex to llm context
* remove prints, cleanup braintrust in this branch
* change json -> class
* remove initialize
* address nits
* check identifier prefix
* update MANIFEST
parent 04a4784287
commit 7b8748c53e
20 changed files with 360 additions and 50 deletions
@@ -26,6 +26,10 @@ class Parameter(BaseModel):
 class LLMAsJudgeContext(BaseModel):
     judge_model: str
     prompt_template: Optional[str] = None
+    judge_score_regex: Optional[List[str]] = Field(
+        description="Regex to extract the score from the judge response",
+        default=None,
+    )
 
 
 @json_schema_type
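For orientation, a minimal standalone sketch of the context shape added above (not the provider code itself; the class name here is made up, and pydantic is assumed as in the rest of the codebase): the judge model is required, while the prompt template and the ordered list of score-extraction regexes are optional.

from typing import List, Optional

from pydantic import BaseModel, Field


class JudgeContextSketch(BaseModel):
    # mirrors the fields of LLMAsJudgeContext shown in the hunk above
    judge_model: str
    prompt_template: Optional[str] = None
    judge_score_regex: Optional[List[str]] = Field(
        description="Regex to extract the score from the judge response",
        default=None,
    )


ctx = JudgeContextSketch(
    judge_model="Llama3.1-8B-Instruct",
    judge_score_regex=[r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"],
)
print(ctx.judge_score_regex)  # patterns are tried in order by the scoring fn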
@@ -18,6 +18,7 @@ from .config import MetaReferenceEvalConfig
 
 
 class ColumnName(Enum):
+    input_query = "input_query"
     expected_answer = "expected_answer"
     chat_completion_input = "chat_completion_input"
     completion_input = "completion_input"
@@ -53,10 +54,12 @@ class MetaReferenceEvalImpl(Eval):
 
         expected_schemas = [
             {
+                ColumnName.input_query.value: StringType(),
                 ColumnName.expected_answer.value: StringType(),
                 ColumnName.chat_completion_input.value: ChatCompletionInputType(),
             },
             {
+                ColumnName.input_query.value: StringType(),
                 ColumnName.expected_answer.value: StringType(),
                 ColumnName.completion_input.value: CompletionInputType(),
             },
@@ -16,6 +16,8 @@ async def get_provider_impl(
 ):
     from .scoring import MetaReferenceScoringImpl
 
-    impl = MetaReferenceScoringImpl(config, deps[Api.datasetio], deps[Api.datasets])
+    impl = MetaReferenceScoringImpl(
+        config, deps[Api.datasetio], deps[Api.datasets], deps[Api.inference]
+    )
     await impl.initialize()
     return impl
@@ -11,24 +11,25 @@ from llama_stack.apis.scoring_functions import *  # noqa: F403
 from llama_stack.apis.common.type_system import *  # noqa: F403
 from llama_stack.apis.datasetio import *  # noqa: F403
 from llama_stack.apis.datasets import *  # noqa: F403
+from llama_stack.apis.inference.inference import Inference
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
 from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.equality_scoring_fn import (
     EqualityScoringFn,
 )
 
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.llm_as_judge_scoring_fn import (
+    LlmAsJudgeScoringFn,
+)
 
 from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import (
     SubsetOfScoringFn,
 )
 
 from .config import MetaReferenceScoringConfig
 
-SUPPORTED_SCORING_FNS = [
-    EqualityScoringFn,
-    SubsetOfScoringFn,
-]
+FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn]
 
-SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORING_FNS}
+LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
 
 
 class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
@@ -37,22 +38,44 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         config: MetaReferenceScoringConfig,
         datasetio_api: DatasetIO,
         datasets_api: Datasets,
+        inference_api: Inference,
     ) -> None:
         self.config = config
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets_api
+        self.inference_api = inference_api
+        self.scoring_fn_id_impls = {}
 
-    async def initialize(self) -> None: ...
+    async def initialize(self) -> None:
+        for x in FIXED_FNS:
+            impl = x()
+            for fn_defs in impl.get_supported_scoring_fn_defs():
+                self.scoring_fn_id_impls[fn_defs.identifier] = impl
+        for x in LLM_JUDGE_FNS:
+            impl = x(inference_api=self.inference_api)
+            for fn_defs in impl.get_supported_scoring_fn_defs():
+                self.scoring_fn_id_impls[fn_defs.identifier] = impl
+            self.llm_as_judge_fn = impl
 
     async def shutdown(self) -> None: ...
 
     async def list_scoring_functions(self) -> List[ScoringFnDef]:
-        return [x.scoring_function_def for x in SUPPORTED_SCORING_FNS]
+        scoring_fn_defs_list = [
+            fn_def
+            for impl in self.scoring_fn_id_impls.values()
+            for fn_def in impl.get_supported_scoring_fn_defs()
+        ]
+
+        for f in scoring_fn_defs_list:
+            assert f.identifier.startswith(
+                "meta-reference"
+            ), "All meta-reference scoring fn must have identifier prefixed with 'meta-reference'! "
+
+        return scoring_fn_defs_list
 
     async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
-        raise NotImplementedError(
-            "Dynamically registering scoring functions is not supported"
-        )
+        self.llm_as_judge_fn.register_scoring_fn_def(function_def)
+        self.scoring_fn_id_impls[function_def.identifier] = self.llm_as_judge_fn
 
     async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
         dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
@@ -99,11 +122,11 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     ) -> ScoreResponse:
         res = {}
         for scoring_fn_id in scoring_functions:
-            if scoring_fn_id not in SCORER_REGISTRY:
+            if scoring_fn_id not in self.scoring_fn_id_impls:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
-            scoring_fn = SCORER_REGISTRY[scoring_fn_id]()
-            score_results = scoring_fn.score(input_rows)
-            agg_results = scoring_fn.aggregate(score_results)
+            scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
+            score_results = await scoring_fn.score(input_rows, scoring_fn_id)
+            agg_results = await scoring_fn.aggregate(score_results)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
                 aggregated_results=agg_results,
@@ -17,21 +17,41 @@ class BaseScoringFn(ABC):
     - aggregate(self, scoring_fn_results)
     """
 
-    scoring_function_def: ScoringFnDef
-
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
+        self.supported_fn_defs_registry = {}
 
     def __str__(self) -> str:
         return self.__class__.__name__
 
+    def get_supported_scoring_fn_defs(self) -> List[ScoringFnDef]:
+        return [x for x in self.supported_fn_defs_registry.values()]
+
+    def register_scoring_fn_def(self, scoring_fn_def: ScoringFnDef) -> None:
+        if scoring_fn_def.identifier in self.supported_fn_defs_registry:
+            raise ValueError(
+                f"Scoring function def with identifier {scoring_fn_def.identifier} already exists."
+            )
+        self.supported_fn_defs_registry[scoring_fn_def.identifier] = scoring_fn_def
+
     @abstractmethod
-    def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
+    async def score_row(
+        self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
+    ) -> ScoringResultRow:
         raise NotImplementedError()
 
     @abstractmethod
-    def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
+    async def aggregate(
+        self, scoring_results: List[ScoringResultRow]
+    ) -> Dict[str, Any]:
         raise NotImplementedError()
 
-    def score(self, input_rows: List[Dict[str, Any]]) -> List[ScoringResultRow]:
-        return [self.score_row(input_row) for input_row in input_rows]
+    async def score(
+        self,
+        input_rows: List[Dict[str, Any]],
+        scoring_fn_identifier: Optional[str] = None,
+    ) -> List[ScoringResultRow]:
+        return [
+            await self.score_row(input_row, scoring_fn_identifier)
+            for input_row in input_rows
+        ]
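To show how the reworked async interface composes, here is a simplified, self-contained sketch (not the llama_stack class: ScoringFnDef and ScoringResultRow are replaced by plain dicts, and the class names are invented for illustration).

# Simplified stand-in for the BaseScoringFn contract above: a registry of
# fn defs plus async score_row/score. Rows are awaited one at a time,
# matching the list comprehension in the real score().
import asyncio
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional


class MiniScoringFn(ABC):
    def __init__(self) -> None:
        # identifier -> fn def (a plain dict here instead of ScoringFnDef)
        self.supported_fn_defs_registry: Dict[str, Dict[str, Any]] = {}

    def register_scoring_fn_def(self, fn_def: Dict[str, Any]) -> None:
        if fn_def["identifier"] in self.supported_fn_defs_registry:
            raise ValueError(f"{fn_def['identifier']} already exists.")
        self.supported_fn_defs_registry[fn_def["identifier"]] = fn_def

    @abstractmethod
    async def score_row(
        self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
    ) -> Dict[str, Any]: ...

    async def score(
        self,
        input_rows: List[Dict[str, Any]],
        scoring_fn_identifier: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        return [await self.score_row(r, scoring_fn_identifier) for r in input_rows]


class ExactMatch(MiniScoringFn):
    async def score_row(self, input_row, scoring_fn_identifier=None):
        match = input_row["expected_answer"] == input_row["generated_answer"]
        return {"score": 1.0 if match else 0.0}


rows = [
    {"expected_answer": "4", "generated_answer": "4"},
    {"expected_answer": "4", "generated_answer": "5"},
]
print(asyncio.run(ExactMatch().score(rows)))  # [{'score': 1.0}, {'score': 0.0}]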
@@ -3,10 +3,13 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from pathlib import Path
 from typing import Any, Dict, List
 
 from llama_stack.apis.scoring import ScoringResultRow
 
+FN_DEFS_PATH = Path(__file__).parent / "fn_defs"
+
 
 def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
     num_correct = sum(result["score"] for result in scoring_results)
|
@ -17,3 +20,12 @@ def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any
|
||||||
"num_correct": num_correct,
|
"num_correct": num_correct,
|
||||||
"num_total": len(scoring_results),
|
"num_total": len(scoring_results),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"average": sum(
|
||||||
|
result["score"] for result in scoring_results if result["score"] is not None
|
||||||
|
)
|
||||||
|
/ len([_ for _ in scoring_results if _["score"] is not None]),
|
||||||
|
}
|
||||||
|
|
|
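A quick worked example of the None-handling in aggregate_average above (standalone copy of the helper, with ScoringResultRow treated as a plain dict and the sample rows invented): rows whose score is None are dropped from both the sum and the count, so a row where no judge regex matched does not drag the average down. Note it would divide by zero if every score were None.

from typing import Any, Dict, List


def aggregate_average(scoring_results: List[Dict[str, Any]]) -> Dict[str, Any]:
    # copy of the helper added above
    return {
        "average": sum(
            result["score"] for result in scoring_results if result["score"] is not None
        )
        / len([_ for _ in scoring_results if _["score"] is not None]),
    }


rows = [{"score": 4}, {"score": None}, {"score": 2}]  # None: no regex matched
print(aggregate_average(rows))  # {'average': 3.0}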
@@ -10,24 +10,32 @@ from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
 from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
 from llama_stack.apis.scoring import *  # noqa: F401, F403
 from llama_stack.apis.common.type_system import *  # noqa: F403
+
 from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
     aggregate_accuracy,
 )
+
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.equality import (
+    equality,
+)
 
 
 class EqualityScoringFn(BaseScoringFn):
     """
     A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
     """
 
-    scoring_function_def = ScoringFnDef(
-        identifier="equality",
-        description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
-        parameters=[],
-        return_type=NumberType(),
-    )
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.supported_fn_defs_registry = {
+            equality.identifier: equality,
+        }
 
-    def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
+    async def score_row(
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = "equality",
+    ) -> ScoringResultRow:
         assert "expected_answer" in input_row, "Expected answer not found in input row."
         assert (
             "generated_answer" in input_row
@@ -40,5 +48,7 @@ class EqualityScoringFn(BaseScoringFn):
             "score": score,
         }
 
-    def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
+    async def aggregate(
+        self, scoring_results: List[ScoringResultRow]
+    ) -> Dict[str, Any]:
         return aggregate_accuracy(scoring_results)
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.common.type_system import NumberType
+from llama_stack.apis.scoring_functions import ScoringFnDef
+
+
+equality = ScoringFnDef(
+    identifier="meta-reference::equality",
+    description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
+    parameters=[],
+    return_type=NumberType(),
+)
@@ -0,0 +1,36 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
+from llama_stack.apis.scoring import *  # noqa: F401, F403
+from llama_stack.apis.common.type_system import NumberType
+
+JUDGE_PROMPT = """
+You will be given a question, a expected_answer, and a system_answer.
+Your task is to provide a 'total rating' scoring how well the system_answer answers compared with ground truth in expected_answer in terms of factual correctness to the question.
+Give your answer as a integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.
+Provide your feedback as follows:
+Feedback:::
+Total rating: (your rating, as a int between 0 and 5)
+Now here are the question, expected_answer, system_answer.
+Question: {input_query}
+Expected Answer: {expected_answer}
+System Answer: {generated_answer}
+Feedback:::
+Total rating:
+"""
+
+llm_as_judge_8b_correctness = ScoringFnDef(
+    identifier="meta-reference::llm_as_judge_8b_correctness",
+    description="Llm As Judge Scoring Function",
+    parameters=[],
+    return_type=NumberType(),
+    context=LLMAsJudgeContext(
+        prompt_template=JUDGE_PROMPT,
+        judge_model="Llama3.1-8B-Instruct",
+        judge_score_regex=[r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"],
+    ),
+)
@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.common.type_system import NumberType
+from llama_stack.apis.scoring_functions import ScoringFnDef
+
+
+subset_of = ScoringFnDef(
+    identifier="meta-reference::subset_of",
+    description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
+    parameters=[],
+    return_type=NumberType(),
+)
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from llama_stack.apis.inference.inference import Inference
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
+    BaseScoringFn,
+)
+from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
+from llama_stack.apis.scoring import *  # noqa: F401, F403
+from llama_stack.apis.common.type_system import *  # noqa: F403
+import re
+
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
+    aggregate_average,
+)
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.llm_as_judge_8b_correctness import (
+    llm_as_judge_8b_correctness,
+)
+
+
+class LlmAsJudgeScoringFn(BaseScoringFn):
+    """
+    A scoring_fn that assigns
+    """
+
+    def __init__(self, inference_api: Inference, *arg, **kwargs) -> None:
+        super().__init__(*arg, **kwargs)
+        self.inference_api = inference_api
+        self.supported_fn_defs_registry = {
+            llm_as_judge_8b_correctness.identifier: llm_as_judge_8b_correctness,
+        }
+
+    async def score_row(
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = None,
+    ) -> ScoringResultRow:
+        assert (
+            scoring_fn_identifier is not None
+        ), "Scoring function identifier not found."
+        fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
+        assert fn_def.context is not None, f"LLMAsJudgeContext not found for {fn_def}."
+        assert (
+            fn_def.context.prompt_template is not None
+        ), "LLM Judge prompt_template not found."
+        assert (
+            fn_def.context.judge_score_regex is not None
+        ), "LLM Judge judge_score_regex not found."
+
+        input_query = input_row["input_query"]
+        expected_answer = input_row["expected_answer"]
+        generated_answer = input_row["generated_answer"]
+
+        judge_input_msg = fn_def.context.prompt_template.format(
+            input_query=input_query,
+            expected_answer=expected_answer,
+            generated_answer=generated_answer,
+        )
+
+        judge_response = await self.inference_api.chat_completion(
+            model=fn_def.context.judge_model,
+            messages=[
+                {
+                    "role": "user",
+                    "content": judge_input_msg,
+                }
+            ],
+        )
+        content = judge_response.completion_message.content
+        rating_regexs = fn_def.context.judge_score_regex
+
+        judge_rating = None
+        for regex in rating_regexs:
+            match = re.search(regex, content)
+            if match:
+                judge_rating = int(match.group(1))
+                break
+
+        return {
+            "score": judge_rating,
+            "judge_feedback": content,
+        }
+
+    async def aggregate(
+        self, scoring_results: List[ScoringResultRow]
+    ) -> Dict[str, Any]:
+        return aggregate_average(scoring_results)
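The core mechanic of score_row above, with the inference call stubbed out (a self-contained sketch: the regex list is the one from llm_as_judge_8b_correctness, while the judge response text and function name are invented for illustration): each pattern is tried in order and the first capture group that matches becomes the integer rating, otherwise the score stays None.

import re
from typing import List, Optional

JUDGE_SCORE_REGEX = [r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"]

# Pretend this string came back from inference_api.chat_completion(...)
judge_response = "Feedback:::\nThe answer is mostly correct.\nTotal rating: 4"


def extract_rating(content: str, patterns: List[str]) -> Optional[int]:
    # try each regex in order; first match wins
    for pattern in patterns:
        match = re.search(pattern, content)
        if match:
            return int(match.group(1))
    return None  # aggregate_average later skips None scores


print(extract_rating(judge_response, JUDGE_SCORE_REGEX))  # 4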
@@ -14,25 +14,27 @@ from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
     aggregate_accuracy,
 )
 
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.subset_of import (
+    subset_of,
+)
+
 
 class SubsetOfScoringFn(BaseScoringFn):
     """
     A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise.
     """
 
-    scoring_function_def = ScoringFnDef(
-        identifier="subset_of",
-        description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
-        parameters=[],
-        return_type=NumberType(),
-    )
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.supported_fn_defs_registry = {
+            subset_of.identifier: subset_of,
+        }
 
-    def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
-        assert "expected_answer" in input_row, "Expected answer not found in input row."
-        assert (
-            "generated_answer" in input_row
-        ), "Generated answer not found in input row."
-
+    async def score_row(
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = "subset_of",
+    ) -> ScoringResultRow:
         expected_answer = input_row["expected_answer"]
         generated_answer = input_row["generated_answer"]
         score = 1.0 if expected_answer in generated_answer else 0.0
|
@ -20,6 +20,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
api_dependencies=[
|
api_dependencies=[
|
||||||
Api.datasetio,
|
Api.datasetio,
|
||||||
Api.datasets,
|
Api.datasets,
|
||||||
|
Api.inference,
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@@ -70,6 +70,7 @@ async def register_dataset(
     if for_generation:
         dataset_schema = {
             "expected_answer": StringType(),
+            "input_query": StringType(),
             "chat_completion_input": ChatCompletionInputType(),
         }
     else:
@@ -16,3 +16,7 @@ providers:
     provider_type: remote::tgi
     config:
       url: http://127.0.0.1:5009
+  - provider_id: test-tgi-2
+    provider_type: remote::tgi
+    config:
+      url: http://127.0.0.1:5010
@@ -65,7 +65,10 @@ async def test_eval(eval_settings):
             model="Llama3.2-1B-Instruct",
             sampling_params=SamplingParams(),
         ),
-        scoring_functions=["subset_of"],
+        scoring_functions=[
+            "meta-reference::subset_of",
+            "meta-reference::llm_as_judge_8b_correctness",
+        ],
     )
     assert response.job_id == "0"
     job_status = await eval_impl.job_status(response.job_id)
@@ -76,4 +79,5 @@ async def test_eval(eval_settings):
 
     assert eval_response is not None
     assert len(eval_response.generations) == 5
-    assert "subset_of" in eval_response.scores
+    assert "meta-reference::subset_of" in eval_response.scores
+    assert "meta-reference::llm_as_judge_8b_correctness" in eval_response.scores
@@ -7,3 +7,8 @@ providers:
   - provider_id: test-meta
     provider_type: meta-reference
     config: {}
+  inference:
+  - provider_id: tgi0
+    provider_type: remote::tgi
+    config:
+      url: http://127.0.0.1:5009
@@ -33,7 +33,9 @@ from llama_stack.providers.tests.resolver import resolve_impls_for_test
 
 @pytest_asyncio.fixture(scope="session")
 async def scoring_settings():
-    impls = await resolve_impls_for_test(Api.scoring, deps=[Api.datasetio])
+    impls = await resolve_impls_for_test(
+        Api.scoring, deps=[Api.datasetio, Api.inference]
+    )
     return {
         "scoring_impl": impls[Api.scoring],
         "scoring_functions_impl": impls[Api.scoring_functions],
@@ -48,7 +50,50 @@ async def test_scoring_functions_list(scoring_settings):
     assert isinstance(scoring_functions, list)
     assert len(scoring_functions) > 0
     function_ids = [f.identifier for f in scoring_functions]
-    assert "equality" in function_ids
+    assert "meta-reference::equality" in function_ids
+    assert "meta-reference::subset_of" in function_ids
+    assert "meta-reference::llm_as_judge_8b_correctness" in function_ids
+
+
+@pytest.mark.asyncio
+async def test_scoring_functions_register(scoring_settings):
+    scoring_impl = scoring_settings["scoring_impl"]
+    scoring_functions_impl = scoring_settings["scoring_functions_impl"]
+    datasets_impl = scoring_settings["datasets_impl"]
+    test_prompt = """Output a number between 0 to 10. Your answer must match the format \n Number: <answer>"""
+    # register the scoring function
+    await scoring_functions_impl.register_scoring_function(
+        ScoringFnDefWithProvider(
+            identifier="meta-reference::llm_as_judge_8b_random",
+            description="Llm As Judge Scoring Function",
+            parameters=[],
+            return_type=NumberType(),
+            context=LLMAsJudgeContext(
+                prompt_template=test_prompt,
+                judge_model="Llama3.1-8B-Instruct",
+                judge_score_regex=[r"Number: (\d+)"],
+            ),
+            provider_id="test-meta",
+        )
+    )
+
+    scoring_functions = await scoring_functions_impl.list_scoring_functions()
+    assert isinstance(scoring_functions, list)
+    assert len(scoring_functions) > 0
+    function_ids = [f.identifier for f in scoring_functions]
+    assert "meta-reference::llm_as_judge_8b_random" in function_ids
+
+    # test score using newly registered scoring function
+    await register_dataset(datasets_impl)
+    response = await datasets_impl.list_datasets()
+    assert len(response) == 1
+    response = await scoring_impl.score_batch(
+        dataset_id=response[0].identifier,
+        scoring_functions=[
+            "meta-reference::llm_as_judge_8b_random",
+        ],
+    )
+    assert "meta-reference::llm_as_judge_8b_random" in response.results
 
 
 @pytest.mark.asyncio
@@ -62,8 +107,14 @@ async def test_scoring_score(scoring_settings):
 
     response = await scoring_impl.score_batch(
         dataset_id=response[0].identifier,
-        scoring_functions=["equality"],
+        scoring_functions=[
+            "meta-reference::equality",
+            "meta-reference::subset_of",
+            "meta-reference::llm_as_judge_8b_correctness",
+        ],
     )
 
-    assert len(response.results) == 1
-    assert "equality" in response.results
+    assert len(response.results) == 3
+    assert "meta-reference::equality" in response.results
+    assert "meta-reference::subset_of" in response.results
+    assert "meta-reference::llm_as_judge_8b_correctness" in response.results
@@ -33,6 +33,10 @@ providers:
     provider_type: remote::tgi
     config:
       url: http://127.0.0.1:5009
+  - provider_id: tgi1
+    provider_type: remote::tgi
+    config:
+      url: http://127.0.0.1:5010
   memory:
   - provider_id: meta-reference
     provider_type: meta-reference