refactor schema check

This commit is contained in:
Xi Yan 2024-12-30 17:58:23 -08:00
parent 86b6d41065
commit 41cff917ca

View file

@ -28,7 +28,7 @@ from llama_stack.apis.scoring import (
ScoringResult, ScoringResult,
ScoringResultRow, ScoringResultRow,
) )
from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.distribution.datatypes import Api from llama_stack.distribution.datatypes import Api
@ -194,7 +194,7 @@ class BraintrustScoringImpl(
async def score_row( async def score_row(
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
) -> ScoringResultRow: ) -> ScoringResultRow:
self.validate_row_schema_for_scoring(input_row) self.validate_row_schema(input_row, get_valid_schemas(Api.scoring.value))
await self.set_api_key() await self.set_api_key()
assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None" assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
expected_answer = input_row["expected_answer"] expected_answer = input_row["expected_answer"]