scoring updates

commit 3a87562e8d
parent 7b50fdb2b1
Author: Xi Yan
Date: 2025-03-12 21:54:12 -07:00

6 changed files with 1346 additions and 1466 deletions

Two file diffs are suppressed because they are too large to display.

View file

@@ -8,7 +8,6 @@ from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.schema_utils import json_schema_type, webmethod

View file

@@ -13,7 +13,6 @@ from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import Job, JobStatus
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@@ -49,27 +48,6 @@ EvalCandidate = register_schema(
)
@json_schema_type
class BenchmarkConfig(BaseModel):
"""A benchmark configuration for evaluation.
:param eval_candidate: The candidate to evaluate.
:param scoring_params: Map between scoring function id and parameters for each scoring function you want to run
:param num_examples: (Optional) The number of examples to evaluate. If not provided, all examples in the dataset will be evaluated
"""
eval_candidate: EvalCandidate
scoring_params: Dict[str, ScoringFnParams] = Field(
description="Map between scoring function id and parameters for each scoring function you want to run",
default_factory=dict,
)
num_examples: Optional[int] = Field(
description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
default=None,
)
# we could optionally add any specific dataset config here
@json_schema_type
class EvaluateResponse(BaseModel):
"""The response from an evaluation.
@@ -87,32 +65,30 @@ class Eval(Protocol):
"""Llama Stack Evaluation API for running evaluations on model and agent candidates."""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
async def run_eval(
async def evaluate_benchmark(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
candidate: EvalCandidate,
) -> Job:
"""Run an evaluation on a benchmark.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param benchmark_config: The configuration for the benchmark.
:param candidate: The candidate to evaluate on.
:return: The job that was created to run the evaluation.
"""
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
@webmethod(route="/eval/rows", method="POST")
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
benchmark_config: BenchmarkConfig,
dataset_rows: List[Dict[str, Any]],
scoring_fn_ids: List[str],
candidate: EvalCandidate,
) -> EvaluateResponse:
"""Evaluate a list of rows on a benchmark.
"""Evaluate a list of rows on a candidate.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param input_rows: The rows to evaluate.
:param scoring_functions: The scoring functions to use for the evaluation.
:param benchmark_config: The configuration for the benchmark.
:param dataset_rows: The rows to evaluate.
:param scoring_fn_ids: The scoring function ids to use for the evaluation.
:param candidate: The candidate to evaluate on.
:return: EvaluateResponse object containing generations and scores
"""

View file

@@ -8,7 +8,7 @@ from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.schema_utils import json_schema_type, webmethod
# mapping of metric to value
@@ -56,23 +56,22 @@ class Scoring(Protocol):
scoring_function_store: ScoringFunctionStore
@webmethod(route="/scoring/score-batch", method="POST")
async def score_batch(
async def score_dataset(
self,
dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]],
save_results_dataset: bool = False,
scoring_fn_ids: List[str],
) -> ScoreBatchResponse: ...
@webmethod(route="/scoring/score", method="POST")
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]],
dataset_rows: List[Dict[str, Any]],
scoring_fn_ids: List[str],
) -> ScoreResponse:
"""Score a list of rows.
:param input_rows: The rows to score.
:param scoring_functions: The scoring functions to use for the scoring.
:param dataset_rows: The rows to score.
:param scoring_fn_ids: The scoring function ids to use for the scoring.
:return: ScoreResponse object containing rows and aggregated results
"""
...
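
A minimal sketch of calling the renamed scoring endpoints, assuming scoring_impl is an object implementing the updated Scoring protocol and that score_dataset drops the old save_results_dataset flag along with the params mapping (as the hunk suggests); the dataset id and scoring function ids are illustrative.

async def score_example(scoring_impl) -> None:
    # batch-score every row of an already registered dataset
    batch = await scoring_impl.score_dataset(
        dataset_id="my-eval-dataset",  # illustrative dataset id
        scoring_fn_ids=["basic::subset_of"],  # illustrative scoring function id
    )
    # score ad-hoc rows without registering a dataset first
    result = await scoring_impl.score(
        dataset_rows=[{"generated_answer": "Paris", "expected_answer": "Paris"}],
        scoring_fn_ids=["basic::equality"],
    )
    # ScoreResponse is assumed to keep its existing results mapping of
    # scoring function id -> ScoringResult
    for fn_id, scoring_result in result.results.items():
        print(fn_id, scoring_result.aggregated_results)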

View file

@@ -67,7 +67,7 @@ class AggregationFunctionType(Enum):
accuracy = "accuracy"
class BasicScoringFnParamsFields(BaseModel):
class BasicScoringFnParams(BaseModel):
"""
:param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. If not provided, no aggregation will be performed.
"""
@@ -78,7 +78,7 @@ class BasicScoringFnParamsFields(BaseModel):
)
class RegexParserScoringFnParamsFields(BaseModel):
class RegexParserScoringFnParams(BaseModel):
"""
:param parsing_regexes: (Optional) Regexes to extract the answer from generated response.
:param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. If not provided, no aggregation will be performed.
@@ -93,7 +93,7 @@ class RegexParserScoringFnParamsFields(BaseModel):
default_factory=list,
)
class CustomLLMAsJudgeScoringFnParamsFields(BaseModel):
class CustomLLMAsJudgeScoringFnParams(BaseModel):
type: Literal["custom_llm_as_judge"] = "custom_llm_as_judge"
judge_model: str
prompt_template: Optional[str] = None
@@ -103,103 +103,103 @@ class CustomLLMAsJudgeScoringFnParamsFields(BaseModel):
)
@json_schema_type
class RegexParserScoringFnParams(BaseModel):
class RegexParserScoringFn(BaseModel):
type: Literal["regex_parser"] = "regex_parser"
regex_parser: RegexParserScoringFnParamsFields
regex_parser: RegexParserScoringFnParams
@json_schema_type
class RegexParserMathScoringFnParams(BaseModel):
class RegexParserMathScoringFn(BaseModel):
type: Literal["regex_parser_math_response"] = "regex_parser_math_response"
regex_parser_math_response: RegexParserScoringFnParamsFields
regex_parser_math_response: RegexParserScoringFnParams
@json_schema_type
class EqualityScoringFnParams(BaseModel):
class EqualityScoringFn(BaseModel):
type: Literal["equality"] = "equality"
equality: BasicScoringFnParamsFields
equality: BasicScoringFnParams
@json_schema_type
class SubsetOfcoringFnParams(BaseModel):
class SubsetOfScoringFn(BaseModel):
type: Literal["subset_of"] = "subset_of"
subset_of: BasicScoringFnParamsFields
subset_of: BasicScoringFnParams
@json_schema_type
class FactualityScoringFnParams(BaseModel):
class FactualityScoringFn(BaseModel):
type: Literal["factuality"] = "factuality"
factuality: BasicScoringFnParamsFields
factuality: BasicScoringFnParams
@json_schema_type
class FaithfulnessScoringFnParams(BaseModel):
class FaithfulnessScoringFn(BaseModel):
type: Literal["faithfulness"] = "faithfulness"
faithfulness: BasicScoringFnParamsFields
faithfulness: BasicScoringFnParams
@json_schema_type
class AnswerCorrectnessScoringFnParams(BaseModel):
class AnswerCorrectnessScoringFn(BaseModel):
type: Literal["answer_correctness"] = "answer_correctness"
answer_correctness: BasicScoringFnParamsFields
answer_correctness: BasicScoringFnParams
@json_schema_type
class AnswerRelevancyScoringFnParams(BaseModel):
class AnswerRelevancyScoringFn(BaseModel):
type: Literal["answer_relevancy"] = "answer_relevancy"
answer_relevancy: BasicScoringFnParamsFields
answer_relevancy: BasicScoringFnParams
@json_schema_type
class AnswerSimilarityScoringFnParams(BaseModel):
class AnswerSimilarityScoringFn(BaseModel):
type: Literal["answer_similarity"] = "answer_similarity"
answer_similarity: BasicScoringFnParamsFields
answer_similarity: BasicScoringFnParams
@json_schema_type
class ContextEntityRecallScoringFnParams(BaseModel):
class ContextEntityRecallScoringFn(BaseModel):
type: Literal["context_entity_recall"] = "context_entity_recall"
context_entity_recall: BasicScoringFnParamsFields
context_entity_recall: BasicScoringFnParams
@json_schema_type
class ContextPrecisionScoringFnParams(BaseModel):
class ContextPrecisionScoringFn(BaseModel):
type: Literal["context_precision"] = "context_precision"
context_precision: BasicScoringFnParamsFields
context_precision: BasicScoringFnParams
@json_schema_type
class ContextRecallScoringFnParams(BaseModel):
class ContextRecallScoringFn(BaseModel):
type: Literal["context_recall"] = "context_recall"
context_recall: BasicScoringFnParamsFields
context_recall: BasicScoringFnParams
@json_schema_type
class ContextRelevancyScoringFnParams(BaseModel):
class ContextRelevancyScoringFn(BaseModel):
type: Literal["context_relevancy"] = "context_relevancy"
context_relevancy: BasicScoringFnParamsFields
context_relevancy: BasicScoringFnParams
@json_schema_type
class CustomLLMAsJudgeScoringFnParams(BaseModel):
class CustomLLMAsJudgeScoringFn(BaseModel):
type: Literal["custom_llm_as_judge"] = "custom_llm_as_judge"
custom_llm_as_judge: CustomLLMAsJudgeScoringFnParamsFields
custom_llm_as_judge: CustomLLMAsJudgeScoringFnParams
ScoringFnParams = register_schema(
ScoringFnDefinition = register_schema(
Annotated[
Union[
CustomLLMAsJudgeScoringFnParams,
RegexParserScoringFnParams,
RegexParserMathScoringFnParams,
EqualityScoringFnParams,
SubsetOfcoringFnParams,
FactualityScoringFnParams,
FaithfulnessScoringFnParams,
AnswerCorrectnessScoringFnParams,
AnswerRelevancyScoringFnParams,
AnswerSimilarityScoringFnParams,
ContextEntityRecallScoringFnParams,
ContextPrecisionScoringFnParams,
ContextRecallScoringFnParams,
ContextRelevancyScoringFnParams,
CustomLLMAsJudgeScoringFn,
RegexParserScoringFn,
RegexParserMathScoringFn,
EqualityScoringFn,
SubsetOfScoringFn,
FactualityScoringFn,
FaithfulnessScoringFn,
AnswerCorrectnessScoringFn,
AnswerRelevancyScoringFn,
AnswerSimilarityScoringFn,
ContextEntityRecallScoringFn,
ContextPrecisionScoringFn,
ContextRecallScoringFn,
ContextRelevancyScoringFn,
],
Field(discriminator="type"),
],
name="ScoringFnParams",
name="ScoringFnDefinition",
)
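
To make the rename concrete, a short sketch of building two variants of the new ScoringFnDefinition union; each wrapper class carries its parameters under a field named after its type literal. The import path, regex, judge model, and prompt template are illustrative.

from llama_stack.apis.scoring_functions import (  # assumed import path once the renames land
    AggregationFunctionType,
    CustomLLMAsJudgeScoringFn,
    CustomLLMAsJudgeScoringFnParams,
    RegexParserScoringFn,
    RegexParserScoringFnParams,
)

regex_fn = RegexParserScoringFn(
    regex_parser=RegexParserScoringFnParams(
        parsing_regexes=[r"Answer:\s*(.+)"],  # illustrative extraction pattern
        aggregation_functions=[AggregationFunctionType.accuracy],
    )
)
judge_fn = CustomLLMAsJudgeScoringFn(
    custom_llm_as_judge=CustomLLMAsJudgeScoringFnParams(
        judge_model="meta-llama/Llama-3.3-70B-Instruct",  # illustrative judge model
        prompt_template="Rate the answer from 0 to 5: {generated_answer}",  # illustrative
    )
)
# both instances validate against the ScoringFnDefinition annotated union
# via the "type" discriminator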
@@ -208,7 +208,7 @@ class CommonScoringFnFields(BaseModel):
:param fn: The scoring function type and parameters.
:param metadata: (Optional) Any additional metadata for this definition (e.g. description).
"""
fn: ScoringFnParams
fn: ScoringFnDefinition
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this definition (e.g. description)",
@@ -288,7 +288,7 @@ class ScoringFunctions(Protocol):
@webmethod(route="/scoring-functions", method="POST")
async def register_scoring_function(
self,
fn: ScoringFnParams,
fn: ScoringFnDefinition,
scoring_fn_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> ScoringFn:
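
And a minimal registration sketch against the updated ScoringFunctions protocol, reusing the regex definition built in the previous sketch; scoring_functions_impl, the scoring function id, and the metadata are illustrative.

async def register_example(scoring_functions_impl) -> None:
    scoring_fn = await scoring_functions_impl.register_scoring_function(
        fn=regex_fn,  # the RegexParserScoringFn built in the previous sketch
        scoring_fn_id="custom::regex-answer-extract",  # illustrative id
        metadata={"description": "extract the final answer with a regex"},
    )
    print(scoring_fn.identifier)  # assumes ScoringFn keeps its resource identifier field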