This commit is contained in:
Botao Chen 2025-03-19 14:38:14 -07:00
parent 9068416bc4
commit c4c56829ad
5 changed files with 15 additions and 16 deletions

View file

@ -6,7 +6,7 @@
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from ..utils.ifeval_utils import INSTRUCTION_DICT, INSTRUCTION_LIST from ..utils.ifeval_utils import INSTRUCTION_DICT, INSTRUCTION_LIST
@ -37,10 +37,6 @@ class IfEvalScoringFn(RegisteredBaseScoringFn):
if scoring_params is not None: if scoring_params is not None:
fn_def.params = scoring_params fn_def.params = scoring_params
assert fn_def.params is not None and fn_def.params.type == ScoringFnParamsType.regex_parser.value, (
f"RegexParserScoringFnParams not found for {fn_def}."
)
instruction_list = input_row["instruction_id_list"] instruction_list = input_row["instruction_id_list"]
generated_answer = input_row["generated_answer"].strip() generated_answer = input_row["generated_answer"].strip()
@ -56,7 +52,9 @@ class IfEvalScoringFn(RegisteredBaseScoringFn):
results[instruction_id + "_total"] += 1.0 results[instruction_id + "_total"] += 1.0
results[instruction_id.split(":")[0] + "_total"] += 1.0 results[instruction_id.split(":")[0] + "_total"] += 1.0
instruction.build_description(**input_row["kwargs"][index]) clean_input_row = {k: v for k, v in input_row["kwargs"][index].items() if v is not None}
print(clean_input_row)
instruction.build_description(**clean_input_row)
args = instruction.get_instruction_args() args = instruction.get_instruction_args()
if args and "prompt" in args: if args and "prompt" in args:
instruction.build_description(prompt=input_row["prompt"]) instruction.build_description(prompt=input_row["prompt"])

View file

@ -3147,7 +3147,7 @@ class LowercaseLettersEnglishChecker(Instruction):
class CommaChecker(Instruction): class CommaChecker(Instruction):
"""Checks the response for no commas.""" """Checks the response for no commas."""
def build_description(self): def build_description(self, **kwargs):
"""Build the instruction description.""" """Build the instruction description."""
self._description_pattern = "In your entire response, refrain from the use of any commas." self._description_pattern = "In your entire response, refrain from the use of any commas."
return self._description_pattern return self._description_pattern
@ -3216,6 +3216,7 @@ class CapitalWordFrequencyChecker(Instruction):
def check_following(self, value): def check_following(self, value):
"""Checks the frequency of words with all capital letters.""" """Checks the frequency of words with all capital letters."""
# Hyphenated words will count as one word # Hyphenated words will count as one word
nltk.download("punkt_tab")
words = nltk.word_tokenize(value) words = nltk.word_tokenize(value)
capital_words = [word for word in words if word.isupper()] capital_words = [word for word in words if word.isupper()]

View file

@ -57,6 +57,7 @@ def aggregate_median(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
AGGREGATION_FUNCTIONS = { AGGREGATION_FUNCTIONS = {
AggregationFunctionType.accuracy: aggregate_accuracy, AggregationFunctionType.accuracy: aggregate_accuracy,
AggregationFunctionType.average: aggregate_average, AggregationFunctionType.average: aggregate_average,
AggregationFunctionType.weighted_average: aggregate_weighted_average,
AggregationFunctionType.categorical_count: aggregate_categorical_count, AggregationFunctionType.categorical_count: aggregate_categorical_count,
AggregationFunctionType.median: aggregate_median, AggregationFunctionType.median: aggregate_median,
} }

View file

@ -204,8 +204,7 @@ def get_distribution_template() -> DistributionTemplate:
), ),
), ),
DatasetInput( DatasetInput(
dataset_id="IfEval", dataset_id="ifeval",
provider_id="huggingface",
purpose=DatasetPurpose.eval_messages_answer, purpose=DatasetPurpose.eval_messages_answer,
source=URIDataSource( source=URIDataSource(
uri="huggingface://datasets/llamastack/IfEval?split=train", uri="huggingface://datasets/llamastack/IfEval?split=train",
@ -240,9 +239,9 @@ def get_distribution_template() -> DistributionTemplate:
scoring_functions=["basic::bfcl"], scoring_functions=["basic::bfcl"],
), ),
BenchmarkInput( BenchmarkInput(
benchmark_id="meta-reference-IfEval", benchmark_id="meta-reference-ifeval",
dataset_id="IfEval", dataset_id="ifeval",
scoring_functions=["basic::IfEval"], scoring_functions=["basic::ifeval"],
), ),
] ]
return DistributionTemplate( return DistributionTemplate(

View file

@ -193,7 +193,7 @@ datasets:
type: uri type: uri
uri: huggingface://datasets/llamastack/IfEval?split=train uri: huggingface://datasets/llamastack/IfEval?split=train
metadata: {} metadata: {}
dataset_id: IfEval dataset_id: ifeval
scoring_fns: [] scoring_fns: []
benchmarks: benchmarks:
- dataset_id: simpleqa - dataset_id: simpleqa
@ -221,11 +221,11 @@ benchmarks:
- basic::bfcl - basic::bfcl
metadata: {} metadata: {}
benchmark_id: meta-reference-bfcl benchmark_id: meta-reference-bfcl
- dataset_id: IfEval - dataset_id: ifeval
scoring_functions: scoring_functions:
- basic::IfEval - basic::ifeval
metadata: {} metadata: {}
benchmark_id: meta-reference-IfEval benchmark_id: meta-reference-ifeval
tool_groups: tool_groups:
- toolgroup_id: builtin::websearch - toolgroup_id: builtin::websearch
provider_id: tavily-search provider_id: tavily-search