update scoring

This commit is contained in:
Xi Yan 2025-03-12 01:01:53 -07:00
parent fdf251234e
commit f2ede05f3f

View file

@ -4,10 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from pydantic import BaseModel from pydantic import BaseModel, Field
from llama_stack.apis.common.job_types import CommonJobFields
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.schema_utils import json_schema_type, webmethod from llama_stack.schema_utils import json_schema_type, webmethod
@ -47,6 +48,18 @@ class ScoreResponse(BaseModel):
results: Dict[str, ScoringResult] results: Dict[str, ScoringResult]
@json_schema_type
class ScoringJob(CommonJobFields):
"""The ScoringJob object representing a scoring job that was created through API."""
type: Literal["scoring"] = "scoring"
# TODO: result files or result datasets ids?
result_files: List[str] = Field(
default_factory=list,
description="Result files of an scoring run. Which can be queried for results.",
)
class ScoringFunctionStore(Protocol): class ScoringFunctionStore(Protocol):
def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: ... def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: ...
@ -59,15 +72,14 @@ class Scoring(Protocol):
async def score_batch( async def score_batch(
self, self,
dataset_id: str, dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]], scoring_functions: List[ScoringFnParams],
save_results_dataset: bool = False, ) -> ScoringJob: ...
) -> ScoreBatchResponse: ...
@webmethod(route="/scoring/score", method="POST") @webmethod(route="/scoring/score", method="POST")
async def score( async def score(
self, self,
input_rows: List[Dict[str, Any]], input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]], scoring_functions: List[ScoringFnParams],
) -> ScoreResponse: ) -> ScoreResponse:
"""Score a list of rows. """Score a list of rows.