From 55e4f4eeb3aa002a4fb4c0e30fb47f6f875cee4f Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Thu, 19 Dec 2024 15:50:15 -0800
Subject: [PATCH] refactor schema check

---
 .../inline/eval/meta_reference/eval.py        | 36 +++------
 .../inline/scoring/braintrust/braintrust.py   | 75 +++++++++++++++----
 .../scoring_fn/fn_defs/answer_correctness.py  |  5 +-
 .../scoring_fn/fn_defs/answer_relevancy.py    | 27 +++++++
 .../scoring_fn/fn_defs/factuality.py          |  5 +-
 .../providers/utils/common/__init__.py        |  5 ++
 .../utils/common/data_schema_utils.py         | 53 +++++++++++++
 7 files changed, 162 insertions(+), 44 deletions(-)
 create mode 100644 llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_relevancy.py
 create mode 100644 llama_stack/providers/utils/common/__init__.py
 create mode 100644 llama_stack/providers/utils/common/data_schema_utils.py

diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py
index 453215e41..7fe4d5c97 100644
--- a/llama_stack/providers/inline/eval/meta_reference/eval.py
+++ b/llama_stack/providers/inline/eval/meta_reference/eval.py
@@ -3,36 +3,31 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from enum import Enum
 from typing import Any, Dict, List, Optional
-from llama_models.llama3.api.datatypes import *  # noqa: F403
+
 from tqdm import tqdm
 
-from .....apis.common.job_types import Job
-from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
-from llama_stack.apis.common.type_system import *  # noqa: F403
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.eval_tasks import EvalTask
-from llama_stack.apis.inference import Inference
+from llama_stack.apis.inference import Inference, UserMessage
 from llama_stack.apis.scoring import Scoring
 from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
+from llama_stack.providers.utils.common.data_schema_utils import (
+    ColumnName,
+    get_expected_schema_for_eval,
+)
 from llama_stack.providers.utils.kvstore import kvstore_impl
 
+from .....apis.common.job_types import Job
+from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
+
 from .config import MetaReferenceEvalConfig
 
 EVAL_TASKS_PREFIX = "eval_tasks:"
 
 
-class ColumnName(Enum):
-    input_query = "input_query"
-    expected_answer = "expected_answer"
-    chat_completion_input = "chat_completion_input"
-    completion_input = "completion_input"
-    generated_answer = "generated_answer"
-
-
 class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
     def __init__(
         self,
@@ -82,18 +77,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
             raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
 
-        expected_schemas = [
-            {
-                ColumnName.input_query.value: StringType(),
-                ColumnName.expected_answer.value: StringType(),
-                ColumnName.chat_completion_input.value: ChatCompletionInputType(),
-            },
-            {
-                ColumnName.input_query.value: StringType(),
-                ColumnName.expected_answer.value: StringType(),
-                ColumnName.completion_input.value: CompletionInputType(),
-            },
-        ]
+        expected_schemas = get_expected_schema_for_eval()
 
         if dataset_def.dataset_schema not in expected_schemas:
             raise ValueError(
diff --git a/llama_stack/providers/inline/scoring/braintrust/braintrust.py b/llama_stack/providers/inline/scoring/braintrust/braintrust.py
index 7a966e67a..e2e28ff06 100644
--- a/llama_stack/providers/inline/scoring/braintrust/braintrust.py
+++ b/llama_stack/providers/inline/scoring/braintrust/braintrust.py
@@ -15,18 +15,47 @@ from llama_stack.apis.datasets import *  # noqa: F403
 import os
 
 from autoevals.llm import Factuality
-from autoevals.ragas import AnswerCorrectness
+from autoevals.ragas import AnswerCorrectness, AnswerRelevancy
+from pydantic import BaseModel
 
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
+from llama_stack.providers.utils.common.data_schema_utils import (
+    get_expected_schema_for_scoring,
+)
 from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
-
 from .config import BraintrustScoringConfig
 from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
+from .scoring_fn.fn_defs.answer_relevancy import answer_relevancy_fn_def
 from .scoring_fn.fn_defs.factuality import factuality_fn_def
 
 
+class BraintrustScoringFnEntry(BaseModel):
+    identifier: str
+    evaluator: Any
+    fn_def: ScoringFn
+
+
+SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY = [
+    BraintrustScoringFnEntry(
+        identifier="braintrust::factuality",
+        evaluator=Factuality(),
+        fn_def=factuality_fn_def,
+    ),
+    BraintrustScoringFnEntry(
+        identifier="braintrust::answer-correctness",
+        evaluator=AnswerCorrectness(),
+        fn_def=answer_correctness_fn_def,
+    ),
+    BraintrustScoringFnEntry(
+        identifier="braintrust::answer-relevancy",
+        evaluator=AnswerRelevancy(),
+        fn_def=answer_relevancy_fn_def,
+    ),
+]
+
+
 class BraintrustScoringImpl(
     Scoring, ScoringFunctionsProtocolPrivate, NeedsRequestProviderData
 ):
@@ -41,12 +70,12 @@ class BraintrustScoringImpl(
         self.datasets_api = datasets_api
 
         self.braintrust_evaluators = {
-            "braintrust::factuality": Factuality(),
-            "braintrust::answer-correctness": AnswerCorrectness(),
+            entry.identifier: entry.evaluator
+            for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
         }
         self.supported_fn_defs_registry = {
-            factuality_fn_def.identifier: factuality_fn_def,
-            answer_correctness_fn_def.identifier: answer_correctness_fn_def,
+            entry.identifier: entry.fn_def
+            for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
         }
 
     async def initialize(self) -> None: ...
@@ -67,6 +96,18 @@ class BraintrustScoringImpl(
             "Registering scoring function not allowed for braintrust provider"
         )
 
+    async def validate_scoring_input_row_schema(
+        self, input_row: Dict[str, Any]
+    ) -> None:
+        expected_schemas = get_expected_schema_for_scoring()
+        for schema in expected_schemas:
+            if all(key in input_row for key in schema):
+                return
+
+        raise ValueError(
+            f"Input row {input_row} does not match any of the expected schemas"
+        )
+
     async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
         if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
@@ -74,15 +115,12 @@ class BraintrustScoringImpl(
                 f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
             )
 
-        for required_column in ["generated_answer", "expected_answer", "input_query"]:
-            if required_column not in dataset_def.dataset_schema:
-                raise ValueError(
-                    f"Dataset {dataset_id} does not have a '{required_column}' column."
-                )
-            if dataset_def.dataset_schema[required_column].type != "string":
-                raise ValueError(
-                    f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
-                )
+        expected_schemas = get_expected_schema_for_scoring()
+
+        if dataset_def.dataset_schema not in expected_schemas:
+            raise ValueError(
+                f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
+            )
 
     async def set_api_key(self) -> None:
         # api key is in the request headers
@@ -130,7 +168,12 @@ class BraintrustScoringImpl(
         input_query = input_row["input_query"]
         evaluator = self.braintrust_evaluators[scoring_fn_identifier]
 
-        result = evaluator(generated_answer, expected_answer, input=input_query)
+        result = evaluator(
+            generated_answer,
+            expected_answer,
+            input=input_query,
+            context=input_row["context"] if "context" in input_row else None,
+        )
         score = result.score
 
         return {"score": score, "metadata": result.metadata}
diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py
index becadf97b..526ba2c37 100644
--- a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py
+++ b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py
@@ -14,7 +14,10 @@ from llama_stack.apis.scoring_functions import (
 
 answer_correctness_fn_def = ScoringFn(
     identifier="braintrust::answer-correctness",
-    description="Scores the correctness of the answer based on the ground truth.. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
+    description=(
+        "Scores the correctness of the answer based on the ground truth. "
+        "Uses Braintrust LLM-based scorer from autoevals library."
+    ),
     provider_id="braintrust",
     provider_resource_id="answer-correctness",
     return_type=NumberType(),
diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_relevancy.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_relevancy.py
new file mode 100644
index 000000000..c37a5dd68
--- /dev/null
+++ b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_relevancy.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.common.type_system import NumberType
+from llama_stack.apis.scoring_functions import (
+    AggregationFunctionType,
+    BasicScoringFnParams,
+    ScoringFn,
+)
+
+
+answer_relevancy_fn_def = ScoringFn(
+    identifier="braintrust::answer-relevancy",
+    description=(
+        "Scores how relevant the answer is to the question. "
+        "Uses Braintrust LLM-based scorer from autoevals library."
+    ),
+    provider_id="braintrust",
+    provider_resource_id="answer-relevancy",
+    return_type=NumberType(),
+    params=BasicScoringFnParams(
+        aggregation_functions=[AggregationFunctionType.average]
+    ),
+)
diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py
index 88cd7b3a7..a4d597c29 100644
--- a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py
+++ b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py
@@ -14,7 +14,10 @@ from llama_stack.apis.scoring_functions import (
 
 factuality_fn_def = ScoringFn(
     identifier="braintrust::factuality",
-    description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
+    description=(
+        "Test output factuality against expected value using Braintrust LLM scorer. "
+        "See: github.com/braintrustdata/autoevals"
+    ),
     provider_id="braintrust",
     provider_resource_id="factuality",
     return_type=NumberType(),
diff --git a/llama_stack/providers/utils/common/__init__.py b/llama_stack/providers/utils/common/__init__.py
new file mode 100644
index 000000000..756f351d8
--- /dev/null
+++ b/llama_stack/providers/utils/common/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/llama_stack/providers/utils/common/data_schema_utils.py b/llama_stack/providers/utils/common/data_schema_utils.py
new file mode 100644
index 000000000..1a0e4c989
--- /dev/null
+++ b/llama_stack/providers/utils/common/data_schema_utils.py
@@ -0,0 +1,53 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from enum import Enum
+
+from llama_stack.apis.common.type_system import (
+    ChatCompletionInputType,
+    CompletionInputType,
+    StringType,
+)
+
+
+class ColumnName(Enum):
+    input_query = "input_query"
+    expected_answer = "expected_answer"
+    chat_completion_input = "chat_completion_input"
+    completion_input = "completion_input"
+    generated_answer = "generated_answer"
+    context = "context"
+
+
+def get_expected_schema_for_scoring():
+    return [
+        {
+            ColumnName.input_query.value: StringType(),
+            ColumnName.expected_answer.value: StringType(),
+            ColumnName.generated_answer.value: StringType(),
+        },
+        {
+            ColumnName.input_query.value: StringType(),
+            ColumnName.expected_answer.value: StringType(),
+            ColumnName.generated_answer.value: StringType(),
+            ColumnName.context.value: StringType(),
+        },
+    ]
+
+
+def get_expected_schema_for_eval():
+    return [
+        {
+            ColumnName.input_query.value: StringType(),
+            ColumnName.expected_answer.value: StringType(),
+            ColumnName.chat_completion_input.value: ChatCompletionInputType(),
+        },
+        {
+            ColumnName.input_query.value: StringType(),
+            ColumnName.expected_answer.value: StringType(),
+            ColumnName.completion_input.value: CompletionInputType(),
+        },
+    ]
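
For reviewers, a minimal sketch of how the shared helpers introduced in data_schema_utils.py are intended to be used by a provider. This snippet is illustrative and not part of the patch; the dataset schema shown is a hypothetical example.

# Illustrative sketch only (not part of the patch). Assumes the module layout
# added above; the example dataset schema below is hypothetical.
from llama_stack.apis.common.type_system import StringType
from llama_stack.providers.utils.common.data_schema_utils import (
    ColumnName,
    get_expected_schema_for_scoring,
)

# A hypothetical scoring dataset schema, keyed by column name.
dataset_schema = {
    ColumnName.input_query.value: StringType(),
    ColumnName.expected_answer.value: StringType(),
    ColumnName.generated_answer.value: StringType(),
}

# Validation is a simple membership check against the accepted schemas,
# mirroring validate_scoring_input_dataset_schema in braintrust.py and the
# eval.py check above.
expected_schemas = get_expected_schema_for_scoring()
if dataset_schema not in expected_schemas:
    raise ValueError(
        f"Dataset schema does not match any of the expected schemas in {expected_schemas}"
    )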