diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py
index d257cb8e7..b555c9f2a 100644
--- a/llama_stack/providers/inline/eval/meta_reference/eval.py
+++ b/llama_stack/providers/inline/eval/meta_reference/eval.py
@@ -13,11 +13,13 @@ from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.eval_tasks import EvalTask
 from llama_stack.apis.inference import Inference, UserMessage
 from llama_stack.apis.scoring import Scoring
+from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
-from llama_stack.providers.utils.common.data_schema_validator_mixin import (
+from llama_stack.providers.utils.common.data_schema_validator import (
     ColumnName,
     DataSchemaValidatorMixin,
+    get_valid_schemas,
 )
 from llama_stack.providers.utils.kvstore import kvstore_impl
@@ -83,7 +85,9 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate, DataSchemaValidatorM
         candidate = task_config.eval_candidate
         scoring_functions = task_def.scoring_functions
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        self.validate_dataset_schema_for_eval(dataset_def.dataset_schema)
+        self.validate_dataset_schema(
+            dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)
+        )
         all_rows = await self.datasetio_api.get_rows_paginated(
             dataset_id=dataset_id,
             rows_in_page=(
diff --git a/llama_stack/providers/inline/scoring/basic/scoring.py b/llama_stack/providers/inline/scoring/basic/scoring.py
index ad1f159c8..f612abda4 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring.py
+++ b/llama_stack/providers/inline/scoring/basic/scoring.py
@@ -14,11 +14,13 @@ from llama_stack.apis.scoring import (
     ScoringResult,
 )
 from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
-from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
-from llama_stack.providers.utils.common.data_schema_validator_mixin import (
-    DataSchemaValidatorMixin,
-)
+from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
+from llama_stack.providers.utils.common.data_schema_validator import (
+    DataSchemaValidatorMixin,
+    get_valid_schemas,
+)
 
 from .config import BasicScoringConfig
 from .scoring_fn.equality_scoring_fn import EqualityScoringFn
 from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
@@ -73,7 +75,9 @@ class BasicScoringImpl(
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        self.validate_dataset_schema_for_scoring(dataset_def.dataset_schema)
+        self.validate_dataset_schema(
+            dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
+        )
 
         all_rows = await self.datasetio_api.get_rows_paginated(
             dataset_id=dataset_id,
diff --git a/llama_stack/providers/inline/scoring/braintrust/braintrust.py b/llama_stack/providers/inline/scoring/braintrust/braintrust.py
index 7b2b88301..4282ef6ec 100644
--- a/llama_stack/providers/inline/scoring/braintrust/braintrust.py
+++ b/llama_stack/providers/inline/scoring/braintrust/braintrust.py
@@ -28,12 +28,15 @@ from llama_stack.apis.scoring import (
     ScoringResult,
     ScoringResultRow,
 )
-from llama_stack.apis.scoring_functions import ScoringFn
+from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
+
+from llama_stack.distribution.datatypes import Api
+
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
-from llama_stack.providers.utils.common.data_schema_validator_mixin import (
+from llama_stack.providers.utils.common.data_schema_validator import (
     DataSchemaValidatorMixin,
+    get_valid_schemas,
 )
 from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
@@ -168,7 +171,9 @@ class BraintrustScoringImpl(
         await self.set_api_key()
 
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        self.validate_dataset_schema_for_scoring(dataset_def.dataset_schema)
+        self.validate_dataset_schema(
+            dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
+        )
 
         all_rows = await self.datasetio_api.get_rows_paginated(
             dataset_id=dataset_id,
@@ -189,7 +194,7 @@ class BraintrustScoringImpl(
     async def score_row(
         self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
     ) -> ScoringResultRow:
-        self.validate_row_schema_for_scoring(input_row)
+        self.validate_row_schema(input_row, get_valid_schemas(Api.scoring.value))
         await self.set_api_key()
         assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
         expected_answer = input_row["expected_answer"]
diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py
index 3f22f73a9..305c13665 100644
--- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py
+++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py
@@ -16,9 +16,11 @@ from llama_stack.apis.scoring import (
     ScoringResult,
 )
 from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
+from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
-from llama_stack.providers.utils.common.data_schema_validator_mixin import (
+from llama_stack.providers.utils.common.data_schema_validator import (
     DataSchemaValidatorMixin,
+    get_valid_schemas,
 )
 
 from .config import LlmAsJudgeScoringConfig
@@ -77,7 +79,9 @@ class LlmAsJudgeScoringImpl(
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        self.validate_dataset_schema_for_scoring(dataset_def.dataset_schema)
+        self.validate_dataset_schema(
+            dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
+        )
 
         all_rows = await self.datasetio_api.get_rows_paginated(
             dataset_id=dataset_id,
diff --git a/llama_stack/providers/utils/common/data_schema_validator.py b/llama_stack/providers/utils/common/data_schema_validator.py
new file mode 100644
index 000000000..d9e6cb6b5
--- /dev/null
+++ b/llama_stack/providers/utils/common/data_schema_validator.py
@@ -0,0 +1,87 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from enum import Enum
+from typing import Any, Dict, List
+
+from llama_stack.apis.common.type_system import (
+    ChatCompletionInputType,
+    CompletionInputType,
+    StringType,
+)
+
+from llama_stack.distribution.datatypes import Api
+
+
+class ColumnName(Enum):
+    input_query = "input_query"
+    expected_answer = "expected_answer"
+    chat_completion_input = "chat_completion_input"
+    completion_input = "completion_input"
+    generated_answer = "generated_answer"
+    context = "context"
+
+
+VALID_SCHEMAS_FOR_SCORING = [
+    {
+        ColumnName.input_query.value: StringType(),
+        ColumnName.expected_answer.value: StringType(),
+        ColumnName.generated_answer.value: StringType(),
+    },
+    {
+        ColumnName.input_query.value: StringType(),
+        ColumnName.expected_answer.value: StringType(),
+        ColumnName.generated_answer.value: StringType(),
+        ColumnName.context.value: StringType(),
+    },
+]
+
+VALID_SCHEMAS_FOR_EVAL = [
+    {
+        ColumnName.input_query.value: StringType(),
+        ColumnName.expected_answer.value: StringType(),
+        ColumnName.chat_completion_input.value: ChatCompletionInputType(),
+    },
+    {
+        ColumnName.input_query.value: StringType(),
+        ColumnName.expected_answer.value: StringType(),
+        ColumnName.completion_input.value: CompletionInputType(),
+    },
+]
+
+
+def get_valid_schemas(api_str: str):
+    if api_str == Api.scoring.value:
+        return VALID_SCHEMAS_FOR_SCORING
+    elif api_str == Api.eval.value:
+        return VALID_SCHEMAS_FOR_EVAL
+    else:
+        raise ValueError(f"Invalid API string: {api_str}")
+
+
+class DataSchemaValidatorMixin:
+    def validate_dataset_schema(
+        self,
+        dataset_schema: Dict[str, Any],
+        expected_schemas: List[Dict[str, Any]],
+    ):
+        if dataset_schema not in expected_schemas:
+            raise ValueError(
+                f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}"
+            )
+
+    def validate_row_schema(
+        self,
+        input_row: Dict[str, Any],
+        expected_schemas: List[Dict[str, Any]],
+    ):
+        for schema in expected_schemas:
+            if all(key in input_row for key in schema):
+                return
+
+        raise ValueError(
+            f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}"
+        )
diff --git a/llama_stack/providers/utils/common/data_schema_validator_mixin.py b/llama_stack/providers/utils/common/data_schema_validator_mixin.py
deleted file mode 100644
index 19b188fba..000000000
--- a/llama_stack/providers/utils/common/data_schema_validator_mixin.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from enum import Enum
-from typing import Any, Dict, List
-
-from llama_stack.apis.common.type_system import (
-    ChatCompletionInputType,
-    CompletionInputType,
-    StringType,
-)
-
-
-class ColumnName(Enum):
-    input_query = "input_query"
-    expected_answer = "expected_answer"
-    chat_completion_input = "chat_completion_input"
-    completion_input = "completion_input"
-    generated_answer = "generated_answer"
-    context = "context"
-
-
-class DataSchemaValidatorMixin:
-    def validate_dataset_schema_for_scoring(self, dataset_schema: Dict[str, Any]):
-        self.validate_dataset_schema(
-            dataset_schema, self.get_expected_schema_for_scoring()
-        )
-
-    def validate_dataset_schema_for_eval(self, dataset_schema: Dict[str, Any]):
-        self.validate_dataset_schema(
-            dataset_schema, self.get_expected_schema_for_eval()
-        )
-
-    def validate_row_schema_for_scoring(self, input_row: Dict[str, Any]):
-        self.validate_row_schema(input_row, self.get_expected_schema_for_scoring())
-
-    def validate_row_schema_for_eval(self, input_row: Dict[str, Any]):
-        self.validate_row_schema(input_row, self.get_expected_schema_for_eval())
-
-    def get_expected_schema_for_scoring(self):
-        return [
-            {
-                ColumnName.input_query.value: StringType(),
-                ColumnName.expected_answer.value: StringType(),
-                ColumnName.generated_answer.value: StringType(),
-            },
-            {
-                ColumnName.input_query.value: StringType(),
-                ColumnName.expected_answer.value: StringType(),
-                ColumnName.generated_answer.value: StringType(),
-                ColumnName.context.value: StringType(),
-            },
-        ]
-
-    def get_expected_schema_for_eval(self):
-        return [
-            {
-                ColumnName.input_query.value: StringType(),
-                ColumnName.expected_answer.value: StringType(),
-                ColumnName.chat_completion_input.value: ChatCompletionInputType(),
-            },
-            {
-                ColumnName.input_query.value: StringType(),
-                ColumnName.expected_answer.value: StringType(),
-                ColumnName.completion_input.value: CompletionInputType(),
-            },
-        ]
-
-    def validate_dataset_schema(
-        self,
-        dataset_schema: Dict[str, Any],
-        expected_schemas: List[Dict[str, Any]],
-    ):
-        if dataset_schema not in expected_schemas:
-            raise ValueError(
-                f"Dataset does not have a correct input schema in {expected_schemas}"
-            )
-
-    def validate_row_schema(
-        self,
-        input_row: Dict[str, Any],
-        expected_schemas: List[Dict[str, Any]],
-    ):
-        for schema in expected_schemas:
-            if all(key in input_row for key in schema):
-                return
-
-        raise ValueError(
-            f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}"
-        )
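Not part of the diff: a minimal usage sketch of the relocated helpers, showing the intended call pattern after the rename. Instantiating DataSchemaValidatorMixin directly and the literal schema dict are illustrative assumptions only; in the providers above, the mixin is inherited by the provider impl and the schema comes from the registered dataset.

# Illustrative sketch (not in the diff); assumes the modules added above are importable.
from llama_stack.apis.common.type_system import StringType
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.utils.common.data_schema_validator import (
    DataSchemaValidatorMixin,
    get_valid_schemas,
)

validator = DataSchemaValidatorMixin()  # providers normally inherit the mixin instead

# Matches the first entry of VALID_SCHEMAS_FOR_SCORING, so this returns silently;
# a schema not in the list raises ValueError.
dataset_schema = {
    "input_query": StringType(),
    "expected_answer": StringType(),
    "generated_answer": StringType(),
}
validator.validate_dataset_schema(dataset_schema, get_valid_schemas(Api.scoring.value))

# Row validation only checks that the keys of at least one valid schema are present.
row = {"input_query": "2+2?", "expected_answer": "4", "generated_answer": "4"}
validator.validate_row_schema(row, get_valid_schemas(Api.scoring.value))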