feat: Add temperature support to responses API (#2065)
# What does this PR do?
Add support for the `temperature` parameter to the responses API.

## Test Plan
Manually tested the simple case; unit tests added for the simple case and for tool calls.

Signed-off-by: Derek Higgins <derekh@redhat.com>
parent f36f68c590
commit 64829947d0

6 changed files with 220 additions and 3 deletions
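For orientation, a minimal sketch of the new parameter in use, mirroring the "simple case" from the test plan. `agents_impl` stands in for any implementation of the Agents protocol (for example `MetaReferenceAgentsImpl`) and is not part of this change; the model id is illustrative.

```python
# Hedged sketch: create_openai_response now accepts `temperature`,
# which is forwarded to the underlying chat completion call(s).
async def simple_case(agents_impl) -> None:
    response = await agents_impl.create_openai_response(
        input="Hello, world!",
        model="meta-llama/Llama-3.1-8B-Instruct",
        temperature=0.1,  # new parameter
    )
    print(response.output[0].content[0].text)
```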
docs/_static/llama-stack-spec.html (vendored, 3 changes)

```diff
@@ -6462,6 +6462,9 @@
         "stream": {
           "type": "boolean"
         },
+        "temperature": {
+          "type": "number"
+        },
         "tools": {
           "type": "array",
           "items": {
```
docs/_static/llama-stack-spec.yaml (vendored, 2 changes)

```diff
@@ -4506,6 +4506,8 @@ components:
           type: boolean
         stream:
           type: boolean
+        temperature:
+          type: number
         tools:
           type: array
           items:
```
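Taken together, the two spec updates add an optional numeric `temperature` field to the create-response request body. Below is an illustrative payload only: the `model` and `input` fields come from the rest of this PR, and the values are examples rather than defaults.

```python
# Illustrative request body; only fields relevant to this change are shown.
request_body = {
    "model": "meta-llama/Llama-3.1-8B-Instruct",
    "input": "Hello, world!",
    "stream": False,      # boolean (unchanged)
    "temperature": 0.1,   # new: schema type "number"
    "tools": [],          # array (unchanged)
}
```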
```diff
@@ -628,6 +628,7 @@ class Agents(Protocol):
         previous_response_id: Optional[str] = None,
         store: Optional[bool] = True,
         stream: Optional[bool] = False,
+        temperature: Optional[float] = None,
         tools: Optional[List[OpenAIResponseInputTool]] = None,
     ) -> Union[OpenAIResponseObject, AsyncIterator[OpenAIResponseObjectStream]]:
         """Create a new OpenAI response.
```
```diff
@@ -270,8 +270,9 @@ class MetaReferenceAgentsImpl(Agents):
         previous_response_id: Optional[str] = None,
         store: Optional[bool] = True,
         stream: Optional[bool] = False,
+        temperature: Optional[float] = None,
         tools: Optional[List[OpenAIResponseInputTool]] = None,
     ) -> OpenAIResponseObject:
         return await self.openai_responses_impl.create_openai_response(
-            input, model, previous_response_id, store, stream, tools
+            input, model, previous_response_id, store, stream, temperature, tools
         )
```
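Because `MetaReferenceAgentsImpl` forwards its arguments positionally, `temperature` is inserted between `stream` and `tools` here so the call lines up with the updated `OpenAIResponsesImpl.create_openai_response` signature shown next.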
```diff
@@ -106,6 +106,7 @@ class OpenAIResponsesImpl:
         previous_response_id: Optional[str] = None,
         store: Optional[bool] = True,
         stream: Optional[bool] = False,
+        temperature: Optional[float] = None,
         tools: Optional[List[OpenAIResponseInputTool]] = None,
     ):
         stream = False if stream is None else stream
@@ -141,6 +142,7 @@ class OpenAIResponsesImpl:
             messages=messages,
             tools=chat_tools,
             stream=stream,
+            temperature=temperature,
         )

         if stream:
@@ -180,7 +182,7 @@ class OpenAIResponsesImpl:
         output_messages: List[OpenAIResponseOutput] = []
         if chat_response.choices[0].message.tool_calls:
             output_messages.extend(
-                await self._execute_tool_and_return_final_output(model, stream, chat_response, messages)
+                await self._execute_tool_and_return_final_output(model, stream, chat_response, messages, temperature)
             )
         else:
             output_messages.extend(await _openai_choices_to_output_messages(chat_response.choices))
@@ -241,7 +243,12 @@ class OpenAIResponsesImpl:
         return chat_tools

     async def _execute_tool_and_return_final_output(
-        self, model_id: str, stream: bool, chat_response: OpenAIChatCompletion, messages: List[OpenAIMessageParam]
+        self,
+        model_id: str,
+        stream: bool,
+        chat_response: OpenAIChatCompletion,
+        messages: List[OpenAIMessageParam],
+        temperature: float,
     ) -> List[OpenAIResponseOutput]:
         output_messages: List[OpenAIResponseOutput] = []
         choice = chat_response.choices[0]
@@ -295,6 +302,7 @@ class OpenAIResponsesImpl:
             model=model_id,
             messages=messages,
             stream=stream,
+            temperature=temperature,
         )
         # type cast to appease mypy
         tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
```
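A condensed, non-verbatim sketch of how `temperature` threads through `OpenAIResponsesImpl`: once on the initial chat completion, and again on the follow-up completion issued after a tool call is executed. Names mirror the hunks above; this is not the actual implementation.

```python
# Non-verbatim sketch of the temperature flow through OpenAIResponsesImpl.
async def temperature_flow(impl, model, messages, chat_tools, stream, temperature):
    chat_response = await impl.inference_api.openai_chat_completion(
        model=model,
        messages=messages,
        tools=chat_tools,
        stream=stream,
        temperature=temperature,  # first completion
    )
    if chat_response.choices[0].message.tool_calls:
        # The tool-execution path receives the same temperature and uses it
        # for the second openai_chat_completion call after tool results are
        # appended to the message history.
        return await impl._execute_tool_and_return_final_output(
            model, stream, chat_response, messages, temperature
        )
    return chat_response
```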
New test file, 202 lines (@@ -0,0 +1,202 @@):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from unittest.mock import AsyncMock

import pytest

from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseInputToolWebSearch,
    OpenAIResponseOutputMessage,
)
from llama_stack.apis.inference.inference import (
    OpenAIAssistantMessageParam,
    OpenAIChatCompletion,
    OpenAIChatCompletionToolCall,
    OpenAIChatCompletionToolCallFunction,
    OpenAIChoice,
    OpenAIUserMessageParam,
)
from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime
from llama_stack.providers.inline.agents.meta_reference.openai_responses import (
    OpenAIResponsesImpl,
)
from llama_stack.providers.utils.kvstore import KVStore


@pytest.fixture
def mock_kvstore():
    kvstore = AsyncMock(spec=KVStore)
    return kvstore


@pytest.fixture
def mock_inference_api():
    inference_api = AsyncMock()
    return inference_api


@pytest.fixture
def mock_tool_groups_api():
    tool_groups_api = AsyncMock(spec=ToolGroups)
    return tool_groups_api


@pytest.fixture
def mock_tool_runtime_api():
    tool_runtime_api = AsyncMock(spec=ToolRuntime)
    return tool_runtime_api


@pytest.fixture
def openai_responses_impl(mock_kvstore, mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api):
    return OpenAIResponsesImpl(
        persistence_store=mock_kvstore,
        inference_api=mock_inference_api,
        tool_groups_api=mock_tool_groups_api,
        tool_runtime_api=mock_tool_runtime_api,
    )


@pytest.mark.asyncio
async def test_create_openai_response_with_string_input(openai_responses_impl, mock_inference_api):
    """Test creating an OpenAI response with a simple string input."""
    # Setup
    input_text = "Hello, world!"
    model = "meta-llama/Llama-3.1-8B-Instruct"

    mock_chat_completion = OpenAIChatCompletion(
        id="chat-completion-123",
        choices=[
            OpenAIChoice(
                message=OpenAIAssistantMessageParam(content="Hello! How can I help you?"),
                finish_reason="stop",
                index=0,
            )
        ],
        created=1234567890,
        model=model,
    )
    mock_inference_api.openai_chat_completion.return_value = mock_chat_completion

    # Execute
    result = await openai_responses_impl.create_openai_response(
        input=input_text,
        model=model,
        temperature=0.1,
    )

    # Verify
    mock_inference_api.openai_chat_completion.assert_called_once_with(
        model=model,
        messages=[OpenAIUserMessageParam(role="user", content="Hello, world!", name=None)],
        tools=None,
        stream=False,
        temperature=0.1,
    )
    openai_responses_impl.persistence_store.set.assert_called_once()
    assert result.model == model
    assert len(result.output) == 1
    assert isinstance(result.output[0], OpenAIResponseOutputMessage)
    assert result.output[0].content[0].text == "Hello! How can I help you?"


@pytest.mark.asyncio
async def test_create_openai_response_with_string_input_with_tools(openai_responses_impl, mock_inference_api):
    """Test creating an OpenAI response with a simple string input and tools."""
    # Setup
    input_text = "What was the score of todays game?"
    model = "meta-llama/Llama-3.1-8B-Instruct"

    mock_chat_completions = [
        OpenAIChatCompletion(
            id="chat-completion-123",
            choices=[
                OpenAIChoice(
                    message=OpenAIAssistantMessageParam(
                        tool_calls=[
                            OpenAIChatCompletionToolCall(
                                id="tool_call_123",
                                type="function",
                                function=OpenAIChatCompletionToolCallFunction(
                                    name="web_search", arguments='{"query":"What was the score of todays game?"}'
                                ),
                            )
                        ],
                    ),
                    finish_reason="stop",
                    index=0,
                )
            ],
            created=1234567890,
            model=model,
        ),
        OpenAIChatCompletion(
            id="chat-completion-123",
            choices=[
                OpenAIChoice(
                    message=OpenAIAssistantMessageParam(content="The score of todays game was 10-12"),
                    finish_reason="stop",
                    index=0,
                )
            ],
            created=1234567890,
            model=model,
        ),
    ]

    mock_inference_api.openai_chat_completion.side_effect = mock_chat_completions

    openai_responses_impl.tool_groups_api.get_tool.return_value = Tool(
        identifier="web_search",
        provider_id="client",
        toolgroup_id="web_search",
        tool_host="client",
        description="Search the web for information",
        parameters=[
            ToolParameter(name="query", parameter_type="string", description="The query to search for", required=True)
        ],
    )

    openai_responses_impl.tool_runtime_api.invoke_tool.return_value = ToolInvocationResult(
        status="completed",
        content="The score of todays game was 10-12",
    )

    # Execute
    result = await openai_responses_impl.create_openai_response(
        input=input_text,
        model=model,
        temperature=0.1,
        tools=[
            OpenAIResponseInputToolWebSearch(
                name="web_search",
            )
        ],
    )

    # Verify
    first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
    assert first_call.kwargs["messages"][0].content == "What was the score of todays game?"
    assert first_call.kwargs["tools"] is not None
    assert first_call.kwargs["temperature"] == 0.1

    second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
    assert second_call.kwargs["messages"][-1].content == "The score of todays game was 10-12"
    assert second_call.kwargs["temperature"] == 0.1

    openai_responses_impl.tool_groups_api.get_tool.assert_called_once_with("web_search")
    openai_responses_impl.tool_runtime_api.invoke_tool.assert_called_once_with(
        tool_name="web_search",
        kwargs={"query": "What was the score of todays game?"},
    )

    openai_responses_impl.persistence_store.set.assert_called_once()

    # Check that we got the content from our mocked tool execution result
    assert len(result.output) >= 1
    assert isinstance(result.output[1], OpenAIResponseOutputMessage)
    assert result.output[1].content[0].text == "The score of todays game was 10-12"
```