move DataSchemaValidatorMixin into standalone utils (#720)

# What does this PR do?

- There's no value in keeping the data schema validation logic in a
  `DataSchemaValidatorMixin`.
- Move the data schema validation logic into standalone utils (see the
  before/after sketch below).
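
As a before/after sketch of the call sites (distilled from the diffs below, not a verbatim excerpt; `dataset_def` and `Api` are the objects the providers already use):

```python
# Before: providers mixed in DataSchemaValidatorMixin and called the method on self
self.validate_dataset_schema(
    dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)

# After: providers import the standalone function and call it directly
validate_dataset_schema(
    dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
```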

## Test Plan
```
pytest -v -s -m llm_as_judge_scoring_together_inference scoring/test_scoring.py --judge-model meta-llama/Llama-3.2-3B-Instruct
pytest -v -s -m basic_scoring_together_inference scoring/test_scoring.py
pytest -v -s -m braintrust_scoring_together_inference scoring/test_scoring.py

pytest -v -s -m meta_reference_eval_together_inference eval/test_eval.py
pytest -v -s -m meta_reference_eval_together_inference_huggingface_datasetio eval/test_eval.py
```



## Sources

Please link relevant resources if necessary.


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
Commit: 7a90fc5854 (parent: 0bc5d05243)
Author: Xi Yan
Date: 2025-01-06 13:25:09 -08:00

5 changed files with 37 additions and 34 deletions

```diff
@@ -18,8 +18,8 @@ from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
 from llama_stack.providers.utils.common.data_schema_validator import (
     ColumnName,
-    DataSchemaValidatorMixin,
     get_valid_schemas,
+    validate_dataset_schema,
 )
 from llama_stack.providers.utils.kvstore import kvstore_impl
@@ -31,7 +31,10 @@ from .config import MetaReferenceEvalConfig
 EVAL_TASKS_PREFIX = "eval_tasks:"


-class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate, DataSchemaValidatorMixin):
+class MetaReferenceEvalImpl(
+    Eval,
+    EvalTasksProtocolPrivate,
+):
     def __init__(
         self,
         config: MetaReferenceEvalConfig,
@@ -85,7 +88,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate, DataSchemaValidatorM
         candidate = task_config.eval_candidate
         scoring_functions = task_def.scoring_functions
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        self.validate_dataset_schema(
+        validate_dataset_schema(
             dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)
         )
         all_rows = await self.datasetio_api.get_rows_paginated(
```

```diff
@@ -18,8 +18,8 @@ from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
 from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
 from llama_stack.providers.utils.common.data_schema_validator import (
-    DataSchemaValidatorMixin,
     get_valid_schemas,
+    validate_dataset_schema,
 )
 from .config import BasicScoringConfig
 from .scoring_fn.equality_scoring_fn import EqualityScoringFn
@@ -30,7 +30,8 @@ FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
 class BasicScoringImpl(
-    Scoring, ScoringFunctionsProtocolPrivate, DataSchemaValidatorMixin
+    Scoring,
+    ScoringFunctionsProtocolPrivate,
 ):
     def __init__(
         self,
@@ -75,7 +76,7 @@ class BasicScoringImpl(
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        self.validate_dataset_schema(
+        validate_dataset_schema(
             dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
         )
```

```diff
@@ -35,8 +35,9 @@ from llama_stack.distribution.datatypes import Api
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
 from llama_stack.providers.utils.common.data_schema_validator import (
-    DataSchemaValidatorMixin,
     get_valid_schemas,
+    validate_dataset_schema,
+    validate_row_schema,
 )
 from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
@@ -111,7 +112,6 @@ class BraintrustScoringImpl(
     Scoring,
     ScoringFunctionsProtocolPrivate,
     NeedsRequestProviderData,
-    DataSchemaValidatorMixin,
 ):
     def __init__(
         self,
@@ -171,7 +171,7 @@ class BraintrustScoringImpl(
         await self.set_api_key()
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        self.validate_dataset_schema(
+        validate_dataset_schema(
             dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
         )
@@ -194,7 +194,7 @@ class BraintrustScoringImpl(
     async def score_row(
         self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
     ) -> ScoringResultRow:
-        self.validate_row_schema(input_row, get_valid_schemas(Api.scoring.value))
+        validate_row_schema(input_row, get_valid_schemas(Api.scoring.value))
         await self.set_api_key()
         assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
         expected_answer = input_row["expected_answer"]
```

```diff
@@ -19,8 +19,8 @@ from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
 from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
 from llama_stack.providers.utils.common.data_schema_validator import (
-    DataSchemaValidatorMixin,
     get_valid_schemas,
+    validate_dataset_schema,
 )
 from .config import LlmAsJudgeScoringConfig
@@ -31,7 +31,8 @@ LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
 class LlmAsJudgeScoringImpl(
-    Scoring, ScoringFunctionsProtocolPrivate, DataSchemaValidatorMixin
+    Scoring,
+    ScoringFunctionsProtocolPrivate,
 ):
     def __init__(
         self,
@@ -79,7 +80,7 @@ class LlmAsJudgeScoringImpl(
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        self.validate_dataset_schema(
+        validate_dataset_schema(
             dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
         )
```

```diff
@@ -62,9 +62,7 @@ def get_valid_schemas(api_str: str):
         raise ValueError(f"Invalid API string: {api_str}")


-class DataSchemaValidatorMixin:
-    def validate_dataset_schema(
-        self,
-        dataset_schema: Dict[str, Any],
-        expected_schemas: List[Dict[str, Any]],
-    ):
+def validate_dataset_schema(
+    dataset_schema: Dict[str, Any],
+    expected_schemas: List[Dict[str, Any]],
+):
@@ -73,8 +71,8 @@ class DataSchemaValidatorMixin:
-            f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}"
-        )
+        f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}"
+    )

-    def validate_row_schema(
-        self,
-        input_row: Dict[str, Any],
-        expected_schemas: List[Dict[str, Any]],
-    ):
+
+def validate_row_schema(
+    input_row: Dict[str, Any],
+    expected_schemas: List[Dict[str, Any]],
+):
```
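
For reference, the resulting standalone utils plausibly end up looking like the sketch below. Only the function signatures and the dataset error message are visible in the diff above; the membership check, the `validate_row_schema` body, and its error message are assumptions, not the PR's verbatim code.

```python
from typing import Any, Dict, List


def validate_dataset_schema(
    dataset_schema: Dict[str, Any],
    expected_schemas: List[Dict[str, Any]],
):
    # Assumed check: the error message in the diff implies a simple membership test.
    if dataset_schema not in expected_schemas:
        raise ValueError(
            f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}"
        )


def validate_row_schema(
    input_row: Dict[str, Any],
    expected_schemas: List[Dict[str, Any]],
):
    # Assumed check: accept the row if it supplies every column of any expected schema.
    for schema in expected_schemas:
        if all(key in input_row for key in schema):
            return
    raise ValueError(
        f"Input row {input_row} does not match any expected schema in {expected_schemas}"
    )
```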