refactor schema check

This commit is contained in:
Xi Yan 2024-12-19 16:20:47 -08:00
parent 55e4f4eeb3
commit c15b0d5395
7 changed files with 82 additions and 119 deletions

View file

@ -14,9 +14,10 @@ from llama_stack.apis.eval_tasks import EvalTask
from llama_stack.apis.inference import Inference, UserMessage
from llama_stack.apis.scoring import Scoring
from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
from llama_stack.providers.utils.common.data_schema_utils import (
from llama_stack.providers.utils.common.data_schema_validator_mixin import (
ColumnName,
get_expected_schema_for_eval,
DataSchemaValidatorMixin,
)
from llama_stack.providers.utils.kvstore import kvstore_impl
@ -28,7 +29,7 @@ from .config import MetaReferenceEvalConfig
EVAL_TASKS_PREFIX = "eval_tasks:"
class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate, DataSchemaValidatorMixin):
def __init__(
self,
config: MetaReferenceEvalConfig,
@ -72,17 +73,17 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
)
self.eval_tasks[task_def.identifier] = task_def
async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
# async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
# dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
# if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
# raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
expected_schemas = get_expected_schema_for_eval()
# expected_schemas = get_expected_schema_for_eval()
if dataset_def.dataset_schema not in expected_schemas:
raise ValueError(
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
)
# if dataset_def.dataset_schema not in expected_schemas:
# raise ValueError(
# f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
# )
async def run_eval(
self,
@ -93,8 +94,8 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
dataset_id = task_def.dataset_id
candidate = task_config.eval_candidate
scoring_functions = task_def.scoring_functions
await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
self.validate_dataset_schema_for_eval(dataset_def.dataset_schema)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=(