OpenAI Responses - image support and multi-turn tool calling

Signed-off-by: Ben Browning <bbrownin@redhat.com>

commit d523c8692a
parent 35b2e2646f
Author: Ben Browning <bbrownin@redhat.com>, 2025-04-18 09:13:48 -04:00 (committed by Ashwin Bharambe)
13 changed files with 186 additions and 34 deletions
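For orientation, here is a minimal client-side sketch of the two features this commit adds: image input parts and chaining a tool-calling turn onto a previous response. The base URL, API key, model id, and image URL are illustrative assumptions, not values from this commit; the call shapes mirror the tests added below.

from openai import OpenAI

# Assumed endpoint and credentials, for illustration only.
client = OpenAI(base_url="http://localhost:8321/openai/v1", api_key="none")

# Turn 1: a user message mixing text and image content parts (new in this commit).
response = client.responses.create(
    model="llama-4-scout",  # hypothetical vision-capable model id
    input=[
        {
            "role": "user",
            "content": [
                {"type": "input_text", "text": "What animal is in this image?"},
                {"type": "input_image", "image_url": "https://example.com/llama.jpg"},
            ],
        }
    ],
)

# Turn 2: chain on the stored response and let the model call the web_search tool.
followup = client.responses.create(
    model="llama-4-scout",
    input="Search the web for more about that animal.",
    previous_response_id=response.id,
    tools=[{"type": "web_search"}],
)
print(followup.output_text)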


@@ -19,6 +19,7 @@ The `llamastack/distribution-together` distribution consists of the following provider configurations.
 | datasetio | `remote::huggingface`, `inline::localfs` |
 | eval | `inline::meta-reference` |
 | inference | `remote::together`, `inline::sentence-transformers` |
+| openai_responses | `inline::openai-responses` |
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |


@@ -80,6 +80,35 @@ class OpenAIResponseObjectStream(BaseModel):
     type: Literal["response.created"] = "response.created"
 
 
+@json_schema_type
+class OpenAIResponseInputMessageContentText(BaseModel):
+    text: str
+    type: Literal["input_text"] = "input_text"
+
+
+@json_schema_type
+class OpenAIResponseInputMessageContentImage(BaseModel):
+    detail: Literal["low", "high", "auto"] = "auto"
+    type: Literal["input_image"] = "input_image"
+    # TODO: handle file_id
+    image_url: Optional[str] = None
+
+
+# TODO: handle file content types
+OpenAIResponseInputMessageContent = Annotated[
+    Union[OpenAIResponseInputMessageContentText, OpenAIResponseInputMessageContentImage],
+    Field(discriminator="type"),
+]
+register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")
+
+
+@json_schema_type
+class OpenAIResponseInputMessage(BaseModel):
+    content: Union[str, List[OpenAIResponseInputMessageContent]]
+    role: Literal["system", "developer", "user", "assistant"]
+    type: Optional[Literal["message"]] = "message"
+
+
 @json_schema_type
 class OpenAIResponseInputToolWebSearch(BaseModel):
     type: Literal["web_search", "web_search_preview_2025_03_11"] = "web_search"

@@ -109,7 +138,7 @@ class OpenAIResponses(Protocol):
     @webmethod(route="/openai/v1/responses", method="POST")
     async def create_openai_response(
         self,
-        input: str,
+        input: Union[str, List[OpenAIResponseInputMessage]],
         model: str,
         previous_response_id: Optional[str] = None,
         store: Optional[bool] = True,
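Concretely, the widened `input` union accepts either the old plain-string shorthand or a list of typed messages. A sketch of the new list form, built from the models defined above (the image URL is illustrative):

from llama_stack.apis.openai_responses.openai_responses import (
    OpenAIResponseInputMessage,
    OpenAIResponseInputMessageContentImage,
    OpenAIResponseInputMessageContentText,
)

# Equivalent string shorthand would be: input="Describe this picture."
message = OpenAIResponseInputMessage(
    role="user",
    content=[
        OpenAIResponseInputMessageContentText(text="Describe this picture."),
        # detail defaults to "auto"; "low" trades image fidelity for fewer tokens
        OpenAIResponseInputMessageContentImage(
            image_url="https://example.com/llama.jpg", detail="low"
        ),
    ],
)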


@@ -6,7 +6,7 @@
 
 import json
 import uuid
-from typing import AsyncIterator, List, Optional, cast
+from typing import AsyncIterator, List, Optional, Union, cast
 
 from openai.types.chat import ChatCompletionToolParam
 
@@ -14,9 +14,12 @@ from llama_stack.apis.inference.inference import (
     Inference,
     OpenAIAssistantMessageParam,
     OpenAIChatCompletion,
+    OpenAIChatCompletionContentPartImageParam,
+    OpenAIChatCompletionContentPartParam,
     OpenAIChatCompletionContentPartTextParam,
     OpenAIChatCompletionToolCallFunction,
     OpenAIChoice,
+    OpenAIImageURL,
     OpenAIMessageParam,
     OpenAIToolMessageParam,
     OpenAIUserMessageParam,
@@ -24,6 +27,9 @@ from llama_stack.apis.inference.inference import (
 from llama_stack.apis.models.models import Models, ModelType
 from llama_stack.apis.openai_responses import OpenAIResponses
 from llama_stack.apis.openai_responses.openai_responses import (
+    OpenAIResponseInputMessage,
+    OpenAIResponseInputMessageContentImage,
+    OpenAIResponseInputMessageContentText,
     OpenAIResponseInputTool,
     OpenAIResponseObject,
     OpenAIResponseObjectStream,
@@ -106,13 +112,14 @@ class OpenAIResponsesImpl(OpenAIResponses):
     async def create_openai_response(
         self,
-        input: str,
+        input: Union[str, List[OpenAIResponseInputMessage]],
         model: str,
         previous_response_id: Optional[str] = None,
         store: Optional[bool] = True,
         stream: Optional[bool] = False,
         tools: Optional[List[OpenAIResponseInputTool]] = None,
     ):
+        stream = False if stream is None else stream
         model_obj = await self.models_api.get_model(model)
         if model_obj is None:
             raise ValueError(f"Model '{model}' not found")

@@ -123,13 +130,34 @@ class OpenAIResponsesImpl(OpenAIResponses):
         if previous_response_id:
             previous_response = await self.get_openai_response(previous_response_id)
             messages.extend(await _previous_response_to_messages(previous_response))
-        messages.append(OpenAIUserMessageParam(content=input))
+        # TODO: refactor this user_content parsing out into a separate method
+        user_content: Union[str, List[OpenAIChatCompletionContentPartParam]] = ""
+        if isinstance(input, list):
+            user_content = []
+            for user_input in input:
+                if isinstance(user_input.content, list):
+                    for user_input_content in user_input.content:
+                        if isinstance(user_input_content, OpenAIResponseInputMessageContentText):
+                            user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input_content.text))
+                        elif isinstance(user_input_content, OpenAIResponseInputMessageContentImage):
+                            if user_input_content.image_url:
+                                image_url = OpenAIImageURL(
+                                    url=user_input_content.image_url, detail=user_input_content.detail
+                                )
+                                user_content.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
+                else:
+                    user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input.content))
+        else:
+            user_content = input
+        messages.append(OpenAIUserMessageParam(content=user_content))
 
         chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
+        # TODO: the code below doesn't handle streaming
         chat_response = await self.inference_api.openai_chat_completion(
             model=model_obj.identifier,
             messages=messages,
             tools=chat_tools,
+            stream=stream,
         )
         # type cast to appease mypy
         chat_response = cast(OpenAIChatCompletion, chat_response)
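The TODO above suggests factoring this parsing out into a helper. A hypothetical extraction (helper name assumed, same imports as the module; not part of this commit) that maps one Responses-style content part to its chat-completion counterpart might look like:

def _convert_input_content_part(part):
    # Hypothetical refactor of the inline parsing above.
    if isinstance(part, OpenAIResponseInputMessageContentText):
        return OpenAIChatCompletionContentPartTextParam(text=part.text)
    if isinstance(part, OpenAIResponseInputMessageContentImage) and part.image_url:
        # Responses "input_image" parts become chat-completion image_url parts.
        image_url = OpenAIImageURL(url=part.image_url, detail=part.detail)
        return OpenAIChatCompletionContentPartImageParam(image_url=image_url)
    return None  # e.g. image parts without a URL (file_id handling is still TODO)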
@@ -139,7 +167,7 @@ class OpenAIResponsesImpl(OpenAIResponses):
         output_messages: List[OpenAIResponseOutput] = []
         if chat_response.choices[0].finish_reason == "tool_calls":
             output_messages.extend(
-                await self._execute_tool_and_return_final_output(model_obj.identifier, chat_response, messages)
+                await self._execute_tool_and_return_final_output(model_obj.identifier, stream, chat_response, messages)
             )
         else:
             output_messages.extend(await _openai_choices_to_output_messages(chat_response.choices))

@@ -198,7 +226,7 @@ class OpenAIResponsesImpl(OpenAIResponses):
         return chat_tools
 
     async def _execute_tool_and_return_final_output(
-        self, model_id: str, chat_response: OpenAIChatCompletion, messages: List[OpenAIMessageParam]
+        self, model_id: str, stream: bool, chat_response: OpenAIChatCompletion, messages: List[OpenAIMessageParam]
     ) -> List[OpenAIResponseOutput]:
         output_messages: List[OpenAIResponseOutput] = []
         choice = chat_response.choices[0]

@@ -211,21 +239,21 @@ class OpenAIResponsesImpl(OpenAIResponses):
         if not choice.message.tool_calls:
             return output_messages
 
-        # TODO: handle multiple tool calls
-        function = choice.message.tool_calls[0].function
-
-        # If the tool call is not a function, we don't need to execute it
-        if not function:
+        # Add the assistant message with tool_calls response to the messages list
+        messages.append(choice.message)
+
+        # TODO: handle multiple tool calls
+        tool_call = choice.message.tool_calls[0]
+        tool_call_id = tool_call.id
+        function = tool_call.function
+
+        # If for some reason the tool call doesn't have a function or id, we can't execute it
+        if not function or not tool_call_id:
             return output_messages
 
         # TODO: telemetry spans for tool calls
         result = await self._execute_tool_call(function)
-        tool_call_prefix = "tc_"
-        if function.name == "web_search":
-            tool_call_prefix = "ws_"
-        tool_call_id = f"{tool_call_prefix}{uuid.uuid4()}"
 
         # Handle tool call failure
         if not result:
             output_messages.append(

@@ -251,6 +279,7 @@ class OpenAIResponsesImpl(OpenAIResponses):
         tool_results_chat_response = await self.inference_api.openai_chat_completion(
             model=model_id,
             messages=messages,
+            stream=stream,
         )
         # type cast to appease mypy
         tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
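The key behavioral change here: the assistant message carrying `tool_calls` is now appended to the conversation, and the model-generated tool call id is reused instead of minting a fresh `ws_`/`tc_` id. The follow-up completion therefore sees the standard OpenAI tool-calling message sequence. Schematically, in OpenAI chat wire format (illustrative values only; the tool message itself is constructed outside this hunk):

# Sketch of the history the second openai_chat_completion call now receives.
history = [
    {"role": "user", "content": "What is the weather in Paris?"},
    {
        "role": "assistant",
        "tool_calls": [
            {
                "id": "call_abc123",  # model-generated id, now reused as tool_call_id
                "type": "function",
                "function": {"name": "web_search", "arguments": '{"query": "weather in Paris"}'},
            }
        ],
    },
    {"role": "tool", "tool_call_id": "call_abc123", "content": "Sunny, 22 C"},
]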


@@ -24,6 +24,8 @@ distribution_spec:
     - inline::braintrust
     telemetry:
     - inline::meta-reference
+    openai_responses:
+    - inline::openai-responses
     tool_runtime:
     - remote::brave-search
     - remote::tavily-search

@@ -31,6 +33,4 @@ distribution_spec:
     - inline::rag-runtime
     - remote::model-context-protocol
     - remote::wolfram-alpha
-    openai_responses:
-    - inline::openai-responses
 image_type: conda


@@ -92,6 +92,14 @@ providers:
       service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/remote-vllm/trace_store.db}
+  openai_responses:
+  - provider_id: openai-responses
+    provider_type: inline::openai-responses
+    config:
+      kvstore:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/openai_responses.db
   tool_runtime:
   - provider_id: brave-search
     provider_type: remote::brave-search

@@ -116,14 +124,6 @@ providers:
     provider_type: remote::wolfram-alpha
     config:
       api_key: ${env.WOLFRAM_ALPHA_API_KEY:}
-  openai_responses:
-  - provider_id: openai-responses
-    provider_type: inline::openai-responses
-    config:
-      kvstore:
-        type: sqlite
-        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/openai_responses.db
   metadata_store:
     type: sqlite
     db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db
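The `kvstore` block above is what backs multi-turn chaining: each response object is persisted under its id so a later request carrying `previous_response_id` can be rehydrated into chat messages. A rough sketch of that pattern, assuming llama-stack's async KVStore get/set interface (the key prefix and helper names here are illustrative, not from this commit):

OPENAI_RESPONSES_PREFIX = "openai_responses:"  # assumed key prefix

async def store_response(kvstore, response) -> None:
    # Persist the serialized response so a later turn can reference it.
    await kvstore.set(
        key=f"{OPENAI_RESPONSES_PREFIX}{response.id}",
        value=response.model_dump_json(),
    )

async def load_response(kvstore, response_id: str):
    # Looked up when a request arrives with previous_response_id set.
    return await kvstore.get(key=f"{OPENAI_RESPONSES_PREFIX}{response_id}")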


@@ -85,6 +85,14 @@ providers:
       service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/remote-vllm/trace_store.db}
+  openai_responses:
+  - provider_id: openai-responses
+    provider_type: inline::openai-responses
+    config:
+      kvstore:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/openai_responses.db
   tool_runtime:
   - provider_id: brave-search
     provider_type: remote::brave-search

@@ -109,14 +117,6 @@ providers:
     provider_type: remote::wolfram-alpha
     config:
       api_key: ${env.WOLFRAM_ALPHA_API_KEY:}
-  openai_responses:
-  - provider_id: openai-responses
-    provider_type: inline::openai-responses
-    config:
-      kvstore:
-        type: sqlite
-        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/openai_responses.db
   metadata_store:
     type: sqlite
     db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db


@@ -31,6 +31,7 @@ def get_distribution_template() -> DistributionTemplate:
         "datasetio": ["remote::huggingface", "inline::localfs"],
         "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
         "telemetry": ["inline::meta-reference"],
+        "openai_responses": ["inline::openai-responses"],
         "tool_runtime": [
             "remote::brave-search",
             "remote::tavily-search",

@@ -39,7 +40,6 @@ def get_distribution_template() -> DistributionTemplate:
             "remote::model-context-protocol",
             "remote::wolfram-alpha",
         ],
-        "openai_responses": ["inline::openai-responses"],
     }
     name = "remote-vllm"
     inference_provider = Provider(


@@ -24,6 +24,8 @@ distribution_spec:
     - inline::basic
     - inline::llm-as-judge
     - inline::braintrust
+    openai_responses:
+    - inline::openai-responses
     tool_runtime:
     - remote::brave-search
     - remote::tavily-search


@@ -5,6 +5,7 @@ apis:
 - datasetio
 - eval
 - inference
+- openai_responses
 - safety
 - scoring
 - telemetry

@@ -87,6 +88,14 @@ providers:
     provider_type: inline::braintrust
     config:
       openai_api_key: ${env.OPENAI_API_KEY:}
+  openai_responses:
+  - provider_id: openai-responses
+    provider_type: inline::openai-responses
+    config:
+      kvstore:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/openai_responses.db
   tool_runtime:
   - provider_id: brave-search
     provider_type: remote::brave-search


@@ -5,6 +5,7 @@ apis:
 - datasetio
 - eval
 - inference
+- openai_responses
 - safety
 - scoring
 - telemetry

@@ -82,6 +83,14 @@ providers:
     provider_type: inline::braintrust
     config:
       openai_api_key: ${env.OPENAI_API_KEY:}
+  openai_responses:
+  - provider_id: openai-responses
+    provider_type: inline::openai-responses
+    config:
+      kvstore:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/openai_responses.db
   tool_runtime:
   - provider_id: brave-search
     provider_type: remote::brave-search


@@ -36,6 +36,7 @@ def get_distribution_template() -> DistributionTemplate:
         "eval": ["inline::meta-reference"],
         "datasetio": ["remote::huggingface", "inline::localfs"],
         "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
+        "openai_responses": ["inline::openai-responses"],
         "tool_runtime": [
             "remote::brave-search",
             "remote::tavily-search",


@@ -36,3 +36,66 @@ def test_web_search_non_streaming(openai_client, client_with_models, text_model_id):
     assert response.output[1].role == "assistant"
     assert len(response.output[1].content) > 0
     assert expected.lower() in response.output_text.lower().strip()
+
+
+def test_input_image_non_streaming(openai_client, vision_model_id):
+    supported_models = ["llama-4", "gpt-4o", "llama4"]
+    if not any(model in vision_model_id.lower() for model in supported_models):
+        pytest.skip(f"Skip for non-supported model: {vision_model_id}")
+
+    response = openai_client.with_options(max_retries=0).responses.create(
+        model=vision_model_id,
+        input=[
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "input_text",
+                        "text": "Identify the type of animal in this image.",
+                    },
+                    {
+                        "type": "input_image",
+                        "image_url": "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg",
+                    },
+                ],
+            }
+        ],
+    )
+    output_text = response.output_text.lower()
+    assert "llama" in output_text
+
+
+def test_multi_turn_web_search_from_image_non_streaming(openai_client, vision_model_id):
+    supported_models = ["llama-4", "gpt-4o", "llama4"]
+    if not any(model in vision_model_id.lower() for model in supported_models):
+        pytest.skip(f"Skip for non-supported model: {vision_model_id}")
+
+    response = openai_client.with_options(max_retries=0).responses.create(
+        model=vision_model_id,
+        input=[
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "input_text",
+                        "text": "Extract a single search keyword that represents the type of animal in this image.",
+                    },
+                    {
+                        "type": "input_image",
+                        "image_url": "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg",
+                    },
+                ],
+            }
+        ],
+    )
+    output_text = response.output_text.lower()
+    assert "llama" in output_text
+
+    search_response = openai_client.with_options(max_retries=0).responses.create(
+        model=vision_model_id,
+        input="Search the web using the search tool for those keywords plus the words 'maverick' and 'scout' and summarize the results.",
+        previous_response_id=response.id,
+        tools=[{"type": "web_search"}],
+    )
+    output_text = search_response.output_text.lower()
+    assert "model" in output_text


@@ -2,6 +2,7 @@ version: '2'
 image_name: openai-api-verification
 apis:
 - inference
+- openai_responses
 - telemetry
 - tool_runtime
 - vector_io

@@ -45,6 +46,14 @@ providers:
       service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/openai/trace_store.db}
+  openai_responses:
+  - provider_id: openai-responses
+    provider_type: inline::openai-responses
+    config:
+      kvstore:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/openai}/openai_responses.db
   tool_runtime:
   - provider_id: brave-search
     provider_type: remote::brave-search