feat: function tools in OpenAI Responses (#2094)
# What does this PR do?

This is a combination of what were previously three separate PRs - #2069, #2075, and #2083. It turns out all three are needed to land a working function-calling Responses implementation. The web search builtin tool was already working, but this wires in support for custom function calling. I ended up combining all three into one PR because they all had lots of merge conflicts, both with each other and with #1806 that just landed, and because landing any of them individually would have left only a partially working implementation merged.

The new things added here are:

* Storing of input items from previous responses, and restoring of those input items when adding previous responses to the conversation state
* Handling of multiple input item message roles, not just "user" messages
* Support for custom tools passed into the Responses API, enabling function calling beyond the builtin web search tool

Closes #2074
Closes #2080

## Test Plan

### Unit Tests

Several new unit tests were added, and they all pass. Ran via:

```
python -m pytest -s -v tests/unit/providers/agents/meta_reference/test_openai_responses.py
```

### Responses API Verification Tests

I ran our verification run.yaml against multiple providers to ensure we were getting a decent pass rate. Specifically, I ensured the new custom tool verification test passed across multiple providers, and that the multi-turn examples passed across at least some of the providers (some providers still struggle with the multi-turn workflows).

Running the stack setup for verification testing:

```
llama stack run --image-type venv tests/verifications/openai-api-verification-run.yaml
```

Together, passing 100% as an example:

```
pytest -s -v 'tests/verifications/openai_api/test_responses.py' --provider=together-llama-stack
```

## Documentation

We will need to start documenting the OpenAI APIs, but the Responses implementation is still evolving rapidly, so that is deferred for now.

---------

Signed-off-by: Derek Higgins <derekh@redhat.com>
Signed-off-by: Ben Browning <bbrownin@redhat.com>
Co-authored-by: Derek Higgins <derekh@redhat.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
Parent: e0d10dd0b1 · Commit: 8e316c9b1e
13 changed files with 1099 additions and 370 deletions
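For orientation before the diff, here is a hedged sketch of what the newly supported custom function tools look like from a client's perspective. The tool definition mirrors the `get_weather` case added to the verification suite in this commit; the base URL, API key, and model name are illustrative placeholders, not something this commit prescribes.

```python
# Illustrative client-side use of the function-calling support added by this PR.
# Assumes a running Llama Stack server with an OpenAI-compatible Responses
# endpoint; the URL, API key, and model below are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

tools = [
    {
        "type": "function",
        "name": "get_weather",
        "description": "Get current temperature for a given location.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "City and country e.g. Bogotá, Colombia",
                }
            },
            "required": ["location"],
            "additionalProperties": False,
        },
    }
]

response = client.responses.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    input="What's the weather like in San Francisco?",
    tools=tools,
)

# With this PR, the model's tool request comes back as a "function_call" output item.
for item in response.output:
    if item.type == "function_call":
        print(item.name, item.arguments)
```

The builtin web search tool continues to be executed server-side as before; the difference here is that `function` tools are forwarded to the model and surfaced back to the caller as `function_call` output items, as the new verification test below asserts.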
@@ -0,0 +1,23 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os

import yaml

from llama_stack.apis.inference.inference import (
    OpenAIChatCompletion,
)

FIXTURES_DIR = os.path.dirname(os.path.abspath(__file__))


def load_chat_completion_fixture(filename: str) -> OpenAIChatCompletion:
    fixture_path = os.path.join(FIXTURES_DIR, filename)

    with open(fixture_path) as f:
        data = yaml.safe_load(f)
        return OpenAIChatCompletion(**data)
```
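The new YAML fixtures below are loaded through this helper and handed to the mocked inference API, replacing the hand-built `OpenAIChatCompletion` objects in the older tests. A minimal sketch of the pattern (the standalone framing is illustrative; the updated tests below do this inside pytest fixtures):

```python
# Illustrative only: loading a YAML chat-completion fixture into a mocked inference API,
# mirroring the pattern used by the updated unit tests.
from unittest.mock import AsyncMock

from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture

mock_inference_api = AsyncMock()
# The mocked chat completion now comes from YAML instead of being built inline.
mock_inference_api.openai_chat_completion.return_value = load_chat_completion_fixture("simple_chat_completion.yaml")
```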
@@ -0,0 +1,9 @@
```yaml
id: chat-completion-123
choices:
  - message:
      content: "Dublin"
      role: assistant
    finish_reason: stop
    index: 0
created: 1234567890
model: meta-llama/Llama-3.1-8B-Instruct
```
@@ -0,0 +1,14 @@
```yaml
id: chat-completion-123
choices:
  - message:
      tool_calls:
        - id: tool_call_123
          type: function
          function:
            name: web_search
            arguments: '{"query":"What is the capital of Ireland?"}'
      role: assistant
    finish_reason: stop
    index: 0
created: 1234567890
model: meta-llama/Llama-3.1-8B-Instruct
```
@@ -4,27 +4,32 @@
```diff
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from unittest.mock import AsyncMock
+from unittest.mock import AsyncMock, patch
 
 import pytest
 
 from llama_stack.apis.agents.openai_responses import (
+    OpenAIResponseInputItemList,
+    OpenAIResponseInputMessageContentText,
     OpenAIResponseInputToolWebSearch,
-    OpenAIResponseOutputMessage,
+    OpenAIResponseMessage,
     OpenAIResponseObject,
     OpenAIResponseOutputMessageContentOutputText,
     OpenAIResponseOutputMessageWebSearchToolCall,
 )
 from llama_stack.apis.inference.inference import (
     OpenAIAssistantMessageParam,
     OpenAIChatCompletion,
     OpenAIChatCompletionToolCall,
     OpenAIChatCompletionToolCallFunction,
     OpenAIChoice,
+    OpenAIChatCompletionContentPartTextParam,
+    OpenAIDeveloperMessageParam,
     OpenAIUserMessageParam,
 )
 from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime
 from llama_stack.providers.inline.agents.meta_reference.openai_responses import (
+    OpenAIResponsePreviousResponseWithInputItems,
     OpenAIResponsesImpl,
 )
 from llama_stack.providers.utils.kvstore import KVStore
+from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture
 
 
 @pytest.fixture
```
@@ -65,21 +70,11 @@ def openai_responses_impl(mock_kvstore, mock_inference_api, mock_tool_groups_api
```diff
 async def test_create_openai_response_with_string_input(openai_responses_impl, mock_inference_api):
     """Test creating an OpenAI response with a simple string input."""
     # Setup
-    input_text = "Hello, world!"
+    input_text = "What is the capital of Ireland?"
     model = "meta-llama/Llama-3.1-8B-Instruct"
 
-    mock_chat_completion = OpenAIChatCompletion(
-        id="chat-completion-123",
-        choices=[
-            OpenAIChoice(
-                message=OpenAIAssistantMessageParam(content="Hello! How can I help you?"),
-                finish_reason="stop",
-                index=0,
-            )
-        ],
-        created=1234567890,
-        model=model,
-    )
+    # Load the chat completion fixture
+    mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
     mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
 
     # Execute
```
@@ -92,7 +87,7 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
```diff
     # Verify
     mock_inference_api.openai_chat_completion.assert_called_once_with(
         model=model,
-        messages=[OpenAIUserMessageParam(role="user", content="Hello, world!", name=None)],
+        messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
        tools=None,
        stream=False,
        temperature=0.1,
```
@@ -100,55 +95,25 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
```diff
     openai_responses_impl.persistence_store.set.assert_called_once()
     assert result.model == model
     assert len(result.output) == 1
-    assert isinstance(result.output[0], OpenAIResponseOutputMessage)
-    assert result.output[0].content[0].text == "Hello! How can I help you?"
+    assert isinstance(result.output[0], OpenAIResponseMessage)
+    assert result.output[0].content[0].text == "Dublin"
 
 
 @pytest.mark.asyncio
 async def test_create_openai_response_with_string_input_with_tools(openai_responses_impl, mock_inference_api):
     """Test creating an OpenAI response with a simple string input and tools."""
     # Setup
-    input_text = "What was the score of todays game?"
+    input_text = "What is the capital of Ireland?"
     model = "meta-llama/Llama-3.1-8B-Instruct"
 
-    mock_chat_completions = [
-        OpenAIChatCompletion(
-            id="chat-completion-123",
-            choices=[
-                OpenAIChoice(
-                    message=OpenAIAssistantMessageParam(
-                        tool_calls=[
-                            OpenAIChatCompletionToolCall(
-                                id="tool_call_123",
-                                type="function",
-                                function=OpenAIChatCompletionToolCallFunction(
-                                    name="web_search", arguments='{"query":"What was the score of todays game?"}'
-                                ),
-                            )
-                        ],
-                    ),
-                    finish_reason="stop",
-                    index=0,
-                )
-            ],
-            created=1234567890,
-            model=model,
-        ),
-        OpenAIChatCompletion(
-            id="chat-completion-123",
-            choices=[
-                OpenAIChoice(
-                    message=OpenAIAssistantMessageParam(content="The score of todays game was 10-12"),
-                    finish_reason="stop",
-                    index=0,
-                )
-            ],
-            created=1234567890,
-            model=model,
-        ),
-    ]
+    # Load the chat completion fixtures
+    tool_call_completion = load_chat_completion_fixture("tool_call_completion.yaml")
+    tool_response_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
 
-    mock_inference_api.openai_chat_completion.side_effect = mock_chat_completions
+    mock_inference_api.openai_chat_completion.side_effect = [
+        tool_call_completion,
+        tool_response_completion,
+    ]
 
     openai_responses_impl.tool_groups_api.get_tool.return_value = Tool(
         identifier="web_search",
```
@@ -163,7 +128,7 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
```diff
 
     openai_responses_impl.tool_runtime_api.invoke_tool.return_value = ToolInvocationResult(
         status="completed",
-        content="The score of todays game was 10-12",
+        content="Dublin",
     )
 
     # Execute
```
@@ -180,23 +145,172 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
```diff
 
     # Verify
     first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
-    assert first_call.kwargs["messages"][0].content == "What was the score of todays game?"
+    assert first_call.kwargs["messages"][0].content == "What is the capital of Ireland?"
     assert first_call.kwargs["tools"] is not None
     assert first_call.kwargs["temperature"] == 0.1
 
     second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
-    assert second_call.kwargs["messages"][-1].content == "The score of todays game was 10-12"
+    assert second_call.kwargs["messages"][-1].content == "Dublin"
     assert second_call.kwargs["temperature"] == 0.1
 
     openai_responses_impl.tool_groups_api.get_tool.assert_called_once_with("web_search")
     openai_responses_impl.tool_runtime_api.invoke_tool.assert_called_once_with(
         tool_name="web_search",
-        kwargs={"query": "What was the score of todays game?"},
+        kwargs={"query": "What is the capital of Ireland?"},
     )
 
     openai_responses_impl.persistence_store.set.assert_called_once()
 
     # Check that we got the content from our mocked tool execution result
     assert len(result.output) >= 1
-    assert isinstance(result.output[1], OpenAIResponseOutputMessage)
-    assert result.output[1].content[0].text == "The score of todays game was 10-12"
+    assert isinstance(result.output[1], OpenAIResponseMessage)
+    assert result.output[1].content[0].text == "Dublin"
```
```python
@pytest.mark.asyncio
async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api):
    """Test creating an OpenAI response with multiple messages."""
    # Setup
    input_messages = [
        OpenAIResponseMessage(role="developer", content="You are a helpful assistant", name=None),
        OpenAIResponseMessage(role="user", content="Name some towns in Ireland", name=None),
        OpenAIResponseMessage(
            role="assistant",
            content=[
                OpenAIResponseInputMessageContentText(text="Galway, Longford, Sligo"),
                OpenAIResponseInputMessageContentText(text="Dublin"),
            ],
            name=None,
        ),
        OpenAIResponseMessage(role="user", content="Which is the largest town in Ireland?", name=None),
    ]
    model = "meta-llama/Llama-3.1-8B-Instruct"

    mock_inference_api.openai_chat_completion.return_value = load_chat_completion_fixture("simple_chat_completion.yaml")

    # Execute
    await openai_responses_impl.create_openai_response(
        input=input_messages,
        model=model,
        temperature=0.1,
    )

    # Verify that the correct messages were sent to the inference API, i.e.
    # all of the response messages were converted to chat completion message objects
    inference_messages = mock_inference_api.openai_chat_completion.call_args_list[0].kwargs["messages"]
    for i, m in enumerate(input_messages):
        if isinstance(m.content, str):
            assert inference_messages[i].content == m.content
        else:
            assert inference_messages[i].content[0].text == m.content[0].text
            assert isinstance(inference_messages[i].content[0], OpenAIChatCompletionContentPartTextParam)
        assert inference_messages[i].role == m.role
        if m.role == "user":
            assert isinstance(inference_messages[i], OpenAIUserMessageParam)
        elif m.role == "assistant":
            assert isinstance(inference_messages[i], OpenAIAssistantMessageParam)
        else:
            assert isinstance(inference_messages[i], OpenAIDeveloperMessageParam)


@pytest.mark.asyncio
async def test_prepend_previous_response_none(openai_responses_impl):
    """Test prepending no previous response to a new response."""

    input = await openai_responses_impl._prepend_previous_response("fake_input", None)
    assert input == "fake_input"


@pytest.mark.asyncio
@patch.object(OpenAIResponsesImpl, "_get_previous_response_with_input")
async def test_prepend_previous_response_basic(get_previous_response_with_input, openai_responses_impl):
    """Test prepending a basic previous response to a new response."""

    input_item_message = OpenAIResponseMessage(
        id="123",
        content=[OpenAIResponseInputMessageContentText(text="fake_previous_input")],
        role="user",
    )
    input_items = OpenAIResponseInputItemList(data=[input_item_message])
    response_output_message = OpenAIResponseMessage(
        id="123",
        content=[OpenAIResponseOutputMessageContentOutputText(text="fake_response")],
        status="completed",
        role="assistant",
    )
    response = OpenAIResponseObject(
        created_at=1,
        id="resp_123",
        model="fake_model",
        output=[response_output_message],
        status="completed",
    )
    previous_response = OpenAIResponsePreviousResponseWithInputItems(
        input_items=input_items,
        response=response,
    )
    get_previous_response_with_input.return_value = previous_response

    input = await openai_responses_impl._prepend_previous_response("fake_input", "resp_123")

    assert len(input) == 3
    # Check for previous input
    assert isinstance(input[0], OpenAIResponseMessage)
    assert input[0].content[0].text == "fake_previous_input"
    # Check for previous output
    assert isinstance(input[1], OpenAIResponseMessage)
    assert input[1].content[0].text == "fake_response"
    # Check for new input
    assert isinstance(input[2], OpenAIResponseMessage)
    assert input[2].content == "fake_input"


@pytest.mark.asyncio
@patch.object(OpenAIResponsesImpl, "_get_previous_response_with_input")
async def test_prepend_previous_response_web_search(get_previous_response_with_input, openai_responses_impl):
    """Test prepending a web search previous response to a new response."""

    input_item_message = OpenAIResponseMessage(
        id="123",
        content=[OpenAIResponseInputMessageContentText(text="fake_previous_input")],
        role="user",
    )
    input_items = OpenAIResponseInputItemList(data=[input_item_message])
    output_web_search = OpenAIResponseOutputMessageWebSearchToolCall(
        id="ws_123",
        status="completed",
    )
    output_message = OpenAIResponseMessage(
        id="123",
        content=[OpenAIResponseOutputMessageContentOutputText(text="fake_web_search_response")],
        status="completed",
        role="assistant",
    )
    response = OpenAIResponseObject(
        created_at=1,
        id="resp_123",
        model="fake_model",
        output=[output_web_search, output_message],
        status="completed",
    )
    previous_response = OpenAIResponsePreviousResponseWithInputItems(
        input_items=input_items,
        response=response,
    )
    get_previous_response_with_input.return_value = previous_response

    input_messages = [OpenAIResponseMessage(content="fake_input", role="user")]
    input = await openai_responses_impl._prepend_previous_response(input_messages, "resp_123")

    assert len(input) == 4
    # Check for previous input
    assert isinstance(input[0], OpenAIResponseMessage)
    assert input[0].content[0].text == "fake_previous_input"
    # Check for previous output web search tool call
    assert isinstance(input[1], OpenAIResponseOutputMessageWebSearchToolCall)
    # Check for previous output web search response
    assert isinstance(input[2], OpenAIResponseMessage)
    assert input[2].content[0].text == "fake_web_search_response"
    # Check for new input
    assert isinstance(input[3], OpenAIResponseMessage)
    assert input[3].content == "fake_input"
```
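The `_prepend_previous_response` tests above back the new multi-turn behavior described in the PR summary: when a caller passes `previous_response_id`, the stored input items and outputs of that response are replayed into the conversation ahead of the new input. A hedged client-side sketch of that flow (endpoint and model are placeholders, as in the earlier example):

```python
# Illustrative multi-turn chaining via previous_response_id (placeholder URL/model).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

first = client.responses.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    input="Name some towns in Ireland",
)

# The stored input items and outputs of `first` are prepended to this turn's input.
second = client.responses.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    input="Which is the largest town in Ireland?",
    previous_response_id=first.id,
)
print(second.output_text)
```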
@@ -31,6 +31,26 @@ test_response_web_search:
```diff
       search_context_size: "low"
     output: "128"
 
+test_response_custom_tool:
+  test_name: test_response_custom_tool
+  test_params:
+    case:
+      - case_id: "sf_weather"
+        input: "What's the weather like in San Francisco?"
+        tools:
+          - type: function
+            name: get_weather
+            description: Get current temperature for a given location.
+            parameters:
+              additionalProperties: false
+              properties:
+                location:
+                  description: "City and country e.g. Bogot\xE1, Colombia"
+                  type: string
+              required:
+                - location
+              type: object
+
 test_response_image:
   test_name: test_response_image
   test_params:
```
@@ -124,6 +124,28 @@ def test_response_non_streaming_web_search(request, openai_client, model, provid
```diff
     assert case["output"].lower() in response.output_text.lower().strip()
 
 
+@pytest.mark.parametrize(
+    "case",
+    responses_test_cases["test_response_custom_tool"]["test_params"]["case"],
+    ids=case_id_generator,
+)
+def test_response_non_streaming_custom_tool(request, openai_client, model, provider, verification_config, case):
+    test_name_base = get_base_test_name(request)
+    if should_skip_test(verification_config, provider, model, test_name_base):
+        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
+
+    response = openai_client.responses.create(
+        model=model,
+        input=case["input"],
+        tools=case["tools"],
+        stream=False,
+    )
+    assert len(response.output) == 1
+    assert response.output[0].type == "function_call"
+    assert response.output[0].status == "completed"
+    assert response.output[0].name == "get_weather"
+
+
 @pytest.mark.parametrize(
     "case",
     responses_test_cases["test_response_image"]["test_params"]["case"],
```
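This verification test stops at the model emitting the call; actually executing the function is left to the caller. A hedged sketch of that next step, continuing from a `response` like the one created in the test above (the `get_weather` implementation is hypothetical, and feeding results back as a `function_call_output` item is not exercised by this PR's tests):

```python
import json


# Hypothetical caller-side handler for a function_call output item.
def get_weather(location: str) -> str:
    # Placeholder implementation; a real caller would hit a weather API here.
    return f"18°C and sunny in {location}"


# `response` is a Responses API result like the one asserted on in the test above.
call = response.output[0]
if call.type == "function_call":
    args = json.loads(call.arguments)  # arguments arrive as a JSON string
    result = get_weather(**args)
    print(f"{call.name}({args}) -> {result}")
```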