diff --git a/llama_stack/models/llama/llama3/tool_utils.py b/llama_stack/models/llama/llama3/tool_utils.py index 574080184..704e03ee1 100644 --- a/llama_stack/models/llama/llama3/tool_utils.py +++ b/llama_stack/models/llama/llama3/tool_utils.py @@ -167,7 +167,7 @@ def parse_llama_tool_call_format(input_string): class ToolUtils: @staticmethod def is_builtin_tool_call(message_body: str) -> bool: - match = re.search(ToolUtils.BUILTIN_TOOL_PATTERN, message_body) + match = re.search(BUILTIN_TOOL_PATTERN, message_body) return match is not None @staticmethod @@ -184,7 +184,7 @@ class ToolUtils: return None @staticmethod - def maybe_extract_custom_tool_call(message_body: str) -> tuple[str, str] | None: + def maybe_extract_custom_tool_call(message_body: str) -> tuple[str, dict] | None: # NOTE: Custom function too calls are still experimental # Sometimes, response is of the form # {"type": "function", "name": "function_name", "parameters": {...} @@ -197,7 +197,11 @@ class ToolUtils: tool_name = match.group("function_name") query = match.group("args") try: - return tool_name, json.loads(query.replace("'", '"')) + parsed_args = json.loads(query.replace("'", '"')) + if isinstance(parsed_args, dict): + return tool_name, parsed_args + else: + return None except Exception as e: print("Exception while parsing json query for custom tool call", query, e) return None @@ -208,12 +212,22 @@ class ToolUtils: ): function_name = response["name"] args = response["parameters"] - return function_name, args + if isinstance(args, dict): + return function_name, args + else: + return None else: return None elif function_calls := parse_llama_tool_call_format(message_body): # FIXME: Enable multiple tool calls - return function_calls[0] + if function_calls and len(function_calls) > 0: + func_name, args_dict = function_calls[0] + if isinstance(args_dict, dict): + return func_name, args_dict + else: + return None + else: + return None else: logger.debug(f"Did not parse tool call from message body: {message_body}") return None @@ -221,29 +235,64 @@ class ToolUtils: @staticmethod def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str: if t.tool_name == BuiltinTool.brave_search: - q = t.arguments["query"] - return f'brave_search.call(query="{q}")' + if isinstance(t.arguments, dict): + q = t.arguments["query"] + return f'brave_search.call(query="{q}")' + else: + # Handle string arguments + return f'brave_search.call(query="{t.arguments}")' elif t.tool_name == BuiltinTool.wolfram_alpha: - q = t.arguments["query"] - return f'wolfram_alpha.call(query="{q}")' + if isinstance(t.arguments, dict): + q = t.arguments["query"] + return f'wolfram_alpha.call(query="{q}")' + else: + # Handle string arguments + return f'wolfram_alpha.call(query="{t.arguments}")' elif t.tool_name == BuiltinTool.photogen: - q = t.arguments["query"] - return f'photogen.call(query="{q}")' + if isinstance(t.arguments, dict): + q = t.arguments["query"] + return f'photogen.call(query="{q}")' + else: + # Handle string arguments + return f'photogen.call(query="{t.arguments}")' elif t.tool_name == BuiltinTool.code_interpreter: - return t.arguments["code"] + if isinstance(t.arguments, dict): + code = t.arguments["code"] + return str(code) + else: + # Handle string arguments + return str(t.arguments) else: fname = t.tool_name if tool_prompt_format == ToolPromptFormat.json: + # For JSON format, we need to handle both string and dict arguments + if isinstance(t.arguments, str): + # Try to parse string as JSON + try: + parsed_args = json.loads(t.arguments) + except json.JSONDecodeError: + parsed_args = {"value": t.arguments} + else: + parsed_args = t.arguments + return json.dumps( { "type": "function", "name": fname, - "parameters": t.arguments, + "parameters": parsed_args, } ) elif tool_prompt_format == ToolPromptFormat.function_tag: - args = json.dumps(t.arguments) + if isinstance(t.arguments, str): + # Try to parse string as JSON + try: + parsed_args = json.loads(t.arguments) + args = json.dumps(parsed_args) + except json.JSONDecodeError: + args = json.dumps({"value": t.arguments}) + else: + args = json.dumps(t.arguments) return f"{args}" elif tool_prompt_format == ToolPromptFormat.python_list: @@ -256,11 +305,22 @@ class ToolUtils: elif isinstance(value, list): return f"[{', '.join(format_value(v) for v in value)}]" elif isinstance(value, dict): - return f"{{{', '.join(f'{k}={format_value(v)}' for k, v in value.items())}}}" + return f"{{{', '.join(f'"{k}"={format_value(v)}' for k, v in value.items())}}}" else: raise ValueError(f"Unsupported type: {type(value)}") - args_str = ", ".join(f"{k}={format_value(v)}" for k, v in t.arguments.items()) + if isinstance(t.arguments, str): + # Try to parse string as JSON + try: + parsed_args = json.loads(t.arguments) + if isinstance(parsed_args, dict): + args_str = ", ".join(f"{k}={format_value(v)}" for k, v in parsed_args.items()) + else: + args_str = f"value={format_value(parsed_args)}" + except json.JSONDecodeError: + args_str = f'value="{t.arguments}"' + else: + args_str = ", ".join(f"{k}={format_value(v)}" for k, v in t.arguments.items()) return f"[{fname}({args_str})]" else: raise ValueError(f"Unsupported tool prompt format: {tool_prompt_format}") diff --git a/pyproject.toml b/pyproject.toml index 72f3a323f..019bcf0b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -245,7 +245,7 @@ exclude = [ "^llama_stack/models/llama/llama3/chat_format\\.py$", "^llama_stack/models/llama/llama3/interface\\.py$", "^llama_stack/models/llama/llama3/tokenizer\\.py$", - "^llama_stack/models/llama/llama3/tool_utils\\.py$", + "^llama_stack/models/llama/llama3_3/prompts\\.py$", "^llama_stack/providers/inline/agents/meta_reference/", "^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$", "^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",