refactor schema check

This commit is contained in:
Xi Yan 2024-12-30 17:53:10 -08:00
parent 3367c52e31
commit eb92322c3c
6 changed files with 115 additions and 104 deletions

View file

@ -13,11 +13,13 @@ from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval_tasks import EvalTask
from llama_stack.apis.inference import Inference, UserMessage
from llama_stack.apis.scoring import Scoring
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator_mixin import (
from llama_stack.providers.utils.common.data_schema_validator import (
ColumnName,
DataSchemaValidatorMixin,
get_valid_schemas,
)
from llama_stack.providers.utils.kvstore import kvstore_impl
@ -83,7 +85,9 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate, DataSchemaValidatorM
candidate = task_config.eval_candidate
scoring_functions = task_def.scoring_functions
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
self.validate_dataset_schema_for_eval(dataset_def.dataset_schema)
self.validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)
)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=(

View file

@ -14,11 +14,13 @@ from llama_stack.apis.scoring import (
ScoringResult,
)
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator_mixin import (
DataSchemaValidatorMixin,
)
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
DataSchemaValidatorMixin,
get_valid_schemas,
)
from .config import BasicScoringConfig
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
@ -73,7 +75,9 @@ class BasicScoringImpl(
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
self.validate_dataset_schema_for_scoring(dataset_def.dataset_schema)
self.validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,

View file

@ -21,10 +21,13 @@ from llama_stack.apis.scoring import (
)
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator_mixin import (
from llama_stack.providers.utils.common.data_schema_validator import (
DataSchemaValidatorMixin,
get_valid_schemas,
)
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
@ -117,7 +120,9 @@ class BraintrustScoringImpl(
await self.set_api_key()
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
self.validate_dataset_schema_for_scoring(dataset_def.dataset_schema)
self.validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,

View file

@ -16,9 +16,11 @@ from llama_stack.apis.scoring import (
ScoringResult,
)
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator_mixin import (
from llama_stack.providers.utils.common.data_schema_validator import (
DataSchemaValidatorMixin,
get_valid_schemas,
)
from .config import LlmAsJudgeScoringConfig
@ -77,7 +79,9 @@ class LlmAsJudgeScoringImpl(
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
self.validate_dataset_schema_for_scoring(dataset_def.dataset_schema)
self.validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,