scoring fix

Xi Yan 2024-11-06 18:07:16 -08:00
parent c5cf9c30be
commit 56239fce90
10 changed files with 104 additions and 15 deletions
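In effect, the commit threads an optional scoring_params argument (a mapping from scoring function identifier to ScoringFnParams) from the public Scoring API through the ScoringRouter down to the meta-reference provider and the individual scoring functions, and renames judge_score_regex to judge_score_regexes. Below is a minimal usage sketch under the new signatures; scoring_impl and rows stand in for an already configured scoring provider and dataset rows (as in the test added at the bottom of this commit), and the import path mirrors the wildcard import used in that test.

# Illustrative only: `scoring_impl` and `rows` are assumed to come from an
# already wired-up stack; they are not defined in this commit's diff.
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams

async def score_with_judge_override(scoring_impl, rows):
    params = {
        "meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
            judge_model="Llama3.1-405B-Instruct",
            prompt_template="Output a number response in the format: Score: <number>.",
            judge_score_regexes=[r"Score: (\d+)"],
        )
    }
    # scoring_params is optional; omitting it keeps the registered defaults.
    return await scoring_impl.score(
        input_rows=rows,
        scoring_functions=["meta-reference::llm_as_judge_8b_correctness"],
        scoring_params=params,
    )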

View file

@@ -49,10 +49,14 @@ class Scoring(Protocol):
         self,
         dataset_id: str,
         scoring_functions: List[str],
+        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse: ...
 
     @webmethod(route="/scoring/score")
     async def score(
-        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
+        self,
+        input_rows: List[Dict[str, Any]],
+        scoring_functions: List[str],
+        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
     ) -> ScoreResponse: ...

View file

@@ -38,7 +38,10 @@ class LLMAsJudgeScoringFnParams(BaseModel):
     )
     judge_model: str
     prompt_template: Optional[str] = None
-    judge_score_regex: Optional[List[str]] = Field()
+    judge_score_regexes: Optional[List[str]] = Field(
+        description="Regexes to extract the answer from generated response",
+        default_factory=list,
+    )
 
 
 @json_schema_type

View file

@@ -212,6 +212,7 @@ class ScoringRouter(Scoring):
         self,
         dataset_id: str,
         scoring_functions: List[str],
+        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         res = {}
@@ -221,6 +222,7 @@ class ScoringRouter(Scoring):
             ).score_batch(
                 dataset_id=dataset_id,
                 scoring_functions=[fn_identifier],
+                scoring_params=scoring_params,
             )
             res.update(score_response.results)
 
@@ -232,7 +234,10 @@
         )
 
     async def score(
-        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
+        self,
+        input_rows: List[Dict[str, Any]],
+        scoring_functions: List[str],
+        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
     ) -> ScoreResponse:
         res = {}
         # look up and map each scoring function to its provider impl
@@ -242,6 +247,7 @@
             ).score(
                 input_rows=input_rows,
                 scoring_functions=[fn_identifier],
+                scoring_params=scoring_params,
             )
             res.update(score_response.results)
 

View file

@@ -74,8 +74,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         return scoring_fn_defs_list
 
     async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
-        self.llm_as_judge_fn.register_scoring_fn_def(function_def)
-        self.scoring_fn_id_impls[function_def.identifier] = self.llm_as_judge_fn
+        raise NotImplementedError("Register scoring function not implemented yet")
 
     async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
         dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
@@ -98,6 +97,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         self,
         dataset_id: str,
         scoring_functions: List[str],
+        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
@@ -106,7 +106,9 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
             rows_in_page=-1,
         )
         res = await self.score(
-            input_rows=all_rows.rows, scoring_functions=scoring_functions
+            input_rows=all_rows.rows,
+            scoring_functions=scoring_functions,
+            scoring_params=scoring_params,
         )
         if save_results_dataset:
             # TODO: persist and register dataset on to server for reading
@@ -118,14 +120,22 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         )
 
     async def score(
-        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
+        self,
+        input_rows: List[Dict[str, Any]],
+        scoring_functions: List[str],
+        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
     ) -> ScoreResponse:
         res = {}
         for scoring_fn_id in scoring_functions:
             if scoring_fn_id not in self.scoring_fn_id_impls:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
             scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
-            score_results = await scoring_fn.score(input_rows, scoring_fn_id)
+            scoring_fn_params = None
+            if scoring_params is not None:
+                scoring_fn_params = scoring_params.get(scoring_fn_id, None)
+            score_results = await scoring_fn.score(
+                input_rows, scoring_fn_id, scoring_fn_params
+            )
             agg_results = await scoring_fn.aggregate(score_results)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
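One detail worth noting in the score() change above: the caller's scoring_params dict is keyed by scoring function identifier, and functions not listed in it simply receive None and keep their registered default params. A small sketch of that resolution (judge_params and the second identifier are illustrative, not part of this commit):

# Sketch of the per-function lookup performed above; `judge_params` is assumed
# to be an LLMAsJudgeScoringFnParams instance supplied by the caller.
scoring_params = {"meta-reference::llm_as_judge_8b_correctness": judge_params}

for scoring_fn_id in [
    "meta-reference::llm_as_judge_8b_correctness",  # receives judge_params
    "meta-reference::equality",  # hypothetical second entry: gets None -> registered defaults
]:
    scoring_fn_params = None
    if scoring_params is not None:
        scoring_fn_params = scoring_params.get(scoring_fn_id, None)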

View file

@@ -36,7 +36,10 @@ class BaseScoringFn(ABC):
 
     @abstractmethod
     async def score_row(
-        self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = None,
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
         raise NotImplementedError()
 
@@ -50,8 +53,9 @@
         self,
         input_rows: List[Dict[str, Any]],
         scoring_fn_identifier: Optional[str] = None,
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> List[ScoringResultRow]:
         return [
-            await self.score_row(input_row, scoring_fn_identifier)
+            await self.score_row(input_row, scoring_fn_identifier, scoring_params)
             for input_row in input_rows
         ]
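For third-party scoring functions, the practical consequence of the BaseScoringFn change is that score_row (and score) now accept an extra optional scoring_params argument. A hypothetical subclass conforming to the updated interface (the class, its identifier, and the import locations are illustrative, not part of this commit):

from typing import Any, Dict, Optional

# BaseScoringFn, ScoringResultRow and ScoringFnParams are assumed to be imported
# from the same modules the built-in scoring functions use.

class ExactPrefixScoringFn(BaseScoringFn):
    """Hypothetical example: scores 1.0 when the generated answer starts with the expected one."""

    async def score_row(
        self,
        input_row: Dict[str, Any],
        scoring_fn_identifier: Optional[str] = "exact_prefix",
        scoring_params: Optional[ScoringFnParams] = None,
    ) -> ScoringResultRow:
        # scoring_params is unused here, but the parameter must be accepted so that
        # BaseScoringFn.score() can forward caller-supplied overrides.
        expected = input_row["expected_answer"]
        generated = input_row["generated_answer"]
        return {"score": 1.0 if generated.startswith(expected) else 0.0}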

View file

@@ -35,6 +35,7 @@ class EqualityScoringFn(BaseScoringFn):
         self,
         input_row: Dict[str, Any],
         scoring_fn_identifier: Optional[str] = "equality",
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
         assert "expected_answer" in input_row, "Expected answer not found in input row."
         assert (

View file

@@ -31,6 +31,10 @@ llm_as_judge_8b_correctness = ScoringFnDef(
     params=LLMAsJudgeScoringFnParams(
         prompt_template=JUDGE_PROMPT,
         judge_model="Llama3.1-8B-Instruct",
-        judge_score_regex=[r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"],
+        judge_score_regexes=[
+            r"Total rating: (\d+)",
+            r"rating: (\d+)",
+            r"Rating: (\d+)",
+        ],
     ),
 )

View file

@@ -36,18 +36,24 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
         self,
         input_row: Dict[str, Any],
         scoring_fn_identifier: Optional[str] = None,
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
         assert (
             scoring_fn_identifier is not None
         ), "Scoring function identifier not found."
         fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
+
+        # override params if scoring_params is provided
+        if scoring_params is not None:
+            fn_def.params = scoring_params
+
         assert fn_def.params is not None, f"LLMAsJudgeparams not found for {fn_def}."
         assert (
             fn_def.params.prompt_template is not None
         ), "LLM Judge prompt_template not found."
         assert (
-            fn_def.params.judge_score_regex is not None
-        ), "LLM Judge judge_score_regex not found."
+            fn_def.params.judge_score_regexes is not None
+        ), "LLM Judge judge_score_regexes not found."
 
         input_query = input_row["input_query"]
         expected_answer = input_row["expected_answer"]
@@ -69,10 +75,10 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
             ],
         )
         content = judge_response.completion_message.content
-        rating_regexs = fn_def.params.judge_score_regex
+        rating_regexes = fn_def.params.judge_score_regexes
 
         judge_rating = None
-        for regex in rating_regexs:
+        for regex in rating_regexes:
             match = re.search(regex, content)
             if match:
                 judge_rating = int(match.group(1))
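For reference, the renamed judge_score_regexes are applied with plain re.search over the judge completion; a self-contained illustration (the content string is made-up sample judge output):

import re

# Made-up judge output; the regexes are the ones registered above for
# llm_as_judge_8b_correctness.
content = "The answer covers the key facts but misses one detail.\nTotal rating: 7"

for regex in [r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"]:
    match = re.search(regex, content)
    if match:
        print(int(match.group(1)))  # prints 7
        break  # the break is just for this demo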

View file

@@ -34,6 +34,7 @@ class SubsetOfScoringFn(BaseScoringFn):
         self,
         input_row: Dict[str, Any],
         scoring_fn_identifier: Optional[str] = "subset_of",
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
         expected_answer = input_row["expected_answer"]
         generated_answer = input_row["generated_answer"]

View file

@@ -7,6 +7,8 @@
 import pytest
 
+from llama_stack.apis.scoring_functions import *  # noqa: F403
+
 from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
 
 
 # How to run this test:
@@ -64,3 +66,51 @@ class TestScoring:
         for x in scoring_functions:
             assert x in response.results
             assert len(response.results[x].score_rows) == 5
+
+    @pytest.mark.asyncio
+    async def test_scoring_score_with_params(self, scoring_stack):
+        scoring_impl, scoring_functions_impl, datasetio_impl, datasets_impl = (
+            scoring_stack
+        )
+        await register_dataset(datasets_impl)
+        response = await datasets_impl.list_datasets()
+        assert len(response) == 1
+
+        # scoring individual rows
+        rows = await datasetio_impl.get_rows_paginated(
+            dataset_id="test_dataset",
+            rows_in_page=3,
+        )
+        assert len(rows.rows) == 3
+
+        params = {
+            "meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
+                judge_model="Llama3.1-405B-Instruct",
+                prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
+                judge_score_regexes=[r"Score: (\d+)"],
+            )
+        }
+
+        scoring_functions = [
+            "meta-reference::llm_as_judge_8b_correctness",
+        ]
+        response = await scoring_impl.score(
+            input_rows=rows.rows,
+            scoring_functions=scoring_functions,
+            scoring_params=params,
+        )
+        assert len(response.results) == len(scoring_functions)
+        for x in scoring_functions:
+            assert x in response.results
+            assert len(response.results[x].score_rows) == len(rows.rows)
+
+        # score batch
+        response = await scoring_impl.score_batch(
+            dataset_id="test_dataset",
+            scoring_functions=scoring_functions,
+            scoring_params=params,
+        )
+        assert len(response.results) == len(scoring_functions)
+        for x in scoring_functions:
+            assert x in response.results
+            assert len(response.results[x].score_rows) == 5