precommit local run

This commit is contained in:
Yang Yang 2025-03-12 13:47:36 -07:00
parent 32d025923c
commit 3c560609dd
6 changed files with 71 additions and 109 deletions

View file

@ -12,7 +12,7 @@ from llama_stack.apis.agents import Agents, StepType
from llama_stack.apis.benchmarks import Benchmark from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference, UserMessage, SystemMessage from llama_stack.apis.inference import Inference, SystemMessage, UserMessage
from llama_stack.apis.scoring import Scoring from llama_stack.apis.scoring import Scoring
from llama_stack.distribution.datatypes import Api from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
@ -118,7 +118,7 @@ class MetaReferenceEvalImpl(
for i, x in tqdm(enumerate(input_rows)): for i, x in tqdm(enumerate(input_rows)):
assert ColumnName.chat_completion_input.value in x, "Invalid input row" assert ColumnName.chat_completion_input.value in x, "Invalid input row"
input_messages = json.loads(x[ColumnName.chat_completion_input.value]) input_messages = json.loads(x[ColumnName.chat_completion_input.value])
input_messages = [UserMessage(**x) for x in input_messages if x['role'] == 'user'] input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"]
# NOTE: only single-turn agent generation is supported. Create a new session for each input row # NOTE: only single-turn agent generation is supported. Create a new session for each input row
session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}") session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
@ -168,11 +168,11 @@ class MetaReferenceEvalImpl(
generations.append({ColumnName.generated_answer.value: response.completion_message.content}) generations.append({ColumnName.generated_answer.value: response.completion_message.content})
elif ColumnName.chat_completion_input.value in x: elif ColumnName.chat_completion_input.value in x:
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value]) chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
input_messages = [UserMessage(**x) for x in chat_completion_input_json if x['role'] == 'user'] input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"]
messages = [] messages = []
if candidate.system_message: if candidate.system_message:
messages.append(candidate.system_message) messages.append(candidate.system_message)
messages += [SystemMessage(**x) for x in chat_completion_input_json if x['role'] == 'system'] messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"]
messages += input_messages messages += input_messages
response = await self.inference_api.chat_completion( response = await self.inference_api.chat_completion(
model_id=candidate.model, model_id=candidate.model,

View file

@ -22,14 +22,19 @@ from llama_stack.providers.utils.common.data_schema_validator import (
) )
from .config import BasicScoringConfig from .config import BasicScoringConfig
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
from .scoring_fn.equality_scoring_fn import EqualityScoringFn from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
FIXED_FNS = [
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn, RegexParserMathResponseScoringFn, BFCLScoringFn] EqualityScoringFn,
SubsetOfScoringFn,
RegexParserScoringFn,
RegexParserMathResponseScoringFn,
BFCLScoringFn,
]
class BasicScoringImpl( class BasicScoringImpl(

View file

@ -4,18 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import json
import re
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.bfcl import bfcl
from ..utils.bfcl.ast_parser import decode_ast from ..utils.bfcl.ast_parser import decode_ast
from ..utils.bfcl.checker import ast_checker, is_empty_output from ..utils.bfcl.checker import ast_checker, is_empty_output
import json from .fn_defs.bfcl import bfcl
import re
def postprocess(x: Dict[str, Any], test_category: str) -> Dict[str, Any]: def postprocess(x: Dict[str, Any], test_category: str) -> Dict[str, Any]:
@ -62,11 +61,7 @@ def gen_relevance_acc(x: Dict[str, Any]) -> Dict[str, float]:
# If `test_category` is "irrelevance", the model is expected to output no function call. # If `test_category` is "irrelevance", the model is expected to output no function call.
# No function call means either the AST decoding fails (a error message is generated) or the decoded AST does not contain any function call (such as a empty list, `[]`). # No function call means either the AST decoding fails (a error message is generated) or the decoded AST does not contain any function call (such as a empty list, `[]`).
# If `test_category` is "relevance", the model is expected to output to a function call, and empty list doesn't count as a function call. # If `test_category` is "relevance", the model is expected to output to a function call, and empty list doesn't count as a function call.
acc = ( acc = not x["contain_func_call"] if "irrelevance" in x["id"] else x["contain_func_call"]
not x["contain_func_call"]
if "irrelevance" in x["id"]
else x["contain_func_call"]
)
return {"valid": float(acc)} return {"valid": float(acc)}
@ -87,12 +82,12 @@ class BFCLScoringFn(RegisteredBaseScoringFn):
scoring_fn_identifier: Optional[str] = "bfcl", scoring_fn_identifier: Optional[str] = "bfcl",
scoring_params: Optional[ScoringFnParams] = None, scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow: ) -> ScoringResultRow:
test_category = re.sub(r'_[0-9_-]+$', '', input_row['id']) test_category = re.sub(r"_[0-9_-]+$", "", input_row["id"])
score_result = postprocess(input_row, test_category) score_result = postprocess(input_row, test_category)
if (test_category in {'irrelevance', 'live_relevance', 'live_irrelevance'}): if test_category in {"irrelevance", "live_relevance", "live_irrelevance"}:
score = gen_relevance_acc(score_result)['valid'] score = gen_relevance_acc(score_result)["valid"]
else: else:
score = gen_valid(score_result)['valid'] score = gen_valid(score_result)["valid"]
return { return {
"score": float(score), "score": float(score),
} }

View file

@ -1,4 +1,4 @@
#ruff: noqa # ruff: noqa
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# #
@ -71,11 +71,7 @@ def parse_java_function_call(source_code):
return f"new {type_text}()" return f"new {type_text}()"
elif node.type == "set": elif node.type == "set":
# Handling sets specifically # Handling sets specifically
items = [ items = [traverse_node(n, True) for n in node.children if n.type not in [",", "set"]]
traverse_node(n, True)
for n in node.children
if n.type not in [",", "set"]
]
return "{" + ", ".join(items) + "}" return "{" + ", ".join(items) + "}"
elif node.child_count > 0: elif node.child_count > 0:
@ -124,9 +120,7 @@ def parse_java_function_call(source_code):
arguments = extract_arguments(arguments_node) arguments = extract_arguments(arguments_node)
for key, value in arguments.items(): for key, value in arguments.items():
if isinstance(value, list): if isinstance(value, list):
raise Exception( raise Exception("Error: Multiple arguments with the same name are not supported.")
"Error: Multiple arguments with the same name are not supported."
)
return [{function_name: arguments}] return [{function_name: arguments}]
else: else:
@ -157,9 +151,7 @@ def parse_javascript_function_call(source_code):
# Extract left (name) and right (value) parts of the assignment # Extract left (name) and right (value) parts of the assignment
name = child.children[0].text.decode("utf-8") name = child.children[0].text.decode("utf-8")
value = child.children[2].text.decode("utf-8") value = child.children[2].text.decode("utf-8")
if (value.startswith('"') and value.endswith('"')) or ( if (value.startswith('"') and value.endswith('"')) or (value.startswith("'") and value.endswith("'")):
value.startswith("'") and value.endswith("'")
):
value = value[1:-1] # Trim the quotation marks value = value[1:-1] # Trim the quotation marks
if name in args: if name in args:
if not isinstance(args[name], list): if not isinstance(args[name], list):
@ -190,9 +182,7 @@ def parse_javascript_function_call(source_code):
parameters = extract_arguments(arguments_node) parameters = extract_arguments(arguments_node)
for key, value in parameters.items(): for key, value in parameters.items():
if isinstance(value, list): if isinstance(value, list):
raise Exception( raise Exception("Error: Multiple arguments with the same name are not supported.")
"Error: Multiple arguments with the same name are not supported."
)
result = [{function_name: parameters}] result = [{function_name: parameters}]
return result return result
@ -209,9 +199,7 @@ def ast_parse(input_str, language="Python"):
extracted.append(resolve_ast_call(elem)) extracted.append(resolve_ast_call(elem))
return extracted return extracted
elif language == "Java": elif language == "Java":
return parse_java_function_call( return parse_java_function_call(input_str[1:-1]) # Remove the [ and ] from the string
input_str[1:-1]
) # Remove the [ and ] from the string
elif language == "JavaScript": elif language == "JavaScript":
return parse_javascript_function_call(input_str[1:-1]) return parse_javascript_function_call(input_str[1:-1])
else: else:
@ -254,17 +242,10 @@ def resolve_ast_by_type(value):
elif isinstance(value, ast.List): elif isinstance(value, ast.List):
output = [resolve_ast_by_type(v) for v in value.elts] output = [resolve_ast_by_type(v) for v in value.elts]
elif isinstance(value, ast.Dict): elif isinstance(value, ast.Dict):
output = { output = {resolve_ast_by_type(k): resolve_ast_by_type(v) for k, v in zip(value.keys, value.values)}
resolve_ast_by_type(k): resolve_ast_by_type(v) elif isinstance(value, ast.NameConstant): # Added this condition to handle boolean values
for k, v in zip(value.keys, value.values)
}
elif isinstance(
value, ast.NameConstant
): # Added this condition to handle boolean values
output = value.value output = value.value
elif isinstance( elif isinstance(value, ast.BinOp): # Added this condition to handle function calls as arguments
value, ast.BinOp
): # Added this condition to handle function calls as arguments
output = eval(ast.unparse(value)) output = eval(ast.unparse(value))
elif isinstance(value, ast.Name): elif isinstance(value, ast.Name):
output = value.id output = value.id
@ -311,7 +292,5 @@ def decode_execute(result):
execution_list = [] execution_list = []
for function_call in decode_output: for function_call in decode_output:
for key, value in function_call.items(): for key, value in function_call.items():
execution_list.append( execution_list.append(f"{key}({','.join([f'{k}={repr(v)}' for k, v in value.items()])})")
f"{key}({','.join([f'{k}={repr(v)}' for k, v in value.items()])})"
)
return execution_list return execution_list

View file

@ -1,4 +1,4 @@
#ruff: noqa # ruff: noqa
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# #
@ -220,9 +220,7 @@ def list_checker(param: str, model_output: list, possible_answer: list):
standardize_possible_answer.append([]) standardize_possible_answer.append([])
for j in range(len(possible_answer[i])): for j in range(len(possible_answer[i])):
if type(possible_answer[i][j]) == str: if type(possible_answer[i][j]) == str:
standardize_possible_answer[i].append( standardize_possible_answer[i].append(standardize_string(possible_answer[i][j]))
standardize_string(possible_answer[i][j])
)
else: else:
standardize_possible_answer[i].append(possible_answer[i][j]) standardize_possible_answer[i].append(possible_answer[i][j])
@ -244,7 +242,6 @@ def dict_checker(param: str, model_output: dict, possible_answers: list):
result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"} result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"}
for i in range(len(possible_answers)): for i in range(len(possible_answers)):
if possible_answers[i] == "": if possible_answers[i] == "":
continue continue
@ -272,9 +269,7 @@ def dict_checker(param: str, model_output: dict, possible_answers: list):
standardize_possible_answer = [] standardize_possible_answer = []
for i in range(len(possible_answer[key])): for i in range(len(possible_answer[key])):
if type(possible_answer[key][i]) == str: if type(possible_answer[key][i]) == str:
standardize_possible_answer.append( standardize_possible_answer.append(standardize_string(possible_answer[key][i]))
standardize_string(possible_answer[key][i])
)
else: else:
standardize_possible_answer.append(possible_answer[key][i]) standardize_possible_answer.append(possible_answer[key][i])
@ -353,7 +348,6 @@ def simple_function_checker(
"error_type": "simple_function_checker:unclear", "error_type": "simple_function_checker:unclear",
} }
# Check if function name matches # Check if function name matches
if func_name not in model_output: if func_name not in model_output:
result["valid"] = False result["valid"] = False
@ -403,9 +397,7 @@ def simple_function_checker(
if expected_type_description in NESTED_CONVERSION_TYPE_LIST: if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
nested_type = param_details[param]["items"]["type"] nested_type = param_details[param]["items"]["type"]
nested_type_converted = JAVA_TYPE_CONVERSION[nested_type] nested_type_converted = JAVA_TYPE_CONVERSION[nested_type]
value = java_type_converter( value = java_type_converter(value, expected_type_description, nested_type)
value, expected_type_description, nested_type
)
else: else:
value = java_type_converter(value, expected_type_description) value = java_type_converter(value, expected_type_description)
@ -426,9 +418,7 @@ def simple_function_checker(
if expected_type_description in NESTED_CONVERSION_TYPE_LIST: if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
nested_type = param_details[param]["items"]["type"] nested_type = param_details[param]["items"]["type"]
nested_type_converted = JS_TYPE_CONVERSION[nested_type] nested_type_converted = JS_TYPE_CONVERSION[nested_type]
value = js_type_converter( value = js_type_converter(value, expected_type_description, nested_type)
value, expected_type_description, nested_type
)
else: else:
value = js_type_converter(value, expected_type_description) value = js_type_converter(value, expected_type_description)
@ -445,11 +435,7 @@ def simple_function_checker(
value = list(value) value = list(value)
# Allow python auto conversion from int to float # Allow python auto conversion from int to float
if ( if language == "Python" and expected_type_description == "float" and type(value) == int:
language == "Python"
and expected_type_description == "float"
and type(value) == int
):
value = float(value) value = float(value)
# Type checking # Type checking
@ -609,9 +595,7 @@ def parallel_function_checker_no_order(
) )
if not result["valid"]: if not result["valid"]:
considered_indices = [ considered_indices = [i for i in range(len(model_output)) if i not in matched_indices]
i for i in range(len(model_output)) if i not in matched_indices
]
all_errors.insert( all_errors.insert(
0, 0,
f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type] f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type]
@ -782,9 +766,7 @@ def executable_checker_simple(
else: else:
# structural match # structural match
pattern_match_result = patten_matcher( pattern_match_result = patten_matcher(exec_output, expected_result, function_call, is_sanity_check)
exec_output, expected_result, function_call, is_sanity_check
)
if not pattern_match_result["valid"]: if not pattern_match_result["valid"]:
return pattern_match_result return pattern_match_result
@ -794,7 +776,6 @@ def executable_checker_simple(
def executable_checker_parallel_no_order( def executable_checker_parallel_no_order(
decoded_result: list, expected_exec_result: list, expected_exec_result_type: list decoded_result: list, expected_exec_result: list, expected_exec_result_type: list
): ):
if len(decoded_result) != len(expected_exec_result): if len(decoded_result) != len(expected_exec_result):
return { return {
"valid": False, "valid": False,
@ -828,18 +809,14 @@ def executable_checker_parallel_no_order(
"sub_error": result["error"], "sub_error": result["error"],
"sub_error_type": result["error_type"], "sub_error_type": result["error_type"],
"model_executed_output": ( "model_executed_output": (
result["model_executed_output"] result["model_executed_output"] if "model_executed_output" in result else None
if "model_executed_output" in result
else None
), ),
} }
} }
) )
if not result["valid"]: if not result["valid"]:
considered_indices = [ considered_indices = [i for i in range(len(decoded_result)) if i not in matched_indices]
i for i in range(len(decoded_result)) if i not in matched_indices
]
all_errors.insert( all_errors.insert(
0, 0,
f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type] f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type]
@ -874,7 +851,6 @@ def executable_checker_rest(func_call, idx):
try: try:
if response.status_code == 200: if response.status_code == 200:
eval_GT_json = json.loads(EVAL_GROUND_TRUTH[idx]) eval_GT_json = json.loads(EVAL_GROUND_TRUTH[idx])
try: try:
if isinstance(eval_GT_json, dict): if isinstance(eval_GT_json, dict):
@ -888,9 +864,7 @@ def executable_checker_rest(func_call, idx):
} }
return { return {
"valid": False, "valid": False,
"error": [ "error": [f"Expected dictionary, but got {type(response.json())}"],
f"Expected dictionary, but got {type(response.json())}"
],
"error_type": "executable_checker_rest:wrong_type", "error_type": "executable_checker_rest:wrong_type",
} }
@ -905,9 +879,7 @@ def executable_checker_rest(func_call, idx):
else: else:
for i in range(len(eval_GT_json)): for i in range(len(eval_GT_json)):
if set(eval_GT_json[i].keys()) != set( if set(eval_GT_json[i].keys()) != set(response.json()[i].keys()):
response.json()[i].keys()
):
return { return {
"valid": False, "valid": False,
"error": [f"Key inconsistency"], "error": [f"Key inconsistency"],
@ -918,16 +890,12 @@ def executable_checker_rest(func_call, idx):
else: else:
return { return {
"valid": False, "valid": False,
"error": [ "error": [f"Expected list, but got {type(response.json())}"],
f"Expected list, but got {type(response.json())}"
],
"error_type": "executable_checker_rest:wrong_type", "error_type": "executable_checker_rest:wrong_type",
} }
return { return {
"valid": False, "valid": False,
"error": [ "error": [f"Expected dict or list, but got {type(response.json())}"],
f"Expected dict or list, but got {type(response.json())}"
],
"error_type": "executable_checker_rest:wrong_type", "error_type": "executable_checker_rest:wrong_type",
} }
except Exception as e: except Exception as e:
@ -941,9 +909,7 @@ def executable_checker_rest(func_call, idx):
else: else:
return { return {
"valid": False, "valid": False,
"error": [ "error": [f"Execution result status code is not 200, got {response.status_code}"],
f"Execution result status code is not 200, got {response.status_code}"
],
"error_type": "executable_checker_rest:wrong_status_code", "error_type": "executable_checker_rest:wrong_status_code",
} }
except Exception as e: except Exception as e:
@ -954,18 +920,12 @@ def executable_checker_rest(func_call, idx):
} }
def ast_checker( def ast_checker(func_description, model_output, possible_answer, language, test_category, model_name):
func_description, model_output, possible_answer, language, test_category, model_name
):
if "parallel" in test_category: if "parallel" in test_category:
return parallel_function_checker_no_order( return parallel_function_checker_no_order(func_description, model_output, possible_answer, language, model_name)
func_description, model_output, possible_answer, language, model_name
)
elif "multiple" in test_category: elif "multiple" in test_category:
return multiple_function_checker( return multiple_function_checker(func_description, model_output, possible_answer, language, model_name)
func_description, model_output, possible_answer, language, model_name
)
else: else:
if len(model_output) != 1: if len(model_output) != 1:

View file

@ -216,6 +216,24 @@ datasets:
split: test split: test
dataset_id: math_500 dataset_id: math_500
provider_id: huggingface provider_id: huggingface
- dataset_schema:
function:
type: string
language:
type: string
ground_truth:
type: string
id:
type: string
chat_completion_input:
type: string
url:
uri: https://huggingface.co/datasets/llamastack/bfcl_v3
metadata:
path: llamastack/bfcl_v3
split: train
dataset_id: bfcl
provider_id: huggingface
scoring_fns: [] scoring_fns: []
benchmarks: benchmarks:
- dataset_id: simpleqa - dataset_id: simpleqa
@ -238,6 +256,11 @@ benchmarks:
- basic::regex_parser_math_response - basic::regex_parser_math_response
metadata: {} metadata: {}
benchmark_id: meta-reference-math-500 benchmark_id: meta-reference-math-500
- dataset_id: bfcl
scoring_functions:
- basic::bfcl
metadata: {}
benchmark_id: meta-reference-bfcl
tool_groups: tool_groups:
- toolgroup_id: builtin::websearch - toolgroup_id: builtin::websearch
provider_id: tavily-search provider_id: tavily-search