precommit local run

This commit is contained in:
Yang Yang 2025-03-12 13:47:36 -07:00
parent 32d025923c
commit 3c560609dd
6 changed files with 71 additions and 109 deletions

View file

@ -12,7 +12,7 @@ from llama_stack.apis.agents import Agents, StepType
from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference, UserMessage, SystemMessage
from llama_stack.apis.inference import Inference, SystemMessage, UserMessage
from llama_stack.apis.scoring import Scoring
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
@ -118,7 +118,7 @@ class MetaReferenceEvalImpl(
for i, x in tqdm(enumerate(input_rows)):
assert ColumnName.chat_completion_input.value in x, "Invalid input row"
input_messages = json.loads(x[ColumnName.chat_completion_input.value])
input_messages = [UserMessage(**x) for x in input_messages if x['role'] == 'user']
input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"]
# NOTE: only single-turn agent generation is supported. Create a new session for each input row
session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
@ -168,11 +168,11 @@ class MetaReferenceEvalImpl(
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
elif ColumnName.chat_completion_input.value in x:
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
input_messages = [UserMessage(**x) for x in chat_completion_input_json if x['role'] == 'user']
input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"]
messages = []
if candidate.system_message:
messages.append(candidate.system_message)
messages += [SystemMessage(**x) for x in chat_completion_input_json if x['role'] == 'system']
messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"]
messages += input_messages
response = await self.inference_api.chat_completion(
model_id=candidate.model,

View file

@ -22,14 +22,19 @@ from llama_stack.providers.utils.common.data_schema_validator import (
)
from .config import BasicScoringConfig
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn, RegexParserMathResponseScoringFn, BFCLScoringFn]
FIXED_FNS = [
EqualityScoringFn,
SubsetOfScoringFn,
RegexParserScoringFn,
RegexParserMathResponseScoringFn,
BFCLScoringFn,
]
class BasicScoringImpl(

View file

@ -4,18 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import re
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.bfcl import bfcl
from ..utils.bfcl.ast_parser import decode_ast
from ..utils.bfcl.checker import ast_checker, is_empty_output
import json
import re
from .fn_defs.bfcl import bfcl
def postprocess(x: Dict[str, Any], test_category: str) -> Dict[str, Any]:
@ -62,11 +61,7 @@ def gen_relevance_acc(x: Dict[str, Any]) -> Dict[str, float]:
# If `test_category` is "irrelevance", the model is expected to output no function call.
# No function call means either the AST decoding fails (an error message is generated) or the decoded AST does not contain any function call (such as an empty list, `[]`).
# If `test_category` is "relevance", the model is expected to output a function call, and an empty list doesn't count as a function call.
acc = (
not x["contain_func_call"]
if "irrelevance" in x["id"]
else x["contain_func_call"]
)
acc = not x["contain_func_call"] if "irrelevance" in x["id"] else x["contain_func_call"]
return {"valid": float(acc)}
@ -87,12 +82,12 @@ class BFCLScoringFn(RegisteredBaseScoringFn):
scoring_fn_identifier: Optional[str] = "bfcl",
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
test_category = re.sub(r'_[0-9_-]+$', '', input_row['id'])
test_category = re.sub(r"_[0-9_-]+$", "", input_row["id"])
score_result = postprocess(input_row, test_category)
if (test_category in {'irrelevance', 'live_relevance', 'live_irrelevance'}):
score = gen_relevance_acc(score_result)['valid']
if test_category in {"irrelevance", "live_relevance", "live_irrelevance"}:
score = gen_relevance_acc(score_result)["valid"]
else:
score = gen_valid(score_result)['valid']
score = gen_valid(score_result)["valid"]
return {
"score": float(score),
}

View file

@ -1,4 +1,4 @@
#ruff: noqa
# ruff: noqa
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
@ -71,11 +71,7 @@ def parse_java_function_call(source_code):
return f"new {type_text}()"
elif node.type == "set":
# Handling sets specifically
items = [
traverse_node(n, True)
for n in node.children
if n.type not in [",", "set"]
]
items = [traverse_node(n, True) for n in node.children if n.type not in [",", "set"]]
return "{" + ", ".join(items) + "}"
elif node.child_count > 0:
@ -124,9 +120,7 @@ def parse_java_function_call(source_code):
arguments = extract_arguments(arguments_node)
for key, value in arguments.items():
if isinstance(value, list):
raise Exception(
"Error: Multiple arguments with the same name are not supported."
)
raise Exception("Error: Multiple arguments with the same name are not supported.")
return [{function_name: arguments}]
else:
@ -157,9 +151,7 @@ def parse_javascript_function_call(source_code):
# Extract left (name) and right (value) parts of the assignment
name = child.children[0].text.decode("utf-8")
value = child.children[2].text.decode("utf-8")
if (value.startswith('"') and value.endswith('"')) or (
value.startswith("'") and value.endswith("'")
):
if (value.startswith('"') and value.endswith('"')) or (value.startswith("'") and value.endswith("'")):
value = value[1:-1] # Trim the quotation marks
if name in args:
if not isinstance(args[name], list):
@ -190,9 +182,7 @@ def parse_javascript_function_call(source_code):
parameters = extract_arguments(arguments_node)
for key, value in parameters.items():
if isinstance(value, list):
raise Exception(
"Error: Multiple arguments with the same name are not supported."
)
raise Exception("Error: Multiple arguments with the same name are not supported.")
result = [{function_name: parameters}]
return result
@ -209,9 +199,7 @@ def ast_parse(input_str, language="Python"):
extracted.append(resolve_ast_call(elem))
return extracted
elif language == "Java":
return parse_java_function_call(
input_str[1:-1]
) # Remove the [ and ] from the string
return parse_java_function_call(input_str[1:-1]) # Remove the [ and ] from the string
elif language == "JavaScript":
return parse_javascript_function_call(input_str[1:-1])
else:
@ -254,17 +242,10 @@ def resolve_ast_by_type(value):
elif isinstance(value, ast.List):
output = [resolve_ast_by_type(v) for v in value.elts]
elif isinstance(value, ast.Dict):
output = {
resolve_ast_by_type(k): resolve_ast_by_type(v)
for k, v in zip(value.keys, value.values)
}
elif isinstance(
value, ast.NameConstant
): # Added this condition to handle boolean values
output = {resolve_ast_by_type(k): resolve_ast_by_type(v) for k, v in zip(value.keys, value.values)}
elif isinstance(value, ast.NameConstant): # Added this condition to handle boolean values
output = value.value
elif isinstance(
value, ast.BinOp
): # Added this condition to handle function calls as arguments
elif isinstance(value, ast.BinOp): # Added this condition to handle function calls as arguments
output = eval(ast.unparse(value))
elif isinstance(value, ast.Name):
output = value.id
@ -311,7 +292,5 @@ def decode_execute(result):
execution_list = []
for function_call in decode_output:
for key, value in function_call.items():
execution_list.append(
f"{key}({','.join([f'{k}={repr(v)}' for k, v in value.items()])})"
)
execution_list.append(f"{key}({','.join([f'{k}={repr(v)}' for k, v in value.items()])})")
return execution_list

View file

@ -1,4 +1,4 @@
#ruff: noqa
# ruff: noqa
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
@ -220,9 +220,7 @@ def list_checker(param: str, model_output: list, possible_answer: list):
standardize_possible_answer.append([])
for j in range(len(possible_answer[i])):
if type(possible_answer[i][j]) == str:
standardize_possible_answer[i].append(
standardize_string(possible_answer[i][j])
)
standardize_possible_answer[i].append(standardize_string(possible_answer[i][j]))
else:
standardize_possible_answer[i].append(possible_answer[i][j])
@ -244,7 +242,6 @@ def dict_checker(param: str, model_output: dict, possible_answers: list):
result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"}
for i in range(len(possible_answers)):
if possible_answers[i] == "":
continue
@ -272,9 +269,7 @@ def dict_checker(param: str, model_output: dict, possible_answers: list):
standardize_possible_answer = []
for i in range(len(possible_answer[key])):
if type(possible_answer[key][i]) == str:
standardize_possible_answer.append(
standardize_string(possible_answer[key][i])
)
standardize_possible_answer.append(standardize_string(possible_answer[key][i]))
else:
standardize_possible_answer.append(possible_answer[key][i])
@ -353,7 +348,6 @@ def simple_function_checker(
"error_type": "simple_function_checker:unclear",
}
# Check if function name matches
if func_name not in model_output:
result["valid"] = False
@ -403,9 +397,7 @@ def simple_function_checker(
if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
nested_type = param_details[param]["items"]["type"]
nested_type_converted = JAVA_TYPE_CONVERSION[nested_type]
value = java_type_converter(
value, expected_type_description, nested_type
)
value = java_type_converter(value, expected_type_description, nested_type)
else:
value = java_type_converter(value, expected_type_description)
@ -426,9 +418,7 @@ def simple_function_checker(
if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
nested_type = param_details[param]["items"]["type"]
nested_type_converted = JS_TYPE_CONVERSION[nested_type]
value = js_type_converter(
value, expected_type_description, nested_type
)
value = js_type_converter(value, expected_type_description, nested_type)
else:
value = js_type_converter(value, expected_type_description)
@ -445,11 +435,7 @@ def simple_function_checker(
value = list(value)
# Allow python auto conversion from int to float
if (
language == "Python"
and expected_type_description == "float"
and type(value) == int
):
if language == "Python" and expected_type_description == "float" and type(value) == int:
value = float(value)
# Type checking
@ -609,9 +595,7 @@ def parallel_function_checker_no_order(
)
if not result["valid"]:
considered_indices = [
i for i in range(len(model_output)) if i not in matched_indices
]
considered_indices = [i for i in range(len(model_output)) if i not in matched_indices]
all_errors.insert(
0,
f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type]
@ -782,9 +766,7 @@ def executable_checker_simple(
else:
# structural match
pattern_match_result = patten_matcher(
exec_output, expected_result, function_call, is_sanity_check
)
pattern_match_result = patten_matcher(exec_output, expected_result, function_call, is_sanity_check)
if not pattern_match_result["valid"]:
return pattern_match_result
@ -794,7 +776,6 @@ def executable_checker_simple(
def executable_checker_parallel_no_order(
decoded_result: list, expected_exec_result: list, expected_exec_result_type: list
):
if len(decoded_result) != len(expected_exec_result):
return {
"valid": False,
@ -828,18 +809,14 @@ def executable_checker_parallel_no_order(
"sub_error": result["error"],
"sub_error_type": result["error_type"],
"model_executed_output": (
result["model_executed_output"]
if "model_executed_output" in result
else None
result["model_executed_output"] if "model_executed_output" in result else None
),
}
}
)
if not result["valid"]:
considered_indices = [
i for i in range(len(decoded_result)) if i not in matched_indices
]
considered_indices = [i for i in range(len(decoded_result)) if i not in matched_indices]
all_errors.insert(
0,
f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type]
@ -874,7 +851,6 @@ def executable_checker_rest(func_call, idx):
try:
if response.status_code == 200:
eval_GT_json = json.loads(EVAL_GROUND_TRUTH[idx])
try:
if isinstance(eval_GT_json, dict):
@ -888,9 +864,7 @@ def executable_checker_rest(func_call, idx):
}
return {
"valid": False,
"error": [
f"Expected dictionary, but got {type(response.json())}"
],
"error": [f"Expected dictionary, but got {type(response.json())}"],
"error_type": "executable_checker_rest:wrong_type",
}
@ -905,9 +879,7 @@ def executable_checker_rest(func_call, idx):
else:
for i in range(len(eval_GT_json)):
if set(eval_GT_json[i].keys()) != set(
response.json()[i].keys()
):
if set(eval_GT_json[i].keys()) != set(response.json()[i].keys()):
return {
"valid": False,
"error": [f"Key inconsistency"],
@ -918,16 +890,12 @@ def executable_checker_rest(func_call, idx):
else:
return {
"valid": False,
"error": [
f"Expected list, but got {type(response.json())}"
],
"error": [f"Expected list, but got {type(response.json())}"],
"error_type": "executable_checker_rest:wrong_type",
}
return {
"valid": False,
"error": [
f"Expected dict or list, but got {type(response.json())}"
],
"error": [f"Expected dict or list, but got {type(response.json())}"],
"error_type": "executable_checker_rest:wrong_type",
}
except Exception as e:
@ -941,9 +909,7 @@ def executable_checker_rest(func_call, idx):
else:
return {
"valid": False,
"error": [
f"Execution result status code is not 200, got {response.status_code}"
],
"error": [f"Execution result status code is not 200, got {response.status_code}"],
"error_type": "executable_checker_rest:wrong_status_code",
}
except Exception as e:
@ -954,18 +920,12 @@ def executable_checker_rest(func_call, idx):
}
def ast_checker(
func_description, model_output, possible_answer, language, test_category, model_name
):
def ast_checker(func_description, model_output, possible_answer, language, test_category, model_name):
if "parallel" in test_category:
return parallel_function_checker_no_order(
func_description, model_output, possible_answer, language, model_name
)
return parallel_function_checker_no_order(func_description, model_output, possible_answer, language, model_name)
elif "multiple" in test_category:
return multiple_function_checker(
func_description, model_output, possible_answer, language, model_name
)
return multiple_function_checker(func_description, model_output, possible_answer, language, model_name)
else:
if len(model_output) != 1:

View file

@ -216,6 +216,24 @@ datasets:
split: test
dataset_id: math_500
provider_id: huggingface
- dataset_schema:
function:
type: string
language:
type: string
ground_truth:
type: string
id:
type: string
chat_completion_input:
type: string
url:
uri: https://huggingface.co/datasets/llamastack/bfcl_v3
metadata:
path: llamastack/bfcl_v3
split: train
dataset_id: bfcl
provider_id: huggingface
scoring_fns: []
benchmarks:
- dataset_id: simpleqa
@ -238,6 +256,11 @@ benchmarks:
- basic::regex_parser_math_response
metadata: {}
benchmark_id: meta-reference-math-500
- dataset_id: bfcl
scoring_functions:
- basic::bfcl
metadata: {}
benchmark_id: meta-reference-bfcl
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search