scoring fix

Xi Yan 2024-11-06 18:07:16 -08:00
parent c5cf9c30be
commit 56239fce90
10 changed files with 104 additions and 15 deletions

View file

@@ -74,8 +74,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         return scoring_fn_defs_list
 
     async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
-        self.llm_as_judge_fn.register_scoring_fn_def(function_def)
-        self.scoring_fn_id_impls[function_def.identifier] = self.llm_as_judge_fn
+        raise NotImplementedError("Register scoring function not implemented yet")
 
     async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
         dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
@@ -98,6 +97,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         self,
         dataset_id: str,
         scoring_functions: List[str],
+        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
@@ -106,7 +106,9 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
             rows_in_page=-1,
         )
         res = await self.score(
-            input_rows=all_rows.rows, scoring_functions=scoring_functions
+            input_rows=all_rows.rows,
+            scoring_functions=scoring_functions,
+            scoring_params=scoring_params,
         )
         if save_results_dataset:
             # TODO: persist and register dataset on to server for reading
@@ -118,14 +120,22 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         )
 
     async def score(
-        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
+        self,
+        input_rows: List[Dict[str, Any]],
+        scoring_functions: List[str],
+        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
     ) -> ScoreResponse:
         res = {}
         for scoring_fn_id in scoring_functions:
             if scoring_fn_id not in self.scoring_fn_id_impls:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
             scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
-            score_results = await scoring_fn.score(input_rows, scoring_fn_id)
+            scoring_fn_params = None
+            if scoring_params is not None:
+                scoring_fn_params = scoring_params.get(scoring_fn_id, None)
+            score_results = await scoring_fn.score(
+                input_rows, scoring_fn_id, scoring_fn_params
+            )
             agg_results = await scoring_fn.aggregate(score_results)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
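
With this change, score_batch and score accept an optional per-function parameter override. A minimal caller-side sketch, assuming an already-constructed MetaReferenceScoringImpl named scoring_impl; the scoring-function identifier and dataset id below are placeholders, not values from this commit:

# Hypothetical caller-side usage; LLMAsJudgeScoringFnParams and JUDGE_PROMPT are the
# names referenced elsewhere in this diff, everything else is a placeholder.
params_override = {
    "llm-as-judge-8b-correctness": LLMAsJudgeScoringFnParams(
        prompt_template=JUDGE_PROMPT,
        judge_model="Llama3.1-8B-Instruct",
        judge_score_regexes=[r"Rating: (\d+)"],
    ),
}

response = await scoring_impl.score_batch(
    dataset_id="my-eval-dataset",
    scoring_functions=["llm-as-judge-8b-correctness"],
    scoring_params=params_override,  # omit to fall back to each function's registered params
)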

View file

@@ -36,7 +36,10 @@ class BaseScoringFn(ABC):
 
     @abstractmethod
     async def score_row(
-        self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = None,
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
         raise NotImplementedError()
 
@@ -50,8 +53,9 @@ class BaseScoringFn(ABC):
         self,
         input_rows: List[Dict[str, Any]],
         scoring_fn_identifier: Optional[str] = None,
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> List[ScoringResultRow]:
         return [
-            await self.score_row(input_row, scoring_fn_identifier)
+            await self.score_row(input_row, scoring_fn_identifier, scoring_params)
             for input_row in input_rows
         ]
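
For reference, a minimal sketch of a concrete subclass against the updated abstract interface; ExactMatchScoringFn and its return shape are illustrative assumptions, not part of this commit:

class ExactMatchScoringFn(BaseScoringFn):  # hypothetical subclass for illustration
    async def score_row(
        self,
        input_row: Dict[str, Any],
        scoring_fn_identifier: Optional[str] = "exact_match",
        scoring_params: Optional[ScoringFnParams] = None,
    ) -> ScoringResultRow:
        # scoring_params is accepted to satisfy the new signature even when unused here
        expected = input_row["expected_answer"]
        generated = input_row["generated_answer"]
        return {"score": 1.0 if generated == expected else 0.0}  # assumed row shape; aggregate() omitted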

View file

@@ -35,6 +35,7 @@ class EqualityScoringFn(BaseScoringFn):
         self,
         input_row: Dict[str, Any],
         scoring_fn_identifier: Optional[str] = "equality",
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
         assert "expected_answer" in input_row, "Expected answer not found in input row."
         assert (

View file

@@ -31,6 +31,10 @@ llm_as_judge_8b_correctness = ScoringFnDef(
     params=LLMAsJudgeScoringFnParams(
         prompt_template=JUDGE_PROMPT,
         judge_model="Llama3.1-8B-Instruct",
-        judge_score_regex=[r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"],
+        judge_score_regexes=[
+            r"Total rating: (\d+)",
+            r"rating: (\d+)",
+            r"Rating: (\d+)",
+        ],
     ),
 )

View file

@@ -36,18 +36,24 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
         self,
         input_row: Dict[str, Any],
         scoring_fn_identifier: Optional[str] = None,
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
         assert (
             scoring_fn_identifier is not None
         ), "Scoring function identifier not found."
         fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
+
+        # override params if scoring_params is provided
+        if scoring_params is not None:
+            fn_def.params = scoring_params
+
         assert fn_def.params is not None, f"LLMAsJudgeparams not found for {fn_def}."
         assert (
             fn_def.params.prompt_template is not None
         ), "LLM Judge prompt_template not found."
         assert (
-            fn_def.params.judge_score_regex is not None
-        ), "LLM Judge judge_score_regex not found."
+            fn_def.params.judge_score_regexes is not None
+        ), "LLM Judge judge_score_regexes not found."
 
         input_query = input_row["input_query"]
         expected_answer = input_row["expected_answer"]
@@ -69,10 +75,10 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
             ],
         )
         content = judge_response.completion_message.content
-        rating_regexs = fn_def.params.judge_score_regex
+        rating_regexes = fn_def.params.judge_score_regexes
 
         judge_rating = None
-        for regex in rating_regexs:
+        for regex in rating_regexes:
             match = re.search(regex, content)
             if match:
                 judge_rating = int(match.group(1))
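
Taken on its own, the judge_score_regexes extraction above amounts to the following standalone sketch; the sample judge output string is made up for illustration:

import re

# Regexes as registered for llm_as_judge_8b_correctness; any matching pattern supplies the rating.
judge_score_regexes = [r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"]
content = "The generated answer covers the key points.\nTotal rating: 4"  # made-up judge output

judge_rating = None
for regex in judge_score_regexes:
    match = re.search(regex, content)
    if match:
        judge_rating = int(match.group(1))

print(judge_rating)  # 4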

View file

@@ -34,6 +34,7 @@ class SubsetOfScoringFn(BaseScoringFn):
         self,
         input_row: Dict[str, Any],
         scoring_fn_identifier: Optional[str] = "subset_of",
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
         expected_answer = input_row["expected_answer"]
         generated_answer = input_row["generated_answer"]
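
The hunk ends before the comparison itself, which this commit leaves untouched. As a rough, assumed illustration of a subset-style check (not necessarily the file's actual implementation):

# Illustrative sketch only; the real check in SubsetOfScoringFn is not shown in this hunk.
def subset_of_check(expected_answer: str, generated_answer: str) -> float:
    # Count the row as correct when the expected answer appears within the generated answer.
    return 1.0 if expected_answer in generated_answer else 0.0

print(subset_of_check("Paris", "The capital of France is Paris."))  # 1.0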