llama-stack-mirror/llama_stack/models/llama/llama3/tool_utils.py
Ben Browning 7641a5cd0b
fix: 100% OpenAI API verification for together and fireworks (#1946)
# What does this PR do?

TLDR: Changes needed to get 100% of the OpenAI API verification tests
passing when run against Llama Stack with the `together`, `fireworks`,
and `openai` providers. `groq` is also better than before, at 88%
passing.

This cleans up the OpenAI API support for image message types
(specifically `image_url` types) and handling of the `response_format`
chat completion parameter. Both of these required a few more Pydantic
model definitions in our Inference API, just to move from the
not-quite-right stubs I had in place to something fleshed out to match
the actual OpenAI API specs.
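For reference, the OpenAI chat completions API allows a message's `content` to be either a string or a list of typed parts. An image part is `{"type": "image_url", "image_url": {"url": ...}}`. The `response_format` parameter is tagged by `type` (`text`, `json_object`, or `json_schema`). The sketch below walks those shapes with plain dicts rather than the Pydantic models this PR adds. Both helper names are hypothetical, and mapping `json_object` to a bare `{"type": "object"}` schema is an assumption of this sketch.

```python
from typing import Any, Dict, List, Optional, Tuple, Union

# OpenAI message content: a plain string, or a list of typed parts such as
# {"type": "text", "text": ...} or {"type": "image_url", "image_url": {"url": ...}}
Content = Union[str, List[Dict[str, Any]]]


def split_content(content: Content) -> List[Tuple[str, str]]:
    """Flatten message content into (kind, value) pairs (hypothetical helper)."""
    if isinstance(content, str):
        return [("text", content)]
    parts: List[Tuple[str, str]] = []
    for part in content:
        if part["type"] == "text":
            parts.append(("text", part["text"]))
        elif part["type"] == "image_url":
            # image_url is an object whose "url" is an http(s) URL or a data: URL
            parts.append(("image", part["image_url"]["url"]))
        else:
            raise ValueError(f"Unsupported content part type: {part['type']}")
    return parts


def json_schema_from_response_format(
    response_format: Optional[Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
    """Return the JSON schema requested via response_format, if any (hypothetical helper)."""
    if not response_format or response_format["type"] == "text":
        return None
    if response_format["type"] == "json_object":
        # Assumption of this sketch: "any JSON object" becomes a bare object schema
        return {"type": "object"}
    if response_format["type"] == "json_schema":
        return response_format["json_schema"].get("schema", {})
    raise ValueError(f"Unsupported response_format type: {response_format['type']}")
```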

As part of testing this, I also found and fixed a bug in the litellm
implementation of openai_completion and openai_chat_completion, so the
providers based on those should actually be working now.

The method `prepare_openai_completion_params` in
`llama_stack/providers/utils/inference/openai_compat.py` was improved to
actually recursively clean up input parameters, including handling of
lists, dicts, and dumping of Pydantic models to dicts. These changes
were required to get to 100% passing tests on the OpenAI API
verification against the `openai` provider.
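The PR doesn't spell out exactly what "clean up" covers. Assuming it means dropping `None`-valued keys and dumping Pydantic models to plain dicts at any depth, a minimal recursive sketch could look like the one below. `clean_params` is a hypothetical name, not the Llama Stack function. The sketch detects models by duck typing on Pydantic v2's `model_dump()`, so it runs without Pydantic installed.

```python
from typing import Any


def clean_params(value: Any) -> Any:
    """Recursively normalize request params (hypothetical sketch, not Llama Stack's code)."""
    # Pydantic v2 models expose model_dump(); turn them into plain dicts first
    if hasattr(value, "model_dump"):
        value = value.model_dump()
    if isinstance(value, dict):
        # Drop None-valued keys so unset options are omitted from the request
        return {k: clean_params(v) for k, v in value.items() if v is not None}
    if isinstance(value, list):
        # Keep list order and length; only clean each element
        return [clean_params(v) for v in value]
    return value
```

Recursing through lists as well as dicts matters here. Chat `messages` is a list of dicts, and each message's `content` may itself be a list of parts.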

With the above, the together.ai provider was passing at the same rate
through Llama Stack as it does without Llama Stack. But, since we have
Llama Stack in the middle, I took the opportunity to clean up the
together.ai provider so that it now also passes the OpenAI API spec
tests we have at 100%. That means together.ai now passes our
verification tests at a higher rate when an OpenAI client talks to
Llama Stack than when it hits together.ai directly, without Llama Stack
in the middle.

Another round of work on Fireworks, improving the translation of
incoming OpenAI chat completion requests into Llama Stack chat completion
requests, gets the fireworks provider passing at 100% too. Server-side
tool calling on fireworks.ai with OpenAI chat completions and Llama 4
models isn't great yet, but by pointing OpenAI clients at Llama Stack's
API we can clean things up and get everything working as expected for
Llama 4 models.

## Test Plan

### OpenAI API Verification Tests

I ran the OpenAI API verification tests as below and 100% of the tests
passed.

This requires a Llama Stack server running the `openai` provider with
the `gpt-4o` and `gpt-4o-mini` models deployed. No template sets this up
out of the box, so I added
`tests/verifications/openai-api-verification-run.yaml` to do it.

First, ensure you have the necessary API key environment variables set:

```
export TOGETHER_API_KEY="..."
export FIREWORKS_API_KEY="..."
export OPENAI_API_KEY="..."
```

Then, run a Llama Stack server that serves up all these providers:

```
llama stack run \
      --image-type venv \
      tests/verifications/openai-api-verification-run.yaml
```

Finally, generate a new verification report against all these providers,
both with and without the Llama Stack server in the middle.

```
python tests/verifications/generate_report.py \
      --run-tests \
      --provider \
        together \
        fireworks \
        groq \
        openai \
        together-llama-stack \
        fireworks-llama-stack \
        groq-llama-stack \
        openai-llama-stack
```

You'll see that most of the configurations with Llama Stack in the
middle now pass at 100%, even though some of them do not pass at 100%
when hitting the backend provider's API directly with an OpenAI client.

### OpenAI Completion Integration Tests with vLLM

I also ran the smaller `test_openai_completion.py` test suite (which is
not yet merged with the verification tests) against several of the
providers. I had to adjust the method signature of
`openai_chat_completion` a bit, which meant touching many of these
providers to match. Here are the tests I ran there, all passing:

```
VLLM_URL="http://localhost:8000/v1" INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" llama stack build --template remote-vllm --image-type venv --run
```

In another terminal:

```
LLAMA_STACK_CONFIG=http://localhost:8321 INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "meta-llama/Llama-3.2-3B-Instruct"
```

### OpenAI Completion Integration Tests with ollama

```
INFERENCE_MODEL="llama3.2:3b-instruct-q8_0" llama stack build --template ollama --image-type venv --run
```

In another terminal:

```
LLAMA_STACK_CONFIG=http://localhost:8321 INFERENCE_MODEL="llama3.2:3b-instruct-q8_0" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "llama3.2:3b-instruct-q8_0"
```

### OpenAI Completion Integration Tests with together.ai

```
INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct-Turbo" llama stack build --template together --image-type venv --run
```

In another terminal:

```
LLAMA_STACK_CONFIG=http://localhost:8321 INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct-Turbo" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "meta-llama/Llama-3.2-3B-Instruct-Turbo"
```

### OpenAI Completion Integration Tests with fireworks.ai

```
INFERENCE_MODEL="meta-llama/Llama-3.1-8B-Instruct" llama stack build --template fireworks --image-type venv --run
```

In another terminal:

```
LLAMA_STACK_CONFIG=http://localhost:8321 INFERENCE_MODEL="meta-llama/Llama-3.1-8B-Instruct" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "meta-llama/Llama-3.1-8B-Instruct"
```

---------

Signed-off-by: Ben Browning <bbrownin@redhat.com>
2025-04-14 08:56:29 -07:00


# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import re
from typing import Optional, Tuple

from llama_stack.log import get_logger

from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat

logger = get_logger(name=__name__, category="inference")

BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")


def is_json(s):
    try:
        parsed = json.loads(s)
        # Return True for valid objects and not for ints, strings, etc
        return isinstance(parsed, dict)
    except json.JSONDecodeError:
        return False


def parse_llama_tool_call_format(input_string):
    """
    Parse tool calls in the format:
    [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]

    Returns a list of (function_name, arguments_dict) tuples or None if parsing fails.
    """
    # Strip outer brackets and whitespace
    input_string = input_string.strip()
    if not (input_string.startswith("[") and input_string.endswith("]")):
        return None

    content = input_string[1:-1].strip()
    if not content:
        return None

    result = []

    # State variables for parsing
    pos = 0
    length = len(content)

    while pos < length:
        # Find function name
        name_end = content.find("(", pos)
        if name_end == -1:
            break

        func_name = content[pos:name_end].strip()

        # Find closing parenthesis for this function call
        paren_level = 1
        args_start = name_end + 1
        args_end = args_start

        while args_end < length and paren_level > 0:
            if content[args_end] == "(":
                paren_level += 1
            elif content[args_end] == ")":
                paren_level -= 1
            args_end += 1

        if paren_level != 0:
            # Unmatched parentheses
            return None

        # Parse arguments
        args_str = content[args_start : args_end - 1].strip()
        args_dict = {}

        if args_str:
            # Split by commas, but respect nested structures
            parts = []
            part_start = 0
            in_quotes = False
            quote_char = None
            nested_level = 0

            for i, char in enumerate(args_str):
                if char in ('"', "'") and (i == 0 or args_str[i - 1] != "\\"):
                    if not in_quotes:
                        in_quotes = True
                        quote_char = char
                    elif char == quote_char:
                        in_quotes = False
                        quote_char = None
                elif not in_quotes:
                    if char in ("{", "["):
                        nested_level += 1
                    elif char in ("}", "]"):
                        nested_level -= 1
                    elif char == "," and nested_level == 0:
                        parts.append(args_str[part_start:i].strip())
                        part_start = i + 1

            parts.append(args_str[part_start:].strip())

            # Process each key=value pair
            for part in parts:
                if "=" in part:
                    key, value = part.split("=", 1)
                    key = key.strip()
                    value = value.strip()

                    # Try to convert value to appropriate Python type
                    if (value.startswith('"') and value.endswith('"')) or (
                        value.startswith("'") and value.endswith("'")
                    ):
                        # String
                        value = value[1:-1]
                    elif value.lower() == "true":
                        value = True
                    elif value.lower() == "false":
                        value = False
                    elif value.lower() == "none":
                        value = None
                    elif value.startswith("{") and value.endswith("}"):
                        # This is a nested dictionary
                        try:
                            # Try to parse as JSON
                            value = json.loads(value.replace("'", '"'))
                        except json.JSONDecodeError:
                            # Keep as string if parsing fails
                            pass
                    elif value.startswith("[") and value.endswith("]"):
                        # This is a nested list
                        try:
                            # Try to parse as JSON
                            value = json.loads(value.replace("'", '"'))
                        except json.JSONDecodeError:
                            # Keep as string if parsing fails
                            pass
                    else:
                        # Try to convert to number
                        try:
                            if "." in value:
                                value = float(value)
                            else:
                                value = int(value)
                        except ValueError:
                            # Keep as string if not a valid number
                            pass

                    args_dict[key] = value

        result.append((func_name, args_dict))

        # Move to the next function call
        pos = args_end

        # Skip the comma between function calls if present
        if pos < length and content[pos] == ",":
            pos += 1

    return result if result else None


class ToolUtils:
    @staticmethod
    def is_builtin_tool_call(message_body: str) -> bool:
        # BUILTIN_TOOL_PATTERN is defined at module level, not on the class
        match = re.search(BUILTIN_TOOL_PATTERN, message_body)
        return match is not None

    @staticmethod
    def maybe_extract_builtin_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
        # Find the first match in the text
        match = re.search(BUILTIN_TOOL_PATTERN, message_body)

        # Check if a match is found and return it
        if match:
            tool_name = match.group("tool_name")
            query = match.group("query")
            return tool_name, query
        else:
            return None

    @staticmethod
    def maybe_extract_custom_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
        # NOTE: Custom function tool calls are still experimental
        # Sometimes, response is of the form
        # {"type": "function", "name": "function_name", "parameters": {...}}
        # and sometimes
        # <function=function_name>{...}</function>
        # and sometimes a Python-style list: [function_name(param=value, ...)]

        # Find the first match in the text
        match = re.search(CUSTOM_TOOL_CALL_PATTERN, message_body)
        if match:
            tool_name = match.group("function_name")
            query = match.group("args")
            try:
                return tool_name, json.loads(query.replace("'", '"'))
            except Exception as e:
                logger.warning(f"Exception while parsing json query for custom tool call: {query} ({e})")
                return None
        elif is_json(message_body):
            response = json.loads(message_body)
            if ("type" in response and response["type"] == "function") or (
                "name" in response and "parameters" in response
            ):
                function_name = response["name"]
                args = response["parameters"]
                return function_name, args
            else:
                return None
        elif function_calls := parse_llama_tool_call_format(message_body):
            # FIXME: Enable multiple tool calls
            return function_calls[0]
        else:
            logger.debug(f"Did not parse tool call from message body: {message_body}")
            return None

    @staticmethod
    def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str:
        if t.tool_name == BuiltinTool.brave_search:
            q = t.arguments["query"]
            return f'brave_search.call(query="{q}")'
        elif t.tool_name == BuiltinTool.wolfram_alpha:
            q = t.arguments["query"]
            return f'wolfram_alpha.call(query="{q}")'
        elif t.tool_name == BuiltinTool.photogen:
            q = t.arguments["query"]
            return f'photogen.call(query="{q}")'
        elif t.tool_name == BuiltinTool.code_interpreter:
            return t.arguments["code"]
        else:
            fname = t.tool_name

            if tool_prompt_format == ToolPromptFormat.json:
                return json.dumps(
                    {
                        "type": "function",
                        "name": fname,
                        "parameters": t.arguments,
                    }
                )
            elif tool_prompt_format == ToolPromptFormat.function_tag:
                args = json.dumps(t.arguments)
                return f"<function={fname}>{args}</function>"
            elif tool_prompt_format == ToolPromptFormat.python_list:

                def format_value(value: RecursiveType) -> str:
                    if isinstance(value, str):
                        return f'"{value}"'
                    elif isinstance(value, (int, float, bool)) or value is None:
                        return str(value)
                    elif isinstance(value, list):
                        return f"[{', '.join(format_value(v) for v in value)}]"
                    elif isinstance(value, dict):
                        return f"{{{', '.join(f'{k}={format_value(v)}' for k, v in value.items())}}}"
                    else:
                        raise ValueError(f"Unsupported type: {type(value)}")

                args_str = ", ".join(f"{k}={format_value(v)}" for k, v in t.arguments.items())
                return f"[{fname}({args_str})]"
            else:
                raise ValueError(f"Unsupported tool prompt format: {tool_prompt_format}")
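The `python_list` format that `encode_tool_call` emits, and that `parse_llama_tool_call_format` reads, is valid Python expression syntax whenever the arguments are plain literals. As an alternative sketch (not what this module does), Python's standard `ast` module can parse that format without a hand-written state machine. The function name below is hypothetical. The sketch only accepts keyword arguments whose values are literals: strings, numbers, `True`/`False`, `None`, and lists or dicts of those.

```python
import ast
from typing import Any, Dict, List, Optional, Tuple


def parse_python_list_tool_calls(text: str) -> Optional[List[Tuple[str, Dict[str, Any]]]]:
    """Parse '[f(a=1), g(b="x")]' into [(name, kwargs), ...] (hypothetical helper)."""
    try:
        tree = ast.parse(text.strip(), mode="eval")
    except (SyntaxError, ValueError):
        return None
    # The whole string must be a single list expression
    if not isinstance(tree.body, ast.List):
        return None
    calls: List[Tuple[str, Dict[str, Any]]] = []
    for node in tree.body.elts:
        # Each element must call a bare name using keyword arguments only
        if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Name) or node.args:
            return None
        kwargs: Dict[str, Any] = {}
        for kw in node.keywords:
            if kw.arg is None:  # reject **mapping unpacking
                return None
            try:
                # literal_eval only accepts literals, so nothing is executed
                kwargs[kw.arg] = ast.literal_eval(kw.value)
            except (ValueError, TypeError):
                return None
        calls.append((node.func.id, kwargs))
    # Mirror the hand-written parser: an empty list means "no tool calls"
    return calls or None
```

This sketch is stricter than `parse_llama_tool_call_format` in two ways. It rejects lowercase `true`/`false`, which that parser accepts. It also rejects the `{k=v}` form that `format_value` emits for nested dicts, which that parser keeps as a raw string. Neither form is valid Python, so the sketch rejects the whole call instead.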