From 9dc717742aaf764d83558ae8113395de610cfa9b Mon Sep 17 00:00:00 2001 From: Yang Yang Date: Tue, 11 Mar 2025 19:18:14 -0700 Subject: [PATCH] feat: [new open benchmark] BFCL_v3 --- .../inline/eval/meta_reference/eval.py | 7 +- .../providers/inline/scoring/basic/scoring.py | 4 +- .../scoring/basic/scoring_fn/bfcl/__init__.py | 0 .../basic/scoring_fn/bfcl/ast_parser.py | 313 +++++ .../scoring/basic/scoring_fn/bfcl/checker.py | 1078 +++++++++++++++++ .../basic/scoring_fn/bfcl/tree_sitter.py | 37 + .../basic/scoring_fn/bfcl_scoring_fn.py | 98 ++ .../scoring/basic/scoring_fn/fn_defs/bfcl.py | 21 + llama_stack/templates/ollama/run.yaml | 25 +- 9 files changed, 1577 insertions(+), 6 deletions(-) create mode 100644 llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl/__init__.py create mode 100644 llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl/ast_parser.py create mode 100644 llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl/checker.py create mode 100644 llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl/tree_sitter.py create mode 100644 llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl_scoring_fn.py create mode 100644 llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/bfcl.py diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py index a1bebaa4c..7616563bb 100644 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -12,7 +12,7 @@ from llama_stack.apis.agents import Agents, StepType from llama_stack.apis.benchmarks import Benchmark from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets -from llama_stack.apis.inference import Inference, UserMessage +from llama_stack.apis.inference import Inference, UserMessage, SystemMessage from llama_stack.apis.scoring import Scoring from llama_stack.distribution.datatypes import Api from llama_stack.providers.datatypes import BenchmarksProtocolPrivate @@ -118,7 +118,7 @@ class MetaReferenceEvalImpl( for i, x in tqdm(enumerate(input_rows)): assert ColumnName.chat_completion_input.value in x, "Invalid input row" input_messages = json.loads(x[ColumnName.chat_completion_input.value]) - input_messages = [UserMessage(**x) for x in input_messages] + input_messages = [UserMessage(**x) for x in input_messages if x['role'] == 'user'] # NOTE: only single-turn agent generation is supported. 
Create a new session for each input row session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}") @@ -168,10 +168,11 @@ class MetaReferenceEvalImpl( generations.append({ColumnName.generated_answer.value: response.completion_message.content}) elif ColumnName.chat_completion_input.value in x: chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value]) - input_messages = [UserMessage(**x) for x in chat_completion_input_json] + input_messages = [UserMessage(**x) for x in chat_completion_input_json if x['role'] == 'user'] messages = [] if candidate.system_message: messages.append(candidate.system_message) + messages += [SystemMessage(**x) for x in chat_completion_input_json if x['role'] == 'system'] messages += input_messages response = await self.inference_api.chat_completion( model_id=candidate.model, diff --git a/llama_stack/providers/inline/scoring/basic/scoring.py b/llama_stack/providers/inline/scoring/basic/scoring.py index 00945b99d..ffc0a95bb 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring.py +++ b/llama_stack/providers/inline/scoring/basic/scoring.py @@ -26,8 +26,10 @@ from .scoring_fn.equality_scoring_fn import EqualityScoringFn from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn +from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn -FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn, RegexParserMathResponseScoringFn] + +FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn, RegexParserMathResponseScoringFn, BFCLScoringFn] class BasicScoringImpl( diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl/__init__.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl/ast_parser.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl/ast_parser.py new file mode 100644 index 000000000..595dfc95a --- /dev/null +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl/ast_parser.py @@ -0,0 +1,313 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
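+# Decoders that turn a model's function-call string into structured calls for the
+# BFCL AST checker, e.g. '[func(a=1)]' -> [{"func": {"a": 1}}] via decode_ast below.
+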
+import ast
+
+from .tree_sitter import get_parser
+
+
+def parse_java_function_call(source_code):
+    if not source_code.endswith(";"):
+        source_code += ";"  # Necessary for the parser not to register an error
+    parser = get_parser("java")
+    tree = parser.parse(bytes(source_code, "utf8"))
+    root_node = tree.root_node
+
+    if root_node.has_error:
+        raise Exception("Error parsing the Java source code.")
+
+    def get_text(node):
+        """Returns the text represented by the node."""
+        return source_code[node.start_byte : node.end_byte]
+
+    def traverse_node(node, nested=False):
+        """Traverse the node to collect texts for complex structures."""
+        if node.type == "string_literal":
+            if nested:
+                return get_text(node)
+            # Strip surrounding quotes from string literals
+            return get_text(node)[1:-1]
+        elif node.type == "character_literal":
+            if nested:
+                return get_text(node)
+            # Strip surrounding single quotes from character literals
+            return get_text(node)[1:-1]
+        if node.type in [
+            "identifier",
+            "class_literal",
+            "type_identifier",
+            "method_invocation",
+        ]:
+            return get_text(node)
+        elif node.type == "array_creation_expression":
+            # Handle array creation expression specifically
+            type_node = node.child_by_field_name("type")
+            value_node = node.child_by_field_name("value")
+            type_text = traverse_node(type_node, True)
+            value_text = traverse_node(value_node, True)
+            return f"new {type_text}[]{value_text}"
+        elif node.type == "object_creation_expression":
+            # Handle object creation expression specifically
+            type_node = node.child_by_field_name("type")
+            arguments_node = node.child_by_field_name("arguments")
+            type_text = traverse_node(type_node, True)
+            if arguments_node:
+                # Process each argument carefully, avoiding unnecessary punctuation
+                argument_texts = []
+                for child in arguments_node.children:
+                    if child.type not in [
+                        ",",
+                        "(",
+                        ")",
+                    ]:  # Exclude commas and parentheses
+                        argument_text = traverse_node(child, True)
+                        argument_texts.append(argument_text)
+                arguments_text = ", ".join(argument_texts)
+                return f"new {type_text}({arguments_text})"
+            else:
+                return f"new {type_text}()"
+        elif node.type == "set":
+            # Handling sets specifically
+            items = [
+                traverse_node(n, True)
+                for n in node.children
+                if n.type not in [",", "set"]
+            ]
+            return "{" + ", ".join(items) + "}"
+
+        elif node.child_count > 0:
+            return "".join(traverse_node(child, True) for child in node.children)
+        else:
+            return get_text(node)
+
+    def extract_arguments(args_node):
+        arguments = {}
+        for child in args_node.children:
+            if child.type == "assignment_expression":
+                # For named parameters
+                name_node, value_node = child.children[0], child.children[2]
+                name = get_text(name_node)
+                value = traverse_node(value_node)
+                if name in arguments:
+                    if not isinstance(arguments[name], list):
+                        arguments[name] = [arguments[name]]
+                    arguments[name].append(value)
+                else:
+                    arguments[name] = value
+            elif child.type in ["identifier", "class_literal", "set"]:
+                # For unnamed parameters and handling sets
+                value = traverse_node(child)
+                if None in arguments:
+                    if not isinstance(arguments[None], list):
+                        arguments[None] = [arguments[None]]
+                    arguments[None].append(value)
+                else:
+                    arguments[None] = value
+        return arguments
+
+    def traverse(node):
+        if node.type == "method_invocation":
+            # Extract the function name and its arguments
+            method_name = get_text(node.child_by_field_name("name"))
+            class_name_node = node.child_by_field_name("object")
+            if class_name_node:
+                class_name = get_text(class_name_node)
+                function_name = f"{class_name}.{method_name}"
+            else:
+                function_name = method_name
+            arguments_node = node.child_by_field_name("arguments")
+            if arguments_node:
+                arguments = extract_arguments(arguments_node)
+                for key, value in arguments.items():
+                    if isinstance(value, list):
+                        raise Exception(
+                            "Error: Multiple arguments with the same name are not supported."
+                        )
+                return [{function_name: arguments}]
+
+        else:
+            for child in node.children:
+                result = traverse(child)
+                if result:
+                    return result
+
+    result = traverse(root_node)
+    return result if result else {}
+
+
+def parse_javascript_function_call(source_code):
+    if not source_code.endswith(";"):
+        source_code += ";"  # Necessary for the parser not to register an error
+    parser = get_parser("javascript")
+    # Parse the source code
+    tree = parser.parse(bytes(source_code, "utf8"))
+    root_node = tree.root_node
+    if root_node.has_error:
+        raise Exception("Error parsing the JavaScript source code.")
+
+    # Function to recursively extract argument details
+    def extract_arguments(node):
+        args = {}
+        for child in node.children:
+            if child.type == "assignment_expression":
+                # Extract left (name) and right (value) parts of the assignment
+                name = child.children[0].text.decode("utf-8")
+                value = child.children[2].text.decode("utf-8")
+                if (value.startswith('"') and value.endswith('"')) or (
+                    value.startswith("'") and value.endswith("'")
+                ):
+                    value = value[1:-1]  # Trim the quotation marks
+                if name in args:
+                    if not isinstance(args[name], list):
+                        args[name] = [args[name]]
+                    args[name].append(value)
+                else:
+                    args[name] = value
+
+            elif child.type == "identifier" or child.type == "true":
+                # Handle non-named arguments and boolean values
+                value = child.text.decode("utf-8")
+                if None in args:
+                    if not isinstance(args[None], list):
+                        args[None] = [args[None]]
+                    args[None].append(value)
+                else:
+                    args[None] = value
+        return args
+
+    # Find the function call and extract its name and arguments
+    if root_node.type == "program":
+        for child in root_node.children:
+            if child.type == "expression_statement":
+                for sub_child in child.children:
+                    if sub_child.type == "call_expression":
+                        function_name = sub_child.children[0].text.decode("utf8")
+                        arguments_node = sub_child.children[1]
+                        parameters = extract_arguments(arguments_node)
+                        for key, value in parameters.items():
+                            if isinstance(value, list):
+                                raise Exception(
+                                    "Error: Multiple arguments with the same name are not supported."
+ ) + result = [{function_name: parameters}] + return result + + +def ast_parse(input_str, language="Python"): + if language == "Python": + cleaned_input = input_str.strip("[]'") + parsed = ast.parse(cleaned_input, mode="eval") + extracted = [] + if isinstance(parsed.body, ast.Call): + extracted.append(resolve_ast_call(parsed.body)) + else: + for elem in parsed.body.elts: + extracted.append(resolve_ast_call(elem)) + return extracted + elif language == "Java": + return parse_java_function_call( + input_str[1:-1] + ) # Remove the [ and ] from the string + elif language == "JavaScript": + return parse_javascript_function_call(input_str[1:-1]) + else: + raise NotImplementedError(f"Unsupported language: {language}") + + +def resolve_ast_call(elem): + # Handle nested attributes for deeply nested module paths + func_parts = [] + func_part = elem.func + while isinstance(func_part, ast.Attribute): + func_parts.append(func_part.attr) + func_part = func_part.value + if isinstance(func_part, ast.Name): + func_parts.append(func_part.id) + func_name = ".".join(reversed(func_parts)) + args_dict = {} + # Parse when args are simply passed as an unnamed dictionary arg + for arg in elem.args: + if isinstance(arg, ast.Dict): + for key, value in zip(arg.keys, arg.values): + if isinstance(key, ast.Constant): + arg_name = key.value + output = resolve_ast_by_type(value) + args_dict[arg_name] = output + for arg in elem.keywords: + output = resolve_ast_by_type(arg.value) + args_dict[arg.arg] = output + return {func_name: args_dict} + + +def resolve_ast_by_type(value): + if isinstance(value, ast.Constant): + if value.value is Ellipsis: + output = "..." + else: + output = value.value + elif isinstance(value, ast.UnaryOp): + output = -value.operand.value + elif isinstance(value, ast.List): + output = [resolve_ast_by_type(v) for v in value.elts] + elif isinstance(value, ast.Dict): + output = { + resolve_ast_by_type(k): resolve_ast_by_type(v) + for k, v in zip(value.keys, value.values) + } + elif isinstance( + value, ast.NameConstant + ): # Added this condition to handle boolean values + output = value.value + elif isinstance( + value, ast.BinOp + ): # Added this condition to handle function calls as arguments + output = eval(ast.unparse(value)) + elif isinstance(value, ast.Name): + output = value.id + elif isinstance(value, ast.Call): + if len(value.keywords) == 0: + output = ast.unparse(value) + else: + output = resolve_ast_call(value) + elif isinstance(value, ast.Tuple): + output = tuple(resolve_ast_by_type(v) for v in value.elts) + elif isinstance(value, ast.Lambda): + output = eval(ast.unparse(value.body[0].value)) + elif isinstance(value, ast.Ellipsis): + output = "..." 
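+    # Subscript values (e.g. data["key"]) are preserved as their unparsed source text.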
+ elif isinstance(value, ast.Subscript): + try: + output = ast.unparse(value.body[0].value) + except: + output = ast.unparse(value.value) + "[" + ast.unparse(value.slice) + "]" + else: + raise Exception(f"Unsupported AST type: {type(value)}") + return output + + +def decode_ast(result, language="Python"): + func = result + func = func.replace("\n", "") # remove new line characters + if not func.startswith("["): + func = "[" + func + if not func.endswith("]"): + func = func + "]" + decoded_output = ast_parse(func, language) + return decoded_output + + +def decode_execute(result): + func = result + func = func.replace("\n", "") # remove new line characters + if not func.startswith("["): + func = "[" + func + if not func.endswith("]"): + func = func + "]" + decode_output = ast_parse(func) + execution_list = [] + for function_call in decode_output: + for key, value in function_call.items(): + execution_list.append( + f"{key}({','.join([f'{k}={repr(v)}' for k, v in value.items()])})" + ) + return execution_list diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl/checker.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl/checker.py new file mode 100644 index 000000000..48abfcf15 --- /dev/null +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl/checker.py @@ -0,0 +1,1078 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import json +import re +import time +from typing import Any + +# Comment out for now until we actually use the rest checker in evals +# import requests # Do not remove this import even though it seems to be unused. It's used in the executable_checker_rest function. + + +class NoAPIKeyError(Exception): + def __init__(self): + self.message = "❗️Please fill in the API keys in the function_credential_config.json file. If you do not provide the API keys, the executable test category results will be inaccurate." 
+ super().__init__(self.message) + + +REAL_TIME_MATCH_ALLOWED_DIFFERENCE = 0.2 + +UNDERSCORE_TO_DOT = [ + "gpt-4o-2024-08-06-FC", + "gpt-4o-2024-05-13-FC", + "gpt-4o-mini-2024-07-18-FC", + "gpt-4-turbo-2024-04-09-FC", + "gpt-4-1106-preview-FC", + "gpt-4-0125-preview-FC", + "gpt-4-0613-FC", + "gpt-3.5-turbo-0125-FC", + "claude-3-opus-20240229-FC", + "claude-3-sonnet-20240229-FC", + "claude-3-haiku-20240307-FC", + "claude-3-5-sonnet-20240620-FC", + "open-mistral-nemo-2407-FC-Any", + "open-mistral-nemo-2407-FC-Auto", + "open-mixtral-8x22b-FC-Any", + "open-mixtral-8x22b-FC-Auto", + "mistral-large-2407-FC", + "mistral-large-2407-FC-Any", + "mistral-large-2407-FC-Auto", + "mistral-small-2402-FC-Any", + "mistral-small-2402-FC-Auto", + "mistral-small-2402-FC", + "gemini-1.0-pro", + "gemini-1.5-pro-preview-0409", + "gemini-1.5-pro-preview-0514", + "gemini-1.5-flash-preview-0514", + "meetkai/functionary-small-v3.1-FC", + "meetkai/functionary-small-v3.2-FC", + "meetkai/functionary-medium-v3.1-FC", + "NousResearch/Hermes-2-Pro-Llama-3-8B", + "NousResearch/Hermes-2-Pro-Llama-3-70B", + "NousResearch/Hermes-2-Pro-Mistral-7B", + "NousResearch/Hermes-2-Theta-Llama-3-8B", + "NousResearch/Hermes-2-Theta-Llama-3-70B", + "command-r-plus-FC", + "command-r-plus-FC-optimized", + "THUDM/glm-4-9b-chat", + "ibm-granite/granite-20b-functioncalling", + "yi-large-fc", +] + +JAVA_TYPE_CONVERSION = { + "byte": int, + "short": int, + "integer": int, + "float": float, + "double": float, + "long": int, + "boolean": bool, + "char": str, + "Array": list, + "ArrayList": list, + "Set": set, + "HashMap": dict, + "Hashtable": dict, + "Queue": list, # this can be `queue.Queue` as well, for simplicity we check with list + "Stack": list, + "String": str, + "any": str, +} + +JS_TYPE_CONVERSION = { + "String": str, + "integer": int, + "float": float, + "Bigint": int, + "Boolean": bool, + "dict": dict, + "array": list, + "any": str, +} + +# We switch to conditional import for the following two imports to avoid unnecessary installations. +# User doesn't need to setup the tree-sitter packages if they are not running the test for that language. +# from js_type_converter import js_type_converter +# from java_type_converter import java_type_converter + +PYTHON_TYPE_MAPPING = { + "string": str, + "integer": int, + "float": float, + "boolean": bool, + "array": list, + "tuple": list, + "dict": dict, + "any": str, +} + +# This is the list of types that we need to recursively check its values +PYTHON_NESTED_TYPE_CHECK_LIST = ["array", "tuple"] + + +NESTED_CONVERSION_TYPE_LIST = ["Array", "ArrayList", "array"] + + +#### Helper functions for AST #### +def find_description(func_descriptions, name): + if type(func_descriptions) == list: + for func_description in func_descriptions: + if func_description["name"] == name: + return func_description + return None + else: + # it is a dict, there is only one function + return func_descriptions + + +def get_possible_answer_type(possible_answer: list): + for answer in possible_answer: + if answer != "": # Optional parameter + return type(answer) + return None + + +def convert_func_name(function_name, model_name: str): + model_name_escaped = model_name.replace("_", "/") + if "." in function_name: + if model_name_escaped in UNDERSCORE_TO_DOT: + # OAI does not support "." in the function name so we replace it with "_". ^[a-zA-Z0-9_-]{1,64}$ is the regex for the name. 
+            # This happens for OpenAI, Mistral, and Google models
+            return re.sub(r"\.", "_", function_name)
+    return function_name
+
+
+def type_checker(
+    param: str,
+    value,
+    possible_answer: list,
+    expected_type_description: str,
+    expected_type_converted,
+    nested_type_converted,
+):
+    # NOTE: This type checker only supports nested type checking for one level deep.
+    # We didn't implement recursive type checking for nested types, as it's not needed for the current use case and it's very complex.
+
+    result: Any = {
+        "valid": True,
+        "error": [],
+        "is_variable": False,
+        "error_type": "type_error:simple",
+    }
+
+    is_variable = False
+    # Check for the case where a variable is used instead of an actual value.
+    # Use the type in possible_answer as the expected type.
+    possible_answer_type = get_possible_answer_type(possible_answer)
+    # If possible_answer only contains optional parameters, we can't determine the type.
+    if possible_answer_type != None:
+        # We are being precise here.
+        # In fact, possible_answer_type should always be string, as that's how we treat variables in possible_answer.
+        if possible_answer_type != expected_type_converted:
+            is_variable = True
+
+    # value is the same type as in function description
+    if type(value) == expected_type_converted:
+        # We don't need to do recursive check for simple types
+        if nested_type_converted == None:
+            result["is_variable"] = is_variable
+            return result
+        else:
+            for possible_answer_item in possible_answer:
+                flag = True  # Each parameter should match at least one possible answer type.
+                # Here, we assume that each item should be the same type. We could also relax it.
+                if type(possible_answer_item) == list:
+                    for value_item in value:
+                        checker_result = type_checker(
+                            param,
+                            value_item,
+                            possible_answer_item,
+                            str(nested_type_converted),
+                            nested_type_converted,
+                            None,
+                        )
+                        if not checker_result["valid"]:
+                            flag = False
+                            break
+
+                if flag:
+                    return {"valid": True, "error": [], "is_variable": is_variable}
+
+            result["valid"] = False
+            result["error"] = [
+                f"Nested type checking failed for parameter {repr(param)}. Expected outer type {expected_type_description} with inner type {str(nested_type_converted)}. Parameter value: {repr(value)}."
+            ]
+            result["error_type"] = "type_error:nested"
+
+    # value is not as expected; check for the case where a variable is used instead of an actual value.
+    # Use the type in possible_answer as the expected type.
+    possible_answer_type = get_possible_answer_type(possible_answer)
+    # If possible_answer only contains optional parameters, we can't determine the type.
+    if possible_answer_type != None:
+        # We are being precise here.
+        # In fact, possible_answer_type should always be string, as that's how we treat variables in possible_answer.
+        if type(value) == possible_answer_type:
+            result["is_variable"] = True
+            return result
+
+    result["valid"] = False
+    result["error"].append(
+        f"Incorrect type for parameter {repr(param)}. Expected type {expected_type_description}, got {type(value).__name__}. Parameter value: {repr(value)}."
+ ) + result["error_type"] = "type_error:simple" + return result + + +def standardize_string(input_string: str): + # This function standardizes the string by removing all the spaces, ",./-_*^" punctuation, and converting it to lowercase + # It will also convert all the single quotes to double quotes + # This is used to compare the model output with the possible answers + # We don't want to punish model for answer like April 1, 2024 vs April 1,2024, vs April 1 2024 + regex_string = r"[ \,\.\/\-\_\*\^]" + return re.sub(regex_string, "", input_string).lower().replace("'", '"') + + +def string_checker(param: str, model_output: str, possible_answer: list): + standardize_possible_answer = [] + standardize_model_output = standardize_string(model_output) + for i in range(len(possible_answer)): + if type(possible_answer[i]) == str: + standardize_possible_answer.append(standardize_string(possible_answer[i])) + + if standardize_model_output not in standardize_possible_answer: + return { + "valid": False, + "error": [ + f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}. Case insensitive." + ], + "error_type": "value_error:string", + } + + return {"valid": True, "error": []} + + +def list_checker(param: str, model_output: list, possible_answer: list): + # Convert the tuple to a list + + standardize_model_output = list(model_output) + + # If the element in the list is a string, we need to standardize it + for i in range(len(standardize_model_output)): + if type(standardize_model_output[i]) == str: + standardize_model_output[i] = standardize_string(model_output[i]) + + standardize_possible_answer: Any = [] + # We also need to standardize the possible answers + for i in range(len(possible_answer)): + standardize_possible_answer.append([]) + for j in range(len(possible_answer[i])): + if type(possible_answer[i][j]) == str: + standardize_possible_answer[i].append( + standardize_string(possible_answer[i][j]) + ) + else: + standardize_possible_answer[i].append(possible_answer[i][j]) + + if standardize_model_output not in standardize_possible_answer: + return { + "valid": False, + "error": [ + f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}." + ], + "error_type": "value_error:list/tuple", + } + + return {"valid": True, "error": []} + + +def dict_checker(param: str, model_output: dict, possible_answers: list): + # This function works for simple dictionaries, but not dictionaries with nested dictionaries. + # The current dataset only contains simple dictionaries, so this is sufficient. 
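+    # Illustrative shape (hypothetical values): model_output = {"city": "Paris"}
+    # matches possible_answers = [{"city": ["Paris"], "country": [""]}], since an
+    # empty string in a value list marks that key as optional.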
+
+    result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"}
+    for i in range(len(possible_answers)):
+
+        if possible_answers[i] == "":
+            continue
+
+        result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"}
+
+        flag = True
+
+        possible_answer = possible_answers[i]
+        # possible_answer is a single dictionary
+
+        for key, value in model_output.items():
+            if key not in possible_answer:
+                result["valid"] = False
+                result["error"].append(f"Unexpected dict key parameter: '{key}'.")  # type: ignore[attr-defined]
+                result["error_type"] = "value_error:dict_key"
+                flag = False
+                break
+
+            standardize_value = value
+            # If the value is a string, we need to standardize it
+            if type(value) == str:
+                standardize_value = standardize_string(value)
+
+            # We also need to standardize the possible answers if they are strings
+            standardize_possible_answer = []
+            for j in range(len(possible_answer[key])):
+                if type(possible_answer[key][j]) == str:
+                    standardize_possible_answer.append(
+                        standardize_string(possible_answer[key][j])
+                    )
+                else:
+                    standardize_possible_answer.append(possible_answer[key][j])
+
+            if standardize_value not in standardize_possible_answer:
+                result["valid"] = False
+                result["error"].append(  # type: ignore[attr-defined]
+                    f"Invalid value for parameter {repr(key)}: {repr(value)}. Expected one of {standardize_possible_answer}."
+                )
+                result["error_type"] = "value_error:dict_value"
+                flag = False
+                break
+
+        for key, value in possible_answer.items():
+            if key not in model_output and "" not in value:
+                result["valid"] = False
+                result["error"].append(f"Missing dict key parameter: '{key}'.")  # type: ignore[attr-defined]
+                result["error_type"] = "value_error:dict_key"
+                flag = False
+                break
+
+        if flag:
+            return {"valid": True, "error": []}
+
+    return result
+
+
+def list_dict_checker(param: str, model_output: list, possible_answers: list):
+    # This function takes in a list of dictionaries and checks if each dictionary is valid
+    # The order of the dictionaries in the list must match the order of the possible answers
+
+    result = {"valid": False, "error": [], "error_type": "list_dict_checker:unclear"}
+
+    for answer_index in range(len(possible_answers)):
+        flag = True  # True means so far, all dictionaries are valid
+
+        # Only proceed if the number of dictionaries in the list matches the number of dictionaries in the possible answers
+        if len(model_output) != len(possible_answers[answer_index]):
+            result["valid"] = False
+            result["error"] = ["Wrong number of dictionaries in the list."]
+            result["error_type"] = "value_error:list_dict_count"
+            flag = False
+            continue
+
+        for dict_index in range(len(model_output)):
+            result = dict_checker(
+                param,
+                model_output[dict_index],
+                [possible_answers[answer_index][dict_index]],
+            )
+            if not result["valid"]:
+                flag = False
+                break
+        if flag:
+            return {"valid": True, "error": []}
+
+    return result
+
+
+def simple_function_checker(
+    func_description: dict,
+    model_output: dict,
+    possible_answer: dict,
+    language: str,
+    model_name: str,
+):
+    possible_answer = list(possible_answer.values())[0]
+    # Extract function name and parameter details
+    func_name = func_description["name"]
+    param_details = func_description["parameters"]["properties"]
+    required_params = func_description["parameters"]["required"]
+
+    # Initialize a result dictionary
+    result = {
+        "valid": True,
+        "error": [],
+        "error_type": "simple_function_checker:unclear",
+    }
+
+    func_name = convert_func_name(func_name, model_name)
+
+    # Check if the function name matches
+    if func_name not in model_output:
+        result["valid"] = False
+        result["error"].append(  # type: ignore[attr-defined]
+            f"Function name {repr(func_name)} not found in model output."
+        )
+        result["error_type"] = "simple_function_checker:wrong_func_name"
+        return result
+
+    model_params = model_output[func_name]
+
+    # Check for required parameters in model output
+    for param in required_params:
+        if param not in model_params:
+            result["valid"] = False
+            result["error"].append(f"Missing required parameter: {repr(param)}.")  # type: ignore[attr-defined]
+            result["error_type"] = "simple_function_checker:missing_required"
+            return result
+
+    # Validate types and values for each parameter in model output
+    for param, value in model_params.items():
+        if param not in param_details or param not in possible_answer:
+            result["valid"] = False
+            result["error"].append(f"Unexpected parameter: {repr(param)}.")  # type: ignore[attr-defined]
+            result["error_type"] = "simple_function_checker:unexpected_param"
+            return result
+
+        full_param_details = param_details[param]
+        expected_type_description = full_param_details["type"]  # This is a string
+        is_variable = False
+        nested_type_converted = None
+
+        if language == "Java":
+            from .java_type_converter import java_type_converter
+
+            expected_type_converted = JAVA_TYPE_CONVERSION[expected_type_description]
+
+            if expected_type_description in JAVA_TYPE_CONVERSION:
+                if type(value) != str:
+                    result["valid"] = False
+                    result["error"].append(  # type: ignore[attr-defined]
+                        f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}."
+                    )
+                    result["error_type"] = "type_error:java"
+                    return result
+
+                if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
+                    nested_type = param_details[param]["items"]["type"]
+                    nested_type_converted = JAVA_TYPE_CONVERSION[nested_type]
+                    value = java_type_converter(
+                        value, expected_type_description, nested_type
+                    )
+                else:
+                    value = java_type_converter(value, expected_type_description)
+
+        elif language == "JavaScript":
+            from .js_type_converter import js_type_converter
+
+            expected_type_converted = JS_TYPE_CONVERSION[expected_type_description]
+
+            if expected_type_description in JS_TYPE_CONVERSION:
+                if type(value) != str:
+                    result["valid"] = False
+                    result["error"].append(  # type: ignore[attr-defined]
+                        f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}."
+                    )
+                    result["error_type"] = "type_error:js"
+                    return result
+
+                if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
+                    nested_type = param_details[param]["items"]["type"]
+                    nested_type_converted = JS_TYPE_CONVERSION[nested_type]
+                    value = js_type_converter(
+                        value, expected_type_description, nested_type
+                    )
+                else:
+                    value = js_type_converter(value, expected_type_description)
+
+        elif language == "Python":
+            expected_type_converted = PYTHON_TYPE_MAPPING[expected_type_description]
+            if expected_type_description in PYTHON_NESTED_TYPE_CHECK_LIST:
+                nested_type = param_details[param]["items"]["type"]
+                nested_type_converted = PYTHON_TYPE_MAPPING[nested_type]
+
+        # We convert all tuple values to lists when the expected type is tuple.
+        # The conversion is necessary because any tuple in the possible answer would become a list after being processed through json.dump() and json.load().
+        # This does introduce some false positives (e.g., when the model provides a list value instead of a tuple). We hope to find a better solution in the future.
+        if expected_type_description == "tuple" and type(value) == tuple:
+            value = list(value)
+
+        # Allow python auto conversion from int to float
+        if (
+            language == "Python"
+            and expected_type_description == "float"
+            and type(value) == int
+        ):
+            value = float(value)
+
+        # Type checking
+        # In fact, we only check for Python here.
+        # Type checks for other languages are handled by the type converter, so their values (after conversion) are always correct.
+        type_check_result = type_checker(
+            param,
+            value,
+            possible_answer[param],
+            expected_type_description,
+            expected_type_converted,
+            nested_type_converted,
+        )
+        is_variable = type_check_result["is_variable"]
+        if not type_check_result["valid"]:
+            return type_check_result
+
+        # It doesn't make sense to special-case dictionaries and lists of dictionaries if the value is a variable.
+        # We can just treat the variable as a string and use the normal flow.
+        if not is_variable:
+            # Special handling for dictionaries
+            if expected_type_converted == dict:
+                result = dict_checker(param, value, possible_answer[param])
+                if not result["valid"]:
+                    return result
+                continue
+
+            # Special handling for lists of dictionaries
+            elif expected_type_converted == list and nested_type_converted == dict:
+                result = list_dict_checker(param, value, possible_answer[param])
+                if not result["valid"]:
+                    return result
+                continue
+
+            # Special handling for strings
+            elif expected_type_converted == str:
+                # We don't check for case sensitivity for strings, as long as it's not a variable
+                result = string_checker(param, value, possible_answer[param])
+                if not result["valid"]:
+                    return result
+                continue
+
+            elif expected_type_converted == list:
+                result = list_checker(param, value, possible_answer[param])
+                if not result["valid"]:
+                    return result
+                continue
+
+        # Check if the value is within the possible answers
+        if value not in possible_answer[param]:
+            result["valid"] = False
+            result["error"].append(  # type: ignore[attr-defined]
+                f"Invalid value for parameter {repr(param)}: {repr(value)}. Expected one of {possible_answer[param]}."
+            )
+            result["error_type"] = "value_error:others"
+            return result
+
+    # Check for optional parameters not provided but allowed
+    for param in possible_answer:
+        if param not in model_params and "" not in possible_answer[param]:
+            result["valid"] = False
+            result["error"].append(  # type: ignore[attr-defined]
+                f"Optional parameter {repr(param)} not provided and not marked as optional."
+            )
+            result["error_type"] = "simple_function_checker:missing_optional"
+            return result
+
+    return result
+
+
+def parallel_function_checker_enforce_order(
+    func_descriptions: list,
+    model_output: list,
+    possible_answers: dict,
+    language: str,
+    model_name: str,
+):
+    if len(model_output) != len(possible_answers):
+        return {
+            "valid": False,
+            "error": ["Wrong number of functions."],
+            "error_type": "parallel_function_checker_enforce_order:wrong_count",
+        }
+
+    func_name_list = list(possible_answers.keys())
+    possible_answers_list = []
+
+    for key, value in possible_answers.items():
+        possible_answers_list.append({key: value})
+
+    for i in range(len(possible_answers_list)):
+        func_description = find_description(func_descriptions, func_name_list[i])
+
+        result = simple_function_checker(
+            func_description,
+            model_output[i],
+            possible_answers_list[i],
+            language,
+            model_name,
+        )
+        if not result["valid"]:
+            return result
+
+    return {"valid": True, "error": []}
+
+
+def parallel_function_checker_no_order(
+    func_descriptions: list,
+    model_output: list,
+    possible_answers: list,
+    language: str,
+    model_name: str,
+):
+    if len(model_output) != len(possible_answers):
+        return {
+            "valid": False,
+            "error": ["Wrong number of functions."],
+            "error_type": "parallel_function_checker_no_order:wrong_count",
+        }
+
+    matched_indices = []
+
+    # We go through the possible answers one by one and eliminate the model output that matches each one.
+    # It must be done this way because we need the ground truth to fetch the correct function description.
+    for i in range(len(possible_answers)):
+        # possible_answers[i] is a dictionary with only one key
+        func_name_expected = list(possible_answers[i].keys())[0]
+        func_description = find_description(func_descriptions, func_name_expected)
+
+        all_errors = []
+
+        for index in range(len(model_output)):
+            if index in matched_indices:
+                continue
+
+            result = simple_function_checker(
+                func_description,
+                model_output[index],
+                possible_answers[i],
+                language,
+                model_name,
+            )
+
+            if result["valid"]:
+                matched_indices.append(index)
+                break
+            else:
+                all_errors.append(
+                    {
+                        f"Model Result Index {index}": {
+                            "sub_error": result["error"],
+                            "sub_error_type": result["error_type"],
+                            "model_output_item": model_output[index],
+                            "possible_answer_item": possible_answers[i],
+                        }
+                    }
+                )
+
+        if not result["valid"]:
+            considered_indices = [
+                i for i in range(len(model_output)) if i not in matched_indices
+            ]
+            all_errors.insert(
+                0,
+                f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.",  # type: ignore[arg-type]
+            )
+            return {
+                "valid": False,
+                "error": all_errors,
+                "error_type": "parallel_function_checker_no_order:cannot_find_match",
+            }
+
+    return {"valid": True, "error": []}
+
+
+def multiple_function_checker(
+    func_descriptions: list,
+    model_output: list,
+    possible_answers: list,
+    language: str,
+    model_name: str,
+):
+    if len(model_output) != len(possible_answers):
+        return {
+            "valid": False,
+            "error": ["Wrong number of functions."],
+            "error_type": "multiple_function_checker:wrong_count",
+        }
+
+    # possible_answers is a list of only one dictionary with only one key
+    func_name_expected = list(possible_answers[0].keys())[0]
+    func_description = find_description(func_descriptions, func_name_expected)
+    return simple_function_checker(
+        func_description,
+        model_output[0],
+        possible_answers[0],
+        language,
+        model_name,
+    )
+
+
+def pattern_matcher(exec_output, expected_result, function_call, is_sanity_check):
+    result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
+
+    if type(exec_output) != type(expected_result):
+        return {
+            "valid": False,
+            "error": [
+                f"Wrong execution result type for {repr(function_call)}. Expected type: {type(expected_result)}, but got: {type(exec_output)}."
+            ],
+            "error_type": "executable_checker:wrong_result_type",
+            "model_executed_output": exec_output,
+        }
+    if type(exec_output) == dict:
+        # We loosen the requirement for the sanity check, as the expected result used in the sanity check might not be the most up-to-date one.
+        # This happens when the key is a timestamp or a random number.
+        if is_sanity_check:
+            if len(exec_output) != len(expected_result):
+                return {
+                    "valid": False,
+                    "error": [
+                        f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}."
+                    ],
+                    "error_type": "executable_checker:wrong_result_type:dict_length",
+                    "model_executed_output": exec_output,
+                }
+            else:
+                return result
+
+        for key, value in expected_result.items():
+            if key not in exec_output:
+                return {
+                    "valid": False,
+                    "error": [
+                        f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not found in the model output."
+                    ],
+                    "error_type": "executable_checker:wrong_result_type:dict_key_not_found",
+                    "model_executed_output": exec_output,
+                }
+        for key, value in exec_output.items():
+            if key not in expected_result:
+                return {
+                    "valid": False,
+                    "error": [
+                        f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not expected in the model output."
+                    ],
+                    "error_type": "executable_checker:wrong_result_type:dict_extra_key",
+                    "model_executed_output": exec_output,
+                }
+    if type(exec_output) == list:
+        if len(exec_output) != len(expected_result):
+            return {
+                "valid": False,
+                "error": [
+                    f"Wrong execution result pattern for {repr(function_call)}. Expect type list, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}."
+                ],
+                "error_type": "executable_checker:wrong_result_type:list_length",
+                "model_executed_output": exec_output,
+            }
+    return result
+
+
+#### Helper functions for Exec ####
+def executable_checker_simple(
+    function_call: str,
+    expected_result,
+    expected_result_type: str,
+    is_sanity_check=False,
+):
+    result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
+
+    exec_dict: Any = {}
+
+    try:
+        exec(
+            "from executable_python_function import *" + "\nresult=" + function_call,
+            exec_dict,
+        )
+        exec_output = exec_dict["result"]
+    except NoAPIKeyError as e:
+        raise e
+    except Exception as e:
+        result["valid"] = False
+        result["error"].append(  # type: ignore[attr-defined]
+            f"Error in execution: {repr(function_call)}. Error: {str(e)}"
+        )
+        result["error_type"] = "executable_checker:execution_error"
+        return result
+
+    # We need to special-handle the case where the execution result is a tuple and convert it to a list,
+    # because when json is stored, the tuple is converted to a list, and so the expected result is a list when loaded from json.
+    if isinstance(exec_output, tuple):
+        exec_output = list(exec_output)
+
+    if expected_result_type == "exact_match":
+        if exec_output != expected_result:
+            result["valid"] = False
+            result["error"].append(  # type: ignore[attr-defined]
+                f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}."
+            )
+            result["error_type"] = "executable_checker:wrong_result"
+            result["model_executed_output"] = exec_output
+            return result
+
+    elif expected_result_type == "real_time_match":
+        # Allow for a 20% difference (REAL_TIME_MATCH_ALLOWED_DIFFERENCE)
+        if (type(expected_result) == float or type(expected_result) == int) and (
+            type(exec_output) == float or type(exec_output) == int
+        ):
+            if not (
+                expected_result * (1 - REAL_TIME_MATCH_ALLOWED_DIFFERENCE)
+                <= exec_output
+                <= expected_result * (1 + REAL_TIME_MATCH_ALLOWED_DIFFERENCE)
+            ):
+                result["valid"] = False
+                result["error"].append(  # type: ignore[attr-defined]
+                    f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. {REAL_TIME_MATCH_ALLOWED_DIFFERENCE * 100}% difference allowed."
+                )
+                result["error_type"] = "executable_checker:wrong_result_real_time"
+                result["model_executed_output"] = exec_output
+                return result
+        else:
+            result["valid"] = False
+            result["error"].append(  # type: ignore[attr-defined]
+                f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. Type needs to be float or int for real time match criteria."
+            )
+            result["error_type"] = "executable_checker:wrong_result_real_time"
+            result["model_executed_output"] = exec_output
+            return result
+
+    else:
+        # structural match
+        pattern_match_result = pattern_matcher(
+            exec_output, expected_result, function_call, is_sanity_check
+        )
+        if not pattern_match_result["valid"]:
+            return pattern_match_result
+
+    return result
+
+
+def executable_checker_parallel_no_order(
+    decoded_result: list, expected_exec_result: list, expected_exec_result_type: list
+):
+    if len(decoded_result) != len(expected_exec_result):
+        return {
+            "valid": False,
+            "error": [
+                f"Wrong number of functions provided. Expected {len(expected_exec_result)}, but got {len(decoded_result)}."
+            ],
+            "error_type": "value_error:exec_result_count",
+        }
+
+    matched_indices = []
+    for i in range(len(expected_exec_result)):
+        all_errors = []
+        for index in range(len(decoded_result)):
+            if index in matched_indices:
+                continue
+
+            result = executable_checker_simple(
+                decoded_result[index],
+                expected_exec_result[i],
+                expected_exec_result_type[i],
+                False,
+            )
+
+            if result["valid"]:
+                matched_indices.append(index)
+                break
+            else:
+                all_errors.append(
+                    {
+                        f"Model Result Index {index}": {
+                            "sub_error": result["error"],
+                            "sub_error_type": result["error_type"],
+                            "model_executed_output": (
+                                result["model_executed_output"]
+                                if "model_executed_output" in result
+                                else None
+                            ),
+                        }
+                    }
+                )
+
+        if not result["valid"]:
+            considered_indices = [
+                i for i in range(len(decoded_result)) if i not in matched_indices
+            ]
+            all_errors.insert(
+                0,
+                f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.",  # type: ignore[arg-type]
+            )
+            return {
+                "valid": False,
+                "error": all_errors,
+                "error_type": "executable_checker:cannot_find_match",
+            }
+
+    return {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
+
+
+#### Main function ####
+def executable_checker_rest(func_call, idx):
+    # Move this here for now to avoid needing to read this file / fix paths to be relative to dataset_dir. Fix when it's actually needed / used.
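+    # The ground-truth file is expected to be JSONL: line `idx` holds the expected
+    # REST response (a JSON object or list) for test case `idx`.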
+ EVAL_GROUND_TRUTH_PATH = "/mnt/wsfuse/fair_llm_v2/datasets/eval/bfcl/rest-eval-response_v5.jsonl" # Ground truth file for v5 for rest execution + with open(EVAL_GROUND_TRUTH_PATH, "r") as f: + EVAL_GROUND_TRUTH = f.readlines() + if "https://geocode.maps.co" in func_call: + time.sleep(2) + if "requests_get" in func_call: + func_call = func_call.replace("requests_get", "requests.get") + try: + response = eval(func_call) + except Exception as e: + return { + "valid": False, + "error": [f"Execution failed. {str(e)}"], + "error_type": "executable_checker_rest:execution_error", + } + + try: + if response.status_code == 200: + + eval_GT_json = json.loads(EVAL_GROUND_TRUTH[idx]) + try: + if isinstance(eval_GT_json, dict): + if isinstance(response.json(), dict): + if set(eval_GT_json.keys()) == set(response.json().keys()): + return {"valid": True, "error": [], "error_type": ""} + return { + "valid": False, + "error": ["Key inconsistency"], + "error_type": "executable_checker_rest:wrong_key", + } + return { + "valid": False, + "error": [ + f"Expected dictionary, but got {type(response.json())}" + ], + "error_type": "executable_checker_rest:wrong_type", + } + + elif isinstance(eval_GT_json, list): + if isinstance(response.json(), list): + if len(eval_GT_json) != len(response.json()): + return { + "valid": False, + "error": [f"Response list length inconsistency."], + "error_type": "value_error:exec_result_rest_count", + } + + else: + for i in range(len(eval_GT_json)): + if set(eval_GT_json[i].keys()) != set( + response.json()[i].keys() + ): + return { + "valid": False, + "error": [f"Key inconsistency"], + "error_type": "executable_checker_rest:wrong_key", + } + + return {"valid": True, "error": []} + else: + return { + "valid": False, + "error": [ + f"Expected list, but got {type(response.json())}" + ], + "error_type": "executable_checker_rest:wrong_type", + } + return { + "valid": False, + "error": [ + f"Expected dict or list, but got {type(response.json())}" + ], + "error_type": "executable_checker_rest:wrong_type", + } + except Exception as e: + return { + "valid": False, + "error": [ + f"Error in execution and type checking. Status code: {response.status_code}. Error: {str(e)}" + ], + "error_type": "executable_checker_rest:response_format_error", + } + else: + return { + "valid": False, + "error": [ + f"Execution result status code is not 200, got {response.status_code}" + ], + "error_type": "executable_checker_rest:wrong_status_code", + } + except Exception as e: + return { + "valid": False, + "error": [f"Cannot get status code of the response. 
Error: {str(e)}"],
+            "error_type": "executable_checker_rest:cannot_get_status_code",
+        }
+
+
+def ast_checker(
+    func_description, model_output, possible_answer, language, test_category, model_name
+):
+    if "parallel" in test_category:
+        return parallel_function_checker_no_order(
+            func_description, model_output, possible_answer, language, model_name
+        )
+
+    elif "multiple" in test_category:
+        return multiple_function_checker(
+            func_description, model_output, possible_answer, language, model_name
+        )
+
+    else:
+        if len(model_output) != 1:
+            return {
+                "valid": False,
+                "error": ["Wrong number of functions."],
+                "error_type": "simple_function_checker:wrong_count",
+            }
+
+        return simple_function_checker(
+            func_description[0],
+            model_output[0],
+            possible_answer[0],
+            language,
+            model_name,
+        )
+
+
+def exec_checker(decoded_result: list, func_description: dict, test_category: str):
+    if "multiple" in test_category or "parallel" in test_category:
+        return executable_checker_parallel_no_order(
+            decoded_result,
+            func_description["execution_result"],
+            func_description["execution_result_type"],
+        )
+
+    else:
+        if len(decoded_result) != 1:
+            return {
+                "valid": False,
+                "error": ["Wrong number of functions."],
+                "error_type": "simple_exec_checker:wrong_count",
+            }
+        return executable_checker_simple(
+            decoded_result[0],
+            func_description["execution_result"][0],
+            func_description["execution_result_type"][0],
+            False,
+        )
+
+
+def is_empty_output(decoded_output):
+    # This function is a patch to the ast decoder for relevance detection.
+    # Sometimes the ast decoder will parse successfully, but the input doesn't really contain a function call.
+    # [], [{}], and anything that is not in function-calling format is considered empty (and thus should be marked as correct).
+    if not is_function_calling_format_output(decoded_output):
+        return True
+    if len(decoded_output) == 0:
+        return True
+    if len(decoded_output) == 1 and len(decoded_output[0]) == 0:
+        return True
+    return False
+
+
+def is_function_calling_format_output(decoded_output):
+    # Ensure the output is a list of dictionaries
+    if type(decoded_output) == list:
+        for item in decoded_output:
+            if type(item) != dict:
+                return False
+        return True
+    return False
diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl/tree_sitter.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl/tree_sitter.py
new file mode 100644
index 000000000..607287f29
--- /dev/null
+++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl/tree_sitter.py
@@ -0,0 +1,37 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+"""
+Tree-sitter changes its API with unfortunate frequency. Modules that need it should
+import it from here so that we can centrally manage things as necessary.
+"""
+
+# These currently work with tree-sitter 0.23.0
+# NOTE: Don't import tree-sitter or any of the language modules in the main module
+# because not all environments have them. Import lazily inside functions where needed.
+
+import importlib
+import typing
+
+if typing.TYPE_CHECKING:
+    import tree_sitter
+
+
+def get_language(language: str) -> "tree_sitter.Language":
+    import tree_sitter
+
+    language_module_name = f"tree_sitter_{language}"
+    try:
+        language_module = importlib.import_module(language_module_name)
+    except ModuleNotFoundError as exc:
+        raise ValueError(
+            f"Language {language} is not found. Please install the tree-sitter-{language} package."
+        ) from exc
+    return tree_sitter.Language(language_module.language())
+
+
+def get_parser(language: str, **kwargs) -> "tree_sitter.Parser":
+    import tree_sitter
+
+    lang = get_language(language)
+    return tree_sitter.Parser(lang, **kwargs)
diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl_scoring_fn.py
new file mode 100644
index 000000000..34bbc0157
--- /dev/null
+++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl_scoring_fn.py
@@ -0,0 +1,98 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+import re
+from typing import Any, Dict, Optional
+
+from llama_stack.apis.scoring import ScoringResultRow
+from llama_stack.apis.scoring_functions import ScoringFnParams
+from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
+
+from .bfcl.ast_parser import decode_ast
+from .bfcl.checker import ast_checker, is_empty_output
+from .fn_defs.bfcl import bfcl
+
+
+def postprocess(x: Dict[str, Any], test_category: str) -> Dict[str, Any]:
+    contain_func_call = False
+    error = None
+    error_type = None
+    checker_result = {}
+    try:
+        prediction = decode_ast(x["generated_answer"], x["language"]) or ""
+        contain_func_call = True
+        if is_empty_output(prediction):
+            contain_func_call = False
+            error = "Did not output in the specified format. Note: the model_result is wrapped in a string to ensure json serializability."
+            error_type = "ast_decoder:decoder_wrong_output_format"
+        else:
+            checker_result = ast_checker(
+                json.loads(x["function"]),
+                prediction,
+                json.loads(x["ground_truth"]),
+                x["language"],
+                test_category=test_category,
+                model_name="",
+            )
+    except Exception as e:
+        prediction = ""
+        error = f"Invalid syntax. Failed to decode AST. {str(e)}"
+        error_type = "ast_decoder:decoder_failed"
+    return {
+        "id": x["id"],
+        "prediction": prediction,
+        "contain_func_call": contain_func_call,
+        "valid": checker_result.get("valid", False),
+        "error": error or checker_result.get("error", ""),
+        "error_type": error_type or checker_result.get("error_type", ""),
+    }
+
+
+def gen_valid(x: Dict[str, Any]) -> Dict[str, float]:
+    return {"valid": x["valid"]}
+
+
+def gen_relevance_acc(x: Dict[str, Any]) -> Dict[str, float]:
+    # This function serves both the relevance and irrelevance tests, which have exactly opposite logic.
+    # If `test_category` is "irrelevance", the model is expected to output no function call.
+    # No function call means either the AST decoding fails (an error message is generated) or the decoded AST does not contain any function call (such as an empty list, `[]`).
+    # If `test_category` is "relevance", the model is expected to output a function call, and an empty list doesn't count as a function call.
+    acc = (
+        not x["contain_func_call"]
+        if "irrelevance" in x["id"]
+        else x["contain_func_call"]
+    )
+    return {"valid": float(acc)}
+
+
+class BFCLScoringFn(RegisteredBaseScoringFn):
+    """
+    A scoring_fn for BFCL
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.supported_fn_defs_registry = {
+            bfcl.identifier: bfcl,
+        }
+
+    async def score_row(
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = "bfcl",
+        scoring_params: Optional[ScoringFnParams] = None,
+    ) -> ScoringResultRow:
+        test_category = re.sub(r"_[0-9_-]+$", "", input_row["id"])
+        score_result = postprocess(input_row, test_category)
+        if test_category in {"irrelevance", "live_relevance", "live_irrelevance"}:
+            score = gen_relevance_acc(score_result)["valid"]
+        else:
+            score = gen_valid(score_result)["valid"]
+        return {
+            "score": float(score),
+        }
diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/bfcl.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/bfcl.py
new file mode 100644
index 000000000..392d92c86
--- /dev/null
+++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/bfcl.py
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.common.type_system import NumberType
+from llama_stack.apis.scoring_functions import (
+    AggregationFunctionType,
+    BasicScoringFnParams,
+    ScoringFn,
+)
+
+bfcl = ScoringFn(
+    identifier="basic::bfcl",
+    description="BFCL complex scoring",
+    return_type=NumberType(),
+    provider_id="basic",
+    provider_resource_id="bfcl",
+    params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
+)
diff --git a/llama_stack/templates/ollama/run.yaml b/llama_stack/templates/ollama/run.yaml
index c9531f417..a19c3ffe2 100644
--- a/llama_stack/templates/ollama/run.yaml
+++ b/llama_stack/templates/ollama/run.yaml
@@ -118,9 +118,30 @@ models:
     model_type: embedding
 shields: []
 vector_dbs: []
-datasets: []
+datasets:
+- dataset_id: bfcl
+  provider_id: huggingface
+  url:
+    uri: https://huggingface.co/datasets/llamastack/bfcl_v3/tree/main
+  metadata:
+    path: llamastack/bfcl_v3
+    split: train
+  dataset_schema:
+    function:
+      type: string
+    language:
+      type: string
+    ground_truth:
+      type: string
+    id:
+      type: string
+    chat_completion_input:
+      type: string
 scoring_fns: []
-benchmarks: []
+benchmarks:
+- benchmark_id: bfcl
+  dataset_id: bfcl
+  scoring_functions: ["basic::bfcl"]
 tool_groups:
 - toolgroup_id: builtin::websearch
   provider_id: tavily-search
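
A quick way to sanity-check the decoder this patch adds, without standing up a
full distribution (the import paths are the ones introduced above; the call
string itself is just an illustrative example):

    from llama_stack.providers.inline.scoring.basic.scoring_fn.bfcl.ast_parser import decode_ast
    from llama_stack.providers.inline.scoring.basic.scoring_fn.bfcl.checker import is_empty_output

    # A single Python-style call in BFCL's bracketed format
    decoded = decode_ast('[calculate_triangle_area(base=10, height=5)]', language="Python")
    assert decoded == [{"calculate_triangle_area": {"base": 10, "height": 5}}]
    assert not is_empty_output(decoded)  # a real call was produced

    # Relevance-style check: empty outputs count as "no function call"
    assert is_empty_output([]) and is_empty_output([{}])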