More scoring functions for RAG

This commit is contained in:
Xi Yan 2024-12-19 16:53:39 -08:00
parent b94ab8d013
commit 9aa4a405ca
8 changed files with 132 additions and 10 deletions

View file

@ -15,7 +15,12 @@ from llama_stack.apis.datasets import * # noqa: F403
import os import os
from autoevals.llm import Factuality from autoevals.llm import Factuality
from autoevals.ragas import AnswerCorrectness from autoevals.ragas import (
AnswerCorrectness,
AnswerRelevancy,
AnswerSimilarity,
Faithfulness,
)
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.distribution.request_headers import NeedsRequestProviderData
@ -27,7 +32,10 @@ from llama_stack.providers.utils.common.data_schema_validator_mixin import (
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
from .config import BraintrustScoringConfig from .config import BraintrustScoringConfig
from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
from .scoring_fn.fn_defs.answer_relevancy import answer_relevancy_fn_def
from .scoring_fn.fn_defs.answer_similarity import answer_similarity_fn_def
from .scoring_fn.fn_defs.factuality import factuality_fn_def from .scoring_fn.fn_defs.factuality import factuality_fn_def
from .scoring_fn.fn_defs.faithfulness import faithfulness_fn_def
class BraintrustScoringFnEntry(BaseModel): class BraintrustScoringFnEntry(BaseModel):
@ -47,6 +55,21 @@ SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY = [
evaluator=AnswerCorrectness(), evaluator=AnswerCorrectness(),
fn_def=answer_correctness_fn_def, fn_def=answer_correctness_fn_def,
), ),
BraintrustScoringFnEntry(
identifier="braintrust::answer-relevancy",
evaluator=AnswerRelevancy(),
fn_def=answer_relevancy_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::answer-similarity",
evaluator=AnswerSimilarity(),
fn_def=answer_similarity_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::faithfulness",
evaluator=Faithfulness(),
fn_def=faithfulness_fn_def,
),
] ]
@ -135,6 +158,7 @@ class BraintrustScoringImpl(
async def score_row( async def score_row(
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
) -> ScoringResultRow: ) -> ScoringResultRow:
self.validate_row_schema_for_scoring(input_row)
await self.set_api_key() await self.set_api_key()
assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None" assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
expected_answer = input_row["expected_answer"] expected_answer = input_row["expected_answer"]
@ -146,6 +170,7 @@ class BraintrustScoringImpl(
generated_answer, generated_answer,
expected_answer, expected_answer,
input=input_query, input=input_query,
context=input_row["context"] if "context" in input_row else None,
) )
score = result.score score = result.score
return {"score": score, "metadata": result.metadata} return {"score": score, "metadata": result.metadata}

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
# Scoring-function definition for "braintrust::answer-relevancy".
# Each row yields a numeric score; per-row scores are aggregated by averaging.
answer_relevancy_fn_def = ScoringFn(
    provider_id="braintrust",
    provider_resource_id="answer-relevancy",
    identifier="braintrust::answer-relevancy",
    description=(
        "Test output relevancy against the input query "
        "using Braintrust LLM scorer. "
        "See: github.com/braintrustdata/autoevals"
    ),
    params=BasicScoringFnParams(
        aggregation_functions=[AggregationFunctionType.average],
    ),
    return_type=NumberType(),
)

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
# Scoring-function definition for "braintrust::answer-similarity".
# Each row yields a numeric score; per-row scores are aggregated by averaging.
answer_similarity_fn_def = ScoringFn(
    provider_id="braintrust",
    provider_resource_id="answer-similarity",
    identifier="braintrust::answer-similarity",
    description=(
        "Test output similarity against expected value "
        "using Braintrust LLM scorer. "
        "See: github.com/braintrustdata/autoevals"
    ),
    params=BasicScoringFnParams(
        aggregation_functions=[AggregationFunctionType.average],
    ),
    return_type=NumberType(),
)

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
# Scoring-function definition for "braintrust::faithfulness".
# Each row yields a numeric score; per-row scores are aggregated by averaging.
# NOTE: faithfulness is judged against the row's retrieved context (not the
# input query), so the description below names the context explicitly.
faithfulness_fn_def = ScoringFn(
    identifier="braintrust::faithfulness",
    description=(
        "Test output faithfulness to the provided context using Braintrust LLM scorer. "
        "See: github.com/braintrustdata/autoevals"
    ),
    provider_id="braintrust",
    provider_resource_id="faithfulness",
    return_type=NumberType(),
    params=BasicScoringFnParams(
        aggregation_functions=[AggregationFunctionType.average]
    ),
)

View file

@ -37,9 +37,15 @@ def data_url_from_file(file_path: str) -> str:
async def register_dataset( async def register_dataset(
datasets_impl: Datasets, for_generation=False, dataset_id="test_dataset" datasets_impl: Datasets,
for_generation=False,
for_rag=False,
dataset_id="test_dataset",
): ):
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv" if for_rag:
test_file = Path(os.path.abspath(__file__)).parent / "test_rag_dataset.csv"
else:
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
test_url = data_url_from_file(str(test_file)) test_url = data_url_from_file(str(test_file))
if for_generation: if for_generation:
@ -48,6 +54,13 @@ async def register_dataset(
"input_query": StringType(), "input_query": StringType(),
"chat_completion_input": ChatCompletionInputType(), "chat_completion_input": ChatCompletionInputType(),
} }
elif for_rag:
dataset_schema = {
"expected_answer": StringType(),
"input_query": StringType(),
"generated_answer": StringType(),
"context": StringType(),
}
else: else:
dataset_schema = { dataset_schema = {
"expected_answer": StringType(), "expected_answer": StringType(),

View file

@ -0,0 +1,6 @@
input_query,context,generated_answer,expected_answer
What is the capital of France?,"France is a country in Western Europe with a population of about 67 million people. Its capital city has been a major European cultural center since the 17th century and is known for landmarks like the Eiffel Tower and the Louvre Museum.",London,Paris
Who is the CEO of Meta?,"Meta Platforms, formerly known as Facebook, is one of the world's largest technology companies. Founded by Mark Zuckerberg in 2004, the company has expanded to include platforms like Instagram, WhatsApp, and virtual reality technologies.",Mark Zuckerberg,Mark Zuckerberg
What is the largest planet in our solar system?,"The solar system consists of eight planets orbiting around the Sun. These planets, in order from the Sun, are Mercury, Venus, Earth, Mars, Jupiter, Saturn, Uranus, and Neptune. Gas giants are significantly larger than terrestrial planets.",Jupiter,Jupiter
What is the smallest country in the world?,"Independent city-states and micronations are among the world's smallest sovereign territories. Some notable examples include Monaco, San Marino, and Vatican City, which is an enclave within Rome, Italy.",China,Vatican City
What is the currency of Japan?,"Japan is an island country in East Asia with a rich cultural heritage and one of the world's largest economies. Its financial system has been established since the Meiji period, with its modern currency being introduced in 1871.",Yen,Yen
1 input_query context generated_answer expected_answer
2 What is the capital of France? France is a country in Western Europe with a population of about 67 million people. Its capital city has been a major European cultural center since the 17th century and is known for landmarks like the Eiffel Tower and the Louvre Museum. London Paris
3 Who is the CEO of Meta? Meta Platforms, formerly known as Facebook, is one of the world's largest technology companies. Founded by Mark Zuckerberg in 2004, the company has expanded to include platforms like Instagram, WhatsApp, and virtual reality technologies. Mark Zuckerberg Mark Zuckerberg
4 What is the largest planet in our solar system? The solar system consists of eight planets orbiting around the Sun. These planets, in order from the Sun, are Mercury, Venus, Earth, Mars, Jupiter, Saturn, Uranus, and Neptune. Gas giants are significantly larger than terrestrial planets. Jupiter Jupiter
5 What is the smallest country in the world? Independent city-states and micronations are among the world's smallest sovereign territories. Some notable examples include Monaco, San Marino, and Vatican City, which is an enclave within Rome, Italy. China Vatican City
6 What is the currency of Japan? Japan is an island country in East Asia with a rich cultural heritage and one of the world's largest economies. Its financial system has been established since the Meiji period, with its modern currency being introduced in 1871. Yen Yen

View file

@ -60,7 +60,7 @@ class TestScoring:
f"{provider_id} provider does not support scoring without params" f"{provider_id} provider does not support scoring without params"
) )
await register_dataset(datasets_impl) await register_dataset(datasets_impl, for_rag=True)
response = await datasets_impl.list_datasets() response = await datasets_impl.list_datasets()
assert len(response) == 1 assert len(response) == 1
@ -112,7 +112,7 @@ class TestScoring:
scoring_stack[Api.datasets], scoring_stack[Api.datasets],
scoring_stack[Api.models], scoring_stack[Api.models],
) )
await register_dataset(datasets_impl) await register_dataset(datasets_impl, for_rag=True)
response = await datasets_impl.list_datasets() response = await datasets_impl.list_datasets()
assert len(response) == 1 assert len(response) == 1
@ -173,7 +173,7 @@ class TestScoring:
scoring_stack[Api.datasets], scoring_stack[Api.datasets],
scoring_stack[Api.models], scoring_stack[Api.models],
) )
await register_dataset(datasets_impl) await register_dataset(datasets_impl, for_rag=True)
rows = await datasetio_impl.get_rows_paginated( rows = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset", dataset_id="test_dataset",
rows_in_page=3, rows_in_page=3,

View file

@ -34,11 +34,11 @@ class DataSchemaValidatorMixin:
dataset_schema, self.get_expected_schema_for_eval() dataset_schema, self.get_expected_schema_for_eval()
) )
def validate_row_schema_for_scoring(self, row_schema: Dict[str, Any]): def validate_row_schema_for_scoring(self, input_row: Dict[str, Any]):
self.validate_row_schema(row_schema, self.get_expected_schema_for_scoring()) self.validate_row_schema(input_row, self.get_expected_schema_for_scoring())
def validate_row_schema_for_eval(self, row_schema: Dict[str, Any]): def validate_row_schema_for_eval(self, input_row: Dict[str, Any]):
self.validate_row_schema(row_schema, self.get_expected_schema_for_eval()) self.validate_row_schema(input_row, self.get_expected_schema_for_eval())
def get_expected_schema_for_scoring(self): def get_expected_schema_for_scoring(self):
return [ return [