scoring updates

Xi Yan 2025-03-12 21:54:12 -07:00
parent 7b50fdb2b1
commit 3a87562e8d
6 changed files with 1346 additions and 1466 deletions

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -8,7 +8,6 @@ from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
 from pydantic import BaseModel, Field

 from llama_stack.apis.resource import Resource, ResourceType
-from llama_stack.apis.scoring_functions import ScoringFnParams
 from llama_stack.schema_utils import json_schema_type, webmethod


@@ -13,7 +13,6 @@ from llama_stack.apis.agents import AgentConfig
 from llama_stack.apis.common.job_types import Job, JobStatus
 from llama_stack.apis.inference import SamplingParams, SystemMessage
 from llama_stack.apis.scoring import ScoringResult
-from llama_stack.apis.scoring_functions import ScoringFnParams
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod

@@ -49,27 +48,6 @@ EvalCandidate = register_schema(
 )

-@json_schema_type
-class BenchmarkConfig(BaseModel):
-    """A benchmark configuration for evaluation.
-
-    :param eval_candidate: The candidate to evaluate.
-    :param scoring_params: Map between scoring function id and parameters for each scoring function you want to run
-    :param num_examples: (Optional) The number of examples to evaluate. If not provided, all examples in the dataset will be evaluated
-    """
-
-    eval_candidate: EvalCandidate
-    scoring_params: Dict[str, ScoringFnParams] = Field(
-        description="Map between scoring function id and parameters for each scoring function you want to run",
-        default_factory=dict,
-    )
-    num_examples: Optional[int] = Field(
-        description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
-        default=None,
-    )
-    # we could optinally add any specific dataset config here
-
 @json_schema_type
 class EvaluateResponse(BaseModel):
     """The response from an evaluation.

@@ -87,32 +65,30 @@ class Eval(Protocol):
     """Llama Stack Evaluation API for running evaluations on model and agent candidates."""

     @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
-    async def run_eval(
+    async def evaluate_benchmark(
         self,
         benchmark_id: str,
-        benchmark_config: BenchmarkConfig,
+        candidate: EvalCandidate,
     ) -> Job:
         """Run an evaluation on a benchmark.

         :param benchmark_id: The ID of the benchmark to run the evaluation on.
-        :param benchmark_config: The configuration for the benchmark.
+        :param candidate: The candidate to evaluate on.
         :return: The job that was created to run the evaluation.
         """

-    @webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
+    @webmethod(route="/eval/rows", method="POST")
     async def evaluate_rows(
         self,
-        benchmark_id: str,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
-        benchmark_config: BenchmarkConfig,
+        dataset_rows: List[Dict[str, Any]],
+        scoring_fn_ids: List[str],
+        candidate: EvalCandidate,
     ) -> EvaluateResponse:
-        """Evaluate a list of rows on a benchmark.
+        """Evaluate a list of rows on a candidate.

-        :param benchmark_id: The ID of the benchmark to run the evaluation on.
-        :param input_rows: The rows to evaluate.
-        :param scoring_functions: The scoring functions to use for the evaluation.
-        :param benchmark_config: The configuration for the benchmark.
+        :param dataset_rows: The rows to evaluate.
+        :param scoring_fn_ids: The scoring function ids to use for the evaluation.
+        :param candidate: The candidate to evaluate on.
         :return: EvaluateResponse object containing generations and scores
         """


@@ -8,7 +8,7 @@ from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
 from pydantic import BaseModel

-from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
+from llama_stack.apis.scoring_functions import ScoringFn
 from llama_stack.schema_utils import json_schema_type, webmethod

 # mapping of metric to value

@@ -56,23 +56,22 @@ class Scoring(Protocol):
     scoring_function_store: ScoringFunctionStore

     @webmethod(route="/scoring/score-batch", method="POST")
-    async def score_batch(
+    async def score_dataset(
         self,
         dataset_id: str,
-        scoring_functions: Dict[str, Optional[ScoringFnParams]],
-        save_results_dataset: bool = False,
+        scoring_fn_ids: List[str],
     ) -> ScoreBatchResponse: ...

     @webmethod(route="/scoring/score", method="POST")
     async def score(
         self,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: Dict[str, Optional[ScoringFnParams]],
+        dataset_rows: List[Dict[str, Any]],
+        scoring_fn_ids: List[str],
     ) -> ScoreResponse:
         """Score a list of rows.

-        :param input_rows: The rows to score.
-        :param scoring_functions: The scoring functions to use for the scoring.
+        :param dataset_rows: The rows to score.
+        :param scoring_fn_ids: The scoring function ids to use for the scoring.
         :return: ScoreResponse object containing rows and aggregated results
         """
         ...
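A corresponding sketch for the renamed Scoring endpoints. The scoring_impl object, dataset id, scoring function ids, and row keys are illustrative assumptions; the parameter names mirror the protocol above, and the shape of ScoreResponse (a results map of ScoringResult) is the pre-existing one.

# Sketch only, inside an async function; scoring_impl stands in for an
# implementation of the Scoring protocol and is not defined in this diff.

# Score rows that are already in memory. Per-function parameter overrides are
# gone from this call; only the scoring function ids are passed.
score_response = await scoring_impl.score(
    dataset_rows=[
        {"input_query": "What is 2 + 2?", "generated_answer": "4", "expected_answer": "4"},
    ],
    scoring_fn_ids=["basic::equality", "basic::subset_of"],  # assumed ids
)
for fn_id, result in score_response.results.items():
    print(fn_id, result.aggregated_results)

# Score an entire registered dataset by id (formerly score_batch, which also
# took per-function scoring params and a save_results_dataset flag).
batch_response = await scoring_impl.score_dataset(
    dataset_id="my-eval-dataset",  # assumed dataset id
    scoring_fn_ids=["basic::equality"],
)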


@@ -67,7 +67,7 @@ class AggregationFunctionType(Enum):
     accuracy = "accuracy"

-class BasicScoringFnParamsFields(BaseModel):
+class BasicScoringFnParams(BaseModel):
     """
     :param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. If not provided, no aggregation will be performed.
     """

@@ -78,7 +78,7 @@ class BasicScoringFnParamsFields(BaseModel):
     )

-class RegexParserScoringFnParamsFields(BaseModel):
+class RegexParserScoringFnParams(BaseModel):
     """
     :param parsing_regexes: (Optional) Regexes to extract the answer from generated response.
     :param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. If not provided, no aggregation will be performed.

@@ -93,7 +93,7 @@ class RegexParserScoringFnParamsFields(BaseModel):
         default_factory=list,
     )

-class CustomLLMAsJudgeScoringFnParamsFields(BaseModel):
+class CustomLLMAsJudgeScoringFnParams(BaseModel):
     type: Literal["custom_llm_as_judge"] = "custom_llm_as_judge"
     judge_model: str
     prompt_template: Optional[str] = None

@@ -103,103 +103,103 @@ class CustomLLMAsJudgeScoringFnParamsFields(BaseModel):
     )

 @json_schema_type
-class RegexParserScoringFnParams(BaseModel):
+class RegexParserScoringFn(BaseModel):
     type: Literal["regex_parser"] = "regex_parser"
-    regex_parser: RegexParserScoringFnParamsFields
+    regex_parser: RegexParserScoringFnParams

 @json_schema_type
-class RegexParserMathScoringFnParams(BaseModel):
+class RegexParserMathScoringFn(BaseModel):
     type: Literal["regex_parser_math_response"] = "regex_parser_math_response"
-    regex_parser_math_response: RegexParserScoringFnParamsFields
+    regex_parser_math_response: RegexParserScoringFnParams

 @json_schema_type
-class EqualityScoringFnParams(BaseModel):
+class EqualityScoringFn(BaseModel):
     type: Literal["equality"] = "equality"
-    equality: BasicScoringFnParamsFields
+    equality: BasicScoringFnParams

 @json_schema_type
-class SubsetOfcoringFnParams(BaseModel):
+class SubsetOfScoringFn(BaseModel):
     type: Literal["subset_of"] = "subset_of"
-    subset_of: BasicScoringFnParamsFields
+    subset_of: BasicScoringFnParams

 @json_schema_type
-class FactualityScoringFnParams(BaseModel):
+class FactualityScoringFn(BaseModel):
     type: Literal["factuality"] = "factuality"
-    factuality: BasicScoringFnParamsFields
+    factuality: BasicScoringFnParams

 @json_schema_type
-class FaithfulnessScoringFnParams(BaseModel):
+class FaithfulnessScoringFn(BaseModel):
     type: Literal["faithfulness"] = "faithfulness"
-    faithfulness: BasicScoringFnParamsFields
+    faithfulness: BasicScoringFnParams

 @json_schema_type
-class AnswerCorrectnessScoringFnParams(BaseModel):
+class AnswerCorrectnessScoringFn(BaseModel):
     type: Literal["answer_correctness"] = "answer_correctness"
-    answer_correctness: BasicScoringFnParamsFields
+    answer_correctness: BasicScoringFnParams

 @json_schema_type
-class AnswerRelevancyScoringFnParams(BaseModel):
+class AnswerRelevancyScoringFn(BaseModel):
     type: Literal["answer_relevancy"] = "answer_relevancy"
-    answer_relevancy: BasicScoringFnParamsFields
+    answer_relevancy: BasicScoringFnParams

 @json_schema_type
-class AnswerSimilarityScoringFnParams(BaseModel):
+class AnswerSimilarityScoringFn(BaseModel):
     type: Literal["answer_similarity"] = "answer_similarity"
-    answer_similarity: BasicScoringFnParamsFields
+    answer_similarity: BasicScoringFnParams

 @json_schema_type
-class ContextEntityRecallScoringFnParams(BaseModel):
+class ContextEntityRecallScoringFn(BaseModel):
     type: Literal["context_entity_recall"] = "context_entity_recall"
-    context_entity_recall: BasicScoringFnParamsFields
+    context_entity_recall: BasicScoringFnParams

 @json_schema_type
-class ContextPrecisionScoringFnParams(BaseModel):
+class ContextPrecisionScoringFn(BaseModel):
     type: Literal["context_precision"] = "context_precision"
-    context_precision: BasicScoringFnParamsFields
+    context_precision: BasicScoringFnParams

 @json_schema_type
-class ContextRecallScoringFnParams(BaseModel):
+class ContextRecallScoringFn(BaseModel):
     type: Literal["context_recall"] = "context_recall"
-    context_recall: BasicScoringFnParamsFields
+    context_recall: BasicScoringFnParams

 @json_schema_type
-class ContextRelevancyScoringFnParams(BaseModel):
+class ContextRelevancyScoringFn(BaseModel):
     type: Literal["context_relevancy"] = "context_relevancy"
-    context_relevancy: BasicScoringFnParamsFields
+    context_relevancy: BasicScoringFnParams

 @json_schema_type
-class CustomLLMAsJudgeScoringFnParams(BaseModel):
+class CustomLLMAsJudgeScoringFn(BaseModel):
     type: Literal["custom_llm_as_judge"] = "custom_llm_as_judge"
-    custom_llm_as_judge: CustomLLMAsJudgeScoringFnParamsFields
+    custom_llm_as_judge: CustomLLMAsJudgeScoringFnParams

-ScoringFnParams = register_schema(
+ScoringFnDefinition = register_schema(
     Annotated[
         Union[
-            CustomLLMAsJudgeScoringFnParams,
-            RegexParserScoringFnParams,
-            RegexParserMathScoringFnParams,
-            EqualityScoringFnParams,
-            SubsetOfcoringFnParams,
-            FactualityScoringFnParams,
-            FaithfulnessScoringFnParams,
-            AnswerCorrectnessScoringFnParams,
-            AnswerRelevancyScoringFnParams,
-            AnswerSimilarityScoringFnParams,
-            ContextEntityRecallScoringFnParams,
-            ContextPrecisionScoringFnParams,
-            ContextRecallScoringFnParams,
-            ContextRelevancyScoringFnParams,
+            CustomLLMAsJudgeScoringFn,
+            RegexParserScoringFn,
+            RegexParserMathScoringFn,
+            EqualityScoringFn,
+            SubsetOfScoringFn,
+            FactualityScoringFn,
+            FaithfulnessScoringFn,
+            AnswerCorrectnessScoringFn,
+            AnswerRelevancyScoringFn,
+            AnswerSimilarityScoringFn,
+            ContextEntityRecallScoringFn,
+            ContextPrecisionScoringFn,
+            ContextRecallScoringFn,
+            ContextRelevancyScoringFn,
         ],
         Field(discriminator="type"),
     ],
-    name="ScoringFnParams",
+    name="ScoringFnDefinition",
 )
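A short sketch of how the renamed types compose after this hunk: the outer *ScoringFn wrapper carries the discriminator `type`, and its single field holds the matching *ScoringFnParams. The judge model id and prompt template are illustrative assumptions, and the import path assumes the package re-exports the renamed classes.

# Sketch: build two scoring function definitions using the renamed classes.
from llama_stack.apis.scoring_functions import (
    AggregationFunctionType,
    CustomLLMAsJudgeScoringFn,
    CustomLLMAsJudgeScoringFnParams,
    RegexParserScoringFn,
    RegexParserScoringFnParams,
)

regex_fn = RegexParserScoringFn(
    regex_parser=RegexParserScoringFnParams(
        parsing_regexes=[r"Answer:\s*(\w+)"],
        aggregation_functions=[AggregationFunctionType.accuracy],
    )
)

judge_fn = CustomLLMAsJudgeScoringFn(
    custom_llm_as_judge=CustomLLMAsJudgeScoringFnParams(
        judge_model="meta-llama/Llama-3.3-70B-Instruct",  # assumed judge model id
        prompt_template="Rate the answer from 0 to 5: {generated_answer}",  # assumed template
    )
)

# Either instance is a valid member of the ScoringFnDefinition annotated
# union; pydantic discriminates on the `type` field each wrapper sets.
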
@@ -208,7 +208,7 @@ class CommonScoringFnFields(BaseModel):
     :param fn: The scoring function type and parameters.
     :param metadata: (Optional) Any additional metadata for this definition (e.g. description).
     """

-    fn: ScoringFnParams
+    fn: ScoringFnDefinition
     metadata: Dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional metadata for this definition (e.g. description)",
@@ -288,7 +288,7 @@ class ScoringFunctions(Protocol):
     @webmethod(route="/scoring-functions", method="POST")
     async def register_scoring_function(
         self,
-        fn: ScoringFnParams,
+        fn: ScoringFnDefinition,
         scoring_fn_id: Optional[str] = None,
         metadata: Optional[Dict[str, Any]] = None,
     ) -> ScoringFn:
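Finally, a sketch of registering one of these definitions through the updated register_scoring_function signature, which now takes a ScoringFnDefinition as `fn`. The scoring_functions_impl object and the chosen id are illustrative assumptions, and `judge_fn` is the instance built in the sketch above.

# Sketch only, inside an async function; scoring_functions_impl stands in for
# an implementation of the ScoringFunctions protocol.
registered = await scoring_functions_impl.register_scoring_function(
    fn=judge_fn,                              # a ScoringFnDefinition instance from the sketch above
    scoring_fn_id="llm-as-judge::my-rubric",  # assumed id; optional per the signature
    metadata={"description": "LLM-as-judge rubric for answer quality"},
)
print(registered)  # the returned ScoringFn resource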