diff --git a/distributions/dependencies.json b/distributions/dependencies.json index c3f039247..d2ed12d3a 100644 --- a/distributions/dependencies.json +++ b/distributions/dependencies.json @@ -30,6 +30,7 @@ "sentencepiece", "tqdm", "transformers", + "tree_sitter", "uvicorn" ], "cerebras": [ @@ -62,6 +63,7 @@ "sentencepiece", "tqdm", "transformers", + "tree_sitter", "uvicorn", "sentence-transformers --no-deps", "torch torchvision --index-url https://download.pytorch.org/whl/cpu" @@ -97,6 +99,7 @@ "sqlite-vec", "tqdm", "transformers", + "tree_sitter", "uvicorn", "sentence-transformers --no-deps", "torch torchvision --index-url https://download.pytorch.org/whl/cpu" @@ -132,6 +135,7 @@ "sentencepiece", "tqdm", "transformers", + "tree_sitter", "uvicorn", "sentence-transformers --no-deps", "torch torchvision --index-url https://download.pytorch.org/whl/cpu" @@ -168,6 +172,7 @@ "sqlite-vec", "tqdm", "transformers", + "tree_sitter", "uvicorn", "sentence-transformers --no-deps", "torch torchvision --index-url https://download.pytorch.org/whl/cpu" @@ -203,6 +208,7 @@ "sentencepiece", "tqdm", "transformers", + "tree_sitter", "uvicorn", "sentence-transformers --no-deps", "torch torchvision --index-url https://download.pytorch.org/whl/cpu" @@ -236,6 +242,7 @@ "sentencepiece", "tqdm", "transformers", + "tree_sitter", "uvicorn" ], "hf-endpoint": [ @@ -270,6 +277,7 @@ "sentencepiece", "tqdm", "transformers", + "tree_sitter", "uvicorn" ], "hf-serverless": [ @@ -304,6 +312,7 @@ "sentencepiece", "tqdm", "transformers", + "tree_sitter", "uvicorn", "sentence-transformers --no-deps", "torch torchvision --index-url https://download.pytorch.org/whl/cpu" @@ -344,6 +353,7 @@ "torchvision", "tqdm", "transformers", + "tree_sitter", "uvicorn", "zmq" ], @@ -385,6 +395,7 @@ "torchvision", "tqdm", "transformers", + "tree_sitter", "uvicorn", "zmq" ], @@ -417,6 +428,7 @@ "sentencepiece", "tqdm", "transformers", + "tree_sitter", "uvicorn" ], "ollama": [ @@ -451,6 +463,7 @@ "sentencepiece", "tqdm", 
"transformers", + "tree_sitter", "uvicorn" ], "open-benchmark": [ @@ -485,6 +498,7 @@ "together", "tqdm", "transformers", + "tree_sitter", "uvicorn" ], "passthrough": [ @@ -517,6 +531,7 @@ "sentencepiece", "tqdm", "transformers", + "tree_sitter", "uvicorn", "sentence-transformers --no-deps", "torch torchvision --index-url https://download.pytorch.org/whl/cpu" @@ -551,6 +566,7 @@ "sentencepiece", "tqdm", "transformers", + "tree_sitter", "uvicorn", "sentence-transformers --no-deps", "torch torchvision --index-url https://download.pytorch.org/whl/cpu" @@ -616,6 +632,7 @@ "sentencepiece", "tqdm", "transformers", + "tree_sitter", "uvicorn", "sentence-transformers --no-deps", "torch torchvision --index-url https://download.pytorch.org/whl/cpu" @@ -651,6 +668,7 @@ "together", "tqdm", "transformers", + "tree_sitter", "uvicorn", "sentence-transformers --no-deps", "torch torchvision --index-url https://download.pytorch.org/whl/cpu" @@ -685,6 +703,7 @@ "sentencepiece", "tqdm", "transformers", + "tree_sitter", "uvicorn", "vllm", "sentence-transformers --no-deps", diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py index a1bebaa4c..85b351262 100644 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -12,7 +12,7 @@ from llama_stack.apis.agents import Agents, StepType from llama_stack.apis.benchmarks import Benchmark from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets -from llama_stack.apis.inference import Inference, UserMessage +from llama_stack.apis.inference import Inference, SystemMessage, UserMessage from llama_stack.apis.scoring import Scoring from llama_stack.distribution.datatypes import Api from llama_stack.providers.datatypes import BenchmarksProtocolPrivate @@ -118,7 +118,7 @@ class MetaReferenceEvalImpl( for i, x in tqdm(enumerate(input_rows)): assert 
ColumnName.chat_completion_input.value in x, "Invalid input row" input_messages = json.loads(x[ColumnName.chat_completion_input.value]) - input_messages = [UserMessage(**x) for x in input_messages] + input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"] # NOTE: only single-turn agent generation is supported. Create a new session for each input row session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}") @@ -168,10 +168,11 @@ class MetaReferenceEvalImpl( generations.append({ColumnName.generated_answer.value: response.completion_message.content}) elif ColumnName.chat_completion_input.value in x: chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value]) - input_messages = [UserMessage(**x) for x in chat_completion_input_json] + input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"] messages = [] if candidate.system_message: messages.append(candidate.system_message) + messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"] messages += input_messages response = await self.inference_api.chat_completion( model_id=candidate.model, diff --git a/llama_stack/providers/inline/scoring/basic/scoring.py b/llama_stack/providers/inline/scoring/basic/scoring.py index 00945b99d..599f5f98c 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring.py +++ b/llama_stack/providers/inline/scoring/basic/scoring.py @@ -22,12 +22,19 @@ from llama_stack.providers.utils.common.data_schema_validator import ( ) from .config import BasicScoringConfig +from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn from .scoring_fn.equality_scoring_fn import EqualityScoringFn from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn -FIXED_FNS = [EqualityScoringFn, 
SubsetOfScoringFn, RegexParserScoringFn, RegexParserMathResponseScoringFn] +FIXED_FNS = [ + EqualityScoringFn, + SubsetOfScoringFn, + RegexParserScoringFn, + RegexParserMathResponseScoringFn, + BFCLScoringFn, +] class BasicScoringImpl( diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl_scoring_fn.py new file mode 100644 index 000000000..f37780f3e --- /dev/null +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl_scoring_fn.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +import re +from typing import Any, Dict, Optional + +from llama_stack.apis.scoring import ScoringResultRow +from llama_stack.apis.scoring_functions import ScoringFnParams +from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn + +from ..utils.bfcl.ast_parser import decode_ast +from ..utils.bfcl.checker import ast_checker, is_empty_output +from .fn_defs.bfcl import bfcl + + +def postprocess(x: Dict[str, Any], test_category: str) -> Dict[str, Any]: + contain_func_call = False + error = None + error_type = None + checker_result = {} + try: + prediction = decode_ast(x["generated_answer"], x["language"]) or "" + contain_func_call = True + # if not is_function_calling_format_output(prediction): + if is_empty_output(prediction): + contain_func_call = False + error = "Did not output in the specified format. Note: the model_result is wrapped in a string to ensure json serializability." 
+ error_type = "ast_decoder:decoder_wrong_output_format" + else: + checker_result = ast_checker( + json.loads(x["function"]), + prediction, + json.loads(x["ground_truth"]), + x["language"], + test_category=test_category, + model_name="", + ) + except Exception as e: + prediction = "" + error = f"Invalid syntax. Failed to decode AST. {str(e)}" + error_type = "ast_decoder:decoder_failed" + return { + "prediction": prediction, + "contain_func_call": contain_func_call, + "valid": checker_result.get("valid", False), + "error": error or checker_result.get("error", ""), + "error_type": error_type or checker_result.get("error_type", ""), + } + + +def gen_valid(x: Dict[str, Any]) -> Dict[str, float]: + return {"valid": x["valid"]} + + +def gen_relevance_acc(x: Dict[str, Any]) -> Dict[str, float]: + # This function serves for both relevance and irrelevance tests, which share the exact opposite logic. + # If `test_category` is "irrelevance", the model is expected to output no function call. + # No function call means either the AST decoding fails (a error message is generated) or the decoded AST does not contain any function call (such as a empty list, `[]`). + # If `test_category` is "relevance", the model is expected to output to a function call, and empty list doesn't count as a function call. 
+ acc = not x["contain_func_call"] if "irrelevance" in x["id"] else x["contain_func_call"] + return {"valid": float(acc)} + + +class BFCLScoringFn(RegisteredBaseScoringFn): + """ + A scoring_fn for BFCL + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.supported_fn_defs_registry = { + bfcl.identifier: bfcl, + } + + async def score_row( + self, + input_row: Dict[str, Any], + scoring_fn_identifier: Optional[str] = "bfcl", + scoring_params: Optional[ScoringFnParams] = None, + ) -> ScoringResultRow: + test_category = re.sub(r"_[0-9_-]+$", "", input_row["id"]) + score_result = postprocess(input_row, test_category) + if test_category in {"irrelevance", "live_relevance", "live_irrelevance"}: + score = gen_relevance_acc(score_result)["valid"] + else: + score = gen_valid(score_result)["valid"] + return { + "score": float(score), + } diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/bfcl.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/bfcl.py new file mode 100644 index 000000000..392d92c86 --- /dev/null +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/bfcl.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from llama_stack.apis.common.type_system import NumberType +from llama_stack.apis.scoring_functions import ( + AggregationFunctionType, + BasicScoringFnParams, + ScoringFn, +) + +bfcl = ScoringFn( + identifier="basic::bfcl", + description="BFCL complex scoring", + return_type=NumberType(), + provider_id="basic", + provider_resource_id="bfcl", + params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]), +) diff --git a/llama_stack/providers/inline/scoring/basic/utils/bfcl/__init__.py b/llama_stack/providers/inline/scoring/basic/utils/bfcl/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/inline/scoring/basic/utils/bfcl/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/inline/scoring/basic/utils/bfcl/ast_parser.py b/llama_stack/providers/inline/scoring/basic/utils/bfcl/ast_parser.py new file mode 100644 index 000000000..445cdfc77 --- /dev/null +++ b/llama_stack/providers/inline/scoring/basic/utils/bfcl/ast_parser.py @@ -0,0 +1,296 @@ +# ruff: noqa +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+import ast + +from .tree_sitter import get_parser + + +def parse_java_function_call(source_code): + if not source_code.endswith(";"): + source_code += ";" # Necessary for the parser not to register an error + parser = get_parser("java") + tree = parser.parse(bytes(source_code, "utf8")) + root_node = tree.root_node + + if root_node.has_error: + raise Exception("Error parsing java the source code.") + + def get_text(node): + """Returns the text represented by the node.""" + return source_code[node.start_byte : node.end_byte] + + def traverse_node(node, nested=False): + if node.type == "string_literal": + if nested: + return get_text(node) + # Strip surrounding quotes from string literals + return get_text(node)[1:-1] + elif node.type == "character_literal": + if nested: + return get_text(node) + # Strip surrounding single quotes from character literals + return get_text(node)[1:-1] + """Traverse the node to collect texts for complex structures.""" + if node.type in [ + "identifier", + "class_literal", + "type_identifier", + "method_invocation", + ]: + return get_text(node) + elif node.type == "array_creation_expression": + # Handle array creation expression specifically + type_node = node.child_by_field_name("type") + value_node = node.child_by_field_name("value") + type_text = traverse_node(type_node, True) + value_text = traverse_node(value_node, True) + return f"new {type_text}[]{value_text}" + elif node.type == "object_creation_expression": + # Handle object creation expression specifically + type_node = node.child_by_field_name("type") + arguments_node = node.child_by_field_name("arguments") + type_text = traverse_node(type_node, True) + if arguments_node: + # Process each argument carefully, avoiding unnecessary punctuation + argument_texts = [] + for child in arguments_node.children: + if child.type not in [ + ",", + "(", + ")", + ]: # Exclude commas and parentheses + argument_text = traverse_node(child, True) + argument_texts.append(argument_text) + 
arguments_text = ", ".join(argument_texts) + return f"new {type_text}({arguments_text})" + else: + return f"new {type_text}()" + elif node.type == "set": + # Handling sets specifically + items = [traverse_node(n, True) for n in node.children if n.type not in [",", "set"]] + return "{" + ", ".join(items) + "}" + + elif node.child_count > 0: + return "".join(traverse_node(child, True) for child in node.children) + else: + return get_text(node) + + def extract_arguments(args_node): + arguments = {} + for child in args_node.children: + if child.type == "assignment_expression": + # For named parameters + name_node, value_node = child.children[0], child.children[2] + name = get_text(name_node) + value = traverse_node(value_node) + if name in arguments: + if not isinstance(arguments[name], list): + arguments[name] = [arguments[name]] + arguments[name].append(value) + else: + arguments[name] = value + # arguments.append({'name': name, 'value': value}) + elif child.type in ["identifier", "class_literal", "set"]: + # For unnamed parameters and handling sets + value = traverse_node(child) + if None in arguments: + if not isinstance(arguments[None], list): + arguments[None] = [arguments[None]] + arguments[None].append(value) + else: + arguments[None] = value + return arguments + + def traverse(node): + if node.type == "method_invocation": + # Extract the function name and its arguments + method_name = get_text(node.child_by_field_name("name")) + class_name_node = node.child_by_field_name("object") + if class_name_node: + class_name = get_text(class_name_node) + function_name = f"{class_name}.{method_name}" + else: + function_name = method_name + arguments_node = node.child_by_field_name("arguments") + if arguments_node: + arguments = extract_arguments(arguments_node) + for key, value in arguments.items(): + if isinstance(value, list): + raise Exception("Error: Multiple arguments with the same name are not supported.") + return [{function_name: arguments}] + + else: + for child 
in node.children: + result = traverse(child) + if result: + return result + + result = traverse(root_node) + return result if result else {} + + +def parse_javascript_function_call(source_code): + if not source_code.endswith(";"): + source_code += ";" # Necessary for the parser not to register an error + parser = get_parser("javascript") + # Parse the source code + tree = parser.parse(bytes(source_code, "utf8")) + root_node = tree.root_node + if root_node.has_error: + raise Exception("Error js parsing the source code.") + + # Function to recursively extract argument details + def extract_arguments(node): + args = {} + for child in node.children: + if child.type == "assignment_expression": + # Extract left (name) and right (value) parts of the assignment + name = child.children[0].text.decode("utf-8") + value = child.children[2].text.decode("utf-8") + if (value.startswith('"') and value.endswith('"')) or (value.startswith("'") and value.endswith("'")): + value = value[1:-1] # Trim the quotation marks + if name in args: + if not isinstance(args[name], list): + args[name] = [args[name]] + args[name].append(value) + else: + args[name] = value + + elif child.type == "identifier" or child.type == "true": + # Handle non-named arguments and boolean values + value = child.text.decode("utf-8") + if None in args: + if not isinstance(args[None], list): + args[None] = [args[None]] + args[None].append(value) + else: + args[None] = value + return args + + # Find the function call and extract its name and arguments + if root_node.type == "program": + for child in root_node.children: + if child.type == "expression_statement": + for sub_child in child.children: + if sub_child.type == "call_expression": + function_name = sub_child.children[0].text.decode("utf8") + arguments_node = sub_child.children[1] + parameters = extract_arguments(arguments_node) + for key, value in parameters.items(): + if isinstance(value, list): + raise Exception("Error: Multiple arguments with the same name 
are not supported.") + result = [{function_name: parameters}] + return result + + +def ast_parse(input_str, language="Python"): + if language == "Python": + cleaned_input = input_str.strip("[]'") + parsed = ast.parse(cleaned_input, mode="eval") + extracted = [] + if isinstance(parsed.body, ast.Call): + extracted.append(resolve_ast_call(parsed.body)) + else: + for elem in parsed.body.elts: + extracted.append(resolve_ast_call(elem)) + return extracted + elif language == "Java": + return parse_java_function_call(input_str[1:-1]) # Remove the [ and ] from the string + elif language == "JavaScript": + return parse_javascript_function_call(input_str[1:-1]) + else: + raise NotImplementedError(f"Unsupported language: {language}") + + +def resolve_ast_call(elem): + # Handle nested attributes for deeply nested module paths + func_parts = [] + func_part = elem.func + while isinstance(func_part, ast.Attribute): + func_parts.append(func_part.attr) + func_part = func_part.value + if isinstance(func_part, ast.Name): + func_parts.append(func_part.id) + func_name = ".".join(reversed(func_parts)) + args_dict = {} + # Parse when args are simply passed as an unnamed dictionary arg + for arg in elem.args: + if isinstance(arg, ast.Dict): + for key, value in zip(arg.keys, arg.values): + if isinstance(key, ast.Constant): + arg_name = key.value + output = resolve_ast_by_type(value) + args_dict[arg_name] = output + for arg in elem.keywords: + output = resolve_ast_by_type(arg.value) + args_dict[arg.arg] = output + return {func_name: args_dict} + + +def resolve_ast_by_type(value): + if isinstance(value, ast.Constant): + if value.value is Ellipsis: + output = "..." 
+ else: + output = value.value + elif isinstance(value, ast.UnaryOp): + output = -value.operand.value + elif isinstance(value, ast.List): + output = [resolve_ast_by_type(v) for v in value.elts] + elif isinstance(value, ast.Dict): + output = {resolve_ast_by_type(k): resolve_ast_by_type(v) for k, v in zip(value.keys, value.values)} + elif isinstance(value, ast.NameConstant): # Added this condition to handle boolean values + output = value.value + elif isinstance(value, ast.BinOp): # Added this condition to handle function calls as arguments + output = eval(ast.unparse(value)) + elif isinstance(value, ast.Name): + output = value.id + elif isinstance(value, ast.Call): + if len(value.keywords) == 0: + output = ast.unparse(value) + else: + output = resolve_ast_call(value) + elif isinstance(value, ast.Tuple): + output = tuple(resolve_ast_by_type(v) for v in value.elts) + elif isinstance(value, ast.Lambda): + output = eval(ast.unparse(value.body[0].value)) + elif isinstance(value, ast.Ellipsis): + output = "..." 
+ elif isinstance(value, ast.Subscript): + try: + output = ast.unparse(value.body[0].value) + except: + output = ast.unparse(value.value) + "[" + ast.unparse(value.slice) + "]" + else: + raise Exception(f"Unsupported AST type: {type(value)}") + return output + + +def decode_ast(result, language="Python"): + func = result + func = func.replace("\n", "") # remove new line characters + if not func.startswith("["): + func = "[" + func + if not func.endswith("]"): + func = func + "]" + decoded_output = ast_parse(func, language) + return decoded_output + + +def decode_execute(result): + func = result + func = func.replace("\n", "") # remove new line characters + if not func.startswith("["): + func = "[" + func + if not func.endswith("]"): + func = func + "]" + decode_output = ast_parse(func) + execution_list = [] + for function_call in decode_output: + for key, value in function_call.items(): + execution_list.append(f"{key}({','.join([f'{k}={repr(v)}' for k, v in value.items()])})") + return execution_list diff --git a/llama_stack/providers/inline/scoring/basic/utils/bfcl/checker.py b/llama_stack/providers/inline/scoring/basic/utils/bfcl/checker.py new file mode 100644 index 000000000..f6aab123c --- /dev/null +++ b/llama_stack/providers/inline/scoring/basic/utils/bfcl/checker.py @@ -0,0 +1,989 @@ +# ruff: noqa +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +import json +import re +import time +from typing import Any + +# Comment out for now until we actually use the rest checker in evals +# import requests # Do not remove this import even though it seems to be unused. It's used in the executable_checker_rest function. + + +class NoAPIKeyError(Exception): + def __init__(self): + self.message = "❗️Please fill in the API keys in the function_credential_config.json file. 
If you do not provide the API keys, the executable test category results will be inaccurate." + super().__init__(self.message) + + +REAL_TIME_MATCH_ALLOWED_DIFFERENCE = 0.2 + + +JAVA_TYPE_CONVERSION = { + "byte": int, + "short": int, + "integer": int, + "float": float, + "double": float, + "long": int, + "boolean": bool, + "char": str, + "Array": list, + "ArrayList": list, + "Set": set, + "HashMap": dict, + "Hashtable": dict, + "Queue": list, # this can be `queue.Queue` as well, for simplicity we check with list + "Stack": list, + "String": str, + "any": str, +} + +JS_TYPE_CONVERSION = { + "String": str, + "integer": int, + "float": float, + "Bigint": int, + "Boolean": bool, + "dict": dict, + "array": list, + "any": str, +} + +# We switch to conditional import for the following two imports to avoid unnecessary installations. +# User doesn't need to setup the tree-sitter packages if they are not running the test for that language. +# from js_type_converter import js_type_converter +# from java_type_converter import java_type_converter + +PYTHON_TYPE_MAPPING = { + "string": str, + "integer": int, + "float": float, + "boolean": bool, + "array": list, + "tuple": list, + "dict": dict, + "any": str, +} + +# This is the list of types that we need to recursively check its values +PYTHON_NESTED_TYPE_CHECK_LIST = ["array", "tuple"] + + +NESTED_CONVERSION_TYPE_LIST = ["Array", "ArrayList", "array"] + + +#### Helper functions for AST #### +def find_description(func_descriptions, name): + if type(func_descriptions) == list: + for func_description in func_descriptions: + if func_description["name"] == name: + return func_description + return None + else: + # it is a dict, there is only one function + return func_descriptions + + +def get_possible_answer_type(possible_answer: list): + for answer in possible_answer: + if answer != "": # Optional parameter + return type(answer) + return None + + +def type_checker( + param: str, + value, + possible_answer: list, + 
expected_type_description: str, + expected_type_converted, + nested_type_converted, +): + # NOTE: This type checker only supports nested type checking for one level deep. + # We didn't implement recursive type checking for nested types, as it's not needed for the current use case and it's very complex. + + result: Any = { + "valid": True, + "error": [], + "is_variable": False, + "error_type": "type_error:simple", + } + + is_variable = False + # check for the case where a variable is used instead of a actual value. + # use the type in possible_answer as the expected type + possible_answer_type = get_possible_answer_type(possible_answer) + # if possible_answer only contains optional parameters, we can't determine the type + if possible_answer_type != None: + # we are being precise here. + # in fact, possible_answer_type should always be string, as that's how we treat varibale in possible_answer + if possible_answer_type != expected_type_converted: + is_variable = True + + # value is the same type as in function description + if type(value) == expected_type_converted: + # We don't need to do recursive check for simple types + if nested_type_converted == None: + result["is_variable"] = is_variable + return result + else: + for possible_answer_item in possible_answer: + flag = True # Each parameter should match to at least one possible answer type. + # Here, we assume that each item should be the same type. We could also relax it. + if type(possible_answer_item) == list: + for value_item in value: + checker_result = type_checker( + param, + value_item, + possible_answer_item, + str(nested_type_converted), + nested_type_converted, + None, + ) + if not checker_result["valid"]: + flag = False + break + + if flag: + return {"valid": True, "error": [], "is_variable": is_variable} + + result["valid"] = False + result["error"] = [ + f"Nested type checking failed for parameter {repr(param)}. 
Expected outer type {expected_type_description} with inner type {str(nested_type_converted)}. Parameter value: {repr(value)}." + ] + result["error_type"] = "type_error:nested" + + # value is not as expected, check for the case where a variable is used instead of a actual value + # use the type in possible_answer as the expected type + possible_answer_type = get_possible_answer_type(possible_answer) + # if possible_answer only contains optional parameters, we can't determine the type + if possible_answer_type != None: + # we are being precise here. + # in fact, possible_answer_type should always be string, as that's how we treat varibale in possible_answer + if type(value) == possible_answer_type: + result["is_variable"] = True + return result + + result["valid"] = False + result["error"].append( + f"Incorrect type for parameter {repr(param)}. Expected type {expected_type_description}, got {type(value).__name__}. Parameter value: {repr(value)}." + ) + result["error_type"] = "type_error:simple" + return result + + +def standardize_string(input_string: str): + # This function standardizes the string by removing all the spaces, ",./-_*^" punctuation, and converting it to lowercase + # It will also convert all the single quotes to double quotes + # This is used to compare the model output with the possible answers + # We don't want to punish model for answer like April 1, 2024 vs April 1,2024, vs April 1 2024 + regex_string = r"[ \,\.\/\-\_\*\^]" + return re.sub(regex_string, "", input_string).lower().replace("'", '"') + + +def string_checker(param: str, model_output: str, possible_answer: list): + standardize_possible_answer = [] + standardize_model_output = standardize_string(model_output) + for i in range(len(possible_answer)): + if type(possible_answer[i]) == str: + standardize_possible_answer.append(standardize_string(possible_answer[i])) + + if standardize_model_output not in standardize_possible_answer: + return { + "valid": False, + "error": [ + f"Invalid value 
for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}. Case insensitive." + ], + "error_type": "value_error:string", + } + + return {"valid": True, "error": []} + + +def list_checker(param: str, model_output: list, possible_answer: list): + # Convert the tuple to a list + + standardize_model_output = list(model_output) + + # If the element in the list is a string, we need to standardize it + for i in range(len(standardize_model_output)): + if type(standardize_model_output[i]) == str: + standardize_model_output[i] = standardize_string(model_output[i]) + + standardize_possible_answer: Any = [] + # We also need to standardize the possible answers + for i in range(len(possible_answer)): + standardize_possible_answer.append([]) + for j in range(len(possible_answer[i])): + if type(possible_answer[i][j]) == str: + standardize_possible_answer[i].append(standardize_string(possible_answer[i][j])) + else: + standardize_possible_answer[i].append(possible_answer[i][j]) + + if standardize_model_output not in standardize_possible_answer: + return { + "valid": False, + "error": [ + f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}." + ], + "error_type": "value_error:list/tuple", + } + + return {"valid": True, "error": []} + + +def dict_checker(param: str, model_output: dict, possible_answers: list): + # This function works for simple dictionaries, but not dictionaries with nested dictionaries. + # The current dataset only contains simple dictionaries, so this is sufficient. 
+ + result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"} + for i in range(len(possible_answers)): + if possible_answers[i] == "": + continue + + result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"} + + flag = True + + possible_answer = possible_answers[i] + # possible_anwer is a single dictionary + + for key, value in model_output.items(): + if key not in possible_answer: + result["valid"] = False + result["error"].append(f"Unexpected dict key parameter: '{key}'.") # type: ignore[attr-defined] + result["error_type"] = "value_error:dict_key" + flag = False + break + + standardize_value = value + # If the value is a string, we need to standardize it + if type(value) == str: + standardize_value = standardize_string(value) + + # We also need to standardize the possible answers if they are string + standardize_possible_answer = [] + for i in range(len(possible_answer[key])): + if type(possible_answer[key][i]) == str: + standardize_possible_answer.append(standardize_string(possible_answer[key][i])) + else: + standardize_possible_answer.append(possible_answer[key][i]) + + if standardize_value not in standardize_possible_answer: + result["valid"] = False + result["error"].append( # type: ignore[attr-defined] + f"Invalid value for parameter {repr(key)}: {repr(value)}. Expected one of {standardize_possible_answer}." 
+ ) + result["error_type"] = "value_error:dict_value" + flag = False + break + + for key, value in possible_answer.items(): + if key not in model_output and "" not in value: + result["valid"] = False + result["error"].append(f"Missing dict key parameter: '{key}'.") # type: ignore[attr-defined] + result["error_type"] = "value_error:dict_key" + flag = False + break + + if flag: + return {"valid": True, "error": []} + + return result + + +def list_dict_checker(param: str, model_output: list, possible_answers: list): + # This function takes in a list of dictionaries and checks if each dictionary is valid + # The order of the dictionaries in the list must match the order of the possible answers + + result = {"valid": False, "error": [], "error_type": "list_dict_checker:unclear"} + + for answer_index in range(len(possible_answers)): + flag = True # True means so far, all dictionaries are valid + + # Only proceed if the number of dictionaries in the list matches the number of dictionaries in the possible answers + if len(model_output) != len(possible_answers[answer_index]): + result["valid"] = False + result["error"] = ["Wrong number of dictionaries in the list."] + result["error_type"] = "value_error:list_dict_count" + flag = False + continue + + for dict_index in range(len(model_output)): + result = dict_checker( + param, + model_output[dict_index], + [possible_answers[answer_index][dict_index]], + ) + if not result["valid"]: + flag = False + break + if flag: + return {"valid": True, "error": []} + + return result + + +def simple_function_checker( + func_description: dict, + model_output: dict, + possible_answer: dict, + language: str, + model_name: str, +): + possible_answer = list(possible_answer.values())[0] + # Extract function name and parameters details + func_name = func_description["name"] + param_details = func_description["parameters"]["properties"] + required_params = func_description["parameters"]["required"] + + # Initialize a result dictionary + result = { 
+ "valid": True, + "error": [], + "error_type": "simple_function_checker:unclear", + } + + # Check if function name matches + if func_name not in model_output: + result["valid"] = False + result["error"].append( # type: ignore[attr-defined] + f"Function name {repr(func_name)} not found in model output." + ) + result["error_type"] = "simple_function_checker:wrong_func_name" + return result + + model_params = model_output[func_name] + + # Check for required parameters in model output + for param in required_params: + if param not in model_params: + result["valid"] = False + result["error"].append(f"Missing required parameter: {repr(param)}.") # type: ignore[attr-defined] + result["error_type"] = "simple_function_checker:missing_required" + return result + + # Validate types and values for each parameter in model output + for param, value in model_params.items(): + if param not in param_details or param not in possible_answer: + result["valid"] = False + result["error"].append(f"Unexpected parameter: {repr(param)}.") # type: ignore[attr-defined] + result["error_type"] = "simple_function_checker:unexpected_param" + return result + + full_param_details = param_details[param] + expected_type_description = full_param_details["type"] # This is a string + is_variable = False + nested_type_converted = None + + if language == "Java": + from evals.utils.bfcl.java_type_converter import java_type_converter + + expected_type_converted = JAVA_TYPE_CONVERSION[expected_type_description] + + if expected_type_description in JAVA_TYPE_CONVERSION: + if type(value) != str: + result["valid"] = False + result["error"].append( # type: ignore[attr-defined] + f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}." 
+ ) + result["error_type"] = "type_error:java" + return result + + if expected_type_description in NESTED_CONVERSION_TYPE_LIST: + nested_type = param_details[param]["items"]["type"] + nested_type_converted = JAVA_TYPE_CONVERSION[nested_type] + value = java_type_converter(value, expected_type_description, nested_type) + else: + value = java_type_converter(value, expected_type_description) + + elif language == "JavaScript": + from evals.utils.bfcl.js_type_converter import js_type_converter + + expected_type_converted = JS_TYPE_CONVERSION[expected_type_description] + + if expected_type_description in JS_TYPE_CONVERSION: + if type(value) != str: + result["valid"] = False + result["error"].append( # type: ignore[attr-defined] + f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}." + ) + result["error_type"] = "type_error:js" + return result + + if expected_type_description in NESTED_CONVERSION_TYPE_LIST: + nested_type = param_details[param]["items"]["type"] + nested_type_converted = JS_TYPE_CONVERSION[nested_type] + value = js_type_converter(value, expected_type_description, nested_type) + else: + value = js_type_converter(value, expected_type_description) + + elif language == "Python": + expected_type_converted = PYTHON_TYPE_MAPPING[expected_type_description] + if expected_type_description in PYTHON_NESTED_TYPE_CHECK_LIST: + nested_type = param_details[param]["items"]["type"] + nested_type_converted = PYTHON_TYPE_MAPPING[nested_type] + + # We convert all tuple value to list when the expected type is tuple. + # The conversion is necessary because any tuple in the possible answer would become a list after being processed through json.dump() and json.load(). + # This does introduce some false positive (eg, when the model provides a list value instead of tuple). We hope to find a better solution in the future. 
+ if expected_type_description == "tuple" and type(value) == tuple: + value = list(value) + + # Allow python auto conversion from int to float + if language == "Python" and expected_type_description == "float" and type(value) == int: + value = float(value) + + # Type checking + # In fact, we only check for Python here. + # Type check for other languages are handled by the type converter, and so their value (after conversion) is always correct. + type_check_result = type_checker( + param, + value, + possible_answer[param], + expected_type_description, + expected_type_converted, + nested_type_converted, + ) + is_variable = type_check_result["is_variable"] + if not type_check_result["valid"]: + return type_check_result + + # It doesn't make sense to special handle dictionaries and list of dictionaries if the value is a variable. + # We can just treat the variable as a string and use the normal flow. + if not is_variable: + # Special handle for dictionaries + if expected_type_converted == dict: + result = dict_checker(param, value, possible_answer[param]) + if not result["valid"]: + return result + continue + + # Special handle for list of dictionaries + elif expected_type_converted == list and nested_type_converted == dict: + result = list_dict_checker(param, value, possible_answer[param]) + if not result["valid"]: + return result + continue + + # Special handle for strings + elif expected_type_converted == str: + # We don't check for case sensitivity for string, as long as it's not a variable + result = string_checker(param, value, possible_answer[param]) + if not result["valid"]: + return result + continue + + elif expected_type_converted == list: + result = list_checker(param, value, possible_answer[param]) + if not result["valid"]: + return result + continue + + # Check if the value is within the possible answers + if value not in possible_answer[param]: + result["valid"] = False + result["error"].append( # type: ignore[attr-defined] + f"Invalid value for 
parameter {repr(param)}: {repr(value)}. Expected one of {possible_answer[param]}." + ) + result["error_type"] = "value_error:others" + return result + + # Check for optional parameters not provided but allowed + for param in possible_answer: + if param not in model_params and "" not in possible_answer[param]: + result["valid"] = False + result["error"].append( # type: ignore[attr-defined] + f"Optional parameter {repr(param)} not provided and not marked as optional." + ) + result["error_type"] = "simple_function_checker:missing_optional" + return result + + return result + + +def parallel_function_checker_enforce_order( + func_descriptions: list, + model_output: list, + possible_answers: dict, + language: str, + model_name: str, +): + if len(model_output) != len(possible_answers): + return { + "valid": False, + "error": ["Wrong number of functions."], + "error_type": "parallel_function_checker_enforce_order:wrong_count", + } + + func_name_list = list(possible_answers.keys()) + possible_answers_list = [] + + for key, value in possible_answers.items(): + possible_answers_list.append({key: value}) + + for i in range(len(possible_answers_list)): + func_description = find_description(func_descriptions, func_name_list[i]) + + result = simple_function_checker( + func_description, + model_output[i], + possible_answers_list[i], + language, + model_name, + ) + if not result["valid"]: + return result + + return {"valid": True, "error": []} + + +def parallel_function_checker_no_order( + func_descriptions: list, + model_output: list, + possible_answers: list, + language: str, + model_name: str, +): + if len(model_output) != len(possible_answers): + return { + "valid": False, + "error": ["Wrong number of functions."], + "error_type": "parallel_function_checker_no_order:wrong_count", + } + + matched_indices = [] + + # We go through the possible answers one by one, and eliminate the model output that matches the possible answer + # It must be this way because we need ground truth
to fetch the correct function description + for i in range(len(possible_answers)): + # possible_answers[i] is a dictionary with only one key + func_name_expected = list(possible_answers[i].keys())[0] + func_description = find_description(func_descriptions, func_name_expected) + + all_errors = [] + + for index in range(len(model_output)): + if index in matched_indices: + continue + + result = simple_function_checker( + func_description, + model_output[index], + possible_answers[i], + language, + model_name, + ) + + if result["valid"]: + matched_indices.append(index) + break + else: + all_errors.append( + { + f"Model Result Index {index}": { + "sub_error": result["error"], + "sub_error_type": result["error_type"], + "model_output_item": model_output[index], + "possible_answer_item": possible_answers[i], + } + } + ) + + if not result["valid"]: + considered_indices = [i for i in range(len(model_output)) if i not in matched_indices] + all_errors.insert( + 0, + f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type] + ) + return { + "valid": False, + "error": all_errors, + "error_type": "parallel_function_checker_no_order:cannot_find_match", + } + + return {"valid": True, "error": []} + + +def multiple_function_checker( + func_descriptions: list, + model_output: list, + possible_answers: list, + language: str, + model_name: str, +): + if len(model_output) != len(possible_answers): + return { + "valid": False, + "error": ["Wrong number of functions."], + "error_type": "multiple_function_checker:wrong_count", + } + + # possible_answers is a list of only one dictionary with only one key + func_name_expected = list(possible_answers[0].keys())[0] + func_description = find_description(func_descriptions, func_name_expected) + return simple_function_checker( + func_description, + model_output[0], + possible_answers[0], + language, + model_name, + ) + + +def patten_matcher(exec_output, 
expected_result, function_call, is_sanity_check): + result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"} + + if type(exec_output) != type(expected_result): + return { + "valid": False, + "error": [ + f"Wrong execution result type for {repr(function_call)}. Expected type: {type(expected_result)}, but got: {type(exec_output)}." + ], + "error_type": "executable_checker:wrong_result_type", + "model_executed_output": exec_output, + } + if type(exec_output) == dict: + # We loosen the requirement for the sanity check as the expected result used in the sanity check might not be the most up-to-date one. + # This happens when the key is a timestamp or a random number. + if is_sanity_check: + if len(exec_output) != len(expected_result): + return { + "valid": False, + "error": [ + f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}." + ], + "error_type": "executable_checker:wrong_result_type:dict_length", + "model_executed_output": exec_output, + } + else: + return result + + for key, value in expected_result.items(): + if key not in exec_output: + return { + "valid": False, + "error": [ + f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not found in the model output." + ], + "error_type": "executable_checker:wrong_result_type:dict_key_not_found", + "model_executed_output": exec_output, + } + for key, value in exec_output.items(): + if key not in expected_result: + return { + "valid": False, + "error": [ + f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not expected in the model output."
+ ], + "error_type": "executable_checker:wrong_result_type:dict_extra_key", + "model_executed_output": exec_output, + } + if type(exec_output) == list: + if len(exec_output) != len(expected_result): + return { + "valid": False, + "error": [ + f"Wrong execution result pattern for {repr(function_call)}. Expect type list, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}." + ], + "error_type": "executable_checker:wrong_result_type:list_length", + "model_executed_output": exec_output, + } + return result + + +#### Helper functions for Exec #### +def executable_checker_simple( + function_call: str, + expected_result, + expected_result_type: str, + is_sanity_check=False, +): + result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"} + + exec_dict: Any = {} + + try: + exec( + "from executable_python_function import *" + "\nresult=" + function_call, + exec_dict, + ) + exec_output = exec_dict["result"] + except NoAPIKeyError as e: + raise e + except Exception as e: + result["valid"] = False + result["error"].append( # type: ignore[attr-defined] + f"Error in execution: {repr(function_call)}. Error: {str(e)}" + ) + result["error_type"] = "executable_checker:execution_error" + return result + + # We need to special handle the case where the execution result is a tuple and convert it to a list + # Because when json is stored, the tuple is converted to a list, and so the expected result is a list when loaded from json + if isinstance(exec_output, tuple): + exec_output = list(exec_output) + + if expected_result_type == "exact_match": + if exec_output != expected_result: + result["valid"] = False + result["error"].append( # type: ignore[attr-defined] + f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}." 
+ ) + result["error_type"] = "executable_checker:wrong_result" + result["model_executed_output"] = exec_output + return result + + elif expected_result_type == "real_time_match": + # Allow for 5% difference + if (type(expected_result) == float or type(expected_result) == int) and ( + type(exec_output) == float or type(exec_output) == int + ): + if not ( + expected_result * (1 - REAL_TIME_MATCH_ALLOWED_DIFFERENCE) + <= exec_output + <= expected_result * (1 + REAL_TIME_MATCH_ALLOWED_DIFFERENCE) + ): + result["valid"] = False + result["error"].append( # type: ignore[attr-defined] + f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. {REAL_TIME_MATCH_ALLOWED_DIFFERENCE * 100}% difference allowed." + ) + result["error_type"] = "executable_checker:wrong_result_real_time" + result["model_executed_output"] = exec_output + return result + else: + result["valid"] = False + result["error"].append( # type: ignore[attr-defined] + f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. Type needs to be float or int for real time match criteria." + ) + result["error_type"] = "executable_checker:wrong_result_real_time" + result["model_executed_output"] = exec_output + return result + + else: + # structural match + pattern_match_result = patten_matcher(exec_output, expected_result, function_call, is_sanity_check) + if not pattern_match_result["valid"]: + return pattern_match_result + + return result + + +def executable_checker_parallel_no_order( + decoded_result: list, expected_exec_result: list, expected_exec_result_type: list +): + if len(decoded_result) != len(expected_exec_result): + return { + "valid": False, + "error": [ + f"Wrong number of functions provided. Expected {len(expected_exec_result)}, but got {len(decoded_result)}." 
+ ], + "error_type": "value_error:exec_result_count", + } + + matched_indices = [] + for i in range(len(expected_exec_result)): + all_errors = [] + for index in range(len(decoded_result)): + if index in matched_indices: + continue + + result = executable_checker_simple( + decoded_result[index], + expected_exec_result[i], + expected_exec_result_type[i], + False, + ) + + if result["valid"]: + matched_indices.append(index) + break + else: + all_errors.append( + { + f"Model Result Index {index}": { + "sub_error": result["error"], + "sub_error_type": result["error_type"], + "model_executed_output": ( + result["model_executed_output"] if "model_executed_output" in result else None + ), + } + } + ) + + if not result["valid"]: + considered_indices = [i for i in range(len(decoded_result)) if i not in matched_indices] + all_errors.insert( + 0, + f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type] + ) + return { + "valid": False, + "error": all_errors, + "error_type": "executable_checker:cannot_find_match", + } + + return {"valid": True, "error": [], "error_type": "executable_checker:unclear"} + + +#### Main function #### +def executable_checker_rest(func_call, idx): + # Move this here for now to avoid needing to read this file / fix paths to be relative to dataset_dir. Fix when it's actually needed / used. + EVAL_GROUND_TRUTH_PATH = "/mnt/wsfuse/fair_llm_v2/datasets/eval/bfcl/rest-eval-response_v5.jsonl" # Ground truth file for v5 for rest execution + with open(EVAL_GROUND_TRUTH_PATH, "r") as f: + EVAL_GROUND_TRUTH = f.readlines() + if "https://geocode.maps.co" in func_call: + time.sleep(2) + if "requests_get" in func_call: + func_call = func_call.replace("requests_get", "requests.get") + try: + response = eval(func_call) + except Exception as e: + return { + "valid": False, + "error": [f"Execution failed. 
{str(e)}"], + "error_type": "executable_checker_rest:execution_error", + } + + try: + if response.status_code == 200: + eval_GT_json = json.loads(EVAL_GROUND_TRUTH[idx]) + try: + if isinstance(eval_GT_json, dict): + if isinstance(response.json(), dict): + if set(eval_GT_json.keys()) == set(response.json().keys()): + return {"valid": True, "error": [], "error_type": ""} + return { + "valid": False, + "error": ["Key inconsistency"], + "error_type": "executable_checker_rest:wrong_key", + } + return { + "valid": False, + "error": [f"Expected dictionary, but got {type(response.json())}"], + "error_type": "executable_checker_rest:wrong_type", + } + + elif isinstance(eval_GT_json, list): + if isinstance(response.json(), list): + if len(eval_GT_json) != len(response.json()): + return { + "valid": False, + "error": [f"Response list length inconsistency."], + "error_type": "value_error:exec_result_rest_count", + } + + else: + for i in range(len(eval_GT_json)): + if set(eval_GT_json[i].keys()) != set(response.json()[i].keys()): + return { + "valid": False, + "error": [f"Key inconsistency"], + "error_type": "executable_checker_rest:wrong_key", + } + + return {"valid": True, "error": []} + else: + return { + "valid": False, + "error": [f"Expected list, but got {type(response.json())}"], + "error_type": "executable_checker_rest:wrong_type", + } + return { + "valid": False, + "error": [f"Expected dict or list, but got {type(response.json())}"], + "error_type": "executable_checker_rest:wrong_type", + } + except Exception as e: + return { + "valid": False, + "error": [ + f"Error in execution and type checking. Status code: {response.status_code}. 
Error: {str(e)}" + ], + "error_type": "executable_checker_rest:response_format_error", + } + else: + return { + "valid": False, + "error": [f"Execution result status code is not 200, got {response.status_code}"], + "error_type": "executable_checker_rest:wrong_status_code", + } + except Exception as e: + return { + "valid": False, + "error": [f"Cannot get status code of the response. Error: {str(e)}"], + "error_type": "executable_checker_rest:cannot_get_status_code", + } + + +def ast_checker(func_description, model_output, possible_answer, language, test_category, model_name): + if "parallel" in test_category: + return parallel_function_checker_no_order(func_description, model_output, possible_answer, language, model_name) + + elif "multiple" in test_category: + return multiple_function_checker(func_description, model_output, possible_answer, language, model_name) + + else: + if len(model_output) != 1: + return { + "valid": False, + "error": ["Wrong number of functions."], + "error_type": "simple_function_checker:wrong_count", + } + + return simple_function_checker( + func_description[0], + model_output[0], + possible_answer[0], + language, + model_name, + ) + + +def exec_checker(decoded_result: list, func_description: dict, test_category: str): + if "multiple" in test_category or "parallel" in test_category: + return executable_checker_parallel_no_order( + decoded_result, + func_description["execution_result"], + func_description["execution_result_type"], + ) + + else: + if len(decoded_result) != 1: + return { + "valid": False, + "error": ["Wrong number of functions."], + "error_type": "simple_exec_checker:wrong_count", + } + return executable_checker_simple( + decoded_result[0], + func_description["execution_result"][0], + func_description["execution_result_type"][0], + False, + ) + + +def is_empty_output(decoded_output): + # This function is a patch to the ast decoder for relevance detection + # Sometimes the ast decoder will parse successfully, but the input 
doesn't really have a function call + # [], [{}], and anything that is not in function calling format is considered empty (and thus should be marked as correct) + if not is_function_calling_format_output(decoded_output): + return True + if len(decoded_output) == 0: + return True + if len(decoded_output) == 1 and len(decoded_output[0]) == 0: + return True + + +def is_function_calling_format_output(decoded_output): + # Ensure the output is a list of dictionaries + if type(decoded_output) == list: + for item in decoded_output: + if type(item) != dict: + return False + return True + return False diff --git a/llama_stack/providers/inline/scoring/basic/utils/bfcl/tree_sitter.py b/llama_stack/providers/inline/scoring/basic/utils/bfcl/tree_sitter.py new file mode 100644 index 000000000..ed97ee360 --- /dev/null +++ b/llama_stack/providers/inline/scoring/basic/utils/bfcl/tree_sitter.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Tree-sitter changes its API with unfortunate frequency. Modules that need it should +import it from here so that we can centrally manage things as necessary. +""" + +# These currently work with tree-sitter 0.23.0 +# NOTE: Don't import tree-sitter or any of the language modules in the main module +# because not all environments have them. Import lazily inside functions where needed. + +import importlib +import typing + +if typing.TYPE_CHECKING: + import tree_sitter + + +def get_language(language: str) -> "tree_sitter.Language": + import tree_sitter + + language_module_name = f"tree_sitter_{language}" + try: + language_module = importlib.import_module(language_module_name) + except ModuleNotFoundError as exc: + raise ValueError( + f"Language {language} is not found. Please install the tree-sitter-{language} package."
+ ) from exc + return tree_sitter.Language(language_module.language()) + + +def get_parser(language: str, **kwargs) -> "tree_sitter.Parser": + import tree_sitter + + lang = get_language(language) + return tree_sitter.Parser(lang, **kwargs) diff --git a/llama_stack/providers/registry/eval.py b/llama_stack/providers/registry/eval.py index 6901c3741..755d30382 100644 --- a/llama_stack/providers/registry/eval.py +++ b/llama_stack/providers/registry/eval.py @@ -14,7 +14,7 @@ def available_providers() -> List[ProviderSpec]: InlineProviderSpec( api=Api.eval, provider_type="inline::meta-reference", - pip_packages=[], + pip_packages=["tree_sitter"], module="llama_stack.providers.inline.eval.meta_reference", config_class="llama_stack.providers.inline.eval.meta_reference.MetaReferenceEvalConfig", api_dependencies=[ diff --git a/llama_stack/providers/utils/common/data_schema_validator.py b/llama_stack/providers/utils/common/data_schema_validator.py index 3d14c4148..eb9d9dd60 100644 --- a/llama_stack/providers/utils/common/data_schema_validator.py +++ b/llama_stack/providers/utils/common/data_schema_validator.py @@ -23,6 +23,10 @@ class ColumnName(Enum): generated_answer = "generated_answer" context = "context" dialog = "dialog" + function = "function" + language = "language" + id = "id" + ground_truth = "ground_truth" VALID_SCHEMAS_FOR_SCORING = [ @@ -37,6 +41,15 @@ VALID_SCHEMAS_FOR_SCORING = [ ColumnName.generated_answer.value: StringType(), ColumnName.context.value: StringType(), }, + { + ColumnName.input_query.value: StringType(), + ColumnName.expected_answer.value: StringType(), + ColumnName.generated_answer.value: StringType(), + ColumnName.function.value: StringType(), + ColumnName.language.value: StringType(), + ColumnName.id.value: StringType(), + ColumnName.ground_truth.value: StringType(), + }, ] VALID_SCHEMAS_FOR_EVAL = [ @@ -50,6 +63,15 @@ VALID_SCHEMAS_FOR_EVAL = [ ColumnName.expected_answer.value: StringType(), ColumnName.completion_input.value: 
CompletionInputType(), }, + { + ColumnName.input_query.value: StringType(), + ColumnName.expected_answer.value: StringType(), + ColumnName.generated_answer.value: StringType(), + ColumnName.function.value: StringType(), + ColumnName.language.value: StringType(), + ColumnName.id.value: StringType(), + ColumnName.ground_truth.value: StringType(), + }, ] diff --git a/llama_stack/templates/open-benchmark/open_benchmark.py b/llama_stack/templates/open-benchmark/open_benchmark.py index 2b40797f9..17f5b8ee7 100644 --- a/llama_stack/templates/open-benchmark/open_benchmark.py +++ b/llama_stack/templates/open-benchmark/open_benchmark.py @@ -226,6 +226,22 @@ def get_distribution_template() -> DistributionTemplate: "chat_completion_input": {"type": "string"}, }, ), + DatasetInput( + dataset_id="bfcl", + provider_id="huggingface", + url=URL(uri="https://huggingface.co/datasets/llamastack/bfcl_v3"), + metadata={ + "path": "llamastack/bfcl_v3", + "split": "train", + }, + dataset_schema={ + "function": {"type": "string"}, + "language": {"type": "string"}, + "ground_truth": {"type": "string"}, + "id": {"type": "string"}, + "chat_completion_input": {"type": "string"}, + }, + ), ] default_benchmarks = [ @@ -249,6 +265,11 @@ def get_distribution_template() -> DistributionTemplate: dataset_id="math_500", scoring_functions=["basic::regex_parser_math_response"], ), + BenchmarkInput( + benchmark_id="meta-reference-bfcl", + dataset_id="bfcl", + scoring_functions=["basic::bfcl"], + ), ] return DistributionTemplate( name=name, diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/templates/open-benchmark/run.yaml index 5ef25435b..6961f8022 100644 --- a/llama_stack/templates/open-benchmark/run.yaml +++ b/llama_stack/templates/open-benchmark/run.yaml @@ -216,6 +216,24 @@ datasets: split: test dataset_id: math_500 provider_id: huggingface +- dataset_schema: + function: + type: string + language: + type: string + ground_truth: + type: string + id: + type: string + 
chat_completion_input: + type: string + url: + uri: https://huggingface.co/datasets/llamastack/bfcl_v3 + metadata: + path: llamastack/bfcl_v3 + split: train + dataset_id: bfcl + provider_id: huggingface scoring_fns: [] benchmarks: - dataset_id: simpleqa @@ -238,6 +256,11 @@ benchmarks: - basic::regex_parser_math_response metadata: {} benchmark_id: meta-reference-math-500 +- dataset_id: bfcl + scoring_functions: + - basic::bfcl + metadata: {} + benchmark_id: meta-reference-bfcl tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search diff --git a/requirements.txt b/requirements.txt index ae8a0af9f..3c382ad84 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,7 @@ distro==1.9.0 exceptiongroup==1.2.2 ; python_full_version < '3.11' filelock==3.17.0 fire==0.7.0 -fsspec==2025.2.0 +fsspec==2024.12.0 h11==0.14.0 httpcore==1.0.7 httpx==0.28.1 diff --git a/uv.lock b/uv.lock index 9ec3680f8..207f5981f 100644 --- a/uv.lock +++ b/uv.lock @@ -769,11 +769,11 @@ wheels = [ [[package]] name = "fsspec" -version = "2025.2.0" +version = "2024.12.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b5/79/68612ed99700e6413de42895aa725463e821a6b3be75c87fcce1b4af4c70/fsspec-2025.2.0.tar.gz", hash = "sha256:1c24b16eaa0a1798afa0337aa0db9b256718ab2a89c425371f5628d22c3b6afd", size = 292283 } +sdist = { url = "https://files.pythonhosted.org/packages/ee/11/de70dee31455c546fbc88301971ec03c328f3d1138cfba14263f651e9551/fsspec-2024.12.0.tar.gz", hash = "sha256:670700c977ed2fb51e0d9f9253177ed20cbde4a3e5c0283cc5385b5870c8533f", size = 291600 } wheels = [ - { url = "https://files.pythonhosted.org/packages/e2/94/758680531a00d06e471ef649e4ec2ed6bf185356a7f9fbfbb7368a40bd49/fsspec-2025.2.0-py3-none-any.whl", hash = "sha256:9de2ad9ce1f85e1931858535bc882543171d197001a0a5eb2ddc04f1781ab95b", size = 184484 }, + { url = 
"https://files.pythonhosted.org/packages/de/86/5486b0188d08aa643e127774a99bac51ffa6cf343e3deb0583956dca5b22/fsspec-2024.12.0-py3-none-any.whl", hash = "sha256:b520aed47ad9804237ff878b504267a3b0b441e97508bd6d2d8774e3db85cee2", size = 183862 }, ] [package.optional-dependencies]