forked from phoenix-oss/llama-stack-mirror
199 lines
7.2 KiB
Python
199 lines
7.2 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# top-level folder for each specific model found within the models/ directory at
|
|
# the top-level of this source tree.
|
|
import ast
|
|
import json
|
|
import re
|
|
from typing import Optional, Tuple
|
|
|
|
from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
|
|
|
|
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
|
|
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
|
|
|
|
|
|
def is_json(s):
|
|
try:
|
|
parsed = json.loads(s)
|
|
# Return True for valid objects and not for ints, strings, etc
|
|
return isinstance(parsed, dict)
|
|
except json.JSONDecodeError:
|
|
return False
|
|
return True
|
|
|
|
|
|
def is_valid_python_list(input_string):
|
|
"""Check if the input string is a valid Python list of function calls"""
|
|
try:
|
|
# Try to parse the string
|
|
tree = ast.parse(input_string)
|
|
|
|
# Check if it's a single expression
|
|
if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Expr):
|
|
return False
|
|
|
|
# Check if the expression is a list
|
|
expr = tree.body[0].value
|
|
if not isinstance(expr, ast.List):
|
|
return False
|
|
|
|
# Check if the list is empty
|
|
if len(expr.elts) == 0:
|
|
return False
|
|
|
|
# Check if all elements in the list are function calls
|
|
for element in expr.elts:
|
|
if not isinstance(element, ast.Call):
|
|
return False
|
|
|
|
# Check if the function call has a valid name
|
|
if not isinstance(element.func, ast.Name):
|
|
return False
|
|
|
|
# Check if all arguments are keyword arguments
|
|
if element.args or not all(isinstance(arg, ast.keyword) for arg in element.keywords):
|
|
return False
|
|
|
|
return True
|
|
|
|
except SyntaxError:
|
|
# If parsing fails, it's not a valid Python expression
|
|
return False
|
|
|
|
|
|
def parse_python_list_for_function_calls(input_string):
|
|
"""
|
|
Parse a Python list of function calls and
|
|
return a list of tuples containing the function name and arguments
|
|
"""
|
|
# Parse the string into an AST
|
|
tree = ast.parse(input_string)
|
|
|
|
# Ensure the input is a list
|
|
if not isinstance(tree.body[0], ast.Expr) or not isinstance(tree.body[0].value, ast.List):
|
|
raise ValueError("Input must be a list of function calls")
|
|
|
|
result = []
|
|
|
|
# Iterate through each function call in the list
|
|
for node in tree.body[0].value.elts:
|
|
if isinstance(node, ast.Call):
|
|
function_name = node.func.id
|
|
function_args = {}
|
|
|
|
# Extract keyword arguments
|
|
for keyword in node.keywords:
|
|
function_args[keyword.arg] = ast.literal_eval(keyword.value)
|
|
|
|
result.append((function_name, function_args))
|
|
|
|
return result
|
|
|
|
|
|
class ToolUtils:
|
|
@staticmethod
|
|
def is_builtin_tool_call(message_body: str) -> bool:
|
|
match = re.search(ToolUtils.BUILTIN_TOOL_PATTERN, message_body)
|
|
return match is not None
|
|
|
|
@staticmethod
|
|
def maybe_extract_builtin_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
|
|
# Find the first match in the text
|
|
match = re.search(BUILTIN_TOOL_PATTERN, message_body)
|
|
|
|
# Check if a match is found and return it
|
|
if match:
|
|
tool_name = match.group("tool_name")
|
|
query = match.group("query")
|
|
return tool_name, query
|
|
else:
|
|
return None
|
|
|
|
@staticmethod
|
|
def maybe_extract_custom_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
|
|
# NOTE: Custom function too calls are still experimental
|
|
# Sometimes, response is of the form
|
|
# {"type": "function", "name": "function_name", "parameters": {...}
|
|
# and some times
|
|
# <function=function_name>(parameters)</function>
|
|
|
|
# Find the first match in the text
|
|
match = re.search(CUSTOM_TOOL_CALL_PATTERN, message_body)
|
|
if match:
|
|
tool_name = match.group("function_name")
|
|
query = match.group("args")
|
|
try:
|
|
return tool_name, json.loads(query.replace("'", '"'))
|
|
except Exception as e:
|
|
print("Exception while parsing json query for custom tool call", query, e)
|
|
return None
|
|
elif is_json(message_body):
|
|
response = json.loads(message_body)
|
|
if ("type" in response and response["type"] == "function") or ("name" in response):
|
|
function_name = response["name"]
|
|
args = response["parameters"]
|
|
return function_name, args
|
|
else:
|
|
return None
|
|
elif is_valid_python_list(message_body):
|
|
res = parse_python_list_for_function_calls(message_body)
|
|
# FIXME: Enable multiple tool calls
|
|
return res[0]
|
|
else:
|
|
return None
|
|
|
|
@staticmethod
|
|
def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str:
|
|
if t.tool_name == BuiltinTool.brave_search:
|
|
q = t.arguments["query"]
|
|
return f'brave_search.call(query="{q}")'
|
|
elif t.tool_name == BuiltinTool.wolfram_alpha:
|
|
q = t.arguments["query"]
|
|
return f'wolfram_alpha.call(query="{q}")'
|
|
elif t.tool_name == BuiltinTool.photogen:
|
|
q = t.arguments["query"]
|
|
return f'photogen.call(query="{q}")'
|
|
elif t.tool_name == BuiltinTool.code_interpreter:
|
|
return t.arguments["code"]
|
|
else:
|
|
fname = t.tool_name
|
|
|
|
if tool_prompt_format == ToolPromptFormat.json:
|
|
return json.dumps(
|
|
{
|
|
"type": "function",
|
|
"name": fname,
|
|
"parameters": t.arguments,
|
|
}
|
|
)
|
|
elif tool_prompt_format == ToolPromptFormat.function_tag:
|
|
args = json.dumps(t.arguments)
|
|
return f"<function={fname}>{args}</function>"
|
|
|
|
elif tool_prompt_format == ToolPromptFormat.python_list:
|
|
|
|
def format_value(value: RecursiveType) -> str:
|
|
if isinstance(value, str):
|
|
return f'"{value}"'
|
|
elif isinstance(value, (int, float, bool)) or value is None:
|
|
return str(value)
|
|
elif isinstance(value, list):
|
|
return f"[{', '.join(format_value(v) for v in value)}]"
|
|
elif isinstance(value, dict):
|
|
return f"{{{', '.join(f'{k}={format_value(v)}' for k, v in value.items())}}}"
|
|
else:
|
|
raise ValueError(f"Unsupported type: {type(value)}")
|
|
|
|
args_str = ", ".join(f"{k}={format_value(v)}" for k, v in t.arguments.items())
|
|
return f"[{fname}({args_str})]"
|
|
else:
|
|
raise ValueError(f"Unsupported tool prompt format: {tool_prompt_format}")
|