mirror of https://github.com/meta-llama/llama-stack.git
llm as judge, move folders
commit 16620a8185
parent bf8bc7a781
13 changed files with 173 additions and 66 deletions
@@ -2,3 +2,4 @@ include requirements.txt
 include llama_stack/distribution/*.sh
 include llama_stack/cli/scripts/*.sh
 include llama_stack/templates/*/build.yaml
+include llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/*.json
@@ -97,6 +97,7 @@ class InferenceRouter(Inference):
             logprobs=logprobs,
         )
         provider = self.routing_table.get_provider_impl(model)
+
         if stream:
             return (chunk async for chunk in await provider.chat_completion(**params))
         else:
@@ -16,6 +16,8 @@ async def get_provider_impl(
 ):
     from .scoring import MetaReferenceScoringImpl

-    impl = MetaReferenceScoringImpl(config, deps[Api.datasetio], deps[Api.datasets])
+    impl = MetaReferenceScoringImpl(
+        config, deps[Api.datasetio], deps[Api.datasets], deps[Api.inference]
+    )
     await impl.initialize()
     return impl
@@ -11,24 +11,25 @@ from llama_stack.apis.scoring_functions import *  # noqa: F403
 from llama_stack.apis.common.type_system import *  # noqa: F403
 from llama_stack.apis.datasetio import *  # noqa: F403
 from llama_stack.apis.datasets import *  # noqa: F403

+from llama_stack.apis.inference.inference import Inference
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
 from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.equality_scoring_fn import (
     EqualityScoringFn,
 )

+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.llm_as_judge_scoring_fn import (
+    LlmAsJudgeScoringFn,
+)
+
 from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import (
     SubsetOfScoringFn,
 )

 from .config import MetaReferenceScoringConfig

-SUPPORTED_SCORING_FNS = [
-    EqualityScoringFn,
-    SubsetOfScoringFn,
-]
+FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn]

-SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORING_FNS}
+LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]


 class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
@@ -37,17 +38,34 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         config: MetaReferenceScoringConfig,
         datasetio_api: DatasetIO,
         datasets_api: Datasets,
+        inference_api: Inference,
     ) -> None:
         self.config = config
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets_api
+        self.inference_api = inference_api
+        self.scoring_fn_id_impls = {}

-    async def initialize(self) -> None: ...
+    async def initialize(self) -> None:
+        for x in FIXED_FNS:
+            impl = x()
+            await impl.initialize()
+            for fn_defs in impl.get_supported_scoring_fn_defs():
+                self.scoring_fn_id_impls[fn_defs.identifier] = impl
+        for x in LLM_JUDGE_FNS:
+            impl = x(inference_api=self.inference_api)
+            await impl.initialize()
+            for fn_defs in impl.get_supported_scoring_fn_defs():
+                self.scoring_fn_id_impls[fn_defs.identifier] = impl

     async def shutdown(self) -> None: ...

     async def list_scoring_functions(self) -> List[ScoringFnDef]:
-        return [x.scoring_function_def for x in SUPPORTED_SCORING_FNS]
+        return [
+            fn_defs
+            for impl in self.scoring_fn_id_impls.values()
+            for fn_defs in impl.get_supported_scoring_fn_defs()
+        ]

     async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
         raise NotImplementedError(
@@ -99,9 +117,9 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     ) -> ScoreResponse:
         res = {}
         for scoring_fn_id in scoring_functions:
-            if scoring_fn_id not in SCORER_REGISTRY:
+            if scoring_fn_id not in self.scoring_fn_id_impls:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
-            scoring_fn = SCORER_REGISTRY[scoring_fn_id]()
+            scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
             score_results = await scoring_fn.score(input_rows, scoring_fn_id)
             agg_results = await scoring_fn.aggregate(score_results)
             res[scoring_fn_id] = ScoringResult(
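The two hunks above replace the module-level SCORER_REGISTRY with a per-instance scoring_fn_id_impls map built in initialize(): fixed scoring functions are constructed with no arguments, LLM-judge functions get the inference API injected, and score() then dispatches by identifier. A minimal, runnable sketch of that pattern follows; the stub classes and FakeInferenceAPI below are illustrative stand-ins, not llama-stack types.

import asyncio
from typing import Any, Dict, List


class FakeInferenceAPI:
    # Illustrative stand-in for the real inference API dependency.
    async def chat_completion(self, model: str, messages: List[Dict[str, str]]) -> str:
        return "Total rating: 4"


class EqualityFn:
    identifiers = ["meta-reference::equality"]

    async def score(self, rows: List[Dict[str, Any]]) -> List[float]:
        return [1.0 if r["expected_answer"] == r["generated_answer"] else 0.0 for r in rows]


class JudgeFn:
    identifiers = ["meta-reference::llm_as_judge_8b_correctness"]

    def __init__(self, inference_api: FakeInferenceAPI) -> None:
        self.inference_api = inference_api

    async def score(self, rows: List[Dict[str, Any]]) -> List[str]:
        # The real implementation prompts the judge model per row and parses the rating.
        return [await self.inference_api.chat_completion("judge-model", []) for _ in rows]


async def main() -> None:
    inference = FakeInferenceAPI()
    fn_id_impls: Dict[str, Any] = {}
    for cls in [EqualityFn]:            # like FIXED_FNS: no-arg construction
        impl = cls()
        for fn_id in impl.identifiers:
            fn_id_impls[fn_id] = impl
    for cls in [JudgeFn]:               # like LLM_JUDGE_FNS: inference injected
        impl = cls(inference_api=inference)
        for fn_id in impl.identifiers:
            fn_id_impls[fn_id] = impl

    rows = [{"expected_answer": "4", "generated_answer": "4"}]
    for fn_id, impl in fn_id_impls.items():  # dispatch by identifier, as in score()
        print(fn_id, await impl.score(rows))


asyncio.run(main())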
@@ -7,6 +7,7 @@ from abc import ABC, abstractmethod
 from typing import Any, Dict, List
 from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
 from llama_stack.apis.scoring import *  # noqa: F401, F403
+import json


 class BaseScoringFn(ABC):
@@ -17,14 +18,30 @@ class BaseScoringFn(ABC):
     - aggregate(self, scoring_fn_results)
     """

-    scoring_function_def: ScoringFnDef
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.supported_fn_defs_registry = {}
+        self.defs_paths = []

     def __str__(self) -> str:
         return self.__class__.__name__

+    async def initialize(self) -> None:
+        for f in self.defs_paths:
+            with open(f, "r") as f:
+                scoring_fn_def = ScoringFnDef(**json.load(f))
+                self.register_scoring_fn_def(scoring_fn_def)
+
+    def get_supported_scoring_fn_defs(self) -> List[ScoringFnDef]:
+        return [x for x in self.supported_fn_defs_registry.values()]
+
+    def register_scoring_fn_def(self, scoring_fn_def: ScoringFnDef) -> None:
+        if scoring_fn_def.identifier in self.supported_fn_defs_registry:
+            raise ValueError(
+                f"Scoring function def with identifier {scoring_fn_def.identifier} already exists."
+            )
+        self.supported_fn_defs_registry[scoring_fn_def.identifier] = scoring_fn_def
+
     @abstractmethod
     async def score_row(
         self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
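The new BaseScoringFn.initialize() reads each JSON file listed in defs_paths, builds a ScoringFnDef from it, and registers it under its identifier, refusing duplicates. A rough standalone sketch of that loading flow, with a simplified stand-in for ScoringFnDef (the real class is a llama-stack model with richer typing):

import json
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional


@dataclass
class FnDef:
    # Simplified stand-in for ScoringFnDef.
    identifier: str
    description: str
    metadata: Dict[str, Any] = field(default_factory=dict)
    parameters: list = field(default_factory=list)
    return_type: Optional[Dict[str, Any]] = None
    context: Optional[Dict[str, Any]] = None


registry: Dict[str, FnDef] = {}

# Write one fn_def JSON shaped like the fn_defs/*.json files added by this commit.
tmp = Path(tempfile.mkdtemp())
(tmp / "equality.json").write_text(json.dumps({
    "identifier": "meta-reference::equality",
    "description": "Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
    "metadata": {},
    "parameters": [],
    "return_type": {"type": "number"},
    "context": None,
}))

# The same loop initialize() runs over self.defs_paths.
for path in [tmp / "equality.json"]:
    with open(path, "r") as f:
        fn_def = FnDef(**json.load(f))
    if fn_def.identifier in registry:
        raise ValueError(f"Scoring function def {fn_def.identifier} already exists.")
    registry[fn_def.identifier] = fn_def

print(list(registry))  # ['meta-reference::equality']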
@@ -3,10 +3,13 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from pathlib import Path
 from typing import Any, Dict, List

 from llama_stack.apis.scoring import ScoringResultRow

+FN_DEFS_PATH = Path(__file__).parent / "fn_defs"
+

 def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
     num_correct = sum(result["score"] for result in scoring_results)
@@ -17,3 +20,12 @@ def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
         "num_correct": num_correct,
         "num_total": len(scoring_results),
     }
+
+
+def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
+    return {
+        "average": sum(
+            result["score"] for result in scoring_results if result["score"] is not None
+        )
+        / len([_ for _ in scoring_results if _["score"] is not None]),
+    }
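aggregate_average, added above, skips rows whose score is None (an LLM judge can fail to produce a parsable rating) and averages the rest; note it would raise ZeroDivisionError if every score were None. A small worked example reusing the helper exactly as written in this hunk:

from typing import Any, Dict, List

ScoringResultRow = Dict[str, Any]


def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
    # Copied from the hunk above: ignore rows whose score is None.
    return {
        "average": sum(
            result["score"] for result in scoring_results if result["score"] is not None
        )
        / len([_ for _ in scoring_results if _["score"] is not None]),
    }


rows = [{"score": 4}, {"score": 2}, {"score": None}]
print(aggregate_average(rows))  # {'average': 3.0} — the None row is excluded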
@@ -10,8 +10,10 @@ from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
 from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
 from llama_stack.apis.scoring import *  # noqa: F401, F403
 from llama_stack.apis.common.type_system import *  # noqa: F403
+
 from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
     aggregate_accuracy,
+    FN_DEFS_PATH,
 )

@@ -20,12 +22,9 @@ class EqualityScoringFn(BaseScoringFn):
     A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
     """

-    scoring_function_def = ScoringFnDef(
-        identifier="equality",
-        description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
-        parameters=[],
-        return_type=NumberType(),
-    )
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.defs_paths = [FN_DEFS_PATH / "equality.json"]

     async def score_row(
         self,
@@ -0,0 +1,10 @@
+{
+    "identifier": "meta-reference::equality",
+    "description": "Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
+    "metadata": {},
+    "parameters": [],
+    "return_type": {
+        "type": "number"
+    },
+    "context": null
+}
@@ -0,0 +1,13 @@
+{
+    "identifier": "meta-reference::llm_as_judge_8b_correctness",
+    "description": "Llm As Judge Scoring Function",
+    "metadata": {},
+    "parameters": [],
+    "return_type": {
+        "type": "number"
+    },
+    "context": {
+        "judge_model": "Llama3.1-8B-Instruct",
+        "prompt_template": "\nYou will be given a question, a expected_answer, and a system_answer.\nYour task is to provide a 'total rating' scoring how well the system_answer answers compared with ground truth in expected_answer in terms of factual correctness to the question.\nGive your answer as a integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.\nProvide your feedback as follows:\nFeedback:::\nTotal rating: (your rating, as a int between 0 and 5)\nNow here are the question, expected_answer, system_answer.\nQuestion: {input_query}\nExpected Answer: {expected_answer}\nSystem Answer: {generated_answer}\nFeedback:::\nTotal rating:\n"
+    }
+}
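The context block of this fn_def carries both the judge model and the prompt template; LlmAsJudgeScoringFn (further down) fills the template's input_query, expected_answer, and generated_answer placeholders per row before calling the judge. A quick illustration of that substitution, using an abbreviated template and made-up row values:

# Abbreviated version of the prompt_template from the JSON above; the
# placeholder names match the keys used by LlmAsJudgeScoringFn.score_row.
prompt_template = (
    "Question: {input_query}\n"
    "Expected Answer: {expected_answer}\n"
    "System Answer: {generated_answer}\n"
    "Feedback:::\nTotal rating:\n"
)

judge_input_msg = prompt_template.format(
    input_query="What is 2 + 2?",
    expected_answer="4",
    generated_answer="The answer is 4.",
)
print(judge_input_msg)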
@@ -0,0 +1,10 @@
+{
+    "identifier": "meta-reference::subset_of",
+    "description": "Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
+    "metadata": {},
+    "parameters": [],
+    "return_type": {
+        "type": "number"
+    },
+    "context": null
+}
@@ -3,31 +3,19 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from llama_stack.apis.inference.inference import Inference
 from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
     BaseScoringFn,
 )
 from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
 from llama_stack.apis.scoring import *  # noqa: F401, F403
 from llama_stack.apis.common.type_system import *  # noqa: F403
-from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
-    aggregate_accuracy,
-)
 import re

-JUDGE_PROMPT = """
-You will be given a question, a expected_answer, and a system_answer.
-Your task is to provide a 'total rating' scoring how well the system_answer answers compared with ground truth in expected_answer in terms of factual correctness to the question.
-Give your answer as a integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.
-Provide your feedback as follows:
-Feedback:::
-Total rating: (your rating, as a int between 0 and 5)
-Now here are the question, expected_answer, system_answer.
-Question: {question}
-Expected Answer: {expected_answer}
-System Answer: {answer}
-Feedback:::
-Total rating:
-"""
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
+    aggregate_average,
+    FN_DEFS_PATH,
+)


 class LlmAsJudgeScoringFn(BaseScoringFn):
@@ -35,27 +23,62 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
     A scoring_fn that assigns
     """

-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-        self.scoring_fn_def_registry = {}
+    def __init__(self, inference_api: Inference, *arg, **kwargs) -> None:
+        super().__init__(*arg, **kwargs)
+        self.inference_api = inference_api
+        self.defs_paths = [FN_DEFS_PATH / "llm_as_judge_8b_correctness.json"]

-    def register_scoring_def(self, scoring_fn_def: ScoringFnDef) -> None:
-        self.scoring_function_def_registry[scoring_fn_def.identifier] = scoring_fn_def
-
-    async def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
-        assert "expected_answer" in input_row, "Expected answer not found in input row."
+    async def score_row(
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = None,
+    ) -> ScoringResultRow:
         assert (
-            "generated_answer" in input_row
-        ), "Generated answer not found in input row."
+            scoring_fn_identifier is not None
+        ), "Scoring function identifier not found."
+        fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
+        assert (
+            fn_def.context is not None and fn_def.context.prompt_template is not None
+        ), "LLM Judge prompt_template not found."

+        input_query = input_row["input_query"]
         expected_answer = input_row["expected_answer"]
         generated_answer = input_row["generated_answer"]
-        score = 1.0 if expected_answer == generated_answer else 0.0
+
+        judge_input_msg = fn_def.context.prompt_template.format(
+            input_query=input_query,
+            expected_answer=expected_answer,
+            generated_answer=generated_answer,
+        )
+
+        judge_response = await self.inference_api.chat_completion(
+            model=fn_def.context.judge_model,
+            messages=[
+                {
+                    "role": "user",
+                    "content": judge_input_msg,
+                }
+            ],
+        )
+        content = judge_response.completion_message.content
+        rating_regexs = [
+            r"Total rating: (\d+)",
+            r"rating: (\d+)",
+            r"Rating: (\d+)",
+        ]
+        judge_rating = None
+        for regex in rating_regexs:
+            match = re.search(regex, content)
+            if match:
+                judge_rating = int(match.group(1))
+                break
+
         return {
-            "score": score,
+            "score": judge_rating,
+            "judge_feedback": content,
         }

     async def aggregate(
         self, scoring_results: List[ScoringResultRow]
     ) -> Dict[str, Any]:
-        return aggregate_accuracy(scoring_results)
+        return aggregate_average(scoring_results)
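score_row now parses the judge's free-form reply by trying a few "Total rating:"-style regexes in order and keeping the first integer match; if none match, the row's score stays None, which is presumably why aggregation switched to aggregate_average. A standalone sketch of that extraction on made-up judge replies:

import re
from typing import Optional


def extract_rating(content: str) -> Optional[int]:
    # Same patterns and ordering as the rating_regexs list above.
    rating_regexs = [
        r"Total rating: (\d+)",
        r"rating: (\d+)",
        r"Rating: (\d+)",
    ]
    for regex in rating_regexs:
        match = re.search(regex, content)
        if match:
            return int(match.group(1))
    return None


print(extract_rating("Feedback:::\nTotal rating: 4"))   # 4
print(extract_rating("I would give this a Rating: 2"))  # 2
print(extract_rating("No score here"))                  # None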
@@ -12,6 +12,7 @@ from llama_stack.apis.scoring import *  # noqa: F401, F403
 from llama_stack.apis.common.type_system import *  # noqa: F403
 from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
     aggregate_accuracy,
+    FN_DEFS_PATH,
 )

@@ -20,23 +21,15 @@ class SubsetOfScoringFn(BaseScoringFn):
     A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise.
     """

-    scoring_function_def = ScoringFnDef(
-        identifier="subset_of",
-        description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
-        parameters=[],
-        return_type=NumberType(),
-    )
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.defs_paths = [FN_DEFS_PATH / "subset_of.json"]

     async def score_row(
         self,
         input_row: Dict[str, Any],
         scoring_fn_identifier: Optional[str] = "subset_of",
     ) -> ScoringResultRow:
         assert "expected_answer" in input_row, "Expected answer not found in input row."
         assert (
             "generated_answer" in input_row
         ), "Generated answer not found in input row."

         expected_answer = input_row["expected_answer"]
         generated_answer = input_row["generated_answer"]
         score = 1.0 if expected_answer in generated_answer else 0.0
@@ -50,7 +50,9 @@ async def test_scoring_functions_list(scoring_settings):
     assert isinstance(scoring_functions, list)
     assert len(scoring_functions) > 0
     function_ids = [f.identifier for f in scoring_functions]
-    assert "equality" in function_ids
+    assert "meta-reference::equality" in function_ids
+    assert "meta-reference::subset_of" in function_ids
+    assert "meta-reference::llm_as_judge_8b_correctness" in function_ids


 @pytest.mark.asyncio
@@ -64,9 +66,15 @@ async def test_scoring_score(scoring_settings):

     response = await scoring_impl.score_batch(
         dataset_id=response[0].identifier,
-        scoring_functions=["equality", "subset_of"],
+        scoring_functions=[
+            "meta-reference::equality",
+            "meta-reference::subset_of",
+            "meta-reference::llm_as_judge_8b_correctness",
+        ],
     )

-    assert len(response.results) == 2
-    assert "equality" in response.results
-    assert "subset_of" in response.results
+    print(response)
+    assert len(response.results) == 3
+    assert "meta-reference::equality" in response.results
+    assert "meta-reference::subset_of" in response.results
+    assert "meta-reference::llm_as_judge_8b_correctness" in response.results