mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-10 04:08:31 +00:00)

feat: [new open benchmark] BFCL_v3

parent 98b1b15e0f, commit 9dc717742a

9 changed files with 1577 additions and 6 deletions
@@ -12,7 +12,7 @@ from llama_stack.apis.agents import Agents, StepType
 from llama_stack.apis.benchmarks import Benchmark
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.inference import Inference, UserMessage
+from llama_stack.apis.inference import Inference, UserMessage, SystemMessage
 from llama_stack.apis.scoring import Scoring
 from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
@@ -118,7 +118,7 @@ class MetaReferenceEvalImpl(
         for i, x in tqdm(enumerate(input_rows)):
             assert ColumnName.chat_completion_input.value in x, "Invalid input row"
             input_messages = json.loads(x[ColumnName.chat_completion_input.value])
-            input_messages = [UserMessage(**x) for x in input_messages]
+            input_messages = [UserMessage(**x) for x in input_messages if x['role'] == 'user']

             # NOTE: only single-turn agent generation is supported. Create a new session for each input row
             session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
@@ -168,10 +168,11 @@ class MetaReferenceEvalImpl(
                 generations.append({ColumnName.generated_answer.value: response.completion_message.content})
             elif ColumnName.chat_completion_input.value in x:
                 chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
-                input_messages = [UserMessage(**x) for x in chat_completion_input_json]
+                input_messages = [UserMessage(**x) for x in chat_completion_input_json if x['role'] == 'user']
                 messages = []
                 if candidate.system_message:
                     messages.append(candidate.system_message)
+                messages += [SystemMessage(**x) for x in chat_completion_input_json if x['role'] == 'system']
                 messages += input_messages
                 response = await self.inference_api.chat_completion(
                     model_id=candidate.model,
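For context, each BFCL row stores its conversation as a JSON-encoded message list, and the two changes above split that list by role instead of treating every entry as a user turn: the agent path keeps only user messages, while the inference path also forwards per-row system messages. A standalone sketch of that filtering with plain dicts in place of the UserMessage/SystemMessage types (the sample content is illustrative, not taken from the dataset):

import json

# Hypothetical row payload; real rows come from the bfcl_v3 dataset.
chat_completion_input = json.dumps([
    {"role": "system", "content": "You are an expert in composing function calls."},
    {"role": "user", "content": "What is the weather in Paris?"},
])

messages = json.loads(chat_completion_input)
system_messages = [m for m in messages if m["role"] == "system"]
user_messages = [m for m in messages if m["role"] == "user"]
print(len(system_messages), len(user_messages))  # 1 1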
@@ -26,8 +26,10 @@ from .scoring_fn.equality_scoring_fn import EqualityScoringFn
 from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn
 from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
 from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
+from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
+
 
-FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn, RegexParserMathResponseScoringFn]
+FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn, RegexParserMathResponseScoringFn, BFCLScoringFn]
 
 
 class BasicScoringImpl(
@@ -0,0 +1,313 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import ast

from .tree_sitter import get_parser


def parse_java_function_call(source_code):
    if not source_code.endswith(";"):
        source_code += ";"  # Necessary for the parser not to register an error
    parser = get_parser("java")
    tree = parser.parse(bytes(source_code, "utf8"))
    root_node = tree.root_node

    if root_node.has_error:
        raise Exception("Error parsing the Java source code.")

    def get_text(node):
        """Returns the text represented by the node."""
        return source_code[node.start_byte : node.end_byte]

    def traverse_node(node, nested=False):
        if node.type == "string_literal":
            if nested:
                return get_text(node)
            # Strip surrounding quotes from string literals
            return get_text(node)[1:-1]
        elif node.type == "character_literal":
            if nested:
                return get_text(node)
            # Strip surrounding single quotes from character literals
            return get_text(node)[1:-1]
        # Traverse the node to collect texts for complex structures.
        if node.type in [
            "identifier",
            "class_literal",
            "type_identifier",
            "method_invocation",
        ]:
            return get_text(node)
        elif node.type == "array_creation_expression":
            # Handle array creation expression specifically
            type_node = node.child_by_field_name("type")
            value_node = node.child_by_field_name("value")
            type_text = traverse_node(type_node, True)
            value_text = traverse_node(value_node, True)
            return f"new {type_text}[]{value_text}"
        elif node.type == "object_creation_expression":
            # Handle object creation expression specifically
            type_node = node.child_by_field_name("type")
            arguments_node = node.child_by_field_name("arguments")
            type_text = traverse_node(type_node, True)
            if arguments_node:
                # Process each argument carefully, avoiding unnecessary punctuation
                argument_texts = []
                for child in arguments_node.children:
                    if child.type not in [
                        ",",
                        "(",
                        ")",
                    ]:  # Exclude commas and parentheses
                        argument_text = traverse_node(child, True)
                        argument_texts.append(argument_text)
                arguments_text = ", ".join(argument_texts)
                return f"new {type_text}({arguments_text})"
            else:
                return f"new {type_text}()"
        elif node.type == "set":
            # Handling sets specifically
            items = [
                traverse_node(n, True)
                for n in node.children
                if n.type not in [",", "set"]
            ]
            return "{" + ", ".join(items) + "}"
        elif node.child_count > 0:
            return "".join(traverse_node(child, True) for child in node.children)
        else:
            return get_text(node)

    def extract_arguments(args_node):
        arguments = {}
        for child in args_node.children:
            if child.type == "assignment_expression":
                # For named parameters
                name_node, value_node = child.children[0], child.children[2]
                name = get_text(name_node)
                value = traverse_node(value_node)
                if name in arguments:
                    if not isinstance(arguments[name], list):
                        arguments[name] = [arguments[name]]
                    arguments[name].append(value)
                else:
                    arguments[name] = value
                # arguments.append({'name': name, 'value': value})
            elif child.type in ["identifier", "class_literal", "set"]:
                # For unnamed parameters and handling sets
                value = traverse_node(child)
                if None in arguments:
                    if not isinstance(arguments[None], list):
                        arguments[None] = [arguments[None]]
                    arguments[None].append(value)
                else:
                    arguments[None] = value
        return arguments

    def traverse(node):
        if node.type == "method_invocation":
            # Extract the function name and its arguments
            method_name = get_text(node.child_by_field_name("name"))
            class_name_node = node.child_by_field_name("object")
            if class_name_node:
                class_name = get_text(class_name_node)
                function_name = f"{class_name}.{method_name}"
            else:
                function_name = method_name
            arguments_node = node.child_by_field_name("arguments")
            if arguments_node:
                arguments = extract_arguments(arguments_node)
                for key, value in arguments.items():
                    if isinstance(value, list):
                        raise Exception(
                            "Error: Multiple arguments with the same name are not supported."
                        )
                return [{function_name: arguments}]
        else:
            for child in node.children:
                result = traverse(child)
                if result:
                    return result

    result = traverse(root_node)
    return result if result else {}


def parse_javascript_function_call(source_code):
    if not source_code.endswith(";"):
        source_code += ";"  # Necessary for the parser not to register an error
    parser = get_parser("javascript")
    # Parse the source code
    tree = parser.parse(bytes(source_code, "utf8"))
    root_node = tree.root_node
    if root_node.has_error:
        raise Exception("Error parsing the JavaScript source code.")

    # Function to recursively extract argument details
    def extract_arguments(node):
        args = {}
        for child in node.children:
            if child.type == "assignment_expression":
                # Extract left (name) and right (value) parts of the assignment
                name = child.children[0].text.decode("utf-8")
                value = child.children[2].text.decode("utf-8")
                if (value.startswith('"') and value.endswith('"')) or (
                    value.startswith("'") and value.endswith("'")
                ):
                    value = value[1:-1]  # Trim the quotation marks
                if name in args:
                    if not isinstance(args[name], list):
                        args[name] = [args[name]]
                    args[name].append(value)
                else:
                    args[name] = value
            elif child.type == "identifier" or child.type == "true":
                # Handle non-named arguments and boolean values
                value = child.text.decode("utf-8")
                if None in args:
                    if not isinstance(args[None], list):
                        args[None] = [args[None]]
                    args[None].append(value)
                else:
                    args[None] = value
        return args

    # Find the function call and extract its name and arguments
    if root_node.type == "program":
        for child in root_node.children:
            if child.type == "expression_statement":
                for sub_child in child.children:
                    if sub_child.type == "call_expression":
                        function_name = sub_child.children[0].text.decode("utf8")
                        arguments_node = sub_child.children[1]
                        parameters = extract_arguments(arguments_node)
                        for key, value in parameters.items():
                            if isinstance(value, list):
                                raise Exception(
                                    "Error: Multiple arguments with the same name are not supported."
                                )
                        result = [{function_name: parameters}]
                        return result


def ast_parse(input_str, language="Python"):
    if language == "Python":
        cleaned_input = input_str.strip("[]'")
        parsed = ast.parse(cleaned_input, mode="eval")
        extracted = []
        if isinstance(parsed.body, ast.Call):
            extracted.append(resolve_ast_call(parsed.body))
        else:
            for elem in parsed.body.elts:
                extracted.append(resolve_ast_call(elem))
        return extracted
    elif language == "Java":
        return parse_java_function_call(
            input_str[1:-1]
        )  # Remove the [ and ] from the string
    elif language == "JavaScript":
        return parse_javascript_function_call(input_str[1:-1])
    else:
        raise NotImplementedError(f"Unsupported language: {language}")


def resolve_ast_call(elem):
    # Handle nested attributes for deeply nested module paths
    func_parts = []
    func_part = elem.func
    while isinstance(func_part, ast.Attribute):
        func_parts.append(func_part.attr)
        func_part = func_part.value
    if isinstance(func_part, ast.Name):
        func_parts.append(func_part.id)
    func_name = ".".join(reversed(func_parts))
    args_dict = {}
    # Parse when args are simply passed as an unnamed dictionary arg
    for arg in elem.args:
        if isinstance(arg, ast.Dict):
            for key, value in zip(arg.keys, arg.values):
                if isinstance(key, ast.Constant):
                    arg_name = key.value
                output = resolve_ast_by_type(value)
                args_dict[arg_name] = output
    for arg in elem.keywords:
        output = resolve_ast_by_type(arg.value)
        args_dict[arg.arg] = output
    return {func_name: args_dict}


def resolve_ast_by_type(value):
    if isinstance(value, ast.Constant):
        if value.value is Ellipsis:
            output = "..."
        else:
            output = value.value
    elif isinstance(value, ast.UnaryOp):
        output = -value.operand.value
    elif isinstance(value, ast.List):
        output = [resolve_ast_by_type(v) for v in value.elts]
    elif isinstance(value, ast.Dict):
        output = {
            resolve_ast_by_type(k): resolve_ast_by_type(v)
            for k, v in zip(value.keys, value.values)
        }
    elif isinstance(
        value, ast.NameConstant
    ):  # Added this condition to handle boolean values
        output = value.value
    elif isinstance(
        value, ast.BinOp
    ):  # Added this condition to handle function calls as arguments
        output = eval(ast.unparse(value))
    elif isinstance(value, ast.Name):
        output = value.id
    elif isinstance(value, ast.Call):
        if len(value.keywords) == 0:
            output = ast.unparse(value)
        else:
            output = resolve_ast_call(value)
    elif isinstance(value, ast.Tuple):
        output = tuple(resolve_ast_by_type(v) for v in value.elts)
    elif isinstance(value, ast.Lambda):
        output = eval(ast.unparse(value.body[0].value))
    elif isinstance(value, ast.Ellipsis):
        output = "..."
    elif isinstance(value, ast.Subscript):
        try:
            output = ast.unparse(value.body[0].value)
        except:
            output = ast.unparse(value.value) + "[" + ast.unparse(value.slice) + "]"
    else:
        raise Exception(f"Unsupported AST type: {type(value)}")
    return output


def decode_ast(result, language="Python"):
    func = result
    func = func.replace("\n", "")  # remove new line characters
    if not func.startswith("["):
        func = "[" + func
    if not func.endswith("]"):
        func = func + "]"
    decoded_output = ast_parse(func, language)
    return decoded_output


def decode_execute(result):
    func = result
    func = func.replace("\n", "")  # remove new line characters
    if not func.startswith("["):
        func = "[" + func
    if not func.endswith("]"):
        func = func + "]"
    decode_output = ast_parse(func)
    execution_list = []
    for function_call in decode_output:
        for key, value in function_call.items():
            execution_list.append(
                f"{key}({','.join([f'{k}={repr(v)}' for k, v in value.items()])})"
            )
    return execution_list
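For reference, a quick look at what these decoders produce; the Python path needs nothing beyond the standard library, while the Java/JavaScript paths additionally require the corresponding tree-sitter language packages (import of decode_ast/decode_execute is assumed, since the file path is not shown in this view):

print(decode_ast("[get_weather(city='Paris', unit='celsius')]", language="Python"))
# [{'get_weather': {'city': 'Paris', 'unit': 'celsius'}}]

print(decode_execute("get_weather(city='Paris')"))
# ["get_weather(city='Paris')"]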
(File diff suppressed because it is too large.)
@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

"""
Tree-sitter changes its API with unfortunate frequency. Modules that need it should
import it from here so that we can centrally manage things as necessary.
"""

# These currently work with tree-sitter 0.23.0
# NOTE: Don't import tree-sitter or any of the language modules in the main module
# because not all environments have them. Import lazily inside functions where needed.

import importlib
import typing

if typing.TYPE_CHECKING:
    import tree_sitter


def get_language(language: str) -> "tree_sitter.Language":
    import tree_sitter

    language_module_name = f"tree_sitter_{language}"
    try:
        language_module = importlib.import_module(language_module_name)
    except ModuleNotFoundError as exc:
        raise ValueError(
            f"Language {language} is not found. Please install the tree-sitter-{language} package."
        ) from exc
    return tree_sitter.Language(language_module.language())


def get_parser(language: str, **kwargs) -> "tree_sitter.Parser":
    import tree_sitter

    lang = get_language(language)
    return tree_sitter.Parser(lang, **kwargs)
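A minimal sanity check of this wrapper, assuming the tree-sitter and tree-sitter-java packages are installed and get_parser is imported from the module above:

parser = get_parser("java")
tree = parser.parse(bytes("Calculator.add(a=1, b=2);", "utf8"))
print(tree.root_node.has_error)  # expect False for syntactically valid input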
@@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, Optional

from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn

from .fn_defs.bfcl import bfcl
from .bfcl.ast_parser import decode_ast
from .bfcl.checker import ast_checker, is_empty_output
import json
import re


def postprocess(x: Dict[str, Any], test_category: str) -> Dict[str, Any]:
    contain_func_call = False
    error = None
    error_type = None
    checker_result = {}
    try:
        prediction = decode_ast(x["generated_answer"], x["language"]) or ""
        contain_func_call = True
        # if not is_function_calling_format_output(prediction):
        if is_empty_output(prediction):
            contain_func_call = False
            error = "Did not output in the specified format. Note: the model_result is wrapped in a string to ensure json serializability."
            error_type = "ast_decoder:decoder_wrong_output_format"
        else:
            checker_result = ast_checker(
                json.loads(x["function"]),
                prediction,
                json.loads(x["ground_truth"]),
                x["language"],
                test_category=test_category,
                model_name="",
            )
    except Exception as e:
        prediction = ""
        error = f"Invalid syntax. Failed to decode AST. {str(e)}"
        error_type = "ast_decoder:decoder_failed"
    return {
        "prediction": prediction,
        "contain_func_call": contain_func_call,
        "valid": checker_result.get("valid", False),
        "error": error or checker_result.get("error", ""),
        "error_type": error_type or checker_result.get("error_type", ""),
    }


def gen_valid(x: Dict[str, Any]) -> Dict[str, float]:
    return {"valid": x["valid"]}


def gen_relevance_acc(x: Dict[str, Any]) -> Dict[str, float]:
    # This function serves both relevance and irrelevance tests, which share exactly opposite logic.
    # If `test_category` is "irrelevance", the model is expected to output no function call.
    # No function call means either the AST decoding fails (an error message is generated) or the decoded AST does not contain any function call (such as an empty list, `[]`).
    # If `test_category` is "relevance", the model is expected to output a function call, and an empty list doesn't count as a function call.
    acc = (
        not x["contain_func_call"]
        if "irrelevance" in x["id"]
        else x["contain_func_call"]
    )
    return {"valid": float(acc)}


class BFCLScoringFn(RegisteredBaseScoringFn):
    """
    A scoring_fn for BFCL
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.supported_fn_defs_registry = {
            bfcl.identifier: bfcl,
        }

    async def score_row(
        self,
        input_row: Dict[str, Any],
        scoring_fn_identifier: Optional[str] = "bfcl",
        scoring_params: Optional[ScoringFnParams] = None,
    ) -> ScoringResultRow:
        test_category = re.sub(r'_[0-9_-]+$', '', input_row['id'])
        score_result = postprocess(input_row, test_category)
        if test_category in {'irrelevance', 'live_relevance', 'live_irrelevance'}:
            score = gen_relevance_acc(score_result)['valid']
        else:
            score = gen_valid(score_result)['valid']
        return {
            "score": float(score),
        }
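For reference, the test category is recovered by stripping the trailing numeric suffix from the row id, and relevance/irrelevance categories are scored on whether any call was produced rather than on AST validity. A standalone illustration of that regex and branching (the sample ids are made up):

import re

for sample_id in ["simple_200", "irrelevance_3", "live_irrelevance_7-1"]:
    test_category = re.sub(r"_[0-9_-]+$", "", sample_id)
    relevance_style = test_category in {"irrelevance", "live_relevance", "live_irrelevance"}
    print(sample_id, "->", test_category, "| relevance-style scoring:", relevance_style)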
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
    AggregationFunctionType,
    BasicScoringFnParams,
    ScoringFn,
)

bfcl = ScoringFn(
    identifier="basic::bfcl",
    description="BFCL complex scoring",
    return_type=NumberType(),
    provider_id="basic",
    provider_resource_id="bfcl",
    params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
)
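The accuracy aggregation declared here presumably reduces the per-row "score" values to their mean; a standalone illustration of that reduction (not the library's aggregation code):

score_rows = [{"score": 1.0}, {"score": 0.0}, {"score": 1.0}]
accuracy = sum(r["score"] for r in score_rows) / len(score_rows)
print(round(accuracy, 3))  # 0.667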
@@ -118,9 +118,30 @@ models:
   model_type: embedding
 shields: []
 vector_dbs: []
-datasets: []
+datasets:
+- dataset_id: bfcl
+  provider_id: huggingface
+  url:
+    uri: https://huggingface.co/datasets/llamastack/bfcl_v3/tree/main
+  metadata:
+    path: llamastack/bfcl_v3
+    split: train
+  dataset_schema:
+    function:
+      type: string
+    language:
+      type: string
+    ground_truth:
+      type: string
+    id:
+      type: string
+    chat_completion_input:
+      type: string
 scoring_fns: []
-benchmarks: []
+benchmarks:
+- benchmark_id: bfcl
+  dataset_id: bfcl
+  scoring_functions: ["basic::bfcl"]
 tool_groups:
 - toolgroup_id: builtin::websearch
   provider_id: tavily-search
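For reference, a row conforming to the dataset_schema above is all strings, with the structured fields JSON-encoded; the values below are abbreviated and purely illustrative:

row = {
    "id": "simple_0",
    "language": "Python",
    "function": '[{"name": "get_weather", "description": "...", "parameters": {...}}]',
    "ground_truth": '[{"get_weather": {"city": ["Paris"]}}]',
    "chat_completion_input": '[{"role": "user", "content": "What is the weather in Paris?"}]',
}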