forked from phoenix-oss/llama-stack-mirror
		
	move DataSchemaValidatorMixin into standalone utils (#720)
# What does this PR do? - there's no value in keeping data schema validation logic in a DataSchemaValidatorMixin - move into data schema validation logic into standalone utils ## Test Plan ``` pytest -v -s -m llm_as_judge_scoring_together_inference scoring/test_scoring.py --judge-model meta-llama/Llama-3.2-3B-Instruct pytest -v -s -m basic_scoring_together_inference scoring/test_scoring.py pytest -v -s -m braintrust_scoring_together_inference scoring/test_scoring.py pytest -v -s -m meta_reference_eval_together_inference eval/test_eval.py pytest -v -s -m meta_reference_eval_together_inference_huggingface_datasetio eval/test_eval.py ``` ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
This commit is contained in:
		
							parent
							
								
									0bc5d05243
								
							
						
					
					
						commit
						7a90fc5854
					
				
					 5 changed files with 37 additions and 34 deletions
				
			
		|  | @ -18,8 +18,8 @@ from llama_stack.providers.datatypes import EvalTasksProtocolPrivate | |||
| 
 | ||||
| from llama_stack.providers.utils.common.data_schema_validator import ( | ||||
|     ColumnName, | ||||
|     DataSchemaValidatorMixin, | ||||
|     get_valid_schemas, | ||||
|     validate_dataset_schema, | ||||
| ) | ||||
| from llama_stack.providers.utils.kvstore import kvstore_impl | ||||
| 
 | ||||
|  | @ -31,7 +31,10 @@ from .config import MetaReferenceEvalConfig | |||
| EVAL_TASKS_PREFIX = "eval_tasks:" | ||||
| 
 | ||||
| 
 | ||||
| class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate, DataSchemaValidatorMixin): | ||||
| class MetaReferenceEvalImpl( | ||||
|     Eval, | ||||
|     EvalTasksProtocolPrivate, | ||||
| ): | ||||
|     def __init__( | ||||
|         self, | ||||
|         config: MetaReferenceEvalConfig, | ||||
|  | @ -85,7 +88,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate, DataSchemaValidatorM | |||
|         candidate = task_config.eval_candidate | ||||
|         scoring_functions = task_def.scoring_functions | ||||
|         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) | ||||
|         self.validate_dataset_schema( | ||||
|         validate_dataset_schema( | ||||
|             dataset_def.dataset_schema, get_valid_schemas(Api.eval.value) | ||||
|         ) | ||||
|         all_rows = await self.datasetio_api.get_rows_paginated( | ||||
|  |  | |||
|  | @ -18,8 +18,8 @@ from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams | |||
| from llama_stack.distribution.datatypes import Api | ||||
| from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate | ||||
| from llama_stack.providers.utils.common.data_schema_validator import ( | ||||
|     DataSchemaValidatorMixin, | ||||
|     get_valid_schemas, | ||||
|     validate_dataset_schema, | ||||
| ) | ||||
| from .config import BasicScoringConfig | ||||
| from .scoring_fn.equality_scoring_fn import EqualityScoringFn | ||||
|  | @ -30,7 +30,8 @@ FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn] | |||
| 
 | ||||
| 
 | ||||
| class BasicScoringImpl( | ||||
|     Scoring, ScoringFunctionsProtocolPrivate, DataSchemaValidatorMixin | ||||
|     Scoring, | ||||
|     ScoringFunctionsProtocolPrivate, | ||||
| ): | ||||
|     def __init__( | ||||
|         self, | ||||
|  | @ -75,7 +76,7 @@ class BasicScoringImpl( | |||
|         save_results_dataset: bool = False, | ||||
|     ) -> ScoreBatchResponse: | ||||
|         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) | ||||
|         self.validate_dataset_schema( | ||||
|         validate_dataset_schema( | ||||
|             dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value) | ||||
|         ) | ||||
| 
 | ||||
|  |  | |||
|  | @ -35,8 +35,9 @@ from llama_stack.distribution.datatypes import Api | |||
| from llama_stack.distribution.request_headers import NeedsRequestProviderData | ||||
| from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate | ||||
| from llama_stack.providers.utils.common.data_schema_validator import ( | ||||
|     DataSchemaValidatorMixin, | ||||
|     get_valid_schemas, | ||||
|     validate_dataset_schema, | ||||
|     validate_row_schema, | ||||
| ) | ||||
| 
 | ||||
| from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics | ||||
|  | @ -111,7 +112,6 @@ class BraintrustScoringImpl( | |||
|     Scoring, | ||||
|     ScoringFunctionsProtocolPrivate, | ||||
|     NeedsRequestProviderData, | ||||
|     DataSchemaValidatorMixin, | ||||
| ): | ||||
|     def __init__( | ||||
|         self, | ||||
|  | @ -171,7 +171,7 @@ class BraintrustScoringImpl( | |||
|         await self.set_api_key() | ||||
| 
 | ||||
|         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) | ||||
|         self.validate_dataset_schema( | ||||
|         validate_dataset_schema( | ||||
|             dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value) | ||||
|         ) | ||||
| 
 | ||||
|  | @ -194,7 +194,7 @@ class BraintrustScoringImpl( | |||
|     async def score_row( | ||||
|         self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None | ||||
|     ) -> ScoringResultRow: | ||||
|         self.validate_row_schema(input_row, get_valid_schemas(Api.scoring.value)) | ||||
|         validate_row_schema(input_row, get_valid_schemas(Api.scoring.value)) | ||||
|         await self.set_api_key() | ||||
|         assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None" | ||||
|         expected_answer = input_row["expected_answer"] | ||||
|  |  | |||
|  | @ -19,8 +19,8 @@ from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams | |||
| from llama_stack.distribution.datatypes import Api | ||||
| from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate | ||||
| from llama_stack.providers.utils.common.data_schema_validator import ( | ||||
|     DataSchemaValidatorMixin, | ||||
|     get_valid_schemas, | ||||
|     validate_dataset_schema, | ||||
| ) | ||||
| 
 | ||||
| from .config import LlmAsJudgeScoringConfig | ||||
|  | @ -31,7 +31,8 @@ LLM_JUDGE_FNS = [LlmAsJudgeScoringFn] | |||
| 
 | ||||
| 
 | ||||
| class LlmAsJudgeScoringImpl( | ||||
|     Scoring, ScoringFunctionsProtocolPrivate, DataSchemaValidatorMixin | ||||
|     Scoring, | ||||
|     ScoringFunctionsProtocolPrivate, | ||||
| ): | ||||
|     def __init__( | ||||
|         self, | ||||
|  | @ -79,7 +80,7 @@ class LlmAsJudgeScoringImpl( | |||
|         save_results_dataset: bool = False, | ||||
|     ) -> ScoreBatchResponse: | ||||
|         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) | ||||
|         self.validate_dataset_schema( | ||||
|         validate_dataset_schema( | ||||
|             dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value) | ||||
|         ) | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue