feat: Add temperature support to responses API (#2065)

# What does this PR do?
Add support for the `temperature` parameter to the responses API.
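
With this change, `temperature` can be passed when creating a response and is forwarded to the underlying chat completion call. A minimal sketch of the call site, based on the new unit tests below (the wrapper function and prompt are illustrative, not part of this change):

```python
from llama_stack.providers.inline.agents.meta_reference.openai_responses import (
    OpenAIResponsesImpl,
)


async def ask_with_temperature(impl: OpenAIResponsesImpl, prompt: str):
    # `temperature` is now accepted here and passed through to
    # inference_api.openai_chat_completion(..., temperature=0.1).
    return await impl.create_openai_response(
        input=prompt,
        model="meta-llama/Llama-3.1-8B-Instruct",
        temperature=0.1,
    )
```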


## Test Plan
Manually tested the simple case.
Unit tests added for the simple case and for tool calls.
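
For reference, one way to run the new unit tests from Python (the test module path below is an assumption about where this file lands in the tree; adjust to match the actual location):

```python
import pytest

# Hypothetical path to the test module added in this commit.
pytest.main(["-v", "tests/unit/providers/agents/meta_reference/test_openai_responses.py"])
```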

Signed-off-by: Derek Higgins <derekh@redhat.com>
Derek Higgins 2025-05-01 19:47:58 +01:00 committed by GitHub
parent f36f68c590
commit 64829947d0
6 changed files with 220 additions and 3 deletions


@@ -0,0 +1,202 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock

import pytest

from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseInputToolWebSearch,
    OpenAIResponseOutputMessage,
)
from llama_stack.apis.inference.inference import (
    OpenAIAssistantMessageParam,
    OpenAIChatCompletion,
    OpenAIChatCompletionToolCall,
    OpenAIChatCompletionToolCallFunction,
    OpenAIChoice,
    OpenAIUserMessageParam,
)
from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime
from llama_stack.providers.inline.agents.meta_reference.openai_responses import (
    OpenAIResponsesImpl,
)
from llama_stack.providers.utils.kvstore import KVStore


@pytest.fixture
def mock_kvstore():
    kvstore = AsyncMock(spec=KVStore)
    return kvstore


@pytest.fixture
def mock_inference_api():
    inference_api = AsyncMock()
    return inference_api


@pytest.fixture
def mock_tool_groups_api():
    tool_groups_api = AsyncMock(spec=ToolGroups)
    return tool_groups_api


@pytest.fixture
def mock_tool_runtime_api():
    tool_runtime_api = AsyncMock(spec=ToolRuntime)
    return tool_runtime_api


@pytest.fixture
def openai_responses_impl(mock_kvstore, mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api):
    return OpenAIResponsesImpl(
        persistence_store=mock_kvstore,
        inference_api=mock_inference_api,
        tool_groups_api=mock_tool_groups_api,
        tool_runtime_api=mock_tool_runtime_api,
    )


@pytest.mark.asyncio
async def test_create_openai_response_with_string_input(openai_responses_impl, mock_inference_api):
    """Test creating an OpenAI response with a simple string input."""
    # Setup
    input_text = "Hello, world!"
    model = "meta-llama/Llama-3.1-8B-Instruct"

    mock_chat_completion = OpenAIChatCompletion(
        id="chat-completion-123",
        choices=[
            OpenAIChoice(
                message=OpenAIAssistantMessageParam(content="Hello! How can I help you?"),
                finish_reason="stop",
                index=0,
            )
        ],
        created=1234567890,
        model=model,
    )
    mock_inference_api.openai_chat_completion.return_value = mock_chat_completion

    # Execute
    result = await openai_responses_impl.create_openai_response(
        input=input_text,
        model=model,
        temperature=0.1,
    )

    # Verify
    mock_inference_api.openai_chat_completion.assert_called_once_with(
        model=model,
        messages=[OpenAIUserMessageParam(role="user", content="Hello, world!", name=None)],
        tools=None,
        stream=False,
        temperature=0.1,
    )
    openai_responses_impl.persistence_store.set.assert_called_once()
    assert result.model == model
    assert len(result.output) == 1
    assert isinstance(result.output[0], OpenAIResponseOutputMessage)
    assert result.output[0].content[0].text == "Hello! How can I help you?"


@pytest.mark.asyncio
async def test_create_openai_response_with_string_input_with_tools(openai_responses_impl, mock_inference_api):
    """Test creating an OpenAI response with a simple string input and tools."""
    # Setup
    input_text = "What was the score of todays game?"
    model = "meta-llama/Llama-3.1-8B-Instruct"

    mock_chat_completions = [
        OpenAIChatCompletion(
            id="chat-completion-123",
            choices=[
                OpenAIChoice(
                    message=OpenAIAssistantMessageParam(
                        tool_calls=[
                            OpenAIChatCompletionToolCall(
                                id="tool_call_123",
                                type="function",
                                function=OpenAIChatCompletionToolCallFunction(
                                    name="web_search", arguments='{"query":"What was the score of todays game?"}'
                                ),
                            )
                        ],
                    ),
                    finish_reason="stop",
                    index=0,
                )
            ],
            created=1234567890,
            model=model,
        ),
        OpenAIChatCompletion(
            id="chat-completion-123",
            choices=[
                OpenAIChoice(
                    message=OpenAIAssistantMessageParam(content="The score of todays game was 10-12"),
                    finish_reason="stop",
                    index=0,
                )
            ],
            created=1234567890,
            model=model,
        ),
    ]
    mock_inference_api.openai_chat_completion.side_effect = mock_chat_completions

    openai_responses_impl.tool_groups_api.get_tool.return_value = Tool(
        identifier="web_search",
        provider_id="client",
        toolgroup_id="web_search",
        tool_host="client",
        description="Search the web for information",
        parameters=[
            ToolParameter(name="query", parameter_type="string", description="The query to search for", required=True)
        ],
    )
    openai_responses_impl.tool_runtime_api.invoke_tool.return_value = ToolInvocationResult(
        status="completed",
        content="The score of todays game was 10-12",
    )

    # Execute
    result = await openai_responses_impl.create_openai_response(
        input=input_text,
        model=model,
        temperature=0.1,
        tools=[
            OpenAIResponseInputToolWebSearch(
                name="web_search",
            )
        ],
    )

    # Verify
    first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
    assert first_call.kwargs["messages"][0].content == "What was the score of todays game?"
    assert first_call.kwargs["tools"] is not None
    assert first_call.kwargs["temperature"] == 0.1

    second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
    assert second_call.kwargs["messages"][-1].content == "The score of todays game was 10-12"
    assert second_call.kwargs["temperature"] == 0.1

    openai_responses_impl.tool_groups_api.get_tool.assert_called_once_with("web_search")
    openai_responses_impl.tool_runtime_api.invoke_tool.assert_called_once_with(
        tool_name="web_search",
        kwargs={"query": "What was the score of todays game?"},
    )
    openai_responses_impl.persistence_store.set.assert_called_once()

    # Check that we got the content from our mocked tool execution result
    assert len(result.output) >= 1
    assert isinstance(result.output[1], OpenAIResponseOutputMessage)
    assert result.output[1].content[0].text == "The score of todays game was 10-12"