fix optional

This commit is contained in:
Xi Yan 2024-11-07 16:22:33 -08:00
parent fd581c3d88
commit 7ca479f400
3 changed files with 6 additions and 6 deletions

View file

@ -48,7 +48,7 @@ class Scoring(Protocol):
async def score_batch( async def score_batch(
self, self,
dataset_id: str, dataset_id: str,
scoring_functions: Optional[Dict[str, ScoringFnParams]] = None, scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False, save_results_dataset: bool = False,
) -> ScoreBatchResponse: ... ) -> ScoreBatchResponse: ...
@ -56,5 +56,5 @@ class Scoring(Protocol):
async def score( async def score(
self, self,
input_rows: List[Dict[str, Any]], input_rows: List[Dict[str, Any]],
scoring_functions: Optional[Dict[str, ScoringFnParams]] = None, scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
) -> ScoreResponse: ... ) -> ScoreResponse: ...

View file

@ -212,7 +212,7 @@ class ScoringRouter(Scoring):
async def score_batch( async def score_batch(
self, self,
dataset_id: str, dataset_id: str,
scoring_functions: Optional[Dict[str, ScoringFnParams]] = None, scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False, save_results_dataset: bool = False,
) -> ScoreBatchResponse: ) -> ScoreBatchResponse:
res = {} res = {}
@ -235,7 +235,7 @@ class ScoringRouter(Scoring):
async def score( async def score(
self, self,
input_rows: List[Dict[str, Any]], input_rows: List[Dict[str, Any]],
scoring_functions: Optional[Dict[str, ScoringFnParams]] = None, scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
) -> ScoreResponse: ) -> ScoreResponse:
res = {} res = {}
# look up and map each scoring function to its provider impl # look up and map each scoring function to its provider impl

View file

@ -96,7 +96,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def score_batch( async def score_batch(
self, self,
dataset_id: str, dataset_id: str,
scoring_functions: Optional[Dict[str, ScoringFnParams]] = None, scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False, save_results_dataset: bool = False,
) -> ScoreBatchResponse: ) -> ScoreBatchResponse:
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id) await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
@ -120,7 +120,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def score( async def score(
self, self,
input_rows: List[Dict[str, Any]], input_rows: List[Dict[str, Any]],
scoring_functions: Optional[Dict[str, ScoringFnParams]] = None, scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
) -> ScoreResponse: ) -> ScoreResponse:
res = {} res = {}
for scoring_fn_id in scoring_functions.keys(): for scoring_fn_id in scoring_functions.keys():