From f2ede05f3f53321c895c588d440ec7f072cb962b Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 12 Mar 2025 01:01:53 -0700 Subject: [PATCH] update scoring --- llama_stack/apis/scoring/scoring.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py index 54a9ac2aa..d442d21d7 100644 --- a/llama_stack/apis/scoring/scoring.py +++ b/llama_stack/apis/scoring/scoring.py @@ -4,10 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, List, Optional, Protocol, runtime_checkable +from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable -from pydantic import BaseModel +from pydantic import BaseModel, Field +from llama_stack.apis.common.job_types import CommonJobFields from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams from llama_stack.schema_utils import json_schema_type, webmethod @@ -47,6 +48,18 @@ class ScoreResponse(BaseModel): results: Dict[str, ScoringResult] +@json_schema_type +class ScoringJob(CommonJobFields): + """The ScoringJob object representing a scoring job that was created through API.""" + + type: Literal["scoring"] = "scoring" + # TODO: result files or result datasets ids? + result_files: List[str] = Field( + default_factory=list, + description="Result files of an scoring run. Which can be queried for results.", + ) + + class ScoringFunctionStore(Protocol): def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: ... @@ -59,15 +72,14 @@ class Scoring(Protocol): async def score_batch( self, dataset_id: str, - scoring_functions: Dict[str, Optional[ScoringFnParams]], - save_results_dataset: bool = False, - ) -> ScoreBatchResponse: ... + scoring_functions: List[ScoringFnParams], + ) -> ScoringJob: ... @webmethod(route="/scoring/score", method="POST") async def score( self, input_rows: List[Dict[str, Any]], - scoring_functions: Dict[str, Optional[ScoringFnParams]], + scoring_functions: List[ScoringFnParams], ) -> ScoreResponse: """Score a list of rows.