llm judge llamastack scorer

Xi Yan 2024-10-15 13:25:46 -07:00
parent 0c4ed66ecc
commit fa68809a2e
10 changed files with 199 additions and 7 deletions

View file

@@ -93,7 +93,7 @@ async def run_main(host: str, port: int, eval_dataset_path: str = ""):
    )
    cprint(f"datasets/create: {response}", "cyan")

-    # # 2. run evals on the registered dataset
+    # 2. run evals on the registered dataset
    eval_task_config = EvaluateTaskConfig(
        dataset_config=EvaluateDatasetConfig(
            dataset_identifier="mmlu-simple-eval-en",
@@ -151,9 +151,21 @@ async def run_main(host: str, port: int, eval_dataset_path: str = ""):
        ),
        eval_scoring_config=EvaluateScoringConfig(
            scorer_config_list=[
-                EvaluateSingleScorerConfig(scorer_name="accuracy"),
+                # EvaluateSingleScorerConfig(scorer_name="accuracy"),
+                # EvaluateSingleScorerConfig(
+                #     scorer_name="braintrust::answer-correctness"
+                # ),
                EvaluateSingleScorerConfig(
-                    scorer_name="braintrust::answer-correctness"
+                    scorer_name="llamastack-llm-judge",
+                    llm_judge_config=LLMJudgeConfig(
+                        judge_processor_config=EvaluateProcessorConfig(
+                            processor_identifier="judge",
+                        ),
+                        judge_model_generation_config=EvaluateModelGenerationConfig(
+                            model="Llama3.1-8B-Instruct",
+                        ),
+                        judge_scoring_config=EvaluateJudgeScoringConfig(),
+                    ),
                ),
            ]
        ),

View file

@@ -13,6 +13,7 @@ GeneratorProcessorRegistry = Registry[BaseGeneratorProcessor]()
PROCESSOR_REGISTRY = {
    "mmlu": MMLUProcessor,
+    "judge": JudgeProcessor,
}

for k, v in PROCESSOR_REGISTRY.items():

View file

@@ -7,6 +7,7 @@
from llama_stack.apis.evals import *  # noqa: F403
from llama_stack.providers.impls.meta_reference.evals.scorer.basic_scorers import *  # noqa: F403
from llama_stack.providers.impls.meta_reference.evals.scorer.braintrust_scorer import *  # noqa: F403
+from llama_stack.providers.impls.meta_reference.evals.scorer.llm_judge_scorer import *  # noqa: F403

from ..registry import Registry
@@ -16,6 +17,7 @@ ScorerRegistry = Registry[BaseScorer]()
SCORER_REGISTRY = {
    "accuracy": AccuracyScorer,
    "random": RandomScorer,
+    "llamastack-llm-judge": LlamaStackLLMJudgeScorer,
    "braintrust::factuality": BraintrustFactualityScorer,
    "braintrust::answer-correctness": BraintrustAnswerCorrectnessScorer,
}
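To make the wiring concrete, here is a minimal sketch (not part of this commit) of how a caller resolves the new scorer by name and constructs it. The ScorerRegistry import path mirrors the generator-processor registry import used later in this commit but is an assumption, as is the my_inference_api placeholder, which stands in for an Inference provider from a running distribution:

    from llama_stack.apis.evals import *  # noqa: F403
    from llama_stack.distribution.registry.scorers import ScorerRegistry  # assumed path

    # Same fields the client example above passes in.
    judge_config = LLMJudgeConfig(
        judge_processor_config=EvaluateProcessorConfig(processor_identifier="judge"),
        judge_model_generation_config=EvaluateModelGenerationConfig(
            model="Llama3.1-8B-Instruct",
        ),
        judge_scoring_config=EvaluateJudgeScoringConfig(),
    )

    # Look up the class registered under "llamastack-llm-judge" and instantiate it.
    scorer_cls = ScorerRegistry.get("llamastack-llm-judge")  # -> LlamaStackLLMJudgeScorer
    judge_scorer = scorer_cls(
        llm_judge_config=judge_config,
        inference_api=my_inference_api,  # hypothetical placeholder for an Inference impl
    )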

View file

@@ -48,7 +48,9 @@ class MetaReferenceEvalsImpl(Evals):
        cprint(f"run_scorer: on {dataset_config} with {eval_scoring_config}", "green")

        run_task = RunScoringTask()
-        eval_result = await run_task.run(dataset_config, eval_scoring_config)
+        eval_result = await run_task.run(
+            dataset_config, eval_scoring_config, self.inference_api
+        )

        return EvaluateResponse(
            eval_result=eval_result,

View file

@@ -30,7 +30,6 @@ class InferenceGenerator(BaseGenerator[PreprocessedSample, GenerationResponseSam
    ) -> List[GenerationResponseSample]:
        generation_outputs = []
        for sample in preprocessed_dataset:
-            print("generation: ", sample)
            response = await self.inference_api.chat_completion(
                model=self.model,
                messages=sample.generation_input.messages,

View file

@@ -3,4 +3,5 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+from .judge_processor import JudgeProcessor  # noqa: F401
from .mmlu_processor import MMLUProcessor  # noqa: F401

View file

@@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re

from llama_stack.apis.evals import *  # noqa: F403

JUDGE_PROMPT = """
You will be given a question, a expected_answer, and a system_answer.
Your task is to provide a 'total rating' scoring how well the system_answer answers compared with ground truth in expected_answer in terms of factual correctness to the question.
Give your answer as a integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.

Provide your feedback as follows:

Feedback:::
Total rating: (your rating, as a int between 0 and 5)

Now here are the question, expected_answer, system_answer.

Question: {question}
Expected Answer: {expected_answer}
System Answer: {answer}

Feedback:::
Total rating:
"""


class JudgeProcessor(
    BaseGeneratorProcessor[
        DictSample, PreprocessedSample, GenerationResponseSample, ScorerInputSample
    ]
):
    """
    Generator processor for LLM Judge
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def preprocess_sample(self, sample: DictSample) -> PreprocessedSample:
        content = JUDGE_PROMPT.format(
            question=sample.data["input_query"],
            expected_answer=sample.data["expected_answer"],
            answer=sample.data["generated_answer"],
        )
        preprocessed_msgs = [
            {
                "role": "user",
                "content": content,
            }
        ]
        processed_sample = PreprocessedSample(
            generation_input=GenerationInput(
                messages=preprocessed_msgs,
            )
        )
        return processed_sample

    def postprocess_sample(
        self, generation_sample: GenerationResponseSample, dataset_sample: DictSample
    ) -> ScorerInputSample:
        response_text = generation_sample.generation_output.completion_message
        match = re.search(r"Total rating: (\d+)", response_text)
        judge_rating = int(match.group(1))
        return ScorerInputSample(
            generated_answer=str(judge_rating),
            expected_answer=dataset_sample.data["expected_answer"],
            generation_output=PostprocessedGeneration(
                completion_message=response_text,
            ),
        )
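As a quick illustration of the round trip this processor implements, the following sketch (not from the commit; the sample strings are invented, and JUDGE_PROMPT from the file above is assumed to be in scope) formats the prompt and parses a judge reply with the same regex. Note that postprocess_sample above assumes the "Total rating:" line is always present:

    import re

    prompt = JUDGE_PROMPT.format(
        question="What is the capital of France?",
        expected_answer="Paris",
        answer="The capital of France is Paris.",
    )
    # preprocess_sample wraps this prompt as a single user message for the judge model.

    judge_reply = "Feedback:::\nTotal rating: 5"  # invented judge output
    match = re.search(r"Total rating: (\d+)", judge_reply)
    rating = int(match.group(1)) if match else None  # guard added here for illustration
    print(rating)  # 5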

View file

@@ -0,0 +1,83 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import threading

import numpy as np

from llama_stack.distribution.registry.generator_processors import (
    GeneratorProcessorRegistry,
)
from llama_stack.providers.impls.meta_reference.evals.generator.inference_generator import (
    InferenceGenerator,
)

from llama_stack.apis.evals.evals import *  # noqa: F401 F403
from llama_stack.apis.datasets.datasets import *  # noqa: F401 F403
from llama_stack.apis.inference import *  # noqa: F403


class LlamaStackLLMJudgeScorer(BaseScorer[ScorerInputSample]):
    def __init__(self, llm_judge_config: LLMJudgeConfig, inference_api: Inference):
        self.llm_judge_config = llm_judge_config
        self.inference_api = inference_api
        # https://stackoverflow.com/questions/74703727/how-to-call-async-function-from-sync-funcion-and-get-result-while-a-loop-is-alr
        # We will use another thread with its own event loop to run the async api within the sync function
        self._loop = asyncio.new_event_loop()
        self._thr = threading.Thread(
            target=self._loop.run_forever, name="Async Runner", daemon=True
        )
        if not self._thr.is_alive():
            self._thr.start()

    def score_sample(self, scorer_input_sample: ScorerInputSample) -> SingleEvalResult:
        input_query = scorer_input_sample.input_query
        generated_answer = scorer_input_sample.generated_answer
        expected_answer = scorer_input_sample.expected_answer

        # Judge F1
        processor = GeneratorProcessorRegistry.get(
            self.llm_judge_config.judge_processor_config.processor_identifier
        )()
        data_sample = DictSample(
            data={
                "input_query": input_query,
                "generated_answer": generated_answer,
                "expected_answer": expected_answer,
            }
        )
        preprocessed_sample = processor.preprocess_sample(data_sample)

        # Judge Generation
        generator = InferenceGenerator(
            model=self.llm_judge_config.judge_model_generation_config.model,
            inference_api=self.inference_api,
        )
        future = asyncio.run_coroutine_threadsafe(
            generator.generate([preprocessed_sample]), self._loop
        )
        generation_outputs = future.result()

        # Judge F2
        postprocessed_sample = processor.postprocess_sample(
            generation_outputs[0], data_sample
        )

        # Judge F3
        score = float(postprocessed_sample.generated_answer)
        return SingleEvalResult(score_data={"judge_score": score})

    def aggregate_results(self, eval_results: List[SingleEvalResult]) -> EvalResult:
        avg_score = np.average(
            [result.score_data["judge_score"] for result in eval_results]
        )

        return EvalResult(
            metrics={
                "avg_judge_score": avg_score,
            }
        )
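The scorer's trickiest piece is calling the async inference API from the synchronous score_sample. Isolated from llama-stack, the pattern looks like this (a minimal sketch with an invented stand-in coroutine, not code from the commit):

    import asyncio
    import threading

    # A daemon thread owns an event loop; synchronous code submits coroutines to it
    # and blocks on the returned future, as LlamaStackLLMJudgeScorer does above.
    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, name="Async Runner", daemon=True).start()

    async def fake_judge_call() -> str:
        await asyncio.sleep(0.1)  # stands in for the judge model generation
        return "Total rating: 4"

    # score_sample() is synchronous, so it blocks on the future instead of awaiting.
    future = asyncio.run_coroutine_threadsafe(fake_judge_call(), loop)
    print(future.result())  # "Total rating: 4"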

View file

@@ -72,7 +72,15 @@ class RunEvalTask(BaseTask):
        scorer_list = []
        for s_conf in scorer_config_list:
            scorer = ScorerRegistry.get(s_conf.scorer_name)
-            scorer_list.append(scorer())
+            if s_conf.llm_judge_config:
+                scorer_list.append(
+                    scorer(
+                        llm_judge_config=s_conf.llm_judge_config,
+                        inference_api=inference_api,
+                    )
+                )
+            else:
+                scorer_list.append(scorer())

        scorer = AggregateScorer(
            scorers=scorer_list,

View file

@@ -50,6 +50,7 @@ class RunScoringTask(BaseTask):
        self,
        dataset_config: EvaluateDatasetConfig,
        eval_scoring_config: EvaluateScoringConfig,
+        inference_api: Inference,
        *args,
        **kwargs,
    ) -> EvalResult:
@@ -69,7 +70,15 @@ class RunScoringTask(BaseTask):
        scorer_list = []
        for s_conf in scorer_config_list:
            scorer = ScorerRegistry.get(s_conf.scorer_name)
-            scorer_list.append(scorer())
+            if s_conf.llm_judge_config:
+                scorer_list.append(
+                    scorer(
+                        llm_judge_config=s_conf.llm_judge_config,
+                        inference_api=inference_api,
+                    )
+                )
+            else:
+                scorer_list.append(scorer())

        scorer = AggregateScorer(
            scorers=scorer_list,