mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-07 11:08:20 +00:00
precommit local run
This commit is contained in:
parent
32d025923c
commit
3c560609dd
6 changed files with 71 additions and 109 deletions
|
@ -12,7 +12,7 @@ from llama_stack.apis.agents import Agents, StepType
|
|||
from llama_stack.apis.benchmarks import Benchmark
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.inference import Inference, UserMessage, SystemMessage
|
||||
from llama_stack.apis.inference import Inference, SystemMessage, UserMessage
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
|
||||
|
@ -118,7 +118,7 @@ class MetaReferenceEvalImpl(
|
|||
for i, x in tqdm(enumerate(input_rows)):
|
||||
assert ColumnName.chat_completion_input.value in x, "Invalid input row"
|
||||
input_messages = json.loads(x[ColumnName.chat_completion_input.value])
|
||||
input_messages = [UserMessage(**x) for x in input_messages if x['role'] == 'user']
|
||||
input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"]
|
||||
|
||||
# NOTE: only single-turn agent generation is supported. Create a new session for each input row
|
||||
session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
|
||||
|
@ -168,11 +168,11 @@ class MetaReferenceEvalImpl(
|
|||
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
|
||||
elif ColumnName.chat_completion_input.value in x:
|
||||
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
|
||||
input_messages = [UserMessage(**x) for x in chat_completion_input_json if x['role'] == 'user']
|
||||
input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"]
|
||||
messages = []
|
||||
if candidate.system_message:
|
||||
messages.append(candidate.system_message)
|
||||
messages += [SystemMessage(**x) for x in chat_completion_input_json if x['role'] == 'system']
|
||||
messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"]
|
||||
messages += input_messages
|
||||
response = await self.inference_api.chat_completion(
|
||||
model_id=candidate.model,
|
||||
|
|
|
@ -22,14 +22,19 @@ from llama_stack.providers.utils.common.data_schema_validator import (
|
|||
)
|
||||
|
||||
from .config import BasicScoringConfig
|
||||
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
|
||||
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
|
||||
from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn
|
||||
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
|
||||
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
|
||||
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
|
||||
|
||||
|
||||
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn, RegexParserMathResponseScoringFn, BFCLScoringFn]
|
||||
FIXED_FNS = [
|
||||
EqualityScoringFn,
|
||||
SubsetOfScoringFn,
|
||||
RegexParserScoringFn,
|
||||
RegexParserMathResponseScoringFn,
|
||||
BFCLScoringFn,
|
||||
]
|
||||
|
||||
|
||||
class BasicScoringImpl(
|
||||
|
|
|
@ -4,18 +4,17 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from .fn_defs.bfcl import bfcl
|
||||
from ..utils.bfcl.ast_parser import decode_ast
|
||||
from ..utils.bfcl.checker import ast_checker, is_empty_output
|
||||
import json
|
||||
import re
|
||||
|
||||
from .fn_defs.bfcl import bfcl
|
||||
|
||||
|
||||
def postprocess(x: Dict[str, Any], test_category: str) -> Dict[str, Any]:
|
||||
|
@ -62,11 +61,7 @@ def gen_relevance_acc(x: Dict[str, Any]) -> Dict[str, float]:
|
|||
# If `test_category` is "irrelevance", the model is expected to output no function call.
|
||||
# No function call means either the AST decoding fails (an error message is generated) or the decoded AST does not contain any function call (such as an empty list, `[]`).
|
||||
# If `test_category` is "relevance", the model is expected to output a function call, and an empty list doesn't count as a function call.
|
||||
acc = (
|
||||
not x["contain_func_call"]
|
||||
if "irrelevance" in x["id"]
|
||||
else x["contain_func_call"]
|
||||
)
|
||||
acc = not x["contain_func_call"] if "irrelevance" in x["id"] else x["contain_func_call"]
|
||||
return {"valid": float(acc)}
|
||||
|
||||
|
||||
|
@ -87,12 +82,12 @@ class BFCLScoringFn(RegisteredBaseScoringFn):
|
|||
scoring_fn_identifier: Optional[str] = "bfcl",
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
test_category = re.sub(r'_[0-9_-]+$', '', input_row['id'])
|
||||
test_category = re.sub(r"_[0-9_-]+$", "", input_row["id"])
|
||||
score_result = postprocess(input_row, test_category)
|
||||
if (test_category in {'irrelevance', 'live_relevance', 'live_irrelevance'}):
|
||||
score = gen_relevance_acc(score_result)['valid']
|
||||
if test_category in {"irrelevance", "live_relevance", "live_irrelevance"}:
|
||||
score = gen_relevance_acc(score_result)["valid"]
|
||||
else:
|
||||
score = gen_valid(score_result)['valid']
|
||||
score = gen_valid(score_result)["valid"]
|
||||
return {
|
||||
"score": float(score),
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
#ruff: noqa
|
||||
# ruff: noqa
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
|
@ -71,11 +71,7 @@ def parse_java_function_call(source_code):
|
|||
return f"new {type_text}()"
|
||||
elif node.type == "set":
|
||||
# Handling sets specifically
|
||||
items = [
|
||||
traverse_node(n, True)
|
||||
for n in node.children
|
||||
if n.type not in [",", "set"]
|
||||
]
|
||||
items = [traverse_node(n, True) for n in node.children if n.type not in [",", "set"]]
|
||||
return "{" + ", ".join(items) + "}"
|
||||
|
||||
elif node.child_count > 0:
|
||||
|
@ -124,9 +120,7 @@ def parse_java_function_call(source_code):
|
|||
arguments = extract_arguments(arguments_node)
|
||||
for key, value in arguments.items():
|
||||
if isinstance(value, list):
|
||||
raise Exception(
|
||||
"Error: Multiple arguments with the same name are not supported."
|
||||
)
|
||||
raise Exception("Error: Multiple arguments with the same name are not supported.")
|
||||
return [{function_name: arguments}]
|
||||
|
||||
else:
|
||||
|
@ -157,9 +151,7 @@ def parse_javascript_function_call(source_code):
|
|||
# Extract left (name) and right (value) parts of the assignment
|
||||
name = child.children[0].text.decode("utf-8")
|
||||
value = child.children[2].text.decode("utf-8")
|
||||
if (value.startswith('"') and value.endswith('"')) or (
|
||||
value.startswith("'") and value.endswith("'")
|
||||
):
|
||||
if (value.startswith('"') and value.endswith('"')) or (value.startswith("'") and value.endswith("'")):
|
||||
value = value[1:-1] # Trim the quotation marks
|
||||
if name in args:
|
||||
if not isinstance(args[name], list):
|
||||
|
@ -190,9 +182,7 @@ def parse_javascript_function_call(source_code):
|
|||
parameters = extract_arguments(arguments_node)
|
||||
for key, value in parameters.items():
|
||||
if isinstance(value, list):
|
||||
raise Exception(
|
||||
"Error: Multiple arguments with the same name are not supported."
|
||||
)
|
||||
raise Exception("Error: Multiple arguments with the same name are not supported.")
|
||||
result = [{function_name: parameters}]
|
||||
return result
|
||||
|
||||
|
@ -209,9 +199,7 @@ def ast_parse(input_str, language="Python"):
|
|||
extracted.append(resolve_ast_call(elem))
|
||||
return extracted
|
||||
elif language == "Java":
|
||||
return parse_java_function_call(
|
||||
input_str[1:-1]
|
||||
) # Remove the [ and ] from the string
|
||||
return parse_java_function_call(input_str[1:-1]) # Remove the [ and ] from the string
|
||||
elif language == "JavaScript":
|
||||
return parse_javascript_function_call(input_str[1:-1])
|
||||
else:
|
||||
|
@ -254,17 +242,10 @@ def resolve_ast_by_type(value):
|
|||
elif isinstance(value, ast.List):
|
||||
output = [resolve_ast_by_type(v) for v in value.elts]
|
||||
elif isinstance(value, ast.Dict):
|
||||
output = {
|
||||
resolve_ast_by_type(k): resolve_ast_by_type(v)
|
||||
for k, v in zip(value.keys, value.values)
|
||||
}
|
||||
elif isinstance(
|
||||
value, ast.NameConstant
|
||||
): # Added this condition to handle boolean values
|
||||
output = {resolve_ast_by_type(k): resolve_ast_by_type(v) for k, v in zip(value.keys, value.values)}
|
||||
elif isinstance(value, ast.NameConstant): # Added this condition to handle boolean values
|
||||
output = value.value
|
||||
elif isinstance(
|
||||
value, ast.BinOp
|
||||
): # Added this condition to handle function calls as arguments
|
||||
elif isinstance(value, ast.BinOp): # Added this condition to handle function calls as arguments
|
||||
output = eval(ast.unparse(value))
|
||||
elif isinstance(value, ast.Name):
|
||||
output = value.id
|
||||
|
@ -311,7 +292,5 @@ def decode_execute(result):
|
|||
execution_list = []
|
||||
for function_call in decode_output:
|
||||
for key, value in function_call.items():
|
||||
execution_list.append(
|
||||
f"{key}({','.join([f'{k}={repr(v)}' for k, v in value.items()])})"
|
||||
)
|
||||
execution_list.append(f"{key}({','.join([f'{k}={repr(v)}' for k, v in value.items()])})")
|
||||
return execution_list
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
#ruff: noqa
|
||||
# ruff: noqa
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
|
@ -220,9 +220,7 @@ def list_checker(param: str, model_output: list, possible_answer: list):
|
|||
standardize_possible_answer.append([])
|
||||
for j in range(len(possible_answer[i])):
|
||||
if type(possible_answer[i][j]) == str:
|
||||
standardize_possible_answer[i].append(
|
||||
standardize_string(possible_answer[i][j])
|
||||
)
|
||||
standardize_possible_answer[i].append(standardize_string(possible_answer[i][j]))
|
||||
else:
|
||||
standardize_possible_answer[i].append(possible_answer[i][j])
|
||||
|
||||
|
@ -244,7 +242,6 @@ def dict_checker(param: str, model_output: dict, possible_answers: list):
|
|||
|
||||
result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"}
|
||||
for i in range(len(possible_answers)):
|
||||
|
||||
if possible_answers[i] == "":
|
||||
continue
|
||||
|
||||
|
@ -272,9 +269,7 @@ def dict_checker(param: str, model_output: dict, possible_answers: list):
|
|||
standardize_possible_answer = []
|
||||
for i in range(len(possible_answer[key])):
|
||||
if type(possible_answer[key][i]) == str:
|
||||
standardize_possible_answer.append(
|
||||
standardize_string(possible_answer[key][i])
|
||||
)
|
||||
standardize_possible_answer.append(standardize_string(possible_answer[key][i]))
|
||||
else:
|
||||
standardize_possible_answer.append(possible_answer[key][i])
|
||||
|
||||
|
@ -353,7 +348,6 @@ def simple_function_checker(
|
|||
"error_type": "simple_function_checker:unclear",
|
||||
}
|
||||
|
||||
|
||||
# Check if function name matches
|
||||
if func_name not in model_output:
|
||||
result["valid"] = False
|
||||
|
@ -403,9 +397,7 @@ def simple_function_checker(
|
|||
if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
|
||||
nested_type = param_details[param]["items"]["type"]
|
||||
nested_type_converted = JAVA_TYPE_CONVERSION[nested_type]
|
||||
value = java_type_converter(
|
||||
value, expected_type_description, nested_type
|
||||
)
|
||||
value = java_type_converter(value, expected_type_description, nested_type)
|
||||
else:
|
||||
value = java_type_converter(value, expected_type_description)
|
||||
|
||||
|
@ -426,9 +418,7 @@ def simple_function_checker(
|
|||
if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
|
||||
nested_type = param_details[param]["items"]["type"]
|
||||
nested_type_converted = JS_TYPE_CONVERSION[nested_type]
|
||||
value = js_type_converter(
|
||||
value, expected_type_description, nested_type
|
||||
)
|
||||
value = js_type_converter(value, expected_type_description, nested_type)
|
||||
else:
|
||||
value = js_type_converter(value, expected_type_description)
|
||||
|
||||
|
@ -445,11 +435,7 @@ def simple_function_checker(
|
|||
value = list(value)
|
||||
|
||||
# Allow python auto conversion from int to float
|
||||
if (
|
||||
language == "Python"
|
||||
and expected_type_description == "float"
|
||||
and type(value) == int
|
||||
):
|
||||
if language == "Python" and expected_type_description == "float" and type(value) == int:
|
||||
value = float(value)
|
||||
|
||||
# Type checking
|
||||
|
@ -609,9 +595,7 @@ def parallel_function_checker_no_order(
|
|||
)
|
||||
|
||||
if not result["valid"]:
|
||||
considered_indices = [
|
||||
i for i in range(len(model_output)) if i not in matched_indices
|
||||
]
|
||||
considered_indices = [i for i in range(len(model_output)) if i not in matched_indices]
|
||||
all_errors.insert(
|
||||
0,
|
||||
f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type]
|
||||
|
@ -782,9 +766,7 @@ def executable_checker_simple(
|
|||
|
||||
else:
|
||||
# structural match
|
||||
pattern_match_result = patten_matcher(
|
||||
exec_output, expected_result, function_call, is_sanity_check
|
||||
)
|
||||
pattern_match_result = patten_matcher(exec_output, expected_result, function_call, is_sanity_check)
|
||||
if not pattern_match_result["valid"]:
|
||||
return pattern_match_result
|
||||
|
||||
|
@ -794,7 +776,6 @@ def executable_checker_simple(
|
|||
def executable_checker_parallel_no_order(
|
||||
decoded_result: list, expected_exec_result: list, expected_exec_result_type: list
|
||||
):
|
||||
|
||||
if len(decoded_result) != len(expected_exec_result):
|
||||
return {
|
||||
"valid": False,
|
||||
|
@ -828,18 +809,14 @@ def executable_checker_parallel_no_order(
|
|||
"sub_error": result["error"],
|
||||
"sub_error_type": result["error_type"],
|
||||
"model_executed_output": (
|
||||
result["model_executed_output"]
|
||||
if "model_executed_output" in result
|
||||
else None
|
||||
result["model_executed_output"] if "model_executed_output" in result else None
|
||||
),
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if not result["valid"]:
|
||||
considered_indices = [
|
||||
i for i in range(len(decoded_result)) if i not in matched_indices
|
||||
]
|
||||
considered_indices = [i for i in range(len(decoded_result)) if i not in matched_indices]
|
||||
all_errors.insert(
|
||||
0,
|
||||
f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type]
|
||||
|
@ -874,7 +851,6 @@ def executable_checker_rest(func_call, idx):
|
|||
|
||||
try:
|
||||
if response.status_code == 200:
|
||||
|
||||
eval_GT_json = json.loads(EVAL_GROUND_TRUTH[idx])
|
||||
try:
|
||||
if isinstance(eval_GT_json, dict):
|
||||
|
@ -888,9 +864,7 @@ def executable_checker_rest(func_call, idx):
|
|||
}
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Expected dictionary, but got {type(response.json())}"
|
||||
],
|
||||
"error": [f"Expected dictionary, but got {type(response.json())}"],
|
||||
"error_type": "executable_checker_rest:wrong_type",
|
||||
}
|
||||
|
||||
|
@ -905,9 +879,7 @@ def executable_checker_rest(func_call, idx):
|
|||
|
||||
else:
|
||||
for i in range(len(eval_GT_json)):
|
||||
if set(eval_GT_json[i].keys()) != set(
|
||||
response.json()[i].keys()
|
||||
):
|
||||
if set(eval_GT_json[i].keys()) != set(response.json()[i].keys()):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Key inconsistency"],
|
||||
|
@ -918,16 +890,12 @@ def executable_checker_rest(func_call, idx):
|
|||
else:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Expected list, but got {type(response.json())}"
|
||||
],
|
||||
"error": [f"Expected list, but got {type(response.json())}"],
|
||||
"error_type": "executable_checker_rest:wrong_type",
|
||||
}
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Expected dict or list, but got {type(response.json())}"
|
||||
],
|
||||
"error": [f"Expected dict or list, but got {type(response.json())}"],
|
||||
"error_type": "executable_checker_rest:wrong_type",
|
||||
}
|
||||
except Exception as e:
|
||||
|
@ -941,9 +909,7 @@ def executable_checker_rest(func_call, idx):
|
|||
else:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Execution result status code is not 200, got {response.status_code}"
|
||||
],
|
||||
"error": [f"Execution result status code is not 200, got {response.status_code}"],
|
||||
"error_type": "executable_checker_rest:wrong_status_code",
|
||||
}
|
||||
except Exception as e:
|
||||
|
@ -954,18 +920,12 @@ def executable_checker_rest(func_call, idx):
|
|||
}
|
||||
|
||||
|
||||
def ast_checker(
|
||||
func_description, model_output, possible_answer, language, test_category, model_name
|
||||
):
|
||||
def ast_checker(func_description, model_output, possible_answer, language, test_category, model_name):
|
||||
if "parallel" in test_category:
|
||||
return parallel_function_checker_no_order(
|
||||
func_description, model_output, possible_answer, language, model_name
|
||||
)
|
||||
return parallel_function_checker_no_order(func_description, model_output, possible_answer, language, model_name)
|
||||
|
||||
elif "multiple" in test_category:
|
||||
return multiple_function_checker(
|
||||
func_description, model_output, possible_answer, language, model_name
|
||||
)
|
||||
return multiple_function_checker(func_description, model_output, possible_answer, language, model_name)
|
||||
|
||||
else:
|
||||
if len(model_output) != 1:
|
||||
|
|
|
@ -216,6 +216,24 @@ datasets:
|
|||
split: test
|
||||
dataset_id: math_500
|
||||
provider_id: huggingface
|
||||
- dataset_schema:
|
||||
function:
|
||||
type: string
|
||||
language:
|
||||
type: string
|
||||
ground_truth:
|
||||
type: string
|
||||
id:
|
||||
type: string
|
||||
chat_completion_input:
|
||||
type: string
|
||||
url:
|
||||
uri: https://huggingface.co/datasets/llamastack/bfcl_v3
|
||||
metadata:
|
||||
path: llamastack/bfcl_v3
|
||||
split: train
|
||||
dataset_id: bfcl
|
||||
provider_id: huggingface
|
||||
scoring_fns: []
|
||||
benchmarks:
|
||||
- dataset_id: simpleqa
|
||||
|
@ -238,6 +256,11 @@ benchmarks:
|
|||
- basic::regex_parser_math_response
|
||||
metadata: {}
|
||||
benchmark_id: meta-reference-math-500
|
||||
- dataset_id: bfcl
|
||||
scoring_functions:
|
||||
- basic::bfcl
|
||||
metadata: {}
|
||||
benchmark_id: meta-reference-bfcl
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue