class DataSchemaValidatorMixin:
    """Mixin that validates dataset schemas and input rows for scoring and eval.

    The ``*_for_scoring`` / ``*_for_eval`` entry points combine one of the
    ``get_expected_schema_for_*`` providers with one of the generic
    ``validate_*_schema`` checkers.  Every helper is looked up through
    ``self``, so a class that overrides one of them is picked up by the
    entry points automatically.

    NOTE(review): ``ColumnName``, ``StringType``, ``ChatCompletionInputType``
    and ``CompletionInputType`` are referenced here but not defined in this
    block; they are expected to come from the surrounding module.
    """

    def validate_dataset_schema_for_scoring(
        self, dataset_schema: Dict[str, Any]
    ) -> None:
        """Raise ``ValueError`` unless ``dataset_schema`` is an accepted scoring schema."""
        self.validate_dataset_schema(
            dataset_schema, self.get_expected_schema_for_scoring()
        )

    def validate_dataset_schema_for_eval(self, dataset_schema: Dict[str, Any]) -> None:
        """Raise ``ValueError`` unless ``dataset_schema`` is an accepted eval schema."""
        self.validate_dataset_schema(
            dataset_schema, self.get_expected_schema_for_eval()
        )

    def validate_row_schema_for_scoring(self, row_schema: Dict[str, Any]) -> None:
        """Raise ``ValueError`` unless the row has the columns of a scoring schema."""
        self.validate_row_schema(row_schema, self.get_expected_schema_for_scoring())

    def validate_row_schema_for_eval(self, row_schema: Dict[str, Any]) -> None:
        """Raise ``ValueError`` unless the row has the columns of an eval schema."""
        self.validate_row_schema(row_schema, self.get_expected_schema_for_eval())

    def get_expected_schema_for_scoring(self) -> List[Dict[str, Any]]:
        """Return the column layouts accepted for scoring.

        Two layouts are accepted: query / expected answer / generated answer,
        and the same three columns plus a ``context`` column.  A fresh list is
        built on every call.
        """
        return [
            {
                ColumnName.input_query.value: StringType(),
                ColumnName.expected_answer.value: StringType(),
                ColumnName.generated_answer.value: StringType(),
            },
            {
                ColumnName.input_query.value: StringType(),
                ColumnName.expected_answer.value: StringType(),
                ColumnName.generated_answer.value: StringType(),
                ColumnName.context.value: StringType(),
            },
        ]

    def get_expected_schema_for_eval(self) -> List[Dict[str, Any]]:
        """Return the column layouts accepted for eval.

        Both layouts carry the query and expected-answer columns; they differ
        in whether the model input is a chat-completion input or a plain
        completion input.  A fresh list is built on every call.
        """
        return [
            {
                ColumnName.input_query.value: StringType(),
                ColumnName.expected_answer.value: StringType(),
                ColumnName.chat_completion_input.value: ChatCompletionInputType(),
            },
            {
                ColumnName.input_query.value: StringType(),
                ColumnName.expected_answer.value: StringType(),
                ColumnName.completion_input.value: CompletionInputType(),
            },
        ]

    def validate_dataset_schema(
        self,
        dataset_schema: Dict[str, Any],
        expected_schemas: List[Dict[str, Any]],
    ) -> None:
        """Check that ``dataset_schema`` equals one of ``expected_schemas``.

        The match is exact: membership in the list uses ``==`` on the whole
        mapping, so both the column names and the column type values must
        compare equal (how type objects compare is decided by their own
        ``__eq__``, which is not visible from here).

        Raises:
            ValueError: if no expected schema equals ``dataset_schema``.
        """
        if dataset_schema not in expected_schemas:
            raise ValueError(
                f"Dataset does not have a correct input schema in {expected_schemas}"
            )

    def validate_row_schema(
        self,
        input_row: Dict[str, Any],
        expected_schemas: List[Dict[str, Any]],
    ) -> None:
        """Check that ``input_row`` has every column of at least one schema.

        Only key presence is tested: extra keys in the row are accepted and
        the row's values are not checked against the schema's column types.

        Raises:
            ValueError: if every expected schema has a key missing from the row.
        """
        if any(
            all(key in input_row for key in schema) for schema in expected_schemas
        ):
            return

        raise ValueError(
            f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}"
        )