mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-31 16:01:46 +00:00
llm as judge, move folders
This commit is contained in:
parent
bf8bc7a781
commit
16620a8185
13 changed files with 173 additions and 66 deletions
|
@ -2,3 +2,4 @@ include requirements.txt
|
||||||
include llama_stack/distribution/*.sh
|
include llama_stack/distribution/*.sh
|
||||||
include llama_stack/cli/scripts/*.sh
|
include llama_stack/cli/scripts/*.sh
|
||||||
include llama_stack/templates/*/build.yaml
|
include llama_stack/templates/*/build.yaml
|
||||||
|
include llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/*.json
|
||||||
|
|
|
@ -97,6 +97,7 @@ class InferenceRouter(Inference):
|
||||||
logprobs=logprobs,
|
logprobs=logprobs,
|
||||||
)
|
)
|
||||||
provider = self.routing_table.get_provider_impl(model)
|
provider = self.routing_table.get_provider_impl(model)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return (chunk async for chunk in await provider.chat_completion(**params))
|
return (chunk async for chunk in await provider.chat_completion(**params))
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -16,6 +16,8 @@ async def get_provider_impl(
|
||||||
):
|
):
|
||||||
from .scoring import MetaReferenceScoringImpl
|
from .scoring import MetaReferenceScoringImpl
|
||||||
|
|
||||||
impl = MetaReferenceScoringImpl(config, deps[Api.datasetio], deps[Api.datasets])
|
impl = MetaReferenceScoringImpl(
|
||||||
|
config, deps[Api.datasetio], deps[Api.datasets], deps[Api.inference]
|
||||||
|
)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
|
@ -11,24 +11,25 @@ from llama_stack.apis.scoring_functions import * # noqa: F403
|
||||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||||
from llama_stack.apis.datasetio import * # noqa: F403
|
from llama_stack.apis.datasetio import * # noqa: F403
|
||||||
from llama_stack.apis.datasets import * # noqa: F403
|
from llama_stack.apis.datasets import * # noqa: F403
|
||||||
|
from llama_stack.apis.inference.inference import Inference
|
||||||
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
||||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.equality_scoring_fn import (
|
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.equality_scoring_fn import (
|
||||||
EqualityScoringFn,
|
EqualityScoringFn,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.llm_as_judge_scoring_fn import (
|
||||||
|
LlmAsJudgeScoringFn,
|
||||||
|
)
|
||||||
|
|
||||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import (
|
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import (
|
||||||
SubsetOfScoringFn,
|
SubsetOfScoringFn,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .config import MetaReferenceScoringConfig
|
from .config import MetaReferenceScoringConfig
|
||||||
|
|
||||||
SUPPORTED_SCORING_FNS = [
|
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn]
|
||||||
EqualityScoringFn,
|
|
||||||
SubsetOfScoringFn,
|
|
||||||
]
|
|
||||||
|
|
||||||
SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORING_FNS}
|
LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
|
||||||
|
|
||||||
|
|
||||||
class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
||||||
|
@ -37,17 +38,34 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
||||||
config: MetaReferenceScoringConfig,
|
config: MetaReferenceScoringConfig,
|
||||||
datasetio_api: DatasetIO,
|
datasetio_api: DatasetIO,
|
||||||
datasets_api: Datasets,
|
datasets_api: Datasets,
|
||||||
|
inference_api: Inference,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.datasetio_api = datasetio_api
|
self.datasetio_api = datasetio_api
|
||||||
self.datasets_api = datasets_api
|
self.datasets_api = datasets_api
|
||||||
|
self.inference_api = inference_api
|
||||||
|
self.scoring_fn_id_impls = {}
|
||||||
|
|
||||||
async def initialize(self) -> None: ...
|
async def initialize(self) -> None:
|
||||||
|
for x in FIXED_FNS:
|
||||||
|
impl = x()
|
||||||
|
await impl.initialize()
|
||||||
|
for fn_defs in impl.get_supported_scoring_fn_defs():
|
||||||
|
self.scoring_fn_id_impls[fn_defs.identifier] = impl
|
||||||
|
for x in LLM_JUDGE_FNS:
|
||||||
|
impl = x(inference_api=self.inference_api)
|
||||||
|
await impl.initialize()
|
||||||
|
for fn_defs in impl.get_supported_scoring_fn_defs():
|
||||||
|
self.scoring_fn_id_impls[fn_defs.identifier] = impl
|
||||||
|
|
||||||
async def shutdown(self) -> None: ...
|
async def shutdown(self) -> None: ...
|
||||||
|
|
||||||
async def list_scoring_functions(self) -> List[ScoringFnDef]:
|
async def list_scoring_functions(self) -> List[ScoringFnDef]:
|
||||||
return [x.scoring_function_def for x in SUPPORTED_SCORING_FNS]
|
return [
|
||||||
|
fn_defs
|
||||||
|
for impl in self.scoring_fn_id_impls.values()
|
||||||
|
for fn_defs in impl.get_supported_scoring_fn_defs()
|
||||||
|
]
|
||||||
|
|
||||||
async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
|
async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
|
@ -99,9 +117,9 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
||||||
) -> ScoreResponse:
|
) -> ScoreResponse:
|
||||||
res = {}
|
res = {}
|
||||||
for scoring_fn_id in scoring_functions:
|
for scoring_fn_id in scoring_functions:
|
||||||
if scoring_fn_id not in SCORER_REGISTRY:
|
if scoring_fn_id not in self.scoring_fn_id_impls:
|
||||||
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
|
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
|
||||||
scoring_fn = SCORER_REGISTRY[scoring_fn_id]()
|
scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
|
||||||
score_results = await scoring_fn.score(input_rows, scoring_fn_id)
|
score_results = await scoring_fn.score(input_rows, scoring_fn_id)
|
||||||
agg_results = await scoring_fn.aggregate(score_results)
|
agg_results = await scoring_fn.aggregate(score_results)
|
||||||
res[scoring_fn_id] = ScoringResult(
|
res[scoring_fn_id] = ScoringResult(
|
||||||
|
|
|
@ -7,6 +7,7 @@ from abc import ABC, abstractmethod
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
class BaseScoringFn(ABC):
|
class BaseScoringFn(ABC):
|
||||||
|
@ -17,14 +18,30 @@ class BaseScoringFn(ABC):
|
||||||
- aggregate(self, scoring_fn_results)
|
- aggregate(self, scoring_fn_results)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
scoring_function_def: ScoringFnDef
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs) -> None:
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
self.supported_fn_defs_registry = {}
|
||||||
|
self.defs_paths = []
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.__class__.__name__
|
return self.__class__.__name__
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
for f in self.defs_paths:
|
||||||
|
with open(f, "r") as f:
|
||||||
|
scoring_fn_def = ScoringFnDef(**json.load(f))
|
||||||
|
self.register_scoring_fn_def(scoring_fn_def)
|
||||||
|
|
||||||
|
def get_supported_scoring_fn_defs(self) -> List[ScoringFnDef]:
|
||||||
|
return [x for x in self.supported_fn_defs_registry.values()]
|
||||||
|
|
||||||
|
def register_scoring_fn_def(self, scoring_fn_def: ScoringFnDef) -> None:
|
||||||
|
if scoring_fn_def.identifier in self.supported_fn_defs_registry:
|
||||||
|
raise ValueError(
|
||||||
|
f"Scoring function def with identifier {scoring_fn_def.identifier} already exists."
|
||||||
|
)
|
||||||
|
self.supported_fn_defs_registry[scoring_fn_def.identifier] = scoring_fn_def
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def score_row(
|
async def score_row(
|
||||||
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
|
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
|
||||||
|
|
|
@ -3,10 +3,13 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from llama_stack.apis.scoring import ScoringResultRow
|
from llama_stack.apis.scoring import ScoringResultRow
|
||||||
|
|
||||||
|
FN_DEFS_PATH = Path(__file__).parent / "fn_defs"
|
||||||
|
|
||||||
|
|
||||||
def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
|
def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
|
||||||
num_correct = sum(result["score"] for result in scoring_results)
|
num_correct = sum(result["score"] for result in scoring_results)
|
||||||
|
@ -17,3 +20,12 @@ def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any
|
||||||
"num_correct": num_correct,
|
"num_correct": num_correct,
|
||||||
"num_total": len(scoring_results),
|
"num_total": len(scoring_results),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"average": sum(
|
||||||
|
result["score"] for result in scoring_results if result["score"] is not None
|
||||||
|
)
|
||||||
|
/ len([_ for _ in scoring_results if _["score"] is not None]),
|
||||||
|
}
|
||||||
|
|
|
@ -10,8 +10,10 @@ from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_
|
||||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
|
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
|
||||||
aggregate_accuracy,
|
aggregate_accuracy,
|
||||||
|
FN_DEFS_PATH,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,12 +22,9 @@ class EqualityScoringFn(BaseScoringFn):
|
||||||
A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
|
A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
scoring_function_def = ScoringFnDef(
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
identifier="equality",
|
super().__init__(*args, **kwargs)
|
||||||
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
|
self.defs_paths = [FN_DEFS_PATH / "equality.json"]
|
||||||
parameters=[],
|
|
||||||
return_type=NumberType(),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def score_row(
|
async def score_row(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -0,0 +1,10 @@
|
||||||
|
{
|
||||||
|
"identifier": "meta-reference::equality",
|
||||||
|
"description": "Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
|
||||||
|
"metadata": {},
|
||||||
|
"parameters": [],
|
||||||
|
"return_type": {
|
||||||
|
"type": "number"
|
||||||
|
},
|
||||||
|
"context": null
|
||||||
|
}
|
|
@ -0,0 +1,13 @@
|
||||||
|
{
|
||||||
|
"identifier": "meta-reference::llm_as_judge_8b_correctness",
|
||||||
|
"description": "Llm As Judge Scoring Function",
|
||||||
|
"metadata": {},
|
||||||
|
"parameters": [],
|
||||||
|
"return_type": {
|
||||||
|
"type": "number"
|
||||||
|
},
|
||||||
|
"context": {
|
||||||
|
"judge_model": "Llama3.1-8B-Instruct",
|
||||||
|
"prompt_template": "\nYou will be given a question, a expected_answer, and a system_answer.\nYour task is to provide a 'total rating' scoring how well the system_answer answers compared with ground truth in expected_answer in terms of factual correctness to the question.\nGive your answer as a integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.\nProvide your feedback as follows:\nFeedback:::\nTotal rating: (your rating, as a int between 0 and 5)\nNow here are the question, expected_answer, system_answer.\nQuestion: {input_query}\nExpected Answer: {expected_answer}\nSystem Answer: {generated_answer}\nFeedback:::\nTotal rating:\n"
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,10 @@
|
||||||
|
{
|
||||||
|
"identifier": "meta-reference::subset_of",
|
||||||
|
"description": "Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
|
||||||
|
"metadata": {},
|
||||||
|
"parameters": [],
|
||||||
|
"return_type": {
|
||||||
|
"type": "number"
|
||||||
|
},
|
||||||
|
"context": null
|
||||||
|
}
|
|
@ -3,31 +3,19 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
from llama_stack.apis.inference.inference import Inference
|
||||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
|
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
|
||||||
BaseScoringFn,
|
BaseScoringFn,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
|
import re
|
||||||
aggregate_accuracy,
|
|
||||||
)
|
|
||||||
|
|
||||||
JUDGE_PROMPT = """
|
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
|
||||||
You will be given a question, a expected_answer, and a system_answer.
|
aggregate_average,
|
||||||
Your task is to provide a 'total rating' scoring how well the system_answer answers compared with ground truth in expected_answer in terms of factual correctness to the question.
|
FN_DEFS_PATH,
|
||||||
Give your answer as a integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.
|
)
|
||||||
Provide your feedback as follows:
|
|
||||||
Feedback:::
|
|
||||||
Total rating: (your rating, as a int between 0 and 5)
|
|
||||||
Now here are the question, expected_answer, system_answer.
|
|
||||||
Question: {question}
|
|
||||||
Expected Answer: {expected_answer}
|
|
||||||
System Answer: {answer}
|
|
||||||
Feedback:::
|
|
||||||
Total rating:
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class LlmAsJudgeScoringFn(BaseScoringFn):
|
class LlmAsJudgeScoringFn(BaseScoringFn):
|
||||||
|
@ -35,27 +23,62 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
|
||||||
A scoring_fn that assigns
|
A scoring_fn that assigns
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs) -> None:
|
def __init__(self, inference_api: Inference, *arg, **kwargs) -> None:
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*arg, **kwargs)
|
||||||
self.scoring_fn_def_registry = {}
|
self.inference_api = inference_api
|
||||||
|
self.defs_paths = [FN_DEFS_PATH / "llm_as_judge_8b_correctness.json"]
|
||||||
|
|
||||||
def register_scoring_def(self, scoring_fn_def: ScoringFnDef) -> None:
|
async def score_row(
|
||||||
self.scoring_function_def_registry[scoring_fn_def.identifier] = scoring_fn_def
|
self,
|
||||||
|
input_row: Dict[str, Any],
|
||||||
async def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
|
scoring_fn_identifier: Optional[str] = None,
|
||||||
assert "expected_answer" in input_row, "Expected answer not found in input row."
|
) -> ScoringResultRow:
|
||||||
assert (
|
assert (
|
||||||
"generated_answer" in input_row
|
scoring_fn_identifier is not None
|
||||||
), "Generated answer not found in input row."
|
), "Scoring function identifier not found."
|
||||||
|
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||||
|
assert (
|
||||||
|
fn_def.context is not None and fn_def.context.prompt_template is not None
|
||||||
|
), "LLM Judge prompt_template not found."
|
||||||
|
|
||||||
|
input_query = input_row["input_query"]
|
||||||
expected_answer = input_row["expected_answer"]
|
expected_answer = input_row["expected_answer"]
|
||||||
generated_answer = input_row["generated_answer"]
|
generated_answer = input_row["generated_answer"]
|
||||||
score = 1.0 if expected_answer == generated_answer else 0.0
|
|
||||||
|
judge_input_msg = fn_def.context.prompt_template.format(
|
||||||
|
input_query=input_query,
|
||||||
|
expected_answer=expected_answer,
|
||||||
|
generated_answer=generated_answer,
|
||||||
|
)
|
||||||
|
|
||||||
|
judge_response = await self.inference_api.chat_completion(
|
||||||
|
model=fn_def.context.judge_model,
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": judge_input_msg,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
content = judge_response.completion_message.content
|
||||||
|
rating_regexs = [
|
||||||
|
r"Total rating: (\d+)",
|
||||||
|
r"rating: (\d+)",
|
||||||
|
r"Rating: (\d+)",
|
||||||
|
]
|
||||||
|
judge_rating = None
|
||||||
|
for regex in rating_regexs:
|
||||||
|
match = re.search(regex, content)
|
||||||
|
if match:
|
||||||
|
judge_rating = int(match.group(1))
|
||||||
|
break
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"score": score,
|
"score": judge_rating,
|
||||||
|
"judge_feedback": content,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def aggregate(
|
async def aggregate(
|
||||||
self, scoring_results: List[ScoringResultRow]
|
self, scoring_results: List[ScoringResultRow]
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
return aggregate_accuracy(scoring_results)
|
return aggregate_average(scoring_results)
|
||||||
|
|
|
@ -12,6 +12,7 @@ from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||||
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
|
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
|
||||||
aggregate_accuracy,
|
aggregate_accuracy,
|
||||||
|
FN_DEFS_PATH,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,23 +21,15 @@ class SubsetOfScoringFn(BaseScoringFn):
|
||||||
A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise.
|
A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
scoring_function_def = ScoringFnDef(
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
identifier="subset_of",
|
super().__init__(*args, **kwargs)
|
||||||
description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
|
self.defs_paths = [FN_DEFS_PATH / "subset_of.json"]
|
||||||
parameters=[],
|
|
||||||
return_type=NumberType(),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def score_row(
|
async def score_row(
|
||||||
self,
|
self,
|
||||||
input_row: Dict[str, Any],
|
input_row: Dict[str, Any],
|
||||||
scoring_fn_identifier: Optional[str] = "subset_of",
|
scoring_fn_identifier: Optional[str] = "subset_of",
|
||||||
) -> ScoringResultRow:
|
) -> ScoringResultRow:
|
||||||
assert "expected_answer" in input_row, "Expected answer not found in input row."
|
|
||||||
assert (
|
|
||||||
"generated_answer" in input_row
|
|
||||||
), "Generated answer not found in input row."
|
|
||||||
|
|
||||||
expected_answer = input_row["expected_answer"]
|
expected_answer = input_row["expected_answer"]
|
||||||
generated_answer = input_row["generated_answer"]
|
generated_answer = input_row["generated_answer"]
|
||||||
score = 1.0 if expected_answer in generated_answer else 0.0
|
score = 1.0 if expected_answer in generated_answer else 0.0
|
||||||
|
|
|
@ -50,7 +50,9 @@ async def test_scoring_functions_list(scoring_settings):
|
||||||
assert isinstance(scoring_functions, list)
|
assert isinstance(scoring_functions, list)
|
||||||
assert len(scoring_functions) > 0
|
assert len(scoring_functions) > 0
|
||||||
function_ids = [f.identifier for f in scoring_functions]
|
function_ids = [f.identifier for f in scoring_functions]
|
||||||
assert "equality" in function_ids
|
assert "meta-reference::equality" in function_ids
|
||||||
|
assert "meta-reference::subset_of" in function_ids
|
||||||
|
assert "meta-reference::llm_as_judge_8b_correctness" in function_ids
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@ -64,9 +66,15 @@ async def test_scoring_score(scoring_settings):
|
||||||
|
|
||||||
response = await scoring_impl.score_batch(
|
response = await scoring_impl.score_batch(
|
||||||
dataset_id=response[0].identifier,
|
dataset_id=response[0].identifier,
|
||||||
scoring_functions=["equality", "subset_of"],
|
scoring_functions=[
|
||||||
|
"meta-reference::equality",
|
||||||
|
"meta-reference::subset_of",
|
||||||
|
"meta-reference::llm_as_judge_8b_correctness",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(response.results) == 2
|
print(response)
|
||||||
assert "equality" in response.results
|
assert len(response.results) == 3
|
||||||
assert "subset_of" in response.results
|
assert "meta-reference::equality" in response.results
|
||||||
|
assert "meta-reference::subset_of" in response.results
|
||||||
|
assert "meta-reference::llm_as_judge_8b_correctness" in response.results
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue