mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-07 11:08:20 +00:00
refine
This commit is contained in:
parent
9068416bc4
commit
c4c56829ad
5 changed files with 15 additions and 16 deletions
|
@ -6,7 +6,7 @@
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from llama_stack.apis.scoring import ScoringResultRow
|
from llama_stack.apis.scoring import ScoringResultRow
|
||||||
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
|
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||||
|
|
||||||
from ..utils.ifeval_utils import INSTRUCTION_DICT, INSTRUCTION_LIST
|
from ..utils.ifeval_utils import INSTRUCTION_DICT, INSTRUCTION_LIST
|
||||||
|
@ -37,10 +37,6 @@ class IfEvalScoringFn(RegisteredBaseScoringFn):
|
||||||
if scoring_params is not None:
|
if scoring_params is not None:
|
||||||
fn_def.params = scoring_params
|
fn_def.params = scoring_params
|
||||||
|
|
||||||
assert fn_def.params is not None and fn_def.params.type == ScoringFnParamsType.regex_parser.value, (
|
|
||||||
f"RegexParserScoringFnParams not found for {fn_def}."
|
|
||||||
)
|
|
||||||
|
|
||||||
instruction_list = input_row["instruction_id_list"]
|
instruction_list = input_row["instruction_id_list"]
|
||||||
generated_answer = input_row["generated_answer"].strip()
|
generated_answer = input_row["generated_answer"].strip()
|
||||||
|
|
||||||
|
@ -56,7 +52,9 @@ class IfEvalScoringFn(RegisteredBaseScoringFn):
|
||||||
results[instruction_id + "_total"] += 1.0
|
results[instruction_id + "_total"] += 1.0
|
||||||
results[instruction_id.split(":")[0] + "_total"] += 1.0
|
results[instruction_id.split(":")[0] + "_total"] += 1.0
|
||||||
|
|
||||||
instruction.build_description(**input_row["kwargs"][index])
|
clean_input_row = {k: v for k, v in input_row["kwargs"][index].items() if v is not None}
|
||||||
|
print(clean_input_row)
|
||||||
|
instruction.build_description(**clean_input_row)
|
||||||
args = instruction.get_instruction_args()
|
args = instruction.get_instruction_args()
|
||||||
if args and "prompt" in args:
|
if args and "prompt" in args:
|
||||||
instruction.build_description(prompt=input_row["prompt"])
|
instruction.build_description(prompt=input_row["prompt"])
|
||||||
|
|
|
@ -3147,7 +3147,7 @@ class LowercaseLettersEnglishChecker(Instruction):
|
||||||
class CommaChecker(Instruction):
|
class CommaChecker(Instruction):
|
||||||
"""Checks the response for no commas."""
|
"""Checks the response for no commas."""
|
||||||
|
|
||||||
def build_description(self):
|
def build_description(self, **kwargs):
|
||||||
"""Build the instruction description."""
|
"""Build the instruction description."""
|
||||||
self._description_pattern = "In your entire response, refrain from the use of any commas."
|
self._description_pattern = "In your entire response, refrain from the use of any commas."
|
||||||
return self._description_pattern
|
return self._description_pattern
|
||||||
|
@ -3216,6 +3216,7 @@ class CapitalWordFrequencyChecker(Instruction):
|
||||||
def check_following(self, value):
|
def check_following(self, value):
|
||||||
"""Checks the frequency of words with all capital letters."""
|
"""Checks the frequency of words with all capital letters."""
|
||||||
# Hyphenated words will count as one word
|
# Hyphenated words will count as one word
|
||||||
|
nltk.download("punkt_tab")
|
||||||
words = nltk.word_tokenize(value)
|
words = nltk.word_tokenize(value)
|
||||||
capital_words = [word for word in words if word.isupper()]
|
capital_words = [word for word in words if word.isupper()]
|
||||||
|
|
||||||
|
|
|
@ -57,6 +57,7 @@ def aggregate_median(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
|
||||||
AGGREGATION_FUNCTIONS = {
|
AGGREGATION_FUNCTIONS = {
|
||||||
AggregationFunctionType.accuracy: aggregate_accuracy,
|
AggregationFunctionType.accuracy: aggregate_accuracy,
|
||||||
AggregationFunctionType.average: aggregate_average,
|
AggregationFunctionType.average: aggregate_average,
|
||||||
|
AggregationFunctionType.weighted_average: aggregate_weighted_average,
|
||||||
AggregationFunctionType.categorical_count: aggregate_categorical_count,
|
AggregationFunctionType.categorical_count: aggregate_categorical_count,
|
||||||
AggregationFunctionType.median: aggregate_median,
|
AggregationFunctionType.median: aggregate_median,
|
||||||
}
|
}
|
||||||
|
|
|
@ -204,8 +204,7 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
DatasetInput(
|
DatasetInput(
|
||||||
dataset_id="IfEval",
|
dataset_id="ifeval",
|
||||||
provider_id="huggingface",
|
|
||||||
purpose=DatasetPurpose.eval_messages_answer,
|
purpose=DatasetPurpose.eval_messages_answer,
|
||||||
source=URIDataSource(
|
source=URIDataSource(
|
||||||
uri="huggingface://datasets/llamastack/IfEval?split=train",
|
uri="huggingface://datasets/llamastack/IfEval?split=train",
|
||||||
|
@ -240,9 +239,9 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
scoring_functions=["basic::bfcl"],
|
scoring_functions=["basic::bfcl"],
|
||||||
),
|
),
|
||||||
BenchmarkInput(
|
BenchmarkInput(
|
||||||
benchmark_id="meta-reference-IfEval",
|
benchmark_id="meta-reference-ifeval",
|
||||||
dataset_id="IfEval",
|
dataset_id="ifeval",
|
||||||
scoring_functions=["basic::IfEval"],
|
scoring_functions=["basic::ifeval"],
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
return DistributionTemplate(
|
return DistributionTemplate(
|
||||||
|
|
|
@ -193,7 +193,7 @@ datasets:
|
||||||
type: uri
|
type: uri
|
||||||
uri: huggingface://datasets/llamastack/IfEval?split=train
|
uri: huggingface://datasets/llamastack/IfEval?split=train
|
||||||
metadata: {}
|
metadata: {}
|
||||||
dataset_id: IfEval
|
dataset_id: ifeval
|
||||||
scoring_fns: []
|
scoring_fns: []
|
||||||
benchmarks:
|
benchmarks:
|
||||||
- dataset_id: simpleqa
|
- dataset_id: simpleqa
|
||||||
|
@ -221,11 +221,11 @@ benchmarks:
|
||||||
- basic::bfcl
|
- basic::bfcl
|
||||||
metadata: {}
|
metadata: {}
|
||||||
benchmark_id: meta-reference-bfcl
|
benchmark_id: meta-reference-bfcl
|
||||||
- dataset_id: IfEval
|
- dataset_id: ifeval
|
||||||
scoring_functions:
|
scoring_functions:
|
||||||
- basic::IfEval
|
- basic::ifeval
|
||||||
metadata: {}
|
metadata: {}
|
||||||
benchmark_id: meta-reference-IfEval
|
benchmark_id: meta-reference-ifeval
|
||||||
tool_groups:
|
tool_groups:
|
||||||
- toolgroup_id: builtin::websearch
|
- toolgroup_id: builtin::websearch
|
||||||
provider_id: tavily-search
|
provider_id: tavily-search
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue