diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py
index 7fe4d5c97..29f4a106f 100644
--- a/llama_stack/providers/inline/eval/meta_reference/eval.py
+++ b/llama_stack/providers/inline/eval/meta_reference/eval.py
@@ -14,9 +14,10 @@ from llama_stack.apis.eval_tasks import EvalTask
 from llama_stack.apis.inference import Inference, UserMessage
 from llama_stack.apis.scoring import Scoring
 from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
-from llama_stack.providers.utils.common.data_schema_utils import (
+
+from llama_stack.providers.utils.common.data_schema_validator_mixin import (
     ColumnName,
-    get_expected_schema_for_eval,
+    DataSchemaValidatorMixin,
 )
 from llama_stack.providers.utils.kvstore import kvstore_impl
 
@@ -28,7 +29,7 @@ from .config import MetaReferenceEvalConfig
 EVAL_TASKS_PREFIX = "eval_tasks:"
 
 
-class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
+class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate, DataSchemaValidatorMixin):
     def __init__(
         self,
         config: MetaReferenceEvalConfig,
@@ -72,17 +73,17 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         )
         self.eval_tasks[task_def.identifier] = task_def
 
-    async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
-        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
-            raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
+    # async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
+    #     dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
+    #     if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
+    #         raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
 
-        expected_schemas = get_expected_schema_for_eval()
+        # expected_schemas = get_expected_schema_for_eval()
 
-        if dataset_def.dataset_schema not in expected_schemas:
-            raise ValueError(
-                f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
-            )
+        # if dataset_def.dataset_schema not in expected_schemas:
+        #     raise ValueError(
+        #         f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
+        #     )
 
     async def run_eval(
         self,
@@ -93,8 +94,8 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         dataset_id = task_def.dataset_id
         candidate = task_config.eval_candidate
         scoring_functions = task_def.scoring_functions
-
-        await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
+        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
+        self.validate_dataset_schema_for_eval(dataset_def.dataset_schema)
         all_rows = await self.datasetio_api.get_rows_paginated(
             dataset_id=dataset_id,
             rows_in_page=(
diff --git a/llama_stack/providers/inline/scoring/basic/scoring.py b/llama_stack/providers/inline/scoring/basic/scoring.py
index 0c0503ff5..921ad4b26 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring.py
+++ b/llama_stack/providers/inline/scoring/basic/scoring.py
@@ -12,6 +12,9 @@ from llama_stack.apis.common.type_system import * # noqa: F403
 from llama_stack.apis.datasetio import * # noqa: F403
 from llama_stack.apis.datasets import * # noqa: F403
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
+from llama_stack.providers.utils.common.data_schema_validator_mixin import (
+    DataSchemaValidatorMixin,
+)
 
 from .config import BasicScoringConfig
 from .scoring_fn.equality_scoring_fn import EqualityScoringFn
@@ -21,7 +24,9 @@ from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
 FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
 
 
-class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
+class BasicScoringImpl(
+    Scoring, ScoringFunctionsProtocolPrivate, DataSchemaValidatorMixin
+):
     def __init__(
         self,
         config: BasicScoringConfig,
@@ -58,30 +63,15 @@ class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def register_scoring_function(self, function_def: ScoringFn) -> None:
         raise NotImplementedError("Register scoring function not implemented yet")
 
-    async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
-        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
-            raise ValueError(
-                f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
-            )
-
-        for required_column in ["generated_answer", "expected_answer", "input_query"]:
-            if required_column not in dataset_def.dataset_schema:
-                raise ValueError(
-                    f"Dataset {dataset_id} does not have a '{required_column}' column."
-                )
-            if dataset_def.dataset_schema[required_column].type != "string":
-                raise ValueError(
-                    f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
-                )
-
     async def score_batch(
         self,
         dataset_id: str,
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
-        await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
+        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
+        self.validate_dataset_schema_for_scoring(dataset_def.dataset_schema)
+
         all_rows = await self.datasetio_api.get_rows_paginated(
             dataset_id=dataset_id,
             rows_in_page=-1,
diff --git a/llama_stack/providers/inline/scoring/braintrust/braintrust.py b/llama_stack/providers/inline/scoring/braintrust/braintrust.py
index e2e28ff06..53e3e7ea6 100644
--- a/llama_stack/providers/inline/scoring/braintrust/braintrust.py
+++ b/llama_stack/providers/inline/scoring/braintrust/braintrust.py
@@ -15,19 +15,18 @@ from llama_stack.apis.datasets import * # noqa: F403
 import os
 
 from autoevals.llm import Factuality
-from autoevals.ragas import AnswerCorrectness, AnswerRelevancy
+from autoevals.ragas import AnswerCorrectness
 from pydantic import BaseModel
 
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
-from llama_stack.providers.utils.common.data_schema_utils import (
-    get_expected_schema_for_scoring,
+from llama_stack.providers.utils.common.data_schema_validator_mixin import (
+    DataSchemaValidatorMixin,
 )
 from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
 
 from .config import BraintrustScoringConfig
 from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
-from .scoring_fn.fn_defs.answer_relevancy import answer_relevancy_fn_def
 from .scoring_fn.fn_defs.factuality import factuality_fn_def
 
 
@@ -48,16 +47,14 @@ SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY = [
         evaluator=AnswerCorrectness(),
         fn_def=answer_correctness_fn_def,
     ),
-    BraintrustScoringFnEntry(
-        identifier="braintrust::answer-relevancy",
-        evaluator=AnswerRelevancy(),
-        fn_def=answer_relevancy_fn_def,
-    ),
 ]
 
 
 class BraintrustScoringImpl(
-    Scoring, ScoringFunctionsProtocolPrivate, NeedsRequestProviderData
+    Scoring,
+    ScoringFunctionsProtocolPrivate,
+    NeedsRequestProviderData,
+    DataSchemaValidatorMixin,
 ):
     def __init__(
         self,
@@ -96,32 +93,6 @@ class BraintrustScoringImpl(
             "Registering scoring function not allowed for braintrust provider"
         )
 
-    async def validate_scoring_input_row_schema(
-        self, input_row: Dict[str, Any]
-    ) -> None:
-        expected_schemas = get_expected_schema_for_scoring()
-        for schema in expected_schemas:
-            if all(key in input_row for key in schema):
-                return
-
-        raise ValueError(
-            f"Input row {input_row} does not match any of the expected schemas"
-        )
-
-    async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
-        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
-            raise ValueError(
-                f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
-            )
-
-        expected_schemas = get_expected_schema_for_scoring()
-
-        if dataset_def.dataset_schema not in expected_schemas:
-            raise ValueError(
-                f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
-            )
-
     async def set_api_key(self) -> None:
         # api key is in the request headers
         if not self.config.openai_api_key:
@@ -141,7 +112,10 @@ class BraintrustScoringImpl(
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         await self.set_api_key()
-        await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
+
+        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
+        self.validate_dataset_schema_for_scoring(dataset_def.dataset_schema)
+
         all_rows = await self.datasetio_api.get_rows_paginated(
             dataset_id=dataset_id,
             rows_in_page=-1,
@@ -172,7 +146,6 @@ class BraintrustScoringImpl(
             generated_answer,
             expected_answer,
             input=input_query,
-            context=input_row["context"] if "context" in input_row else None,
         )
         score = result.score
         return {"score": score, "metadata": result.metadata}
diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_relevancy.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_relevancy.py
deleted file mode 100644
index c37a5dd68..000000000
--- a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_relevancy.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import (
-    AggregationFunctionType,
-    BasicScoringFnParams,
-    ScoringFn,
-)
-
-
-answer_relevancy_fn_def = ScoringFn(
-    identifier="braintrust::answer-relevancy",
-    description=(
-        "Scores answer relevancy according to the question"
-        "Uses Braintrust LLM-based scorer from autoevals library."
-    ),
-    provider_id="braintrust",
-    provider_resource_id="answer-relevancy",
-    return_type=NumberType(),
-    params=BasicScoringFnParams(
-        aggregation_functions=[AggregationFunctionType.average]
-    ),
-)
diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py
index 09780e6fb..3f22f73a9 100644
--- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py
+++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py
@@ -17,6 +17,9 @@ from llama_stack.apis.scoring import (
 )
 from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
+from llama_stack.providers.utils.common.data_schema_validator_mixin import (
+    DataSchemaValidatorMixin,
+)
 
 from .config import LlmAsJudgeScoringConfig
 from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
@@ -25,7 +28,9 @@ from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
 LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
 
 
-class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
+class LlmAsJudgeScoringImpl(
+    Scoring, ScoringFunctionsProtocolPrivate, DataSchemaValidatorMixin
+):
     def __init__(
         self,
         config: LlmAsJudgeScoringConfig,
@@ -65,30 +70,15 @@ class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def register_scoring_function(self, function_def: ScoringFn) -> None:
         raise NotImplementedError("Register scoring function not implemented yet")
 
-    async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
-        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
-            raise ValueError(
-                f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
-            )
-
-        for required_column in ["generated_answer", "expected_answer", "input_query"]:
-            if required_column not in dataset_def.dataset_schema:
-                raise ValueError(
-                    f"Dataset {dataset_id} does not have a '{required_column}' column."
-                )
-            if dataset_def.dataset_schema[required_column].type != "string":
-                raise ValueError(
-                    f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
-                )
-
     async def score_batch(
         self,
         dataset_id: str,
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
-        await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
+        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
+        self.validate_dataset_schema_for_scoring(dataset_def.dataset_schema)
+
         all_rows = await self.datasetio_api.get_rows_paginated(
             dataset_id=dataset_id,
             rows_in_page=-1,
diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py
index 38da74128..d6794d488 100644
--- a/llama_stack/providers/tests/eval/test_eval.py
+++ b/llama_stack/providers/tests/eval/test_eval.py
@@ -7,8 +7,7 @@
 
 import pytest
 
-from llama_models.llama3.api import SamplingParams, URL
-
+from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType
 
 from llama_stack.apis.eval.eval import (
@@ -16,6 +15,7 @@ from llama_stack.apis.eval.eval import (
     BenchmarkEvalTaskConfig,
     ModelCandidate,
 )
+from llama_stack.apis.inference import SamplingParams
 from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams
 from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
diff --git a/llama_stack/providers/utils/common/data_schema_utils.py b/llama_stack/providers/utils/common/data_schema_validator_mixin.py
similarity index 55%
rename from llama_stack/providers/utils/common/data_schema_utils.py
rename to llama_stack/providers/utils/common/data_schema_validator_mixin.py
index 1a0e4c989..5736c2be8 100644
--- a/llama_stack/providers/utils/common/data_schema_utils.py
+++ b/llama_stack/providers/utils/common/data_schema_validator_mixin.py
@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 
 from enum import Enum
+from typing import Any, Dict, List
 
 from llama_stack.apis.common.type_system import (
     ChatCompletionInputType,
@@ -51,3 +52,38 @@
             ColumnName.completion_input.value: CompletionInputType(),
         },
     ]
+
+
+def validate_dataset_schema(
+    dataset_schema: Dict[str, Any], expected_schemas: List[Dict[str, Any]]
+):
+    if dataset_schema not in expected_schemas:
+        raise ValueError(
+            f"Dataset does not have a correct input schema in {expected_schemas}"
+        )
+
+
+def validate_row_schema(
+    input_row: Dict[str, Any], expected_schemas: List[Dict[str, Any]]
+):
+    for schema in expected_schemas:
+        if all(key in input_row for key in schema):
+            return
+
+    raise ValueError(
+        f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}"
+    )
+
+
+class DataSchemaValidatorMixin:
+    def validate_dataset_schema_for_scoring(self, dataset_schema: Dict[str, Any]):
+        validate_dataset_schema(dataset_schema, get_expected_schema_for_scoring())
+
+    def validate_dataset_schema_for_eval(self, dataset_schema: Dict[str, Any]):
+        validate_dataset_schema(dataset_schema, get_expected_schema_for_eval())
+
+    def validate_row_schema_for_scoring(self, row_schema: Dict[str, Any]):
+        validate_row_schema(row_schema, get_expected_schema_for_scoring())
+
+    def validate_row_schema_for_eval(self, row_schema: Dict[str, Any]):
+        validate_row_schema(row_schema, get_expected_schema_for_eval())
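
For reference, a minimal sketch of how a provider would consume the new mixin, based only on the methods added in data_schema_validator_mixin.py above. The ExampleScoringImpl class, its datasets_api attribute, and the method bodies are illustrative assumptions rather than code from this change; only the import path and the two validate_* calls mirror the hunks in this diff.

    # Illustrative only -- not part of this diff. ExampleScoringImpl and its
    # datasets_api attribute are assumptions; the mixin methods come from
    # data_schema_validator_mixin.py introduced above.
    from typing import Any, Dict

    from llama_stack.providers.utils.common.data_schema_validator_mixin import (
        DataSchemaValidatorMixin,
    )


    class ExampleScoringImpl(DataSchemaValidatorMixin):
        def __init__(self, datasets_api) -> None:
            self.datasets_api = datasets_api

        async def score_batch(self, dataset_id: str) -> None:
            # Fetch the dataset definition, then check its schema against the
            # expected scoring schemas; raises ValueError on a mismatch.
            dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
            self.validate_dataset_schema_for_scoring(dataset_def.dataset_schema)

        def check_row(self, input_row: Dict[str, Any]) -> None:
            # Row-level variant for when only a single row is available.
            self.validate_row_schema_for_scoring(input_row)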