mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
llm judge llamastack scorer
This commit is contained in:
parent
0c4ed66ecc
commit
fa68809a2e
10 changed files with 199 additions and 7 deletions
|
@ -93,7 +93,7 @@ async def run_main(host: str, port: int, eval_dataset_path: str = ""):
|
||||||
)
|
)
|
||||||
cprint(f"datasets/create: {response}", "cyan")
|
cprint(f"datasets/create: {response}", "cyan")
|
||||||
|
|
||||||
# # 2. run evals on the registered dataset
|
# 2. run evals on the registered dataset
|
||||||
eval_task_config = EvaluateTaskConfig(
|
eval_task_config = EvaluateTaskConfig(
|
||||||
dataset_config=EvaluateDatasetConfig(
|
dataset_config=EvaluateDatasetConfig(
|
||||||
dataset_identifier="mmlu-simple-eval-en",
|
dataset_identifier="mmlu-simple-eval-en",
|
||||||
|
@ -151,9 +151,21 @@ async def run_main(host: str, port: int, eval_dataset_path: str = ""):
|
||||||
),
|
),
|
||||||
eval_scoring_config=EvaluateScoringConfig(
|
eval_scoring_config=EvaluateScoringConfig(
|
||||||
scorer_config_list=[
|
scorer_config_list=[
|
||||||
EvaluateSingleScorerConfig(scorer_name="accuracy"),
|
# EvaluateSingleScorerConfig(scorer_name="accuracy"),
|
||||||
|
# EvaluateSingleScorerConfig(
|
||||||
|
# scorer_name="braintrust::answer-correctness"
|
||||||
|
# ),
|
||||||
EvaluateSingleScorerConfig(
|
EvaluateSingleScorerConfig(
|
||||||
scorer_name="braintrust::answer-correctness"
|
scorer_name="llamastack-llm-judge",
|
||||||
|
llm_judge_config=LLMJudgeConfig(
|
||||||
|
judge_processor_config=EvaluateProcessorConfig(
|
||||||
|
processor_identifier="judge",
|
||||||
|
),
|
||||||
|
judge_model_generation_config=EvaluateModelGenerationConfig(
|
||||||
|
model="Llama3.1-8B-Instruct",
|
||||||
|
),
|
||||||
|
judge_scoring_config=EvaluateJudgeScoringConfig(),
|
||||||
|
),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
|
|
|
@ -13,6 +13,7 @@ GeneratorProcessorRegistry = Registry[BaseGeneratorProcessor]()
|
||||||
|
|
||||||
PROCESSOR_REGISTRY = {
|
PROCESSOR_REGISTRY = {
|
||||||
"mmlu": MMLUProcessor,
|
"mmlu": MMLUProcessor,
|
||||||
|
"judge": JudgeProcessor,
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, v in PROCESSOR_REGISTRY.items():
|
for k, v in PROCESSOR_REGISTRY.items():
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
from llama_stack.apis.evals import * # noqa: F403
|
from llama_stack.apis.evals import * # noqa: F403
|
||||||
from llama_stack.providers.impls.meta_reference.evals.scorer.basic_scorers import * # noqa: F403
|
from llama_stack.providers.impls.meta_reference.evals.scorer.basic_scorers import * # noqa: F403
|
||||||
from llama_stack.providers.impls.meta_reference.evals.scorer.braintrust_scorer import * # noqa: F403
|
from llama_stack.providers.impls.meta_reference.evals.scorer.braintrust_scorer import * # noqa: F403
|
||||||
|
from llama_stack.providers.impls.meta_reference.evals.scorer.llm_judge_scorer import * # noqa: F403
|
||||||
|
|
||||||
from ..registry import Registry
|
from ..registry import Registry
|
||||||
|
|
||||||
|
@ -16,6 +17,7 @@ ScorerRegistry = Registry[BaseScorer]()
|
||||||
SCORER_REGISTRY = {
|
SCORER_REGISTRY = {
|
||||||
"accuracy": AccuracyScorer,
|
"accuracy": AccuracyScorer,
|
||||||
"random": RandomScorer,
|
"random": RandomScorer,
|
||||||
|
"llamastack-llm-judge": LlamaStackLLMJudgeScorer,
|
||||||
"braintrust::factuality": BraintrustFactualityScorer,
|
"braintrust::factuality": BraintrustFactualityScorer,
|
||||||
"braintrust::answer-correctness": BraintrustAnswerCorrectnessScorer,
|
"braintrust::answer-correctness": BraintrustAnswerCorrectnessScorer,
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,7 +48,9 @@ class MetaReferenceEvalsImpl(Evals):
|
||||||
cprint(f"run_scorer: on {dataset_config} with {eval_scoring_config}", "green")
|
cprint(f"run_scorer: on {dataset_config} with {eval_scoring_config}", "green")
|
||||||
|
|
||||||
run_task = RunScoringTask()
|
run_task = RunScoringTask()
|
||||||
eval_result = await run_task.run(dataset_config, eval_scoring_config)
|
eval_result = await run_task.run(
|
||||||
|
dataset_config, eval_scoring_config, self.inference_api
|
||||||
|
)
|
||||||
|
|
||||||
return EvaluateResponse(
|
return EvaluateResponse(
|
||||||
eval_result=eval_result,
|
eval_result=eval_result,
|
||||||
|
|
|
@ -30,7 +30,6 @@ class InferenceGenerator(BaseGenerator[PreprocessedSample, GenerationResponseSam
|
||||||
) -> List[GenerationResponseSample]:
|
) -> List[GenerationResponseSample]:
|
||||||
generation_outputs = []
|
generation_outputs = []
|
||||||
for sample in preprocessed_dataset:
|
for sample in preprocessed_dataset:
|
||||||
print("generation: ", sample)
|
|
||||||
response = await self.inference_api.chat_completion(
|
response = await self.inference_api.chat_completion(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=sample.generation_input.messages,
|
messages=sample.generation_input.messages,
|
||||||
|
|
|
@ -3,4 +3,5 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
from .judge_processor import JudgeProcessor # noqa: F401
|
||||||
from .mmlu_processor import MMLUProcessor # noqa: F401
|
from .mmlu_processor import MMLUProcessor # noqa: F401
|
||||||
|
|
|
@ -0,0 +1,75 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
import re
|
||||||
|
|
||||||
|
from llama_stack.apis.evals import * # noqa: F403
|
||||||
|
|
||||||
|
JUDGE_PROMPT = """
|
||||||
|
You will be given a question, a expected_answer, and a system_answer.
|
||||||
|
Your task is to provide a 'total rating' scoring how well the system_answer answers compared with ground truth in expected_answer in terms of factual correctness to the question.
|
||||||
|
Give your answer as a integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.
|
||||||
|
|
||||||
|
Provide your feedback as follows:
|
||||||
|
|
||||||
|
Feedback:::
|
||||||
|
Total rating: (your rating, as a int between 0 and 5)
|
||||||
|
|
||||||
|
Now here are the question, expected_answer, system_answer.
|
||||||
|
|
||||||
|
Question: {question}
|
||||||
|
Expected Answer: {expected_answer}
|
||||||
|
System Answer: {answer}
|
||||||
|
|
||||||
|
Feedback:::
|
||||||
|
Total rating:
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class JudgeProcessor(
|
||||||
|
BaseGeneratorProcessor[
|
||||||
|
DictSample, PreprocessedSample, GenerationResponseSample, ScorerInputSample
|
||||||
|
]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generator processor for LLM Judge
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def preprocess_sample(self, sample: DictSample) -> PreprocessedSample:
|
||||||
|
content = JUDGE_PROMPT.format(
|
||||||
|
question=sample.data["input_query"],
|
||||||
|
expected_answer=sample.data["expected_answer"],
|
||||||
|
answer=sample.data["generated_answer"],
|
||||||
|
)
|
||||||
|
preprocessed_msgs = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": content,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
processed_sample = PreprocessedSample(
|
||||||
|
generation_input=GenerationInput(
|
||||||
|
messages=preprocessed_msgs,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return processed_sample
|
||||||
|
|
||||||
|
def postprocess_sample(
|
||||||
|
self, generation_sample: GenerationResponseSample, dataset_sample: DictSample
|
||||||
|
) -> ScorerInputSample:
|
||||||
|
response_text = generation_sample.generation_output.completion_message
|
||||||
|
match = re.search(r"Total rating: (\d+)", response_text)
|
||||||
|
judge_rating = int(match.group(1))
|
||||||
|
|
||||||
|
return ScorerInputSample(
|
||||||
|
generated_answer=str(judge_rating),
|
||||||
|
expected_answer=dataset_sample.data["expected_answer"],
|
||||||
|
generation_output=PostprocessedGeneration(
|
||||||
|
completion_message=response_text,
|
||||||
|
),
|
||||||
|
)
|
|
@ -0,0 +1,83 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from llama_stack.distribution.registry.generator_processors import (
|
||||||
|
GeneratorProcessorRegistry,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.impls.meta_reference.evals.generator.inference_generator import (
|
||||||
|
InferenceGenerator,
|
||||||
|
)
|
||||||
|
|
||||||
|
from llama_stack.apis.evals.evals import * # noqa: F401 F403
|
||||||
|
from llama_stack.apis.datasets.datasets import * # noqa: F401 F403
|
||||||
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaStackLLMJudgeScorer(BaseScorer[ScorerInputSample]):
|
||||||
|
def __init__(self, llm_judge_config: LLMJudgeConfig, inference_api: Inference):
|
||||||
|
self.llm_judge_config = llm_judge_config
|
||||||
|
self.inference_api = inference_api
|
||||||
|
# https://stackoverflow.com/questions/74703727/how-to-call-async-function-from-sync-funcion-and-get-result-while-a-loop-is-alr
|
||||||
|
# We will use another thread wih its own event loop to run the async api within sync function
|
||||||
|
self._loop = asyncio.new_event_loop()
|
||||||
|
self._thr = threading.Thread(
|
||||||
|
target=self._loop.run_forever, name="Async Runner", daemon=True
|
||||||
|
)
|
||||||
|
if not self._thr.is_alive():
|
||||||
|
self._thr.start()
|
||||||
|
|
||||||
|
def score_sample(self, scorer_input_sample: ScorerInputSample) -> SingleEvalResult:
|
||||||
|
input_query = scorer_input_sample.input_query
|
||||||
|
generated_answer = scorer_input_sample.generated_answer
|
||||||
|
expected_answer = scorer_input_sample.expected_answer
|
||||||
|
|
||||||
|
# Judge F1
|
||||||
|
processor = GeneratorProcessorRegistry.get(
|
||||||
|
self.llm_judge_config.judge_processor_config.processor_identifier
|
||||||
|
)()
|
||||||
|
data_sample = DictSample(
|
||||||
|
data={
|
||||||
|
"input_query": input_query,
|
||||||
|
"generated_answer": generated_answer,
|
||||||
|
"expected_answer": expected_answer,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
preprocessed_sample = processor.preprocess_sample(data_sample)
|
||||||
|
|
||||||
|
# Judge Generation
|
||||||
|
generator = InferenceGenerator(
|
||||||
|
model=self.llm_judge_config.judge_model_generation_config.model,
|
||||||
|
inference_api=self.inference_api,
|
||||||
|
)
|
||||||
|
|
||||||
|
future = asyncio.run_coroutine_threadsafe(
|
||||||
|
generator.generate([preprocessed_sample]), self._loop
|
||||||
|
)
|
||||||
|
generation_outputs = future.result()
|
||||||
|
# Judge F2
|
||||||
|
postprocessed_sample = processor.postprocess_sample(
|
||||||
|
generation_outputs[0], data_sample
|
||||||
|
)
|
||||||
|
|
||||||
|
# Judge F3
|
||||||
|
score = float(postprocessed_sample.generated_answer)
|
||||||
|
|
||||||
|
return SingleEvalResult(score_data={"judge_score": score})
|
||||||
|
|
||||||
|
def aggregate_results(self, eval_results: List[SingleEvalResult]) -> EvalResult:
|
||||||
|
avg_score = np.average(
|
||||||
|
[result.score_data["judge_score"] for result in eval_results]
|
||||||
|
)
|
||||||
|
|
||||||
|
return EvalResult(
|
||||||
|
metrics={
|
||||||
|
"avg_judge_score": avg_score,
|
||||||
|
}
|
||||||
|
)
|
|
@ -72,7 +72,15 @@ class RunEvalTask(BaseTask):
|
||||||
scorer_list = []
|
scorer_list = []
|
||||||
for s_conf in scorer_config_list:
|
for s_conf in scorer_config_list:
|
||||||
scorer = ScorerRegistry.get(s_conf.scorer_name)
|
scorer = ScorerRegistry.get(s_conf.scorer_name)
|
||||||
scorer_list.append(scorer())
|
if s_conf.llm_judge_config:
|
||||||
|
scorer_list.append(
|
||||||
|
scorer(
|
||||||
|
llm_judge_config=s_conf.llm_judge_config,
|
||||||
|
inference_api=inference_api,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
scorer_list.append(scorer())
|
||||||
|
|
||||||
scorer = AggregateScorer(
|
scorer = AggregateScorer(
|
||||||
scorers=scorer_list,
|
scorers=scorer_list,
|
||||||
|
|
|
@ -50,6 +50,7 @@ class RunScoringTask(BaseTask):
|
||||||
self,
|
self,
|
||||||
dataset_config: EvaluateDatasetConfig,
|
dataset_config: EvaluateDatasetConfig,
|
||||||
eval_scoring_config: EvaluateScoringConfig,
|
eval_scoring_config: EvaluateScoringConfig,
|
||||||
|
inference_api: Inference,
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> EvalResult:
|
) -> EvalResult:
|
||||||
|
@ -69,7 +70,15 @@ class RunScoringTask(BaseTask):
|
||||||
scorer_list = []
|
scorer_list = []
|
||||||
for s_conf in scorer_config_list:
|
for s_conf in scorer_config_list:
|
||||||
scorer = ScorerRegistry.get(s_conf.scorer_name)
|
scorer = ScorerRegistry.get(s_conf.scorer_name)
|
||||||
scorer_list.append(scorer())
|
if s_conf.llm_judge_config:
|
||||||
|
scorer_list.append(
|
||||||
|
scorer(
|
||||||
|
llm_judge_config=s_conf.llm_judge_config,
|
||||||
|
inference_api=inference_api,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
scorer_list.append(scorer())
|
||||||
|
|
||||||
scorer = AggregateScorer(
|
scorer = AggregateScorer(
|
||||||
scorers=scorer_list,
|
scorers=scorer_list,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue