Mirror of https://github.com/meta-llama/llama-stack.git
scoring fix
commit 56239fce90
parent c5cf9c30be
10 changed files with 104 additions and 15 deletions
@@ -49,10 +49,14 @@ class Scoring(Protocol):
         self,
         dataset_id: str,
         scoring_functions: List[str],
+        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse: ...

     @webmethod(route="/scoring/score")
     async def score(
-        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
+        self,
+        input_rows: List[Dict[str, Any]],
+        scoring_functions: List[str],
+        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
     ) -> ScoreResponse: ...
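
Both protocol methods gain an optional scoring_params argument: a dict keyed by scoring function identifier that overrides the parameters registered with that function. A minimal sketch of calling the new score signature, assuming an already-resolved scoring_impl and the meta-reference::llm_as_judge_8b_correctness function exercised in the test at the bottom of this diff (the example rows are made up; only the input_query / expected_answer / generated_answer keys come from this diff):

from typing import Any, Dict, List

from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams


async def score_with_overrides(scoring_impl) -> None:
    # Made-up rows; the keys mirror what the scoring functions in this commit read.
    input_rows: List[Dict[str, Any]] = [
        {
            "input_query": "What is 2 + 2?",
            "expected_answer": "4",
            "generated_answer": "The answer is 4.",
        }
    ]
    # Per-function overrides, keyed by the scoring function identifier.
    scoring_params = {
        "meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
            judge_model="Llama3.1-405B-Instruct",
            prompt_template="Respond in the format: Score: <number>",
            judge_score_regexes=[r"Score: (\d+)"],
        )
    }
    response = await scoring_impl.score(
        input_rows=input_rows,
        scoring_functions=["meta-reference::llm_as_judge_8b_correctness"],
        scoring_params=scoring_params,
    )
    print(response.results)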
@@ -38,7 +38,10 @@ class LLMAsJudgeScoringFnParams(BaseModel):
     )
     judge_model: str
     prompt_template: Optional[str] = None
-    judge_score_regex: Optional[List[str]] = Field()
+    judge_score_regexes: Optional[List[str]] = Field(
+        description="Regexes to extract the answer from generated response",
+        default_factory=list,
+    )


 @json_schema_type
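
The field rename from judge_score_regex to judge_score_regexes keeps the type and the empty-list default. A quick sketch of constructing the params model, assuming LLMAsJudgeScoringFnParams is importable from llama_stack.apis.scoring_functions as the test below implies via its star import:

from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams

# Omitting the field falls back to default_factory, i.e. an empty list.
params = LLMAsJudgeScoringFnParams(judge_model="Llama3.1-8B-Instruct")
print(params.judge_score_regexes)  # []

# The usual case: one or more regexes whose first capture group holds the score.
params = LLMAsJudgeScoringFnParams(
    judge_model="Llama3.1-8B-Instruct",
    judge_score_regexes=[r"Total rating: (\d+)", r"Rating: (\d+)"],
)
print(params.judge_score_regexes)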
@@ -212,6 +212,7 @@ class ScoringRouter(Scoring):
         self,
         dataset_id: str,
         scoring_functions: List[str],
+        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         res = {}
@@ -221,6 +222,7 @@ class ScoringRouter(Scoring):
             ).score_batch(
                 dataset_id=dataset_id,
                 scoring_functions=[fn_identifier],
+                scoring_params=scoring_params,
             )
             res.update(score_response.results)

@@ -232,7 +234,10 @@ class ScoringRouter(Scoring):
         )

     async def score(
-        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
+        self,
+        input_rows: List[Dict[str, Any]],
+        scoring_functions: List[str],
+        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
     ) -> ScoreResponse:
         res = {}
         # look up and map each scoring function to its provider impl
@@ -242,6 +247,7 @@ class ScoringRouter(Scoring):
             ).score(
                 input_rows=input_rows,
                 scoring_functions=[fn_identifier],
+                scoring_params=scoring_params,
             )
             res.update(score_response.results)
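
The router changes are pure pass-through: each requested scoring function is dispatched to its provider one at a time, the whole scoring_params dict is forwarded untouched, and the per-function results are merged. A self-contained toy of that fan-out pattern; the providers dict here is a stand-in for the router's real provider lookup, which this hunk does not show:

from typing import Any, Dict, List, Optional


async def route_score(
    providers: Dict[str, Any],  # hypothetical: scoring fn identifier -> provider impl
    input_rows: List[Dict[str, Any]],
    scoring_functions: List[str],
    scoring_params: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    res: Dict[str, Any] = {}
    for fn_identifier in scoring_functions:
        score_response = await providers[fn_identifier].score(
            input_rows=input_rows,
            scoring_functions=[fn_identifier],
            scoring_params=scoring_params,  # forwarded as-is; each provider picks its own entry
        )
        res.update(score_response.results)
    return res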
@@ -74,8 +74,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         return scoring_fn_defs_list

     async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
-        self.llm_as_judge_fn.register_scoring_fn_def(function_def)
-        self.scoring_fn_id_impls[function_def.identifier] = self.llm_as_judge_fn
+        raise NotImplementedError("Register scoring function not implemented yet")

     async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
         dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
@@ -98,6 +97,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         self,
         dataset_id: str,
         scoring_functions: List[str],
+        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
@@ -106,7 +106,9 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
             rows_in_page=-1,
         )
         res = await self.score(
-            input_rows=all_rows.rows, scoring_functions=scoring_functions
+            input_rows=all_rows.rows,
+            scoring_functions=scoring_functions,
+            scoring_params=scoring_params,
         )
         if save_results_dataset:
             # TODO: persist and register dataset on to server for reading
@@ -118,14 +120,22 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         )

     async def score(
-        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
+        self,
+        input_rows: List[Dict[str, Any]],
+        scoring_functions: List[str],
+        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
     ) -> ScoreResponse:
         res = {}
         for scoring_fn_id in scoring_functions:
             if scoring_fn_id not in self.scoring_fn_id_impls:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
             scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
-            score_results = await scoring_fn.score(input_rows, scoring_fn_id)
+            scoring_fn_params = None
+            if scoring_params is not None:
+                scoring_fn_params = scoring_params.get(scoring_fn_id, None)
+            score_results = await scoring_fn.score(
+                input_rows, scoring_fn_id, scoring_fn_params
+            )
             agg_results = await scoring_fn.aggregate(score_results)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
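
The substantive change in the meta-reference impl is per-function parameter resolution: scoring_params is optional, and each scoring function only ever sees its own entry, or None, in which case the params baked into the registered fn def apply (see the judge hunk further down). A tiny sketch of just that lookup:

from typing import Dict, Optional


def resolve_override(
    scoring_fn_id: str,
    scoring_params: Optional[Dict[str, object]] = None,
) -> Optional[object]:
    # Missing dict or missing key both mean "fall back to the fn def's own params".
    if scoring_params is None:
        return None
    return scoring_params.get(scoring_fn_id, None)


overrides = {"meta-reference::llm_as_judge_8b_correctness": object()}  # placeholder value
assert resolve_override("meta-reference::llm_as_judge_8b_correctness", overrides) is not None
assert resolve_override("some-other::scoring_fn", overrides) is None  # hypothetical id
assert resolve_override("meta-reference::llm_as_judge_8b_correctness") is None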
@@ -36,7 +36,10 @@ class BaseScoringFn(ABC):

     @abstractmethod
     async def score_row(
-        self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = None,
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
         raise NotImplementedError()

@@ -50,8 +53,9 @@ class BaseScoringFn(ABC):
         self,
         input_rows: List[Dict[str, Any]],
         scoring_fn_identifier: Optional[str] = None,
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> List[ScoringResultRow]:
         return [
-            await self.score_row(input_row, scoring_fn_identifier)
+            await self.score_row(input_row, scoring_fn_identifier, scoring_params)
             for input_row in input_rows
         ]
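
Concretely, every score_row implementation now has to accept a third scoring_params argument, even if it ignores it (as equality and subset_of do below), and the default score forwards it row by row. A self-contained sketch of that contract using a stand-in base class rather than the real BaseScoringFn, so the snippet runs on its own:

import asyncio
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

ScoringResultRow = Dict[str, Any]  # stand-in for the real row-result type


class ScoringFnSketch(ABC):
    """Stand-in that mirrors the updated BaseScoringFn signatures."""

    @abstractmethod
    async def score_row(
        self,
        input_row: Dict[str, Any],
        scoring_fn_identifier: Optional[str] = None,
        scoring_params: Optional[Any] = None,
    ) -> ScoringResultRow: ...

    async def score(
        self,
        input_rows: List[Dict[str, Any]],
        scoring_fn_identifier: Optional[str] = None,
        scoring_params: Optional[Any] = None,
    ) -> List[ScoringResultRow]:
        # Same shape as the diff above: forward the params to every row.
        return [
            await self.score_row(row, scoring_fn_identifier, scoring_params)
            for row in input_rows
        ]


class ExactMatch(ScoringFnSketch):
    async def score_row(self, input_row, scoring_fn_identifier=None, scoring_params=None):
        ok = input_row["generated_answer"] == input_row["expected_answer"]
        return {"score": 1.0 if ok else 0.0}


rows = [{"generated_answer": "4", "expected_answer": "4"}]
print(asyncio.run(ExactMatch().score(rows, "exact_match")))  # [{'score': 1.0}]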
@@ -35,6 +35,7 @@ class EqualityScoringFn(BaseScoringFn):
         self,
         input_row: Dict[str, Any],
         scoring_fn_identifier: Optional[str] = "equality",
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
         assert "expected_answer" in input_row, "Expected answer not found in input row."
         assert (
@@ -31,6 +31,10 @@ llm_as_judge_8b_correctness = ScoringFnDef(
     params=LLMAsJudgeScoringFnParams(
         prompt_template=JUDGE_PROMPT,
         judge_model="Llama3.1-8B-Instruct",
-        judge_score_regex=[r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"],
+        judge_score_regexes=[
+            r"Total rating: (\d+)",
+            r"rating: (\d+)",
+            r"Rating: (\d+)",
+        ],
     ),
 )
@@ -36,18 +36,24 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
         self,
         input_row: Dict[str, Any],
         scoring_fn_identifier: Optional[str] = None,
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
         assert (
             scoring_fn_identifier is not None
         ), "Scoring function identifier not found."
         fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
+
+        # override params if scoring_params is provided
+        if scoring_params is not None:
+            fn_def.params = scoring_params
+
         assert fn_def.params is not None, f"LLMAsJudgeparams not found for {fn_def}."
         assert (
             fn_def.params.prompt_template is not None
         ), "LLM Judge prompt_template not found."
         assert (
-            fn_def.params.judge_score_regex is not None
-        ), "LLM Judge judge_score_regex not found."
+            fn_def.params.judge_score_regexes is not None
+        ), "LLM Judge judge_score_regexes not found."

         input_query = input_row["input_query"]
         expected_answer = input_row["expected_answer"]
@@ -69,10 +75,10 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
             ],
         )
         content = judge_response.completion_message.content
-        rating_regexs = fn_def.params.judge_score_regex
+        rating_regexes = fn_def.params.judge_score_regexes

         judge_rating = None
-        for regex in rating_regexs:
+        for regex in rating_regexes:
             match = re.search(regex, content)
             if match:
                 judge_rating = int(match.group(1))
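
The judge scoring function runs the judge model's completion through judge_score_regexes and takes the first capture group as the integer rating; the hunk is cut off before showing whether the loop stops at the first hit, so this standalone version simply returns on the first match:

import re
from typing import List, Optional


def extract_judge_rating(content: str, judge_score_regexes: List[str]) -> Optional[int]:
    # Try each regex against the judge's free-form completion; group 1 holds the score.
    for regex in judge_score_regexes:
        match = re.search(regex, content)
        if match:
            return int(match.group(1))
    return None


# Using the regexes registered for llm_as_judge_8b_correctness above.
print(extract_judge_rating(
    "Feedback: mostly correct.\nTotal rating: 4",
    [r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"],
))  # 4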
@@ -34,6 +34,7 @@ class SubsetOfScoringFn(BaseScoringFn):
         self,
         input_row: Dict[str, Any],
         scoring_fn_identifier: Optional[str] = "subset_of",
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
         expected_answer = input_row["expected_answer"]
         generated_answer = input_row["generated_answer"]
@@ -7,6 +7,8 @@

 import pytest

 from llama_stack.apis.scoring_functions import * # noqa: F403
+
+from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset

 # How to run this test:
@@ -64,3 +66,51 @@ class TestScoring:
         for x in scoring_functions:
             assert x in response.results
             assert len(response.results[x].score_rows) == 5
+
+    @pytest.mark.asyncio
+    async def test_scoring_score_with_params(self, scoring_stack):
+        scoring_impl, scoring_functions_impl, datasetio_impl, datasets_impl = (
+            scoring_stack
+        )
+        await register_dataset(datasets_impl)
+        response = await datasets_impl.list_datasets()
+        assert len(response) == 1
+
+        # scoring individual rows
+        rows = await datasetio_impl.get_rows_paginated(
+            dataset_id="test_dataset",
+            rows_in_page=3,
+        )
+        assert len(rows.rows) == 3
+
+        params = {
+            "meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
+                judge_model="Llama3.1-405B-Instruct",
+                prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
+                judge_score_regexes=[r"Score: (\d+)"],
+            )
+        }
+
+        scoring_functions = [
+            "meta-reference::llm_as_judge_8b_correctness",
+        ]
+        response = await scoring_impl.score(
+            input_rows=rows.rows,
+            scoring_functions=scoring_functions,
+            scoring_params=params,
+        )
+        assert len(response.results) == len(scoring_functions)
+        for x in scoring_functions:
+            assert x in response.results
+            assert len(response.results[x].score_rows) == len(rows.rows)
+
+        # score batch
+        response = await scoring_impl.score_batch(
+            dataset_id="test_dataset",
+            scoring_functions=scoring_functions,
+            scoring_params=params,
+        )
+        assert len(response.results) == len(scoring_functions)
+        for x in scoring_functions:
+            assert x in response.results
+            assert len(response.results[x].score_rows) == 5