diff --git a/llama_stack/models/llama/llama3/tool_utils.py b/llama_stack/models/llama/llama3/tool_utils.py index fc8287eb6..ef39ba0a5 100644 --- a/llama_stack/models/llama/llama3/tool_utils.py +++ b/llama_stack/models/llama/llama3/tool_utils.py @@ -4,13 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. -import ast import json import re from typing import Optional, Tuple @@ -35,80 +28,141 @@ def is_json(s): return True -def is_valid_python_list(input_string): - """Check if the input string is a valid Python list of function calls""" - try: - # Try to parse the string - tree = ast.parse(input_string) - - # Check if it's a single expression - if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Expr): - return False - - # Check if the expression is a list - expr = tree.body[0].value - if not isinstance(expr, ast.List): - return False - - # Check if the list is empty - if len(expr.elts) == 0: - return False - - # Check if all elements in the list are function calls - for element in expr.elts: - if not isinstance(element, ast.Call): - return False - - # Check if the function call has a valid name - if not isinstance(element.func, ast.Name): - return False - - # Check if all arguments are keyword arguments - if element.args or not all(isinstance(arg, ast.keyword) for arg in element.keywords): - return False - - return True - - except SyntaxError: - # If parsing fails, it's not a valid Python expression - return False - - -def parse_python_list_for_function_calls(input_string): +def parse_llama_tool_call_format(input_string): """ - Parse a Python list of function calls and - return a list of tuples containing the function name and arguments - """ - # Parse the string into an AST - tree = ast.parse(input_string) + Parse tool calls in the format: + [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] - # Ensure the input is a list - if not isinstance(tree.body[0], ast.Expr) or not isinstance(tree.body[0].value, ast.List): - raise ValueError("Input must be a list of function calls") + Returns a list of (function_name, arguments_dict) tuples or None if parsing fails. + """ + # Strip outer brackets and whitespace + input_string = input_string.strip() + if not (input_string.startswith("[") and input_string.endswith("]")): + return None + + content = input_string[1:-1].strip() + if not content: + return None result = [] - # Iterate through each function call in the list - for node in tree.body[0].value.elts: - if isinstance(node, ast.Call): - function_name = node.func.id - function_args = {} + # State variables for parsing + pos = 0 + length = len(content) - # Extract keyword arguments - for keyword in node.keywords: - try: - function_args[keyword.arg] = ast.literal_eval(keyword.value) - except ValueError as e: - logger.error( - f"Error parsing tool call argument '{keyword.arg}': {e}, full input string: '{input_string}'" - ) - raise ValueError( - f"Error parsing tool call argument '{keyword.arg}', full input string: '{input_string}'" - ) from e + while pos < length: + # Find function name + name_end = content.find("(", pos) + if name_end == -1: + break - result.append((function_name, function_args)) + func_name = content[pos:name_end].strip() - return result + # Find closing parenthesis for this function call + paren_level = 1 + args_start = name_end + 1 + args_end = args_start + + while args_end < length and paren_level > 0: + if content[args_end] == "(": + paren_level += 1 + elif content[args_end] == ")": + paren_level -= 1 + args_end += 1 + + if paren_level != 0: + # Unmatched parentheses + return None + + # Parse arguments + args_str = content[args_start : args_end - 1].strip() + args_dict = {} + + if args_str: + # Split by commas, but respect nested structures + parts = [] + part_start = 0 + in_quotes = False + quote_char = None + nested_level = 0 + + for i, char in enumerate(args_str): + if char in ('"', "'") and (i == 0 or args_str[i - 1] != "\\"): + if not in_quotes: + in_quotes = True + quote_char = char + elif char == quote_char: + in_quotes = False + quote_char = None + elif not in_quotes: + if char in ("{", "["): + nested_level += 1 + elif char in ("}", "]"): + nested_level -= 1 + elif char == "," and nested_level == 0: + parts.append(args_str[part_start:i].strip()) + part_start = i + 1 + + parts.append(args_str[part_start:].strip()) + + # Process each key=value pair + for part in parts: + if "=" in part: + key, value = part.split("=", 1) + key = key.strip() + value = value.strip() + + # Try to convert value to appropriate Python type + if (value.startswith('"') and value.endswith('"')) or ( + value.startswith("'") and value.endswith("'") + ): + # String + value = value[1:-1] + elif value.lower() == "true": + value = True + elif value.lower() == "false": + value = False + elif value.lower() == "none": + value = None + elif value.startswith("{") and value.endswith("}"): + # This is a nested dictionary + try: + # Try to parse as JSON + value = json.loads(value.replace("'", '"')) + except json.JSONDecodeError: + # Keep as string if parsing fails + pass + elif value.startswith("[") and value.endswith("]"): + # This is a nested list + try: + # Try to parse as JSON + value = json.loads(value.replace("'", '"')) + except json.JSONDecodeError: + # Keep as string if parsing fails + pass + else: + # Try to convert to number + try: + if "." in value: + value = float(value) + else: + value = int(value) + except ValueError: + # Keep as string if not a valid number + pass + + args_dict[key] = value + + result.append((func_name, args_dict)) + + # Move to the next function call + pos = args_end + + # Skip the comma between function calls if present + if pos < length and content[pos] == ",": + pos += 1 + + return result if result else None class ToolUtils: @@ -156,11 +210,11 @@ class ToolUtils: return function_name, args else: return None - elif is_valid_python_list(message_body): - res = parse_python_list_for_function_calls(message_body) + elif function_calls := parse_llama_tool_call_format(message_body): # FIXME: Enable multiple tool calls - return res[0] + return function_calls[0] else: + logger.debug(f"Did not parse tool call from message body: {message_body}") return None @staticmethod diff --git a/tests/unit/models/llama/llama3/test_tool_utils.py b/tests/unit/models/llama/llama3/test_tool_utils.py new file mode 100644 index 000000000..f576953de --- /dev/null +++ b/tests/unit/models/llama/llama3/test_tool_utils.py @@ -0,0 +1,145 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from llama_stack.models.llama.llama3.tool_utils import ToolUtils + + +class TestMaybeExtractCustomToolCall: + def test_valid_single_tool_call(self): + input_string = '[get_weather(location="San Francisco", units="celsius")]' + result = ToolUtils.maybe_extract_custom_tool_call(input_string) + + assert result is not None + assert len(result) == 2 + assert result[0] == "get_weather" + assert result[1] == {"location": "San Francisco", "units": "celsius"} + + def test_valid_multiple_tool_calls(self): + input_string = '[search(query="python programming"), get_time(timezone="UTC")]' + result = ToolUtils.maybe_extract_custom_tool_call(input_string) + + # Note: maybe_extract_custom_tool_call currently only returns the first tool call + assert result is not None + assert len(result) == 2 + assert result[0] == "search" + assert result[1] == {"query": "python programming"} + + def test_different_value_types(self): + input_string = '[analyze_data(count=42, enabled=True, ratio=3.14, name="test", options=None)]' + result = ToolUtils.maybe_extract_custom_tool_call(input_string) + + assert result is not None + assert len(result) == 2 + assert result[0] == "analyze_data" + assert result[1] == {"count": 42, "enabled": True, "ratio": 3.14, "name": "test", "options": None} + + def test_nested_structures(self): + input_string = '[complex_function(filters={"min": 10, "max": 100}, tags=["important", "urgent"])]' + result = ToolUtils.maybe_extract_custom_tool_call(input_string) + + # This test checks that nested structures are handled + assert result is not None + assert len(result) == 2 + assert result[0] == "complex_function" + assert "filters" in result[1] + assert sorted(result[1]["filters"].items()) == sorted({"min": 10, "max": 100}.items()) + + assert "tags" in result[1] + assert result[1]["tags"] == ["important", "urgent"] + + def test_hyphenated_function_name(self): + input_string = '[weather-forecast(city="London")]' + result = ToolUtils.maybe_extract_custom_tool_call(input_string) + + assert result is not None + assert len(result) == 2 + assert result[0] == "weather-forecast" # Function name remains hyphenated + assert result[1] == {"city": "London"} + + def test_empty_input(self): + input_string = "[]" + result = ToolUtils.maybe_extract_custom_tool_call(input_string) + + assert result is None + + def test_invalid_format(self): + invalid_inputs = [ + 'get_weather(location="San Francisco")', # Missing outer brackets + '{get_weather(location="San Francisco")}', # Wrong outer brackets + '[get_weather(location="San Francisco"]', # Unmatched brackets + '[get_weather{location="San Francisco"}]', # Wrong inner brackets + "just some text", # Not a tool call format at all + ] + + for input_string in invalid_inputs: + result = ToolUtils.maybe_extract_custom_tool_call(input_string) + assert result is None + + def test_quotes_handling(self): + input_string = '[search(query="Text with \\"quotes\\" inside")]' + result = ToolUtils.maybe_extract_custom_tool_call(input_string) + + # This test checks that escaped quotes are handled correctly + assert result is not None + + def test_single_quotes_in_arguments(self): + input_string = "[add-note(name='demonote', content='demonstrating Llama Stack and MCP integration')]" + result = ToolUtils.maybe_extract_custom_tool_call(input_string) + + assert result is not None + assert len(result) == 2 + assert result[0] == "add-note" # Function name remains hyphenated + assert result[1] == {"name": "demonote", "content": "demonstrating Llama Stack and MCP integration"} + + def test_json_format(self): + input_string = '{"type": "function", "name": "search_web", "parameters": {"query": "AI research"}}' + result = ToolUtils.maybe_extract_custom_tool_call(input_string) + + assert result is not None + assert len(result) == 2 + assert result[0] == "search_web" + assert result[1] == {"query": "AI research"} + + def test_python_list_format(self): + input_string = "[calculate(x=10, y=20)]" + result = ToolUtils.maybe_extract_custom_tool_call(input_string) + + assert result is not None + assert len(result) == 2 + assert result[0] == "calculate" + assert result[1] == {"x": 10, "y": 20} + + def test_complex_nested_structures(self): + input_string = '[advanced_query(config={"filters": {"categories": ["books", "electronics"], "price_range": {"min": 10, "max": 500}}, "sort": {"field": "relevance", "order": "desc"}})]' + result = ToolUtils.maybe_extract_custom_tool_call(input_string) + + assert result is not None + assert len(result) == 2 + assert result[0] == "advanced_query" + + # Verify the overall structure + assert "config" in result[1] + assert isinstance(result[1]["config"], dict) + + # Verify the first level of nesting + config = result[1]["config"] + assert "filters" in config + assert "sort" in config + + # Verify the second level of nesting (filters) + filters = config["filters"] + assert "categories" in filters + assert "price_range" in filters + + # Verify the list within the dict + assert filters["categories"] == ["books", "electronics"] + + # Verify the nested dict within another dict + assert filters["price_range"]["min"] == 10 + assert filters["price_range"]["max"] == 500 + + # Verify the sort dictionary + assert config["sort"]["field"] == "relevance" + assert config["sort"]["order"] == "desc"