mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-30 16:53:12 +00:00
This gets the fireworks provider passing 100% of our OpenAI API
verification tests when run against a Llama Stack server using the
fireworks provider. Testing against Fireworks directly, without Llama
Stack in the middle, has a lower pass rate.
The main changes are in how we divert Llama model OpenAI chat
completion requests to the Llama Stack chat completion API (vs
OpenAI), which applies all the client-side formatting necessary to get
tool calls working properly on Fireworks.
A side-effect of this work is any provider using the
OpenAIChatCompletionToLlamaStackMixin (renamed from
OpenAIChatCompletionUnsupportedMixin) will also get a better
conversion from OpenAI to Llama Stack, including streaming and
non-stream responses.
A small change was required to
`llama_stack/models/llama/llama3/tool_utils.py` to get tests to 100%
because code there was incorrectly assuming any JSON response with a
`name` key was a tool call response. One of our verification tests
produces JSON keys with a `name` key that is not a tool call response,
so I tightened up the logic there to require both a `name` and
`parameters` key in the JSON response before it gets considered a
potential tool call. The `parameters` key was required by the code
anyway, but it wasn't explicitly checking for its existence.
Lastly, this adds some new verification test configs so we can see the
results of using OpenAI APIs against SaaS services directly compared
to hitting Llama Stack with a remote provider pointing at that SaaS
service.
You can run these verification tests like:
```
llama stack run \
--image-type venv \
tests/verifications/openai-api-verification-run.yaml
python tests/verifications/generate_report.py \
--run-tests \
--provider together fireworks openai \
together-llama-stack \
fireworks-llama-stack \
openai-llama-stack
```
Signed-off-by: Ben Browning <bbrownin@redhat.com>
267 lines
9.8 KiB
Python
267 lines
9.8 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import json
|
|
import re
|
|
from typing import Optional, Tuple
|
|
|
|
from llama_stack.log import get_logger
|
|
|
|
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
|
|
|
|
# Module-level logger; messages from this file are emitted under the "inference" category.
logger = get_logger(name=__name__, category="inference")
|
|
|
|
# Regex source (a plain string, not compiled) for a built-in tool invocation of
# the form `<tool_name>.call(query="<query>")`. The word before `.call` is
# captured as `tool_name` and the double-quoted text as `query`; the query
# itself cannot contain a double quote because it is matched with `[^"]*`.
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
|
|
# Compiled regex for a custom tool call written as `<function=NAME>{...}`.
# `args` is matched lazily and without DOTALL, so it stops at the first `}` on
# the same line; a nested JSON object is therefore cut short at its inner `}`.
# NOTE(review): `function_name` excludes `}` rather than `>` and relies on
# backtracking to find the `>{` boundary -- confirm this is intentional.
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
|
|
|
|
|
|
def is_json(s):
    """Return True only if ``s`` decodes to a JSON object (a Python ``dict``).

    Valid JSON that decodes to anything else (list, number, string, ``null``)
    yields False, as does input that cannot be decoded at all.

    Args:
        s: Candidate JSON text.

    Returns:
        bool: True when ``json.loads(s)`` succeeds and produces a ``dict``.
    """
    try:
        parsed = json.loads(s)
    except (TypeError, ValueError):
        # ValueError covers json.JSONDecodeError (and undecodable bytes);
        # TypeError covers non-text input such as None.
        return False
    # Return True for valid objects and not for ints, strings, etc
    return isinstance(parsed, dict)
|
|
|
|
|
|
def _find_call_end(content, args_start, respect_quotes):
    """Return the index just past the ``)`` closing the call whose arguments start at ``args_start``.

    ``args_start`` must point just after the opening ``(`` (so it is >= 1).
    When ``respect_quotes`` is true, parentheses inside single- or
    double-quoted text are ignored. Returns -1 if no matching ``)`` is found.
    """
    depth = 1
    quote_char = None
    for idx in range(args_start, len(content)):
        char = content[idx]
        if respect_quotes and char in ('"', "'") and content[idx - 1] != "\\":
            if quote_char is None:
                quote_char = char
            elif char == quote_char:
                quote_char = None
            continue
        if quote_char is not None:
            # Inside a quoted string: parentheses are literal text.
            continue
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                return idx + 1
    return -1


def _split_top_level_args(args_str):
    """Split ``args_str`` on commas that are outside quotes and outside ``{}``, ``[]`` or ``()``."""
    parts = []
    part_start = 0
    quote_char = None
    nested_level = 0
    for i, char in enumerate(args_str):
        if char in ('"', "'") and (i == 0 or args_str[i - 1] != "\\"):
            if quote_char is None:
                quote_char = char
            elif char == quote_char:
                quote_char = None
        elif quote_char is None:
            if char in ("{", "[", "("):
                nested_level += 1
            elif char in ("}", "]", ")"):
                nested_level -= 1
            elif char == "," and nested_level == 0:
                parts.append(args_str[part_start:i].strip())
                part_start = i + 1
    parts.append(args_str[part_start:].strip())
    return parts


def _convert_arg_value(value):
    """Convert the textual ``value`` of one argument into the closest Python value.

    Quoted text becomes a string without its quotes; ``true``/``false``/``none``
    (any case) become ``True``/``False``/``None``; ``{...}`` and ``[...]`` are
    parsed as JSON after swapping single quotes for double quotes; anything else
    is tried as a number. Text that fits none of these is returned unchanged.
    """
    if (value.startswith('"') and value.endswith('"')) or (value.startswith("'") and value.endswith("'")):
        return value[1:-1]

    lowered = value.lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False
    if lowered == "none":
        return None

    looks_like_dict = value.startswith("{") and value.endswith("}")
    looks_like_list = value.startswith("[") and value.endswith("]")
    if looks_like_dict or looks_like_list:
        try:
            return json.loads(value.replace("'", '"'))
        except json.JSONDecodeError:
            # Keep as string if parsing fails
            return value

    try:
        return float(value) if "." in value else int(value)
    except ValueError:
        # Keep as string if not a valid number
        return value


def parse_llama_tool_call_format(input_string):
    """
    Parse tool calls in the format:
    [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]

    Parentheses and commas that appear inside quoted string values do not
    terminate the call or split the arguments. If the quotes in a call are
    unbalanced, the closing parenthesis is located by plain parenthesis counting
    instead, so loosely formatted input is still parsed on a best-effort basis.

    Returns a list of (function_name, arguments_dict) tuples or None if parsing fails.
    """
    # Strip outer brackets and whitespace
    input_string = input_string.strip()
    if not (input_string.startswith("[") and input_string.endswith("]")):
        return None

    content = input_string[1:-1].strip()
    if not content:
        return None

    result = []
    pos = 0
    length = len(content)

    while pos < length:
        # Find function name
        name_end = content.find("(", pos)
        if name_end == -1:
            break

        func_name = content[pos:name_end].strip()

        # Find closing parenthesis for this function call, preferring the
        # quote-aware scan and falling back to plain counting.
        args_start = name_end + 1
        args_end = _find_call_end(content, args_start, respect_quotes=True)
        if args_end == -1:
            args_end = _find_call_end(content, args_start, respect_quotes=False)
        if args_end == -1:
            # Unmatched parentheses
            return None

        # Parse arguments
        args_str = content[args_start : args_end - 1].strip()
        args_dict = {}
        if args_str:
            for part in _split_top_level_args(args_str):
                # Fragments without "=" are not key=value pairs and are skipped.
                if "=" in part:
                    key, value = part.split("=", 1)
                    args_dict[key.strip()] = _convert_arg_value(value.strip())

        result.append((func_name, args_dict))

        # Move to the next function call
        pos = args_end

        # Skip the comma between function calls if present
        if pos < length and content[pos] == ",":
            pos += 1

    return result if result else None
|
|
|
|
|
|
class ToolUtils:
    """Static helpers for recognising, extracting and encoding tool calls in message text."""

    @staticmethod
    def is_builtin_tool_call(message_body: str) -> bool:
        """Return True if ``message_body`` contains a ``<tool>.call(query="...")`` invocation."""
        # BUILTIN_TOOL_PATTERN is a module-level constant, not a class attribute.
        return re.search(BUILTIN_TOOL_PATTERN, message_body) is not None

    @staticmethod
    def maybe_extract_builtin_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
        """Return ``(tool_name, query)`` for the first built-in tool call found, else None."""
        # Find the first match in the text
        match = re.search(BUILTIN_TOOL_PATTERN, message_body)
        if match is None:
            return None
        return match.group("tool_name"), match.group("query")

    @staticmethod
    def maybe_extract_custom_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
        """Return ``(function_name, arguments)`` for a custom tool call, else None.

        Three encodings are tried in order:

        1. ``<function=function_name>{...json args...}``
        2. a JSON object containing both a ``name`` and a ``parameters`` key
        3. the python-list form ``[func(a=1, b="x")]`` (only the first call is returned)

        NOTE(review): the second tuple element is the decoded arguments (normally a
        dict), not a string as the return annotation suggests.
        """
        # NOTE: Custom function tool calls are still experimental
        # Sometimes, response is of the form
        # {"type": "function", "name": "function_name", "parameters": {...}
        # and some times
        # <function=function_name>(parameters)</function>

        # Find the first match in the text
        match = re.search(CUSTOM_TOOL_CALL_PATTERN, message_body)
        if match:
            tool_name = match.group("function_name")
            query = match.group("args")
            parse_error = None
            # Try the text as-is first so valid JSON containing apostrophes is
            # not corrupted; only then fall back to treating single quotes as
            # JSON string delimiters.
            for candidate in (query, query.replace("'", '"')):
                try:
                    return tool_name, json.loads(candidate)
                except (ValueError, RecursionError) as e:
                    parse_error = e
            logger.warning(f"Exception while parsing json query for custom tool call: {query} {parse_error}")
            return None
        elif is_json(message_body):
            response = json.loads(message_body)
            # Both keys are required: a JSON object that merely has a "name"
            # (or only a "type" of "function") is not a usable tool call.
            if "name" in response and "parameters" in response:
                return response["name"], response["parameters"]
            return None
        elif function_calls := parse_llama_tool_call_format(message_body):
            # FIXME: Enable multiple tool calls
            return function_calls[0]
        else:
            logger.debug(f"Did not parse tool call from message body: {message_body}")
            return None

    @staticmethod
    def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str:
        """Render the tool call ``t`` as the text for the given prompt format.

        The brave_search, wolfram_alpha and photogen built-ins are rendered as
        ``<tool>.call(query="...")`` and code_interpreter as its raw ``code``
        argument, regardless of ``tool_prompt_format``. Any other tool is
        rendered according to ``tool_prompt_format``.

        Raises:
            ValueError: if ``tool_prompt_format`` is not json, function_tag or
                python_list, or if an argument value has an unsupported type in
                the python_list format.
        """
        if t.tool_name == BuiltinTool.brave_search:
            q = t.arguments["query"]
            return f'brave_search.call(query="{q}")'
        if t.tool_name == BuiltinTool.wolfram_alpha:
            q = t.arguments["query"]
            return f'wolfram_alpha.call(query="{q}")'
        if t.tool_name == BuiltinTool.photogen:
            q = t.arguments["query"]
            return f'photogen.call(query="{q}")'
        if t.tool_name == BuiltinTool.code_interpreter:
            return t.arguments["code"]

        fname = t.tool_name

        if tool_prompt_format == ToolPromptFormat.json:
            return json.dumps(
                {
                    "type": "function",
                    "name": fname,
                    "parameters": t.arguments,
                }
            )

        if tool_prompt_format == ToolPromptFormat.function_tag:
            args = json.dumps(t.arguments)
            return f"<function={fname}>{args}</function>"

        if tool_prompt_format == ToolPromptFormat.python_list:

            def format_value(value: RecursiveType) -> str:
                # Strings are double-quoted; scalars use str(); containers recurse.
                if isinstance(value, str):
                    return f'"{value}"'
                elif isinstance(value, (int, float, bool)) or value is None:
                    return str(value)
                elif isinstance(value, list):
                    return f"[{', '.join(format_value(v) for v in value)}]"
                elif isinstance(value, dict):
                    return f"{{{', '.join(f'{k}={format_value(v)}' for k, v in value.items())}}}"
                else:
                    raise ValueError(f"Unsupported type: {type(value)}")

            args_str = ", ".join(f"{k}={format_value(v)}" for k, v in t.arguments.items())
            return f"[{fname}({args_str})]"

        raise ValueError(f"Unsupported tool prompt format: {tool_prompt_format}")