diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/ifeval_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/ifeval_scoring_fn.py
index 56e3f9937..a1d5ac365 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring_fn/ifeval_scoring_fn.py
+++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/ifeval_scoring_fn.py
@@ -6,7 +6,7 @@
 from typing import Any, Dict, Optional
 
 from llama_stack.apis.scoring import ScoringResultRow
-from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
+from llama_stack.apis.scoring_functions import ScoringFnParams
 from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
 
 from ..utils.ifeval_utils import INSTRUCTION_DICT, INSTRUCTION_LIST
@@ -37,10 +37,6 @@ class IfEvalScoringFn(RegisteredBaseScoringFn):
         if scoring_params is not None:
             fn_def.params = scoring_params
 
-        assert fn_def.params is not None and fn_def.params.type == ScoringFnParamsType.regex_parser.value, (
-            f"RegexParserScoringFnParams not found for {fn_def}."
-        )
-
         instruction_list = input_row["instruction_id_list"]
         generated_answer = input_row["generated_answer"].strip()
 
@@ -56,7 +52,9 @@ class IfEvalScoringFn(RegisteredBaseScoringFn):
             results[instruction_id + "_total"] += 1.0
             results[instruction_id.split(":")[0] + "_total"] += 1.0
 
-            instruction.build_description(**input_row["kwargs"][index])
+            clean_input_row = {k: v for k, v in input_row["kwargs"][index].items() if v is not None}
+            print(clean_input_row)
+            instruction.build_description(**clean_input_row)
             args = instruction.get_instruction_args()
             if args and "prompt" in args:
                 instruction.build_description(prompt=input_row["prompt"])
diff --git a/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py b/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py
index 344060349..28605159f 100644
--- a/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py
+++ b/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py
@@ -3147,7 +3147,7 @@ class LowercaseLettersEnglishChecker(Instruction):
 class CommaChecker(Instruction):
     """Checks the response for no commas."""
 
-    def build_description(self):
+    def build_description(self, **kwargs):
         """Build the instruction description."""
         self._description_pattern = "In your entire response, refrain from the use of any commas."
         return self._description_pattern
@@ -3216,6 +3216,7 @@ class CapitalWordFrequencyChecker(Instruction):
     def check_following(self, value):
         """Checks the frequency of words with all capital letters."""
         # Hyphenated words will count as one word
+        nltk.download("punkt_tab")
         words = nltk.word_tokenize(value)
         capital_words = [word for word in words if word.isupper()]
 
diff --git a/llama_stack/providers/utils/scoring/aggregation_utils.py b/llama_stack/providers/utils/scoring/aggregation_utils.py
index 42ae8a9fe..7254c9433 100644
--- a/llama_stack/providers/utils/scoring/aggregation_utils.py
+++ b/llama_stack/providers/utils/scoring/aggregation_utils.py
@@ -57,6 +57,7 @@ def aggregate_median(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
 AGGREGATION_FUNCTIONS = {
     AggregationFunctionType.accuracy: aggregate_accuracy,
     AggregationFunctionType.average: aggregate_average,
+    AggregationFunctionType.weighted_average: aggregate_weighted_average,
     AggregationFunctionType.categorical_count: aggregate_categorical_count,
     AggregationFunctionType.median: aggregate_median,
 }
diff --git a/llama_stack/templates/open-benchmark/open_benchmark.py b/llama_stack/templates/open-benchmark/open_benchmark.py
index 996f8eec5..448361cb7 100644
--- a/llama_stack/templates/open-benchmark/open_benchmark.py
+++ b/llama_stack/templates/open-benchmark/open_benchmark.py
@@ -204,8 +204,7 @@ def get_distribution_template() -> DistributionTemplate:
             ),
         ),
         DatasetInput(
-            dataset_id="IfEval",
-            provider_id="huggingface",
+            dataset_id="ifeval",
             purpose=DatasetPurpose.eval_messages_answer,
             source=URIDataSource(
                 uri="huggingface://datasets/llamastack/IfEval?split=train",
@@ -240,9 +239,9 @@ def get_distribution_template() -> DistributionTemplate:
             scoring_functions=["basic::bfcl"],
         ),
         BenchmarkInput(
-            benchmark_id="meta-reference-IfEval",
-            dataset_id="IfEval",
-            scoring_functions=["basic::IfEval"],
+            benchmark_id="meta-reference-ifeval",
+            dataset_id="ifeval",
+            scoring_functions=["basic::ifeval"],
         ),
     ]
     return DistributionTemplate(
diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/templates/open-benchmark/run.yaml
index 4668ed578..ded2d1294 100644
--- a/llama_stack/templates/open-benchmark/run.yaml
+++ b/llama_stack/templates/open-benchmark/run.yaml
@@ -193,7 +193,7 @@ datasets:
     type: uri
     uri: huggingface://datasets/llamastack/IfEval?split=train
   metadata: {}
-  dataset_id: IfEval
+  dataset_id: ifeval
 scoring_fns: []
 benchmarks:
 - dataset_id: simpleqa
@@ -221,11 +221,11 @@ benchmarks:
   - basic::bfcl
   metadata: {}
   benchmark_id: meta-reference-bfcl
-- dataset_id: IfEval
+- dataset_id: ifeval
   scoring_functions:
-  - basic::IfEval
+  - basic::ifeval
   metadata: {}
-  benchmark_id: meta-reference-IfEval
+  benchmark_id: meta-reference-ifeval
 tool_groups:
 - toolgroup_id: builtin::websearch
   provider_id: tavily-search