fix(scoring): remove broken dataset validation in score_batch methods (#4420)

The Dataset model no longer has a dataset_schema attribute; it was removed
during a refactor (5287b437a), so this validation can no longer run.
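
For illustration, a minimal sketch of the now-broken call path (taken from
the removed lines in the diffs below; the exact exception raised is an
assumption):

    # The removed validation path: Dataset no longer defines dataset_schema,
    # so the attribute access fails before validation can run.
    dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
    validate_dataset_schema(
        dataset_def.dataset_schema,  # AttributeError here (assumed)
        get_valid_schemas(Api.scoring.value),
    )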

Changes:
- basic scoring: removed validate_dataset_schema call and related imports
- llm_as_judge scoring: removed validate_dataset_schema call and related imports
- braintrust scoring: removed validate_dataset_schema call and related imports

Validation is no longer needed at the dataset level since:
- Dataset model changed from having dataset_schema to purpose/source fields
- Scoring functions validate required fields when processing rows (sketched below)
- Invalid data will fail naturally with clear error messages
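
A minimal sketch of the row-level checks that replace dataset-level
validation (the function body and field names are illustrative, loosely
modeled on an equality-style scorer, not the exact implementation):

    # Illustrative only: each scoring function checks the fields it needs
    # while processing rows, so a bad row fails with a clear, local error.
    def score_row(row: dict) -> float:
        for field in ("expected_answer", "generated_answer"):  # assumed fields
            if field not in row:
                raise ValueError(f"required field '{field}' missing from row")
        return 1.0 if row["generated_answer"] == row["expected_answer"] else 0.0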

Fixes: #4419

Signed-off-by: Derek Higgins <derekh@redhat.com>
commit 5ebcde3042 (parent e710622d4c)
Derek Higgins, 2025-12-19 23:52:52 +00:00, committed by GitHub
3 changed files with 0 additions and 20 deletions

basic scoring:

@@ -5,11 +5,6 @@
 # the root directory of this source tree.
 
 from typing import Any
-from llama_stack.core.datatypes import Api
-from llama_stack.providers.utils.common.data_schema_validator import (
-    get_valid_schemas,
-    validate_dataset_schema,
-)
 from llama_stack_api import (
     DatasetIO,
     Datasets,
@@ -84,9 +79,6 @@ class BasicScoringImpl(
         scoring_functions: dict[str, ScoringFnParams | None] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
-        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
-
         all_rows = await self.datasetio_api.iterrows(
             dataset_id=dataset_id,
             limit=-1,

braintrust scoring:

@@ -23,7 +23,6 @@ from llama_stack.core.datatypes import Api
 from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.providers.utils.common.data_schema_validator import (
     get_valid_schemas,
-    validate_dataset_schema,
     validate_row_schema,
 )
 from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
@@ -165,9 +164,6 @@ class BraintrustScoringImpl(
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         await self.set_api_key()
-        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
-
         all_rows = await self.datasetio_api.iterrows(
             dataset_id=dataset_id,
             limit=-1,
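
Note that braintrust keeps its per-row check via validate_row_schema (the
import retained above). A minimal sketch of that surviving path, with the
call pattern assumed from the retained import:

    # Per-row validation survives in braintrust: each row is checked
    # against the schemas accepted by the scoring API (call shape assumed).
    for input_row in all_rows.data:
        validate_row_schema(input_row, get_valid_schemas(Api.scoring.value))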

llm_as_judge scoring:

@@ -5,11 +5,6 @@
 # the root directory of this source tree.
 
 from typing import Any
-from llama_stack.core.datatypes import Api
-from llama_stack.providers.utils.common.data_schema_validator import (
-    get_valid_schemas,
-    validate_dataset_schema,
-)
 from llama_stack_api import (
     DatasetIO,
     Datasets,
@@ -73,9 +68,6 @@ class LlmAsJudgeScoringImpl(
         scoring_functions: dict[str, ScoringFnParams | None] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
-        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
-
         all_rows = await self.datasetio_api.iterrows(
             dataset_id=dataset_id,
             limit=-1,
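
For completeness, a hedged usage sketch of the simplified path (the dataset
id and scoring-function id are hypothetical):

    # score_batch now goes straight to iterrows; a malformed row surfaces
    # as a field-level error from the scoring function itself.
    response = await scoring_impl.score_batch(
        dataset_id="my-eval-dataset",  # hypothetical dataset id
        scoring_functions={"basic::equality": None},  # id assumed for illustration
    )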