mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat: [new open benchmark] BFCL_v3 (#1578)
# What does this PR do? create a new dataset BFCL_v3 from https://gorilla.cs.berkeley.edu/blogs/13_bfcl_v3_multi_turn.html overall each question asks the model to perform a task described in natural language, and additionally a set of available functions and their schema are given for the model to choose from. the model is required to write the function call form including function name and parameters , to achieve the stated purpose. the results are validated against provided ground truth, to make sure that the generated function call and the ground truth function call are syntactically and semantically equivalent, by checking their AST . ## Test Plan start server by ``` llama stack run ./llama_stack/templates/ollama/run.yaml ``` then send traffic ``` llama-stack-client eval run-benchmark "bfcl" --model-id meta-llama/Llama-3.2-3B-Instruct --output-dir /tmp/gpqa --num-examples 2 ``` [//]: # (## Documentation)
This commit is contained in:
parent
78d4872c0c
commit
a626b7bce3
15 changed files with 1546 additions and 9 deletions
|
@ -30,6 +30,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn"
|
||||
],
|
||||
"cerebras": [
|
||||
|
@ -62,6 +63,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
|
@ -97,6 +99,7 @@
|
|||
"sqlite-vec",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
|
@ -132,6 +135,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
|
@ -168,6 +172,7 @@
|
|||
"sqlite-vec",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
|
@ -203,6 +208,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
|
@ -236,6 +242,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn"
|
||||
],
|
||||
"hf-endpoint": [
|
||||
|
@ -270,6 +277,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn"
|
||||
],
|
||||
"hf-serverless": [
|
||||
|
@ -304,6 +312,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
|
@ -344,6 +353,7 @@
|
|||
"torchvision",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"zmq"
|
||||
],
|
||||
|
@ -385,6 +395,7 @@
|
|||
"torchvision",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"zmq"
|
||||
],
|
||||
|
@ -417,6 +428,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn"
|
||||
],
|
||||
"ollama": [
|
||||
|
@ -451,6 +463,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn"
|
||||
],
|
||||
"open-benchmark": [
|
||||
|
@ -485,6 +498,7 @@
|
|||
"together",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn"
|
||||
],
|
||||
"passthrough": [
|
||||
|
@ -517,6 +531,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
|
@ -551,6 +566,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
|
@ -616,6 +632,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
|
@ -651,6 +668,7 @@
|
|||
"together",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
|
@ -685,6 +703,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"vllm",
|
||||
"sentence-transformers --no-deps",
|
||||
|
|
|
@ -12,7 +12,7 @@ from llama_stack.apis.agents import Agents, StepType
|
|||
from llama_stack.apis.benchmarks import Benchmark
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.inference import Inference, UserMessage
|
||||
from llama_stack.apis.inference import Inference, SystemMessage, UserMessage
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
|
||||
|
@ -118,7 +118,7 @@ class MetaReferenceEvalImpl(
|
|||
for i, x in tqdm(enumerate(input_rows)):
|
||||
assert ColumnName.chat_completion_input.value in x, "Invalid input row"
|
||||
input_messages = json.loads(x[ColumnName.chat_completion_input.value])
|
||||
input_messages = [UserMessage(**x) for x in input_messages]
|
||||
input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"]
|
||||
|
||||
# NOTE: only single-turn agent generation is supported. Create a new session for each input row
|
||||
session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
|
||||
|
@ -168,10 +168,11 @@ class MetaReferenceEvalImpl(
|
|||
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
|
||||
elif ColumnName.chat_completion_input.value in x:
|
||||
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
|
||||
input_messages = [UserMessage(**x) for x in chat_completion_input_json]
|
||||
input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"]
|
||||
messages = []
|
||||
if candidate.system_message:
|
||||
messages.append(candidate.system_message)
|
||||
messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"]
|
||||
messages += input_messages
|
||||
response = await self.inference_api.chat_completion(
|
||||
model_id=candidate.model,
|
||||
|
|
|
@ -22,12 +22,19 @@ from llama_stack.providers.utils.common.data_schema_validator import (
|
|||
)
|
||||
|
||||
from .config import BasicScoringConfig
|
||||
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
|
||||
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
|
||||
from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn
|
||||
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
|
||||
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
|
||||
|
||||
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn, RegexParserMathResponseScoringFn]
|
||||
FIXED_FNS = [
|
||||
EqualityScoringFn,
|
||||
SubsetOfScoringFn,
|
||||
RegexParserScoringFn,
|
||||
RegexParserMathResponseScoringFn,
|
||||
BFCLScoringFn,
|
||||
]
|
||||
|
||||
|
||||
class BasicScoringImpl(
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from ..utils.bfcl.ast_parser import decode_ast
|
||||
from ..utils.bfcl.checker import ast_checker, is_empty_output
|
||||
from .fn_defs.bfcl import bfcl
|
||||
|
||||
|
||||
def postprocess(x: Dict[str, Any], test_category: str) -> Dict[str, Any]:
|
||||
contain_func_call = False
|
||||
error = None
|
||||
error_type = None
|
||||
checker_result = {}
|
||||
try:
|
||||
prediction = decode_ast(x["generated_answer"], x["language"]) or ""
|
||||
contain_func_call = True
|
||||
# if not is_function_calling_format_output(prediction):
|
||||
if is_empty_output(prediction):
|
||||
contain_func_call = False
|
||||
error = "Did not output in the specified format. Note: the model_result is wrapped in a string to ensure json serializability."
|
||||
error_type = "ast_decoder:decoder_wrong_output_format"
|
||||
else:
|
||||
checker_result = ast_checker(
|
||||
json.loads(x["function"]),
|
||||
prediction,
|
||||
json.loads(x["ground_truth"]),
|
||||
x["language"],
|
||||
test_category=test_category,
|
||||
model_name="",
|
||||
)
|
||||
except Exception as e:
|
||||
prediction = ""
|
||||
error = f"Invalid syntax. Failed to decode AST. {str(e)}"
|
||||
error_type = "ast_decoder:decoder_failed"
|
||||
return {
|
||||
"prediction": prediction,
|
||||
"contain_func_call": contain_func_call,
|
||||
"valid": checker_result.get("valid", False),
|
||||
"error": error or checker_result.get("error", ""),
|
||||
"error_type": error_type or checker_result.get("error_type", ""),
|
||||
}
|
||||
|
||||
|
||||
def gen_valid(x: Dict[str, Any]) -> Dict[str, float]:
|
||||
return {"valid": x["valid"]}
|
||||
|
||||
|
||||
def gen_relevance_acc(x: Dict[str, Any]) -> Dict[str, float]:
|
||||
# This function serves for both relevance and irrelevance tests, which share the exact opposite logic.
|
||||
# If `test_category` is "irrelevance", the model is expected to output no function call.
|
||||
# No function call means either the AST decoding fails (a error message is generated) or the decoded AST does not contain any function call (such as a empty list, `[]`).
|
||||
# If `test_category` is "relevance", the model is expected to output to a function call, and empty list doesn't count as a function call.
|
||||
acc = not x["contain_func_call"] if "irrelevance" in x["id"] else x["contain_func_call"]
|
||||
return {"valid": float(acc)}
|
||||
|
||||
|
||||
class BFCLScoringFn(RegisteredBaseScoringFn):
|
||||
"""
|
||||
A scoring_fn for BFCL
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {
|
||||
bfcl.identifier: bfcl,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = "bfcl",
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
test_category = re.sub(r"_[0-9_-]+$", "", input_row["id"])
|
||||
score_result = postprocess(input_row, test_category)
|
||||
if test_category in {"irrelevance", "live_relevance", "live_irrelevance"}:
|
||||
score = gen_relevance_acc(score_result)["valid"]
|
||||
else:
|
||||
score = gen_valid(score_result)["valid"]
|
||||
return {
|
||||
"score": float(score),
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
bfcl = ScoringFn(
|
||||
identifier="basic::bfcl",
|
||||
description="BFCL complex scoring",
|
||||
return_type=NumberType(),
|
||||
provider_id="basic",
|
||||
provider_resource_id="bfcl",
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
|
||||
)
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -0,0 +1,296 @@
|
|||
# ruff: noqa
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import ast
|
||||
|
||||
from .tree_sitter import get_parser
|
||||
|
||||
|
||||
def parse_java_function_call(source_code):
|
||||
if not source_code.endswith(";"):
|
||||
source_code += ";" # Necessary for the parser not to register an error
|
||||
parser = get_parser("java")
|
||||
tree = parser.parse(bytes(source_code, "utf8"))
|
||||
root_node = tree.root_node
|
||||
|
||||
if root_node.has_error:
|
||||
raise Exception("Error parsing java the source code.")
|
||||
|
||||
def get_text(node):
|
||||
"""Returns the text represented by the node."""
|
||||
return source_code[node.start_byte : node.end_byte]
|
||||
|
||||
def traverse_node(node, nested=False):
|
||||
if node.type == "string_literal":
|
||||
if nested:
|
||||
return get_text(node)
|
||||
# Strip surrounding quotes from string literals
|
||||
return get_text(node)[1:-1]
|
||||
elif node.type == "character_literal":
|
||||
if nested:
|
||||
return get_text(node)
|
||||
# Strip surrounding single quotes from character literals
|
||||
return get_text(node)[1:-1]
|
||||
"""Traverse the node to collect texts for complex structures."""
|
||||
if node.type in [
|
||||
"identifier",
|
||||
"class_literal",
|
||||
"type_identifier",
|
||||
"method_invocation",
|
||||
]:
|
||||
return get_text(node)
|
||||
elif node.type == "array_creation_expression":
|
||||
# Handle array creation expression specifically
|
||||
type_node = node.child_by_field_name("type")
|
||||
value_node = node.child_by_field_name("value")
|
||||
type_text = traverse_node(type_node, True)
|
||||
value_text = traverse_node(value_node, True)
|
||||
return f"new {type_text}[]{value_text}"
|
||||
elif node.type == "object_creation_expression":
|
||||
# Handle object creation expression specifically
|
||||
type_node = node.child_by_field_name("type")
|
||||
arguments_node = node.child_by_field_name("arguments")
|
||||
type_text = traverse_node(type_node, True)
|
||||
if arguments_node:
|
||||
# Process each argument carefully, avoiding unnecessary punctuation
|
||||
argument_texts = []
|
||||
for child in arguments_node.children:
|
||||
if child.type not in [
|
||||
",",
|
||||
"(",
|
||||
")",
|
||||
]: # Exclude commas and parentheses
|
||||
argument_text = traverse_node(child, True)
|
||||
argument_texts.append(argument_text)
|
||||
arguments_text = ", ".join(argument_texts)
|
||||
return f"new {type_text}({arguments_text})"
|
||||
else:
|
||||
return f"new {type_text}()"
|
||||
elif node.type == "set":
|
||||
# Handling sets specifically
|
||||
items = [traverse_node(n, True) for n in node.children if n.type not in [",", "set"]]
|
||||
return "{" + ", ".join(items) + "}"
|
||||
|
||||
elif node.child_count > 0:
|
||||
return "".join(traverse_node(child, True) for child in node.children)
|
||||
else:
|
||||
return get_text(node)
|
||||
|
||||
def extract_arguments(args_node):
|
||||
arguments = {}
|
||||
for child in args_node.children:
|
||||
if child.type == "assignment_expression":
|
||||
# For named parameters
|
||||
name_node, value_node = child.children[0], child.children[2]
|
||||
name = get_text(name_node)
|
||||
value = traverse_node(value_node)
|
||||
if name in arguments:
|
||||
if not isinstance(arguments[name], list):
|
||||
arguments[name] = [arguments[name]]
|
||||
arguments[name].append(value)
|
||||
else:
|
||||
arguments[name] = value
|
||||
# arguments.append({'name': name, 'value': value})
|
||||
elif child.type in ["identifier", "class_literal", "set"]:
|
||||
# For unnamed parameters and handling sets
|
||||
value = traverse_node(child)
|
||||
if None in arguments:
|
||||
if not isinstance(arguments[None], list):
|
||||
arguments[None] = [arguments[None]]
|
||||
arguments[None].append(value)
|
||||
else:
|
||||
arguments[None] = value
|
||||
return arguments
|
||||
|
||||
def traverse(node):
|
||||
if node.type == "method_invocation":
|
||||
# Extract the function name and its arguments
|
||||
method_name = get_text(node.child_by_field_name("name"))
|
||||
class_name_node = node.child_by_field_name("object")
|
||||
if class_name_node:
|
||||
class_name = get_text(class_name_node)
|
||||
function_name = f"{class_name}.{method_name}"
|
||||
else:
|
||||
function_name = method_name
|
||||
arguments_node = node.child_by_field_name("arguments")
|
||||
if arguments_node:
|
||||
arguments = extract_arguments(arguments_node)
|
||||
for key, value in arguments.items():
|
||||
if isinstance(value, list):
|
||||
raise Exception("Error: Multiple arguments with the same name are not supported.")
|
||||
return [{function_name: arguments}]
|
||||
|
||||
else:
|
||||
for child in node.children:
|
||||
result = traverse(child)
|
||||
if result:
|
||||
return result
|
||||
|
||||
result = traverse(root_node)
|
||||
return result if result else {}
|
||||
|
||||
|
||||
def parse_javascript_function_call(source_code):
|
||||
if not source_code.endswith(";"):
|
||||
source_code += ";" # Necessary for the parser not to register an error
|
||||
parser = get_parser("javascript")
|
||||
# Parse the source code
|
||||
tree = parser.parse(bytes(source_code, "utf8"))
|
||||
root_node = tree.root_node
|
||||
if root_node.has_error:
|
||||
raise Exception("Error js parsing the source code.")
|
||||
|
||||
# Function to recursively extract argument details
|
||||
def extract_arguments(node):
|
||||
args = {}
|
||||
for child in node.children:
|
||||
if child.type == "assignment_expression":
|
||||
# Extract left (name) and right (value) parts of the assignment
|
||||
name = child.children[0].text.decode("utf-8")
|
||||
value = child.children[2].text.decode("utf-8")
|
||||
if (value.startswith('"') and value.endswith('"')) or (value.startswith("'") and value.endswith("'")):
|
||||
value = value[1:-1] # Trim the quotation marks
|
||||
if name in args:
|
||||
if not isinstance(args[name], list):
|
||||
args[name] = [args[name]]
|
||||
args[name].append(value)
|
||||
else:
|
||||
args[name] = value
|
||||
|
||||
elif child.type == "identifier" or child.type == "true":
|
||||
# Handle non-named arguments and boolean values
|
||||
value = child.text.decode("utf-8")
|
||||
if None in args:
|
||||
if not isinstance(args[None], list):
|
||||
args[None] = [args[None]]
|
||||
args[None].append(value)
|
||||
else:
|
||||
args[None] = value
|
||||
return args
|
||||
|
||||
# Find the function call and extract its name and arguments
|
||||
if root_node.type == "program":
|
||||
for child in root_node.children:
|
||||
if child.type == "expression_statement":
|
||||
for sub_child in child.children:
|
||||
if sub_child.type == "call_expression":
|
||||
function_name = sub_child.children[0].text.decode("utf8")
|
||||
arguments_node = sub_child.children[1]
|
||||
parameters = extract_arguments(arguments_node)
|
||||
for key, value in parameters.items():
|
||||
if isinstance(value, list):
|
||||
raise Exception("Error: Multiple arguments with the same name are not supported.")
|
||||
result = [{function_name: parameters}]
|
||||
return result
|
||||
|
||||
|
||||
def ast_parse(input_str, language="Python"):
|
||||
if language == "Python":
|
||||
cleaned_input = input_str.strip("[]'")
|
||||
parsed = ast.parse(cleaned_input, mode="eval")
|
||||
extracted = []
|
||||
if isinstance(parsed.body, ast.Call):
|
||||
extracted.append(resolve_ast_call(parsed.body))
|
||||
else:
|
||||
for elem in parsed.body.elts:
|
||||
extracted.append(resolve_ast_call(elem))
|
||||
return extracted
|
||||
elif language == "Java":
|
||||
return parse_java_function_call(input_str[1:-1]) # Remove the [ and ] from the string
|
||||
elif language == "JavaScript":
|
||||
return parse_javascript_function_call(input_str[1:-1])
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported language: {language}")
|
||||
|
||||
|
||||
def resolve_ast_call(elem):
|
||||
# Handle nested attributes for deeply nested module paths
|
||||
func_parts = []
|
||||
func_part = elem.func
|
||||
while isinstance(func_part, ast.Attribute):
|
||||
func_parts.append(func_part.attr)
|
||||
func_part = func_part.value
|
||||
if isinstance(func_part, ast.Name):
|
||||
func_parts.append(func_part.id)
|
||||
func_name = ".".join(reversed(func_parts))
|
||||
args_dict = {}
|
||||
# Parse when args are simply passed as an unnamed dictionary arg
|
||||
for arg in elem.args:
|
||||
if isinstance(arg, ast.Dict):
|
||||
for key, value in zip(arg.keys, arg.values):
|
||||
if isinstance(key, ast.Constant):
|
||||
arg_name = key.value
|
||||
output = resolve_ast_by_type(value)
|
||||
args_dict[arg_name] = output
|
||||
for arg in elem.keywords:
|
||||
output = resolve_ast_by_type(arg.value)
|
||||
args_dict[arg.arg] = output
|
||||
return {func_name: args_dict}
|
||||
|
||||
|
||||
def resolve_ast_by_type(value):
|
||||
if isinstance(value, ast.Constant):
|
||||
if value.value is Ellipsis:
|
||||
output = "..."
|
||||
else:
|
||||
output = value.value
|
||||
elif isinstance(value, ast.UnaryOp):
|
||||
output = -value.operand.value
|
||||
elif isinstance(value, ast.List):
|
||||
output = [resolve_ast_by_type(v) for v in value.elts]
|
||||
elif isinstance(value, ast.Dict):
|
||||
output = {resolve_ast_by_type(k): resolve_ast_by_type(v) for k, v in zip(value.keys, value.values)}
|
||||
elif isinstance(value, ast.NameConstant): # Added this condition to handle boolean values
|
||||
output = value.value
|
||||
elif isinstance(value, ast.BinOp): # Added this condition to handle function calls as arguments
|
||||
output = eval(ast.unparse(value))
|
||||
elif isinstance(value, ast.Name):
|
||||
output = value.id
|
||||
elif isinstance(value, ast.Call):
|
||||
if len(value.keywords) == 0:
|
||||
output = ast.unparse(value)
|
||||
else:
|
||||
output = resolve_ast_call(value)
|
||||
elif isinstance(value, ast.Tuple):
|
||||
output = tuple(resolve_ast_by_type(v) for v in value.elts)
|
||||
elif isinstance(value, ast.Lambda):
|
||||
output = eval(ast.unparse(value.body[0].value))
|
||||
elif isinstance(value, ast.Ellipsis):
|
||||
output = "..."
|
||||
elif isinstance(value, ast.Subscript):
|
||||
try:
|
||||
output = ast.unparse(value.body[0].value)
|
||||
except:
|
||||
output = ast.unparse(value.value) + "[" + ast.unparse(value.slice) + "]"
|
||||
else:
|
||||
raise Exception(f"Unsupported AST type: {type(value)}")
|
||||
return output
|
||||
|
||||
|
||||
def decode_ast(result, language="Python"):
|
||||
func = result
|
||||
func = func.replace("\n", "") # remove new line characters
|
||||
if not func.startswith("["):
|
||||
func = "[" + func
|
||||
if not func.endswith("]"):
|
||||
func = func + "]"
|
||||
decoded_output = ast_parse(func, language)
|
||||
return decoded_output
|
||||
|
||||
|
||||
def decode_execute(result):
|
||||
func = result
|
||||
func = func.replace("\n", "") # remove new line characters
|
||||
if not func.startswith("["):
|
||||
func = "[" + func
|
||||
if not func.endswith("]"):
|
||||
func = func + "]"
|
||||
decode_output = ast_parse(func)
|
||||
execution_list = []
|
||||
for function_call in decode_output:
|
||||
for key, value in function_call.items():
|
||||
execution_list.append(f"{key}({','.join([f'{k}={repr(v)}' for k, v in value.items()])})")
|
||||
return execution_list
|
989
llama_stack/providers/inline/scoring/basic/utils/bfcl/checker.py
Normal file
989
llama_stack/providers/inline/scoring/basic/utils/bfcl/checker.py
Normal file
|
@ -0,0 +1,989 @@
|
|||
# ruff: noqa
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
# Comment out for now until we actually use the rest checker in evals
|
||||
# import requests # Do not remove this import even though it seems to be unused. It's used in the executable_checker_rest function.
|
||||
|
||||
|
||||
class NoAPIKeyError(Exception):
|
||||
def __init__(self):
|
||||
self.message = "❗️Please fill in the API keys in the function_credential_config.json file. If you do not provide the API keys, the executable test category results will be inaccurate."
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
REAL_TIME_MATCH_ALLOWED_DIFFERENCE = 0.2
|
||||
|
||||
|
||||
JAVA_TYPE_CONVERSION = {
|
||||
"byte": int,
|
||||
"short": int,
|
||||
"integer": int,
|
||||
"float": float,
|
||||
"double": float,
|
||||
"long": int,
|
||||
"boolean": bool,
|
||||
"char": str,
|
||||
"Array": list,
|
||||
"ArrayList": list,
|
||||
"Set": set,
|
||||
"HashMap": dict,
|
||||
"Hashtable": dict,
|
||||
"Queue": list, # this can be `queue.Queue` as well, for simplicity we check with list
|
||||
"Stack": list,
|
||||
"String": str,
|
||||
"any": str,
|
||||
}
|
||||
|
||||
JS_TYPE_CONVERSION = {
|
||||
"String": str,
|
||||
"integer": int,
|
||||
"float": float,
|
||||
"Bigint": int,
|
||||
"Boolean": bool,
|
||||
"dict": dict,
|
||||
"array": list,
|
||||
"any": str,
|
||||
}
|
||||
|
||||
# We switch to conditional import for the following two imports to avoid unnecessary installations.
|
||||
# User doesn't need to setup the tree-sitter packages if they are not running the test for that language.
|
||||
# from js_type_converter import js_type_converter
|
||||
# from java_type_converter import java_type_converter
|
||||
|
||||
PYTHON_TYPE_MAPPING = {
|
||||
"string": str,
|
||||
"integer": int,
|
||||
"float": float,
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"tuple": list,
|
||||
"dict": dict,
|
||||
"any": str,
|
||||
}
|
||||
|
||||
# This is the list of types that we need to recursively check its values
|
||||
PYTHON_NESTED_TYPE_CHECK_LIST = ["array", "tuple"]
|
||||
|
||||
|
||||
NESTED_CONVERSION_TYPE_LIST = ["Array", "ArrayList", "array"]
|
||||
|
||||
|
||||
#### Helper functions for AST ####
|
||||
def find_description(func_descriptions, name):
|
||||
if type(func_descriptions) == list:
|
||||
for func_description in func_descriptions:
|
||||
if func_description["name"] == name:
|
||||
return func_description
|
||||
return None
|
||||
else:
|
||||
# it is a dict, there is only one function
|
||||
return func_descriptions
|
||||
|
||||
|
||||
def get_possible_answer_type(possible_answer: list):
|
||||
for answer in possible_answer:
|
||||
if answer != "": # Optional parameter
|
||||
return type(answer)
|
||||
return None
|
||||
|
||||
|
||||
def type_checker(
|
||||
param: str,
|
||||
value,
|
||||
possible_answer: list,
|
||||
expected_type_description: str,
|
||||
expected_type_converted,
|
||||
nested_type_converted,
|
||||
):
|
||||
# NOTE: This type checker only supports nested type checking for one level deep.
|
||||
# We didn't implement recursive type checking for nested types, as it's not needed for the current use case and it's very complex.
|
||||
|
||||
result: Any = {
|
||||
"valid": True,
|
||||
"error": [],
|
||||
"is_variable": False,
|
||||
"error_type": "type_error:simple",
|
||||
}
|
||||
|
||||
is_variable = False
|
||||
# check for the case where a variable is used instead of a actual value.
|
||||
# use the type in possible_answer as the expected type
|
||||
possible_answer_type = get_possible_answer_type(possible_answer)
|
||||
# if possible_answer only contains optional parameters, we can't determine the type
|
||||
if possible_answer_type != None:
|
||||
# we are being precise here.
|
||||
# in fact, possible_answer_type should always be string, as that's how we treat varibale in possible_answer
|
||||
if possible_answer_type != expected_type_converted:
|
||||
is_variable = True
|
||||
|
||||
# value is the same type as in function description
|
||||
if type(value) == expected_type_converted:
|
||||
# We don't need to do recursive check for simple types
|
||||
if nested_type_converted == None:
|
||||
result["is_variable"] = is_variable
|
||||
return result
|
||||
else:
|
||||
for possible_answer_item in possible_answer:
|
||||
flag = True # Each parameter should match to at least one possible answer type.
|
||||
# Here, we assume that each item should be the same type. We could also relax it.
|
||||
if type(possible_answer_item) == list:
|
||||
for value_item in value:
|
||||
checker_result = type_checker(
|
||||
param,
|
||||
value_item,
|
||||
possible_answer_item,
|
||||
str(nested_type_converted),
|
||||
nested_type_converted,
|
||||
None,
|
||||
)
|
||||
if not checker_result["valid"]:
|
||||
flag = False
|
||||
break
|
||||
|
||||
if flag:
|
||||
return {"valid": True, "error": [], "is_variable": is_variable}
|
||||
|
||||
result["valid"] = False
|
||||
result["error"] = [
|
||||
f"Nested type checking failed for parameter {repr(param)}. Expected outer type {expected_type_description} with inner type {str(nested_type_converted)}. Parameter value: {repr(value)}."
|
||||
]
|
||||
result["error_type"] = "type_error:nested"
|
||||
|
||||
# value is not as expected, check for the case where a variable is used instead of a actual value
|
||||
# use the type in possible_answer as the expected type
|
||||
possible_answer_type = get_possible_answer_type(possible_answer)
|
||||
# if possible_answer only contains optional parameters, we can't determine the type
|
||||
if possible_answer_type != None:
|
||||
# we are being precise here.
|
||||
# in fact, possible_answer_type should always be string, as that's how we treat varibale in possible_answer
|
||||
if type(value) == possible_answer_type:
|
||||
result["is_variable"] = True
|
||||
return result
|
||||
|
||||
result["valid"] = False
|
||||
result["error"].append(
|
||||
f"Incorrect type for parameter {repr(param)}. Expected type {expected_type_description}, got {type(value).__name__}. Parameter value: {repr(value)}."
|
||||
)
|
||||
result["error_type"] = "type_error:simple"
|
||||
return result
|
||||
|
||||
|
||||
def standardize_string(input_string: str):
|
||||
# This function standardizes the string by removing all the spaces, ",./-_*^" punctuation, and converting it to lowercase
|
||||
# It will also convert all the single quotes to double quotes
|
||||
# This is used to compare the model output with the possible answers
|
||||
# We don't want to punish model for answer like April 1, 2024 vs April 1,2024, vs April 1 2024
|
||||
regex_string = r"[ \,\.\/\-\_\*\^]"
|
||||
return re.sub(regex_string, "", input_string).lower().replace("'", '"')
|
||||
|
||||
|
||||
def string_checker(param: str, model_output: str, possible_answer: list):
|
||||
standardize_possible_answer = []
|
||||
standardize_model_output = standardize_string(model_output)
|
||||
for i in range(len(possible_answer)):
|
||||
if type(possible_answer[i]) == str:
|
||||
standardize_possible_answer.append(standardize_string(possible_answer[i]))
|
||||
|
||||
if standardize_model_output not in standardize_possible_answer:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}. Case insensitive."
|
||||
],
|
||||
"error_type": "value_error:string",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def list_checker(param: str, model_output: list, possible_answer: list):
|
||||
# Convert the tuple to a list
|
||||
|
||||
standardize_model_output = list(model_output)
|
||||
|
||||
# If the element in the list is a string, we need to standardize it
|
||||
for i in range(len(standardize_model_output)):
|
||||
if type(standardize_model_output[i]) == str:
|
||||
standardize_model_output[i] = standardize_string(model_output[i])
|
||||
|
||||
standardize_possible_answer: Any = []
|
||||
# We also need to standardize the possible answers
|
||||
for i in range(len(possible_answer)):
|
||||
standardize_possible_answer.append([])
|
||||
for j in range(len(possible_answer[i])):
|
||||
if type(possible_answer[i][j]) == str:
|
||||
standardize_possible_answer[i].append(standardize_string(possible_answer[i][j]))
|
||||
else:
|
||||
standardize_possible_answer[i].append(possible_answer[i][j])
|
||||
|
||||
if standardize_model_output not in standardize_possible_answer:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}."
|
||||
],
|
||||
"error_type": "value_error:list/tuple",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def dict_checker(param: str, model_output: dict, possible_answers: list):
|
||||
# This function works for simple dictionaries, but not dictionaries with nested dictionaries.
|
||||
# The current dataset only contains simple dictionaries, so this is sufficient.
|
||||
|
||||
result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"}
|
||||
for i in range(len(possible_answers)):
|
||||
if possible_answers[i] == "":
|
||||
continue
|
||||
|
||||
result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"}
|
||||
|
||||
flag = True
|
||||
|
||||
possible_answer = possible_answers[i]
|
||||
# possible_anwer is a single dictionary
|
||||
|
||||
for key, value in model_output.items():
|
||||
if key not in possible_answer:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Unexpected dict key parameter: '{key}'.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "value_error:dict_key"
|
||||
flag = False
|
||||
break
|
||||
|
||||
standardize_value = value
|
||||
# If the value is a string, we need to standardize it
|
||||
if type(value) == str:
|
||||
standardize_value = standardize_string(value)
|
||||
|
||||
# We also need to standardize the possible answers if they are string
|
||||
standardize_possible_answer = []
|
||||
for i in range(len(possible_answer[key])):
|
||||
if type(possible_answer[key][i]) == str:
|
||||
standardize_possible_answer.append(standardize_string(possible_answer[key][i]))
|
||||
else:
|
||||
standardize_possible_answer.append(possible_answer[key][i])
|
||||
|
||||
if standardize_value not in standardize_possible_answer:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Invalid value for parameter {repr(key)}: {repr(value)}. Expected one of {standardize_possible_answer}."
|
||||
)
|
||||
result["error_type"] = "value_error:dict_value"
|
||||
flag = False
|
||||
break
|
||||
|
||||
for key, value in possible_answer.items():
|
||||
if key not in model_output and "" not in value:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Missing dict key parameter: '{key}'.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "value_error:dict_key"
|
||||
flag = False
|
||||
break
|
||||
|
||||
if flag:
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def list_dict_checker(param: str, model_output: list, possible_answers: list):
|
||||
# This function takes in a list of dictionaries and checks if each dictionary is valid
|
||||
# The order of the dictionaries in the list must match the order of the possible answers
|
||||
|
||||
result = {"valid": False, "error": [], "error_type": "list_dict_checker:unclear"}
|
||||
|
||||
for answer_index in range(len(possible_answers)):
|
||||
flag = True # True means so far, all dictionaries are valid
|
||||
|
||||
# Only proceed if the number of dictionaries in the list matches the number of dictionaries in the possible answers
|
||||
if len(model_output) != len(possible_answers[answer_index]):
|
||||
result["valid"] = False
|
||||
result["error"] = ["Wrong number of dictionaries in the list."]
|
||||
result["error_type"] = "value_error:list_dict_count"
|
||||
flag = False
|
||||
continue
|
||||
|
||||
for dict_index in range(len(model_output)):
|
||||
result = dict_checker(
|
||||
param,
|
||||
model_output[dict_index],
|
||||
[possible_answers[answer_index][dict_index]],
|
||||
)
|
||||
if not result["valid"]:
|
||||
flag = False
|
||||
break
|
||||
if flag:
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def simple_function_checker(
|
||||
func_description: dict,
|
||||
model_output: dict,
|
||||
possible_answer: dict,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
possible_answer = list(possible_answer.values())[0]
|
||||
# Extract function name and parameters details
|
||||
func_name = func_description["name"]
|
||||
param_details = func_description["parameters"]["properties"]
|
||||
required_params = func_description["parameters"]["required"]
|
||||
|
||||
# Initialize a result dictionary
|
||||
result = {
|
||||
"valid": True,
|
||||
"error": [],
|
||||
"error_type": "simple_function_checker:unclear",
|
||||
}
|
||||
|
||||
# Check if function name matches
|
||||
if func_name not in model_output:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Function name {repr(func_name)} not found in model output."
|
||||
)
|
||||
result["error_type"] = "simple_function_checker:wrong_func_name"
|
||||
return result
|
||||
|
||||
model_params = model_output[func_name]
|
||||
|
||||
# Check for required parameters in model output
|
||||
for param in required_params:
|
||||
if param not in model_params:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Missing required parameter: {repr(param)}.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "simple_function_checker:missing_required"
|
||||
return result
|
||||
|
||||
# Validate types and values for each parameter in model output
|
||||
for param, value in model_params.items():
|
||||
if param not in param_details or param not in possible_answer:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Unexpected parameter: {repr(param)}.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "simple_function_checker:unexpected_param"
|
||||
return result
|
||||
|
||||
full_param_details = param_details[param]
|
||||
expected_type_description = full_param_details["type"] # This is a string
|
||||
is_variable = False
|
||||
nested_type_converted = None
|
||||
|
||||
if language == "Java":
|
||||
from evals.utils.bfcl.java_type_converter import java_type_converter
|
||||
|
||||
expected_type_converted = JAVA_TYPE_CONVERSION[expected_type_description]
|
||||
|
||||
if expected_type_description in JAVA_TYPE_CONVERSION:
|
||||
if type(value) != str:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}."
|
||||
)
|
||||
result["error_type"] = "type_error:java"
|
||||
return result
|
||||
|
||||
if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
|
||||
nested_type = param_details[param]["items"]["type"]
|
||||
nested_type_converted = JAVA_TYPE_CONVERSION[nested_type]
|
||||
value = java_type_converter(value, expected_type_description, nested_type)
|
||||
else:
|
||||
value = java_type_converter(value, expected_type_description)
|
||||
|
||||
elif language == "JavaScript":
|
||||
from evals.utils.bfcl.js_type_converter import js_type_converter
|
||||
|
||||
expected_type_converted = JS_TYPE_CONVERSION[expected_type_description]
|
||||
|
||||
if expected_type_description in JS_TYPE_CONVERSION:
|
||||
if type(value) != str:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}."
|
||||
)
|
||||
result["error_type"] = "type_error:js"
|
||||
return result
|
||||
|
||||
if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
|
||||
nested_type = param_details[param]["items"]["type"]
|
||||
nested_type_converted = JS_TYPE_CONVERSION[nested_type]
|
||||
value = js_type_converter(value, expected_type_description, nested_type)
|
||||
else:
|
||||
value = js_type_converter(value, expected_type_description)
|
||||
|
||||
elif language == "Python":
|
||||
expected_type_converted = PYTHON_TYPE_MAPPING[expected_type_description]
|
||||
if expected_type_description in PYTHON_NESTED_TYPE_CHECK_LIST:
|
||||
nested_type = param_details[param]["items"]["type"]
|
||||
nested_type_converted = PYTHON_TYPE_MAPPING[nested_type]
|
||||
|
||||
# We convert all tuple value to list when the expected type is tuple.
|
||||
# The conversion is necessary because any tuple in the possible answer would become a list after being processed through json.dump() and json.load().
|
||||
# This does introduce some false positive (eg, when the model provides a list value instead of tuple). We hope to find a better solution in the future.
|
||||
if expected_type_description == "tuple" and type(value) == tuple:
|
||||
value = list(value)
|
||||
|
||||
# Allow python auto conversion from int to float
|
||||
if language == "Python" and expected_type_description == "float" and type(value) == int:
|
||||
value = float(value)
|
||||
|
||||
# Type checking
|
||||
# In fact, we only check for Python here.
|
||||
# Type check for other languages are handled by the type converter, and so their value (after conversion) is always correct.
|
||||
type_check_result = type_checker(
|
||||
param,
|
||||
value,
|
||||
possible_answer[param],
|
||||
expected_type_description,
|
||||
expected_type_converted,
|
||||
nested_type_converted,
|
||||
)
|
||||
is_variable = type_check_result["is_variable"]
|
||||
if not type_check_result["valid"]:
|
||||
return type_check_result
|
||||
|
||||
# It doesn't make sense to special handle dictionaries and list of dictionaries if the value is a variable.
|
||||
# We can just treat the variable as a string and use the normal flow.
|
||||
if not is_variable:
|
||||
# Special handle for dictionaries
|
||||
if expected_type_converted == dict:
|
||||
result = dict_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
# Special handle for list of dictionaries
|
||||
elif expected_type_converted == list and nested_type_converted == dict:
|
||||
result = list_dict_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
# Special handle for strings
|
||||
elif expected_type_converted == str:
|
||||
# We don't check for case sensitivity for string, as long as it's not a variable
|
||||
result = string_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
elif expected_type_converted == list:
|
||||
result = list_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
# Check if the value is within the possible answers
|
||||
if value not in possible_answer[param]:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Invalid value for parameter {repr(param)}: {repr(value)}. Expected one of {possible_answer[param]}."
|
||||
)
|
||||
result["error_type"] = "value_error:others"
|
||||
return result
|
||||
|
||||
# Check for optional parameters not provided but allowed
|
||||
for param in possible_answer:
|
||||
if param not in model_params and "" not in possible_answer[param]:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Optional parameter {repr(param)} not provided and not marked as optional."
|
||||
)
|
||||
result["error_type"] = "simple_function_checker:missing_optional"
|
||||
return result
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def parallel_function_checker_enforce_order(
|
||||
func_descriptions: list,
|
||||
model_output: list,
|
||||
possible_answers: dict,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
if len(model_output) != len(possible_answers):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "parallel_function_checker_enforce_order:wrong_count",
|
||||
}
|
||||
|
||||
func_name_list = list(possible_answers.keys())
|
||||
possible_answers_list = []
|
||||
|
||||
for key, value in possible_answers.items():
|
||||
possible_answers_list.append({key: value})
|
||||
|
||||
for i in range(len(possible_answers_list)):
|
||||
func_description = find_description(func_descriptions, func_name_list[i])
|
||||
|
||||
result = simple_function_checker(
|
||||
func_description,
|
||||
model_output[i],
|
||||
possible_answers_list[i],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
if not result["valid"]:
|
||||
return result
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def parallel_function_checker_no_order(
|
||||
func_descriptions: list,
|
||||
model_output: list,
|
||||
possible_answers: list,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
if len(model_output) != len(possible_answers):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "parallel_function_checker_no_order:wrong_count",
|
||||
}
|
||||
|
||||
matched_indices = []
|
||||
|
||||
# We go throught the possible answers one by one, and eliminate the model output that matches the possible answer
|
||||
# It must be this way because we need ground truth to fetch the correct function description
|
||||
for i in range(len(possible_answers)):
|
||||
# possible_answers[i] is a dictionary with only one key
|
||||
func_name_expected = list(possible_answers[i].keys())[0]
|
||||
func_description = find_description(func_descriptions, func_name_expected)
|
||||
|
||||
all_errors = []
|
||||
|
||||
for index in range(len(model_output)):
|
||||
if index in matched_indices:
|
||||
continue
|
||||
|
||||
result = simple_function_checker(
|
||||
func_description,
|
||||
model_output[index],
|
||||
possible_answers[i],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
|
||||
if result["valid"]:
|
||||
matched_indices.append(index)
|
||||
break
|
||||
else:
|
||||
all_errors.append(
|
||||
{
|
||||
f"Model Result Index {index}": {
|
||||
"sub_error": result["error"],
|
||||
"sub_error_type": result["error_type"],
|
||||
"model_output_item": model_output[index],
|
||||
"possible_answer_item": possible_answers[i],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if not result["valid"]:
|
||||
considered_indices = [i for i in range(len(model_output)) if i not in matched_indices]
|
||||
all_errors.insert(
|
||||
0,
|
||||
f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type]
|
||||
)
|
||||
return {
|
||||
"valid": False,
|
||||
"error": all_errors,
|
||||
"error_type": "parallel_function_checker_no_order:cannot_find_match",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def multiple_function_checker(
|
||||
func_descriptions: list,
|
||||
model_output: list,
|
||||
possible_answers: list,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
if len(model_output) != len(possible_answers):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "multiple_function_checker:wrong_count",
|
||||
}
|
||||
|
||||
# possible_answers is a list of only one dictionary with only one key
|
||||
func_name_expected = list(possible_answers[0].keys())[0]
|
||||
func_description = find_description(func_descriptions, func_name_expected)
|
||||
return simple_function_checker(
|
||||
func_description,
|
||||
model_output[0],
|
||||
possible_answers[0],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
|
||||
|
||||
def patten_matcher(exec_output, expected_result, function_call, is_sanity_check):
|
||||
result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
|
||||
|
||||
if type(exec_output) != type(expected_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result type for {repr(function_call)}. Expected type: {type(expected_result)}, but got: {type(exec_output)}."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
if type(exec_output) == dict:
|
||||
# We loose the requirement for the sanity check as the expected result used in the sanity check might not be the most up-to-date one.
|
||||
# This happens when the key is a timestamp or a random number.
|
||||
if is_sanity_check:
|
||||
if len(exec_output) != len(expected_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:dict_length",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
else:
|
||||
return result
|
||||
|
||||
for key, value in expected_result.items():
|
||||
if key not in exec_output:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not found in the model output."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:dict_key_not_found",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
for key, value in exec_output.items():
|
||||
if key not in expected_result:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not expected in the model output."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:dict_extra_key",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
if type(exec_output) == list:
|
||||
if len(exec_output) != len(expected_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type list, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:list_length",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
#### Helper functions for Exec ####
|
||||
def executable_checker_simple(
|
||||
function_call: str,
|
||||
expected_result,
|
||||
expected_result_type: str,
|
||||
is_sanity_check=False,
|
||||
):
|
||||
result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
|
||||
|
||||
exec_dict: Any = {}
|
||||
|
||||
try:
|
||||
exec(
|
||||
"from executable_python_function import *" + "\nresult=" + function_call,
|
||||
exec_dict,
|
||||
)
|
||||
exec_output = exec_dict["result"]
|
||||
except NoAPIKeyError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Error in execution: {repr(function_call)}. Error: {str(e)}"
|
||||
)
|
||||
result["error_type"] = "executable_checker:execution_error"
|
||||
return result
|
||||
|
||||
# We need to special handle the case where the execution result is a tuple and convert it to a list
|
||||
# Because when json is stored, the tuple is converted to a list, and so the expected result is a list when loaded from json
|
||||
if isinstance(exec_output, tuple):
|
||||
exec_output = list(exec_output)
|
||||
|
||||
if expected_result_type == "exact_match":
|
||||
if exec_output != expected_result:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}."
|
||||
)
|
||||
result["error_type"] = "executable_checker:wrong_result"
|
||||
result["model_executed_output"] = exec_output
|
||||
return result
|
||||
|
||||
elif expected_result_type == "real_time_match":
|
||||
# Allow for 5% difference
|
||||
if (type(expected_result) == float or type(expected_result) == int) and (
|
||||
type(exec_output) == float or type(exec_output) == int
|
||||
):
|
||||
if not (
|
||||
expected_result * (1 - REAL_TIME_MATCH_ALLOWED_DIFFERENCE)
|
||||
<= exec_output
|
||||
<= expected_result * (1 + REAL_TIME_MATCH_ALLOWED_DIFFERENCE)
|
||||
):
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. {REAL_TIME_MATCH_ALLOWED_DIFFERENCE * 100}% difference allowed."
|
||||
)
|
||||
result["error_type"] = "executable_checker:wrong_result_real_time"
|
||||
result["model_executed_output"] = exec_output
|
||||
return result
|
||||
else:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. Type needs to be float or int for real time match criteria."
|
||||
)
|
||||
result["error_type"] = "executable_checker:wrong_result_real_time"
|
||||
result["model_executed_output"] = exec_output
|
||||
return result
|
||||
|
||||
else:
|
||||
# structural match
|
||||
pattern_match_result = patten_matcher(exec_output, expected_result, function_call, is_sanity_check)
|
||||
if not pattern_match_result["valid"]:
|
||||
return pattern_match_result
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def executable_checker_parallel_no_order(
|
||||
decoded_result: list, expected_exec_result: list, expected_exec_result_type: list
|
||||
):
|
||||
if len(decoded_result) != len(expected_exec_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong number of functions provided. Expected {len(expected_exec_result)}, but got {len(decoded_result)}."
|
||||
],
|
||||
"error_type": "value_error:exec_result_count",
|
||||
}
|
||||
|
||||
matched_indices = []
|
||||
for i in range(len(expected_exec_result)):
|
||||
all_errors = []
|
||||
for index in range(len(decoded_result)):
|
||||
if index in matched_indices:
|
||||
continue
|
||||
|
||||
result = executable_checker_simple(
|
||||
decoded_result[index],
|
||||
expected_exec_result[i],
|
||||
expected_exec_result_type[i],
|
||||
False,
|
||||
)
|
||||
|
||||
if result["valid"]:
|
||||
matched_indices.append(index)
|
||||
break
|
||||
else:
|
||||
all_errors.append(
|
||||
{
|
||||
f"Model Result Index {index}": {
|
||||
"sub_error": result["error"],
|
||||
"sub_error_type": result["error_type"],
|
||||
"model_executed_output": (
|
||||
result["model_executed_output"] if "model_executed_output" in result else None
|
||||
),
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if not result["valid"]:
|
||||
considered_indices = [i for i in range(len(decoded_result)) if i not in matched_indices]
|
||||
all_errors.insert(
|
||||
0,
|
||||
f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type]
|
||||
)
|
||||
return {
|
||||
"valid": False,
|
||||
"error": all_errors,
|
||||
"error_type": "executable_checker:cannot_find_match",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
|
||||
|
||||
|
||||
#### Main function ####
|
||||
def executable_checker_rest(func_call, idx):
|
||||
# Move this here for now to avoid needing to read this file / fix paths to be relative to dataset_dir. Fix when it's actually needed / used.
|
||||
EVAL_GROUND_TRUTH_PATH = "/mnt/wsfuse/fair_llm_v2/datasets/eval/bfcl/rest-eval-response_v5.jsonl" # Ground truth file for v5 for rest execution
|
||||
with open(EVAL_GROUND_TRUTH_PATH, "r") as f:
|
||||
EVAL_GROUND_TRUTH = f.readlines()
|
||||
if "https://geocode.maps.co" in func_call:
|
||||
time.sleep(2)
|
||||
if "requests_get" in func_call:
|
||||
func_call = func_call.replace("requests_get", "requests.get")
|
||||
try:
|
||||
response = eval(func_call)
|
||||
except Exception as e:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Execution failed. {str(e)}"],
|
||||
"error_type": "executable_checker_rest:execution_error",
|
||||
}
|
||||
|
||||
try:
|
||||
if response.status_code == 200:
|
||||
eval_GT_json = json.loads(EVAL_GROUND_TRUTH[idx])
|
||||
try:
|
||||
if isinstance(eval_GT_json, dict):
|
||||
if isinstance(response.json(), dict):
|
||||
if set(eval_GT_json.keys()) == set(response.json().keys()):
|
||||
return {"valid": True, "error": [], "error_type": ""}
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Key inconsistency"],
|
||||
"error_type": "executable_checker_rest:wrong_key",
|
||||
}
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Expected dictionary, but got {type(response.json())}"],
|
||||
"error_type": "executable_checker_rest:wrong_type",
|
||||
}
|
||||
|
||||
elif isinstance(eval_GT_json, list):
|
||||
if isinstance(response.json(), list):
|
||||
if len(eval_GT_json) != len(response.json()):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Response list length inconsistency."],
|
||||
"error_type": "value_error:exec_result_rest_count",
|
||||
}
|
||||
|
||||
else:
|
||||
for i in range(len(eval_GT_json)):
|
||||
if set(eval_GT_json[i].keys()) != set(response.json()[i].keys()):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Key inconsistency"],
|
||||
"error_type": "executable_checker_rest:wrong_key",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
else:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Expected list, but got {type(response.json())}"],
|
||||
"error_type": "executable_checker_rest:wrong_type",
|
||||
}
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Expected dict or list, but got {type(response.json())}"],
|
||||
"error_type": "executable_checker_rest:wrong_type",
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Error in execution and type checking. Status code: {response.status_code}. Error: {str(e)}"
|
||||
],
|
||||
"error_type": "executable_checker_rest:response_format_error",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Execution result status code is not 200, got {response.status_code}"],
|
||||
"error_type": "executable_checker_rest:wrong_status_code",
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Cannot get status code of the response. Error: {str(e)}"],
|
||||
"error_type": "executable_checker_rest:cannot_get_status_code",
|
||||
}
|
||||
|
||||
|
||||
def ast_checker(func_description, model_output, possible_answer, language, test_category, model_name):
|
||||
if "parallel" in test_category:
|
||||
return parallel_function_checker_no_order(func_description, model_output, possible_answer, language, model_name)
|
||||
|
||||
elif "multiple" in test_category:
|
||||
return multiple_function_checker(func_description, model_output, possible_answer, language, model_name)
|
||||
|
||||
else:
|
||||
if len(model_output) != 1:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "simple_function_checker:wrong_count",
|
||||
}
|
||||
|
||||
return simple_function_checker(
|
||||
func_description[0],
|
||||
model_output[0],
|
||||
possible_answer[0],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
|
||||
|
||||
def exec_checker(decoded_result: list, func_description: dict, test_category: str):
|
||||
if "multiple" in test_category or "parallel" in test_category:
|
||||
return executable_checker_parallel_no_order(
|
||||
decoded_result,
|
||||
func_description["execution_result"],
|
||||
func_description["execution_result_type"],
|
||||
)
|
||||
|
||||
else:
|
||||
if len(decoded_result) != 1:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "simple_exec_checker:wrong_count",
|
||||
}
|
||||
return executable_checker_simple(
|
||||
decoded_result[0],
|
||||
func_description["execution_result"][0],
|
||||
func_description["execution_result_type"][0],
|
||||
False,
|
||||
)
|
||||
|
||||
|
||||
def is_empty_output(decoded_output):
|
||||
# This function is a patch to the ast decoder for relevance detection
|
||||
# Sometimes the ast decoder will parse successfully, but the input doens't really have a function call
|
||||
# [], [{}], and anything that is not in function calling format is considered empty (and thus should be marked as correct)
|
||||
if not is_function_calling_format_output(decoded_output):
|
||||
return True
|
||||
if len(decoded_output) == 0:
|
||||
return True
|
||||
if len(decoded_output) == 1 and len(decoded_output[0]) == 0:
|
||||
return True
|
||||
|
||||
|
||||
def is_function_calling_format_output(decoded_output):
|
||||
# Ensure the output is a list of dictionaries
|
||||
if type(decoded_output) == list:
|
||||
for item in decoded_output:
|
||||
if type(item) != dict:
|
||||
return False
|
||||
return True
|
||||
return False
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Tree-sitter changes its API with unfortunate frequency. Modules that need it should
|
||||
import it from here so that we can centrally manage things as necessary.
|
||||
"""
|
||||
|
||||
# These currently work with tree-sitter 0.23.0
|
||||
# NOTE: Don't import tree-sitter or any of the language modules in the main module
|
||||
# because not all environments have them. Import lazily inside functions where needed.
|
||||
|
||||
import importlib
|
||||
import typing
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import tree_sitter
|
||||
|
||||
|
||||
def get_language(language: str) -> "tree_sitter.Language":
|
||||
import tree_sitter
|
||||
|
||||
language_module_name = f"tree_sitter_{language}"
|
||||
try:
|
||||
language_module = importlib.import_module(language_module_name)
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ValueError(
|
||||
f"Language {language} is not found. Please install the tree-sitter-{language} package."
|
||||
) from exc
|
||||
return tree_sitter.Language(language_module.language())
|
||||
|
||||
|
||||
def get_parser(language: str, **kwargs) -> "tree_sitter.Parser":
|
||||
import tree_sitter
|
||||
|
||||
lang = get_language(language)
|
||||
return tree_sitter.Parser(lang, **kwargs)
|
|
@ -14,7 +14,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
InlineProviderSpec(
|
||||
api=Api.eval,
|
||||
provider_type="inline::meta-reference",
|
||||
pip_packages=[],
|
||||
pip_packages=["tree_sitter"],
|
||||
module="llama_stack.providers.inline.eval.meta_reference",
|
||||
config_class="llama_stack.providers.inline.eval.meta_reference.MetaReferenceEvalConfig",
|
||||
api_dependencies=[
|
||||
|
|
|
@ -23,6 +23,10 @@ class ColumnName(Enum):
|
|||
generated_answer = "generated_answer"
|
||||
context = "context"
|
||||
dialog = "dialog"
|
||||
function = "function"
|
||||
language = "language"
|
||||
id = "id"
|
||||
ground_truth = "ground_truth"
|
||||
|
||||
|
||||
VALID_SCHEMAS_FOR_SCORING = [
|
||||
|
@ -37,6 +41,15 @@ VALID_SCHEMAS_FOR_SCORING = [
|
|||
ColumnName.generated_answer.value: StringType(),
|
||||
ColumnName.context.value: StringType(),
|
||||
},
|
||||
{
|
||||
ColumnName.input_query.value: StringType(),
|
||||
ColumnName.expected_answer.value: StringType(),
|
||||
ColumnName.generated_answer.value: StringType(),
|
||||
ColumnName.function.value: StringType(),
|
||||
ColumnName.language.value: StringType(),
|
||||
ColumnName.id.value: StringType(),
|
||||
ColumnName.ground_truth.value: StringType(),
|
||||
},
|
||||
]
|
||||
|
||||
VALID_SCHEMAS_FOR_EVAL = [
|
||||
|
@ -50,6 +63,15 @@ VALID_SCHEMAS_FOR_EVAL = [
|
|||
ColumnName.expected_answer.value: StringType(),
|
||||
ColumnName.completion_input.value: CompletionInputType(),
|
||||
},
|
||||
{
|
||||
ColumnName.input_query.value: StringType(),
|
||||
ColumnName.expected_answer.value: StringType(),
|
||||
ColumnName.generated_answer.value: StringType(),
|
||||
ColumnName.function.value: StringType(),
|
||||
ColumnName.language.value: StringType(),
|
||||
ColumnName.id.value: StringType(),
|
||||
ColumnName.ground_truth.value: StringType(),
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -226,6 +226,22 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"chat_completion_input": {"type": "string"},
|
||||
},
|
||||
),
|
||||
DatasetInput(
|
||||
dataset_id="bfcl",
|
||||
provider_id="huggingface",
|
||||
url=URL(uri="https://huggingface.co/datasets/llamastack/bfcl_v3"),
|
||||
metadata={
|
||||
"path": "llamastack/bfcl_v3",
|
||||
"split": "train",
|
||||
},
|
||||
dataset_schema={
|
||||
"function": {"type": "string"},
|
||||
"language": {"type": "string"},
|
||||
"ground_truth": {"type": "string"},
|
||||
"id": {"type": "string"},
|
||||
"chat_completion_input": {"type": "string"},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
default_benchmarks = [
|
||||
|
@ -249,6 +265,11 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
dataset_id="math_500",
|
||||
scoring_functions=["basic::regex_parser_math_response"],
|
||||
),
|
||||
BenchmarkInput(
|
||||
benchmark_id="meta-reference-bfcl",
|
||||
dataset_id="bfcl",
|
||||
scoring_functions=["basic::bfcl"],
|
||||
),
|
||||
]
|
||||
return DistributionTemplate(
|
||||
name=name,
|
||||
|
|
|
@ -216,6 +216,24 @@ datasets:
|
|||
split: test
|
||||
dataset_id: math_500
|
||||
provider_id: huggingface
|
||||
- dataset_schema:
|
||||
function:
|
||||
type: string
|
||||
language:
|
||||
type: string
|
||||
ground_truth:
|
||||
type: string
|
||||
id:
|
||||
type: string
|
||||
chat_completion_input:
|
||||
type: string
|
||||
url:
|
||||
uri: https://huggingface.co/datasets/llamastack/bfcl_v3
|
||||
metadata:
|
||||
path: llamastack/bfcl_v3
|
||||
split: train
|
||||
dataset_id: bfcl
|
||||
provider_id: huggingface
|
||||
scoring_fns: []
|
||||
benchmarks:
|
||||
- dataset_id: simpleqa
|
||||
|
@ -238,6 +256,11 @@ benchmarks:
|
|||
- basic::regex_parser_math_response
|
||||
metadata: {}
|
||||
benchmark_id: meta-reference-math-500
|
||||
- dataset_id: bfcl
|
||||
scoring_functions:
|
||||
- basic::bfcl
|
||||
metadata: {}
|
||||
benchmark_id: meta-reference-bfcl
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
|
|
|
@ -12,7 +12,7 @@ distro==1.9.0
|
|||
exceptiongroup==1.2.2 ; python_full_version < '3.11'
|
||||
filelock==3.17.0
|
||||
fire==0.7.0
|
||||
fsspec==2025.2.0
|
||||
fsspec==2024.12.0
|
||||
h11==0.14.0
|
||||
httpcore==1.0.7
|
||||
httpx==0.28.1
|
||||
|
|
6
uv.lock
generated
6
uv.lock
generated
|
@ -769,11 +769,11 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "fsspec"
|
||||
version = "2025.2.0"
|
||||
version = "2024.12.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b5/79/68612ed99700e6413de42895aa725463e821a6b3be75c87fcce1b4af4c70/fsspec-2025.2.0.tar.gz", hash = "sha256:1c24b16eaa0a1798afa0337aa0db9b256718ab2a89c425371f5628d22c3b6afd", size = 292283 }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ee/11/de70dee31455c546fbc88301971ec03c328f3d1138cfba14263f651e9551/fsspec-2024.12.0.tar.gz", hash = "sha256:670700c977ed2fb51e0d9f9253177ed20cbde4a3e5c0283cc5385b5870c8533f", size = 291600 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e2/94/758680531a00d06e471ef649e4ec2ed6bf185356a7f9fbfbb7368a40bd49/fsspec-2025.2.0-py3-none-any.whl", hash = "sha256:9de2ad9ce1f85e1931858535bc882543171d197001a0a5eb2ddc04f1781ab95b", size = 184484 },
|
||||
{ url = "https://files.pythonhosted.org/packages/de/86/5486b0188d08aa643e127774a99bac51ffa6cf343e3deb0583956dca5b22/fsspec-2024.12.0-py3-none-any.whl", hash = "sha256:b520aed47ad9804237ff878b504267a3b0b441e97508bd6d2d8774e3db85cee2", size = 183862 },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue