mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 17:29:01 +00:00
[rag evals][2/n] add more braintrust scoring fns for RAG eval (#666)
# What does this PR do? - add more braintrust scoring functions for RAG eval - add tests for evaluating against context ## Test Plan ``` pytest -v -s -m braintrust_scoring_together_inference scoring/test_scoring.py ``` <img width="850" alt="image" src="https://github.com/user-attachments/assets/2f8f0693-ea13-422c-a183-f798faf86433" /> **Example Output** - https://gist.github.com/yanxi0830/2acf3b8b3e8132fda2a48b1f0a49711b <img width="827" alt="image" src="https://github.com/user-attachments/assets/9014b957-107c-4c23-bbc0-812cbd0b16da" /> <img width="436" alt="image" src="https://github.com/user-attachments/assets/21e9da17-f426-49b2-9113-855cab7b3d40" /> ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
This commit is contained in:
parent
eb92322c3c
commit
2da455f48e
12 changed files with 276 additions and 12 deletions
|
@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional
|
|||
|
||||
from tqdm import tqdm
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.agents import Agents, StepType
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.eval_tasks import EvalTask
|
||||
|
@ -139,11 +139,21 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate, DataSchemaValidatorM
|
|||
)
|
||||
]
|
||||
final_event = turn_response[-1].event.payload
|
||||
generations.append(
|
||||
{
|
||||
ColumnName.generated_answer.value: final_event.turn.output_message.content
|
||||
}
|
||||
|
||||
# check if there's a memory retrieval step and extract the context
|
||||
memory_rag_context = None
|
||||
for step in final_event.turn.steps:
|
||||
if step.step_type == StepType.memory_retrieval.value:
|
||||
memory_rag_context = " ".join(x.text for x in step.inserted_context)
|
||||
|
||||
agent_generation = {}
|
||||
agent_generation[ColumnName.generated_answer.value] = (
|
||||
final_event.turn.output_message.content
|
||||
)
|
||||
if memory_rag_context:
|
||||
agent_generation[ColumnName.context.value] = memory_rag_context
|
||||
|
||||
generations.append(agent_generation)
|
||||
|
||||
return generations
|
||||
|
||||
|
|
|
@ -7,7 +7,16 @@ import os
|
|||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from autoevals.llm import Factuality
|
||||
from autoevals.ragas import AnswerCorrectness
|
||||
from autoevals.ragas import (
|
||||
AnswerCorrectness,
|
||||
AnswerRelevancy,
|
||||
AnswerSimilarity,
|
||||
ContextEntityRecall,
|
||||
ContextPrecision,
|
||||
ContextRecall,
|
||||
ContextRelevancy,
|
||||
Faithfulness,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
|
@ -19,7 +28,7 @@ from llama_stack.apis.scoring import (
|
|||
ScoringResult,
|
||||
ScoringResultRow,
|
||||
)
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
|
@ -33,7 +42,14 @@ from llama_stack.providers.utils.common.data_schema_validator import (
|
|||
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
|
||||
from .config import BraintrustScoringConfig
|
||||
from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
|
||||
from .scoring_fn.fn_defs.answer_relevancy import answer_relevancy_fn_def
|
||||
from .scoring_fn.fn_defs.answer_similarity import answer_similarity_fn_def
|
||||
from .scoring_fn.fn_defs.context_entity_recall import context_entity_recall_fn_def
|
||||
from .scoring_fn.fn_defs.context_precision import context_precision_fn_def
|
||||
from .scoring_fn.fn_defs.context_recall import context_recall_fn_def
|
||||
from .scoring_fn.fn_defs.context_relevancy import context_relevancy_fn_def
|
||||
from .scoring_fn.fn_defs.factuality import factuality_fn_def
|
||||
from .scoring_fn.fn_defs.faithfulness import faithfulness_fn_def
|
||||
|
||||
|
||||
class BraintrustScoringFnEntry(BaseModel):
|
||||
|
@ -53,6 +69,41 @@ SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY = [
|
|||
evaluator=AnswerCorrectness(),
|
||||
fn_def=answer_correctness_fn_def,
|
||||
),
|
||||
BraintrustScoringFnEntry(
|
||||
identifier="braintrust::answer-relevancy",
|
||||
evaluator=AnswerRelevancy(),
|
||||
fn_def=answer_relevancy_fn_def,
|
||||
),
|
||||
BraintrustScoringFnEntry(
|
||||
identifier="braintrust::answer-similarity",
|
||||
evaluator=AnswerSimilarity(),
|
||||
fn_def=answer_similarity_fn_def,
|
||||
),
|
||||
BraintrustScoringFnEntry(
|
||||
identifier="braintrust::faithfulness",
|
||||
evaluator=Faithfulness(),
|
||||
fn_def=faithfulness_fn_def,
|
||||
),
|
||||
BraintrustScoringFnEntry(
|
||||
identifier="braintrust::context-entity-recall",
|
||||
evaluator=ContextEntityRecall(),
|
||||
fn_def=context_entity_recall_fn_def,
|
||||
),
|
||||
BraintrustScoringFnEntry(
|
||||
identifier="braintrust::context-precision",
|
||||
evaluator=ContextPrecision(),
|
||||
fn_def=context_precision_fn_def,
|
||||
),
|
||||
BraintrustScoringFnEntry(
|
||||
identifier="braintrust::context-recall",
|
||||
evaluator=ContextRecall(),
|
||||
fn_def=context_recall_fn_def,
|
||||
),
|
||||
BraintrustScoringFnEntry(
|
||||
identifier="braintrust::context-relevancy",
|
||||
evaluator=ContextRelevancy(),
|
||||
fn_def=context_relevancy_fn_def,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
@ -143,6 +194,7 @@ class BraintrustScoringImpl(
|
|||
async def score_row(
|
||||
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
|
||||
) -> ScoringResultRow:
|
||||
self.validate_row_schema(input_row, get_valid_schemas(Api.scoring.value))
|
||||
await self.set_api_key()
|
||||
assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
|
||||
expected_answer = input_row["expected_answer"]
|
||||
|
@ -154,6 +206,7 @@ class BraintrustScoringImpl(
|
|||
generated_answer,
|
||||
expected_answer,
|
||||
input=input_query,
|
||||
context=input_row["context"] if "context" in input_row else None,
|
||||
)
|
||||
score = result.score
|
||||
return {"score": score, "metadata": result.metadata}
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
answer_relevancy_fn_def = ScoringFn(
|
||||
identifier="braintrust::answer-relevancy",
|
||||
description=(
|
||||
"Test output relevancy against the input query using Braintrust LLM scorer. "
|
||||
"See: github.com/braintrustdata/autoevals"
|
||||
),
|
||||
provider_id="braintrust",
|
||||
provider_resource_id="answer-relevancy",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
)
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
answer_similarity_fn_def = ScoringFn(
|
||||
identifier="braintrust::answer-similarity",
|
||||
description=(
|
||||
"Test output similarity against expected value using Braintrust LLM scorer. "
|
||||
"See: github.com/braintrustdata/autoevals"
|
||||
),
|
||||
provider_id="braintrust",
|
||||
provider_resource_id="answer-similarity",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
)
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
context_entity_recall_fn_def = ScoringFn(
|
||||
identifier="braintrust::context-entity-recall",
|
||||
description=(
|
||||
"Evaluates how well the context captures the named entities present in the "
|
||||
"reference answer. See: github.com/braintrustdata/autoevals"
|
||||
),
|
||||
provider_id="braintrust",
|
||||
provider_resource_id="context-entity-recall",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
)
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
context_precision_fn_def = ScoringFn(
|
||||
identifier="braintrust::context-precision",
|
||||
description=(
|
||||
"Measures how much of the provided context is actually relevant to answering the "
|
||||
"question. See: github.com/braintrustdata/autoevals"
|
||||
),
|
||||
provider_id="braintrust",
|
||||
provider_resource_id="context-precision",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
)
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
context_recall_fn_def = ScoringFn(
|
||||
identifier="braintrust::context-recall",
|
||||
description=(
|
||||
"Evaluates how well the context covers the information needed to answer the "
|
||||
"question. See: github.com/braintrustdata/autoevals"
|
||||
),
|
||||
provider_id="braintrust",
|
||||
provider_resource_id="context-recall",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
)
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
context_relevancy_fn_def = ScoringFn(
|
||||
identifier="braintrust::context-relevancy",
|
||||
description=(
|
||||
"Assesses how relevant the provided context is to the given question. "
|
||||
"See: github.com/braintrustdata/autoevals"
|
||||
),
|
||||
provider_id="braintrust",
|
||||
provider_resource_id="context-relevancy",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
)
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
faithfulness_fn_def = ScoringFn(
|
||||
identifier="braintrust::faithfulness",
|
||||
description=(
|
||||
"Test output faithfulness to the input query using Braintrust LLM scorer. "
|
||||
"See: github.com/braintrustdata/autoevals"
|
||||
),
|
||||
provider_id="braintrust",
|
||||
provider_resource_id="faithfulness",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.average]
|
||||
),
|
||||
)
|
|
@ -38,8 +38,14 @@ def data_url_from_file(file_path: str) -> str:
|
|||
|
||||
|
||||
async def register_dataset(
|
||||
datasets_impl: Datasets, for_generation=False, dataset_id="test_dataset"
|
||||
datasets_impl: Datasets,
|
||||
for_generation=False,
|
||||
for_rag=False,
|
||||
dataset_id="test_dataset",
|
||||
):
|
||||
if for_rag:
|
||||
test_file = Path(os.path.abspath(__file__)).parent / "test_rag_dataset.csv"
|
||||
else:
|
||||
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
|
||||
test_url = data_url_from_file(str(test_file))
|
||||
|
||||
|
@ -49,6 +55,13 @@ async def register_dataset(
|
|||
"input_query": StringType(),
|
||||
"chat_completion_input": ChatCompletionInputType(),
|
||||
}
|
||||
elif for_rag:
|
||||
dataset_schema = {
|
||||
"expected_answer": StringType(),
|
||||
"input_query": StringType(),
|
||||
"generated_answer": StringType(),
|
||||
"context": StringType(),
|
||||
}
|
||||
else:
|
||||
dataset_schema = {
|
||||
"expected_answer": StringType(),
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
input_query,context,generated_answer,expected_answer
|
||||
What is the capital of France?,"France is a country in Western Europe with a population of about 67 million people. Its capital city has been a major European cultural center since the 17th century and is known for landmarks like the Eiffel Tower and the Louvre Museum.",London,Paris
|
||||
Who is the CEO of Meta?,"Meta Platforms, formerly known as Facebook, is one of the world's largest technology companies. Founded by Mark Zuckerberg in 2004, the company has expanded to include platforms like Instagram, WhatsApp, and virtual reality technologies.",Mark Zuckerberg,Mark Zuckerberg
|
||||
What is the largest planet in our solar system?,"The solar system consists of eight planets orbiting around the Sun. These planets, in order from the Sun, are Mercury, Venus, Earth, Mars, Jupiter, Saturn, Uranus, and Neptune. Gas giants are significantly larger than terrestrial planets.",Jupiter,Jupiter
|
||||
What is the smallest country in the world?,"Independent city-states and micronations are among the world's smallest sovereign territories. Some notable examples include Monaco, San Marino, and Vatican City, which is an enclave within Rome, Italy.",China,Vatican City
|
||||
What is the currency of Japan?,"Japan is an island country in East Asia with a rich cultural heritage and one of the world's largest economies. Its financial system has been established since the Meiji period, with its modern currency being introduced in 1871.",Yen,Yen
|
|
|
@ -60,7 +60,7 @@ class TestScoring:
|
|||
f"{provider_id} provider does not support scoring without params"
|
||||
)
|
||||
|
||||
await register_dataset(datasets_impl)
|
||||
await register_dataset(datasets_impl, for_rag=True)
|
||||
response = await datasets_impl.list_datasets()
|
||||
assert len(response) == 1
|
||||
|
||||
|
@ -112,7 +112,7 @@ class TestScoring:
|
|||
scoring_stack[Api.datasets],
|
||||
scoring_stack[Api.models],
|
||||
)
|
||||
await register_dataset(datasets_impl)
|
||||
await register_dataset(datasets_impl, for_rag=True)
|
||||
response = await datasets_impl.list_datasets()
|
||||
assert len(response) == 1
|
||||
|
||||
|
@ -173,7 +173,7 @@ class TestScoring:
|
|||
scoring_stack[Api.datasets],
|
||||
scoring_stack[Api.models],
|
||||
)
|
||||
await register_dataset(datasets_impl)
|
||||
await register_dataset(datasets_impl, for_rag=True)
|
||||
rows = await datasetio_impl.get_rows_paginated(
|
||||
dataset_id="test_dataset",
|
||||
rows_in_page=3,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue