diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py
index a68582057..c2bfdcd23 100644
--- a/llama_stack/apis/scoring/scoring.py
+++ b/llama_stack/apis/scoring/scoring.py
@@ -48,7 +48,7 @@ class Scoring(Protocol):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse: ...

@@ -56,5 +56,5 @@ class Scoring(Protocol):
     async def score(
         self,
         input_rows: List[Dict[str, Any]],
-        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse: ...
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 4b28a20d7..7fc65800f 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -212,7 +212,7 @@ class ScoringRouter(Scoring):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         res = {}
@@ -235,7 +235,7 @@ class ScoringRouter(Scoring):
     async def score(
         self,
         input_rows: List[Dict[str, Any]],
-        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse:
         res = {}
         # look up and map each scoring function to its provider impl
diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring.py b/llama_stack/providers/inline/meta_reference/scoring/scoring.py
index d6eb3ae96..c4add966d 100644
--- a/llama_stack/providers/inline/meta_reference/scoring/scoring.py
+++ b/llama_stack/providers/inline/meta_reference/scoring/scoring.py
@@ -96,7 +96,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
@@ -120,7 +120,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def score(
         self,
         input_rows: List[Dict[str, Any]],
-        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse:
         res = {}
         for scoring_fn_id in scoring_functions.keys():
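
For context, a minimal call-site sketch under the new signature: the type moves the `Optional` inside the dict, so callers now pass the `scoring_functions` dict itself and set a value to `None` when a scoring function needs no extra parameters. The `scoring_impl` object, the scoring function id, and the `ScoringFnParams` import path below are illustrative assumptions, not part of this diff.

```python
from typing import Any, Dict, List, Optional

# Assumed import path for illustration; ScoringFnParams is the params type
# referenced by llama_stack/apis/scoring/scoring.py above.
from llama_stack.apis.scoring_functions import ScoringFnParams


async def score_rows(scoring_impl, input_rows: List[Dict[str, Any]]):
    # Keys name the scoring functions to run; each value is the per-function
    # params, or None when the function takes no parameters.
    scoring_functions: Dict[str, Optional[ScoringFnParams]] = {
        "meta-reference::equality": None,  # hypothetical scoring fn id
    }
    response = await scoring_impl.score(
        input_rows=input_rows,
        scoring_functions=scoring_functions,
    )
    return response.results
```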