mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 00:34:44 +00:00
chore: Refactor OpenAIChatCompletion's to be loaded from yaml
Future tests can then re-use the content Signed-off-by: Derek Higgins <derekh@redhat.com>
This commit is contained in:
parent
fe5f5e530c
commit
1369b5858e
4 changed files with 114 additions and 63 deletions
|
@ -0,0 +1,74 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from llama_stack.apis.inference.inference import (
|
||||||
|
OpenAIAssistantMessageParam,
|
||||||
|
OpenAIChatCompletion,
|
||||||
|
OpenAIChatCompletionToolCall,
|
||||||
|
OpenAIChatCompletionToolCallFunction,
|
||||||
|
OpenAIChoice,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_chat_completion_fixture(filename: str) -> OpenAIChatCompletion:
    """
    Load a YAML fixture file and convert it to an OpenAIChatCompletion object.

    Args:
        filename: Name of the YAML file (without path)

    Returns:
        OpenAIChatCompletion object
    """
    fixture_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), filename)

    with open(fixture_path) as f:
        data = yaml.safe_load(f)

    def _build_tool_call(entry: dict) -> OpenAIChatCompletionToolCall:
        # Assemble the nested function payload first, then wrap it in a tool call.
        fn_data = entry.get("function", {})
        return OpenAIChatCompletionToolCall(
            id=entry.get("id"),
            type=entry.get("type"),
            function=OpenAIChatCompletionToolCallFunction(
                name=fn_data.get("name"),
                arguments=fn_data.get("arguments"),
            ),
        )

    def _build_choice(entry: dict) -> OpenAIChoice:
        msg_data = entry.get("message", {})
        # tool_calls stays None when the fixture message carries no "tool_calls" key,
        # so content-only fixtures round-trip exactly as before.
        calls = None
        if "tool_calls" in msg_data:
            calls = [_build_tool_call(tc) for tc in msg_data.get("tool_calls", [])]
        return OpenAIChoice(
            message=OpenAIAssistantMessageParam(
                content=msg_data.get("content"),
                tool_calls=calls,
            ),
            finish_reason=entry.get("finish_reason"),
            index=entry.get("index", 0),
        )

    return OpenAIChatCompletion(
        id=data.get("id"),
        choices=[_build_choice(c) for c in data.get("choices", [])],
        created=data.get("created"),
        model=data.get("model"),
    )
|
|
@ -0,0 +1,8 @@
id: chat-completion-123
choices:
  - message:
      content: "Dublin"
    finish_reason: stop
    index: 0
created: 1234567890
model: meta-llama/Llama-3.1-8B-Instruct
|
@ -0,0 +1,13 @@
id: chat-completion-123
choices:
  - message:
      tool_calls:
        - id: tool_call_123
          type: function
          function:
            name: web_search
            arguments: '{"query":"What is the capital of Ireland?"}'
    finish_reason: stop
    index: 0
created: 1234567890
model: meta-llama/Llama-3.1-8B-Instruct
|
@ -13,11 +13,6 @@ from llama_stack.apis.agents.openai_responses import (
|
||||||
OpenAIResponseOutputMessage,
|
OpenAIResponseOutputMessage,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.inference.inference import (
|
from llama_stack.apis.inference.inference import (
|
||||||
OpenAIAssistantMessageParam,
|
|
||||||
OpenAIChatCompletion,
|
|
||||||
OpenAIChatCompletionToolCall,
|
|
||||||
OpenAIChatCompletionToolCallFunction,
|
|
||||||
OpenAIChoice,
|
|
||||||
OpenAIUserMessageParam,
|
OpenAIUserMessageParam,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime
|
from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime
|
||||||
|
@ -25,6 +20,7 @@ from llama_stack.providers.inline.agents.meta_reference.openai_responses import
|
||||||
OpenAIResponsesImpl,
|
OpenAIResponsesImpl,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.utils.kvstore import KVStore
|
from llama_stack.providers.utils.kvstore import KVStore
|
||||||
|
from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -65,21 +61,11 @@ def openai_responses_impl(mock_kvstore, mock_inference_api, mock_tool_groups_api
|
||||||
async def test_create_openai_response_with_string_input(openai_responses_impl, mock_inference_api):
|
async def test_create_openai_response_with_string_input(openai_responses_impl, mock_inference_api):
|
||||||
"""Test creating an OpenAI response with a simple string input."""
|
"""Test creating an OpenAI response with a simple string input."""
|
||||||
# Setup
|
# Setup
|
||||||
input_text = "Hello, world!"
|
input_text = "What is the capital of Ireland?"
|
||||||
model = "meta-llama/Llama-3.1-8B-Instruct"
|
model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||||
|
|
||||||
mock_chat_completion = OpenAIChatCompletion(
|
# Load the chat completion fixture
|
||||||
id="chat-completion-123",
|
mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
|
||||||
choices=[
|
|
||||||
OpenAIChoice(
|
|
||||||
message=OpenAIAssistantMessageParam(content="Hello! How can I help you?"),
|
|
||||||
finish_reason="stop",
|
|
||||||
index=0,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
created=1234567890,
|
|
||||||
model=model,
|
|
||||||
)
|
|
||||||
mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
|
mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
|
@ -92,7 +78,7 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
|
||||||
# Verify
|
# Verify
|
||||||
mock_inference_api.openai_chat_completion.assert_called_once_with(
|
mock_inference_api.openai_chat_completion.assert_called_once_with(
|
||||||
model=model,
|
model=model,
|
||||||
messages=[OpenAIUserMessageParam(role="user", content="Hello, world!", name=None)],
|
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
|
||||||
tools=None,
|
tools=None,
|
||||||
stream=False,
|
stream=False,
|
||||||
temperature=0.1,
|
temperature=0.1,
|
||||||
|
@ -101,54 +87,24 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
|
||||||
assert result.model == model
|
assert result.model == model
|
||||||
assert len(result.output) == 1
|
assert len(result.output) == 1
|
||||||
assert isinstance(result.output[0], OpenAIResponseOutputMessage)
|
assert isinstance(result.output[0], OpenAIResponseOutputMessage)
|
||||||
assert result.output[0].content[0].text == "Hello! How can I help you?"
|
assert result.output[0].content[0].text == "Dublin"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_openai_response_with_string_input_with_tools(openai_responses_impl, mock_inference_api):
|
async def test_create_openai_response_with_string_input_with_tools(openai_responses_impl, mock_inference_api):
|
||||||
"""Test creating an OpenAI response with a simple string input and tools."""
|
"""Test creating an OpenAI response with a simple string input and tools."""
|
||||||
# Setup
|
# Setup
|
||||||
input_text = "What was the score of todays game?"
|
input_text = "What is the capital of Ireland?"
|
||||||
model = "meta-llama/Llama-3.1-8B-Instruct"
|
model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||||
|
|
||||||
mock_chat_completions = [
|
# Load the chat completion fixtures
|
||||||
OpenAIChatCompletion(
|
tool_call_completion = load_chat_completion_fixture("tool_call_completion.yaml")
|
||||||
id="chat-completion-123",
|
tool_response_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
|
||||||
choices=[
|
|
||||||
OpenAIChoice(
|
|
||||||
message=OpenAIAssistantMessageParam(
|
|
||||||
tool_calls=[
|
|
||||||
OpenAIChatCompletionToolCall(
|
|
||||||
id="tool_call_123",
|
|
||||||
type="function",
|
|
||||||
function=OpenAIChatCompletionToolCallFunction(
|
|
||||||
name="web_search", arguments='{"query":"What was the score of todays game?"}'
|
|
||||||
),
|
|
||||||
)
|
|
||||||
],
|
|
||||||
),
|
|
||||||
finish_reason="stop",
|
|
||||||
index=0,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
created=1234567890,
|
|
||||||
model=model,
|
|
||||||
),
|
|
||||||
OpenAIChatCompletion(
|
|
||||||
id="chat-completion-123",
|
|
||||||
choices=[
|
|
||||||
OpenAIChoice(
|
|
||||||
message=OpenAIAssistantMessageParam(content="The score of todays game was 10-12"),
|
|
||||||
finish_reason="stop",
|
|
||||||
index=0,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
created=1234567890,
|
|
||||||
model=model,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
mock_inference_api.openai_chat_completion.side_effect = mock_chat_completions
|
mock_inference_api.openai_chat_completion.side_effect = [
|
||||||
|
tool_call_completion,
|
||||||
|
tool_response_completion,
|
||||||
|
]
|
||||||
|
|
||||||
openai_responses_impl.tool_groups_api.get_tool.return_value = Tool(
|
openai_responses_impl.tool_groups_api.get_tool.return_value = Tool(
|
||||||
identifier="web_search",
|
identifier="web_search",
|
||||||
|
@ -163,7 +119,7 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
|
||||||
|
|
||||||
openai_responses_impl.tool_runtime_api.invoke_tool.return_value = ToolInvocationResult(
|
openai_responses_impl.tool_runtime_api.invoke_tool.return_value = ToolInvocationResult(
|
||||||
status="completed",
|
status="completed",
|
||||||
content="The score of todays game was 10-12",
|
content="Dublin",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
|
@ -180,18 +136,18 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
|
||||||
|
|
||||||
# Verify
|
# Verify
|
||||||
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||||
assert first_call.kwargs["messages"][0].content == "What was the score of todays game?"
|
assert first_call.kwargs["messages"][0].content == "What is the capital of Ireland?"
|
||||||
assert first_call.kwargs["tools"] is not None
|
assert first_call.kwargs["tools"] is not None
|
||||||
assert first_call.kwargs["temperature"] == 0.1
|
assert first_call.kwargs["temperature"] == 0.1
|
||||||
|
|
||||||
second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
|
second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
|
||||||
assert second_call.kwargs["messages"][-1].content == "The score of todays game was 10-12"
|
assert second_call.kwargs["messages"][-1].content == "Dublin"
|
||||||
assert second_call.kwargs["temperature"] == 0.1
|
assert second_call.kwargs["temperature"] == 0.1
|
||||||
|
|
||||||
openai_responses_impl.tool_groups_api.get_tool.assert_called_once_with("web_search")
|
openai_responses_impl.tool_groups_api.get_tool.assert_called_once_with("web_search")
|
||||||
openai_responses_impl.tool_runtime_api.invoke_tool.assert_called_once_with(
|
openai_responses_impl.tool_runtime_api.invoke_tool.assert_called_once_with(
|
||||||
tool_name="web_search",
|
tool_name="web_search",
|
||||||
kwargs={"query": "What was the score of todays game?"},
|
kwargs={"query": "What is the capital of Ireland?"},
|
||||||
)
|
)
|
||||||
|
|
||||||
openai_responses_impl.persistence_store.set.assert_called_once()
|
openai_responses_impl.persistence_store.set.assert_called_once()
|
||||||
|
@ -199,4 +155,4 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
|
||||||
# Check that we got the content from our mocked tool execution result
|
# Check that we got the content from our mocked tool execution result
|
||||||
assert len(result.output) >= 1
|
assert len(result.output) >= 1
|
||||||
assert isinstance(result.output[1], OpenAIResponseOutputMessage)
|
assert isinstance(result.output[1], OpenAIResponseOutputMessage)
|
||||||
assert result.output[1].content[0].text == "The score of todays game was 10-12"
|
assert result.output[1].content[0].text == "Dublin"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue