From d523c8692a6c17162619cf6808f96314857c8939 Mon Sep 17 00:00:00 2001
From: Ben Browning
Date: Fri, 18 Apr 2025 09:13:48 -0400
Subject: [PATCH] OpenAI Responses - image support and multi-turn tool calling

Signed-off-by: Ben Browning
---
 .../self_hosted_distro/together.md            |  1 +
 .../apis/openai_responses/openai_responses.py | 31 ++++++++-
 .../openai_responses/openai_responses.py      | 57 ++++++++++++-----
 llama_stack/templates/remote-vllm/build.yaml  |  4 +-
 .../remote-vllm/run-with-safety.yaml          | 16 ++---
 llama_stack/templates/remote-vllm/run.yaml    | 16 ++---
 llama_stack/templates/remote-vllm/vllm.py     |  2 +-
 llama_stack/templates/together/build.yaml     |  2 +
 .../templates/together/run-with-safety.yaml   |  9 +++
 llama_stack/templates/together/run.yaml       |  9 +++
 llama_stack/templates/together/together.py    |  1 +
 .../test_web_search_builtin.py                | 63 +++++++++++++++++++
 .../openai-api-verification-run.yaml          |  9 +++
 13 files changed, 186 insertions(+), 34 deletions(-)

diff --git a/docs/source/distributions/self_hosted_distro/together.md b/docs/source/distributions/self_hosted_distro/together.md
index 3ebb1f59e..5da0ee980 100644
--- a/docs/source/distributions/self_hosted_distro/together.md
+++ b/docs/source/distributions/self_hosted_distro/together.md
@@ -19,6 +19,7 @@ The `llamastack/distribution-together` distribution consists of the following pr
 | datasetio | `remote::huggingface`, `inline::localfs` |
 | eval | `inline::meta-reference` |
 | inference | `remote::together`, `inline::sentence-transformers` |
+| openai_responses | `inline::openai-responses` |
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
diff --git a/llama_stack/apis/openai_responses/openai_responses.py b/llama_stack/apis/openai_responses/openai_responses.py
index 4741f2502..87ccfdabd 100644
--- a/llama_stack/apis/openai_responses/openai_responses.py
+++ b/llama_stack/apis/openai_responses/openai_responses.py
@@ -80,6 +80,35 @@ class OpenAIResponseObjectStream(BaseModel):
     type: Literal["response.created"] = "response.created"


+@json_schema_type
+class OpenAIResponseInputMessageContentText(BaseModel):
+    text: str
+    type: Literal["input_text"] = "input_text"
+
+
+@json_schema_type
+class OpenAIResponseInputMessageContentImage(BaseModel):
+    detail: Literal["low", "high", "auto"] = "auto"
+    type: Literal["input_image"] = "input_image"
+    # TODO: handle file_id
+    image_url: Optional[str] = None
+
+
+# TODO: handle file content types
+OpenAIResponseInputMessageContent = Annotated[
+    Union[OpenAIResponseInputMessageContentText, OpenAIResponseInputMessageContentImage],
+    Field(discriminator="type"),
+]
+register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")
+
+
+@json_schema_type
+class OpenAIResponseInputMessage(BaseModel):
+    content: Union[str, List[OpenAIResponseInputMessageContent]]
+    role: Literal["system", "developer", "user", "assistant"]
+    type: Optional[Literal["message"]] = "message"
+
+
 @json_schema_type
 class OpenAIResponseInputToolWebSearch(BaseModel):
     type: Literal["web_search", "web_search_preview_2025_03_11"] = "web_search"
@@ -109,7 +138,7 @@ class OpenAIResponses(Protocol):
     @webmethod(route="/openai/v1/responses", method="POST")
     async def create_openai_response(
         self,
-        input: str,
+        input: Union[str, List[OpenAIResponseInputMessage]],
         model: str,
         previous_response_id: Optional[str] = None,
         store: Optional[bool] = True,
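Note on the schema change above: `create_openai_response` now accepts either a plain string or a list of typed input messages whose content mixes text and image parts. Below is a minimal sketch (not part of this patch) of building such an input with the classes added above; the prompt text and image URL are placeholder values.

```python
# Minimal sketch using the input types added above (not part of this patch).
# The prompt and image URL are placeholder values; the import path matches the diff.
from llama_stack.apis.openai_responses.openai_responses import (
    OpenAIResponseInputMessage,
    OpenAIResponseInputMessageContentImage,
    OpenAIResponseInputMessageContentText,
)

input_messages = [
    OpenAIResponseInputMessage(
        role="user",
        content=[
            OpenAIResponseInputMessageContentText(text="What animal is shown here?"),
            OpenAIResponseInputMessageContentImage(
                image_url="https://example.com/llama.jpg",  # placeholder URL
                detail="auto",
            ),
        ],
    )
]

# A plain string is still accepted, so existing callers keep working:
# await responses_api.create_openai_response(input="Say hello", model="...")
```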
diff --git a/llama_stack/providers/inline/openai_responses/openai_responses.py b/llama_stack/providers/inline/openai_responses/openai_responses.py
index 2825fed95..2a137e5c1 100644
--- a/llama_stack/providers/inline/openai_responses/openai_responses.py
+++ b/llama_stack/providers/inline/openai_responses/openai_responses.py
@@ -6,7 +6,7 @@
 import json
 import uuid

-from typing import AsyncIterator, List, Optional, cast
+from typing import AsyncIterator, List, Optional, Union, cast

 from openai.types.chat import ChatCompletionToolParam

@@ -14,9 +14,12 @@ from llama_stack.apis.inference.inference import (
     Inference,
     OpenAIAssistantMessageParam,
     OpenAIChatCompletion,
+    OpenAIChatCompletionContentPartImageParam,
+    OpenAIChatCompletionContentPartParam,
     OpenAIChatCompletionContentPartTextParam,
     OpenAIChatCompletionToolCallFunction,
     OpenAIChoice,
+    OpenAIImageURL,
     OpenAIMessageParam,
     OpenAIToolMessageParam,
     OpenAIUserMessageParam,
@@ -24,6 +27,9 @@ from llama_stack.apis.inference.inference import (
 from llama_stack.apis.models.models import Models, ModelType
 from llama_stack.apis.openai_responses import OpenAIResponses
 from llama_stack.apis.openai_responses.openai_responses import (
+    OpenAIResponseInputMessage,
+    OpenAIResponseInputMessageContentImage,
+    OpenAIResponseInputMessageContentText,
     OpenAIResponseInputTool,
     OpenAIResponseObject,
     OpenAIResponseObjectStream,
@@ -106,13 +112,14 @@ class OpenAIResponsesImpl(OpenAIResponses):

     async def create_openai_response(
         self,
-        input: str,
+        input: Union[str, List[OpenAIResponseInputMessage]],
         model: str,
         previous_response_id: Optional[str] = None,
         store: Optional[bool] = True,
         stream: Optional[bool] = False,
         tools: Optional[List[OpenAIResponseInputTool]] = None,
     ):
+        stream = False if stream is None else stream
         model_obj = await self.models_api.get_model(model)
         if model_obj is None:
             raise ValueError(f"Model '{model}' not found")
@@ -123,13 +130,34 @@ class OpenAIResponsesImpl(OpenAIResponses):
         if previous_response_id:
             previous_response = await self.get_openai_response(previous_response_id)
             messages.extend(await _previous_response_to_messages(previous_response))
-        messages.append(OpenAIUserMessageParam(content=input))
+        # TODO: refactor this user_content parsing out into a separate method
+        user_content: Union[str, List[OpenAIChatCompletionContentPartParam]] = ""
+        if isinstance(input, list):
+            user_content = []
+            for user_input in input:
+                if isinstance(user_input.content, list):
+                    for user_input_content in user_input.content:
+                        if isinstance(user_input_content, OpenAIResponseInputMessageContentText):
+                            user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input_content.text))
+                        elif isinstance(user_input_content, OpenAIResponseInputMessageContentImage):
+                            if user_input_content.image_url:
+                                image_url = OpenAIImageURL(
+                                    url=user_input_content.image_url, detail=user_input_content.detail
+                                )
+                                user_content.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
+                else:
+                    user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input.content))
+        else:
+            user_content = input
+        messages.append(OpenAIUserMessageParam(content=user_content))

         chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
+        # TODO: the code below doesn't handle streaming
         chat_response = await self.inference_api.openai_chat_completion(
             model=model_obj.identifier,
             messages=messages,
             tools=chat_tools,
+            stream=stream,
         )
         # type cast to appease mypy
         chat_response = cast(OpenAIChatCompletion, chat_response)
@@ -139,7 +167,7 @@ class OpenAIResponsesImpl(OpenAIResponses):
         output_messages: List[OpenAIResponseOutput] = []
         if chat_response.choices[0].finish_reason == "tool_calls":
             output_messages.extend(
-                await self._execute_tool_and_return_final_output(model_obj.identifier, chat_response, messages)
+                await self._execute_tool_and_return_final_output(model_obj.identifier, stream, chat_response, messages)
             )
         else:
             output_messages.extend(await _openai_choices_to_output_messages(chat_response.choices))
@@ -198,7 +226,7 @@ class OpenAIResponsesImpl(OpenAIResponses):
         return chat_tools

     async def _execute_tool_and_return_final_output(
-        self, model_id: str, chat_response: OpenAIChatCompletion, messages: List[OpenAIMessageParam]
+        self, model_id: str, stream: bool, chat_response: OpenAIChatCompletion, messages: List[OpenAIMessageParam]
     ) -> List[OpenAIResponseOutput]:
         output_messages: List[OpenAIResponseOutput] = []
         choice = chat_response.choices[0]
@@ -211,21 +239,21 @@ class OpenAIResponsesImpl(OpenAIResponses):
         if not choice.message.tool_calls:
             return output_messages

-        # TODO: handle multiple tool calls
-        function = choice.message.tool_calls[0].function
+        # Add the assistant message with tool_calls response to the messages list
+        messages.append(choice.message)

-        # If the tool call is not a function, we don't need to execute it
-        if not function:
+        # TODO: handle multiple tool calls
+        tool_call = choice.message.tool_calls[0]
+        tool_call_id = tool_call.id
+        function = tool_call.function
+
+        # If for some reason the tool call doesn't have a function or id, we can't execute it
+        if not function or not tool_call_id:
             return output_messages

         # TODO: telemetry spans for tool calls
         result = await self._execute_tool_call(function)

-        tool_call_prefix = "tc_"
-        if function.name == "web_search":
-            tool_call_prefix = "ws_"
-        tool_call_id = f"{tool_call_prefix}{uuid.uuid4()}"
-
         # Handle tool call failure
         if not result:
             output_messages.append(
@@ -251,6 +279,7 @@ class OpenAIResponsesImpl(OpenAIResponses):
         tool_results_chat_response = await self.inference_api.openai_chat_completion(
             model=model_id,
             messages=messages,
+            stream=stream,
         )
         # type cast to appease mypy
         tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
diff --git a/llama_stack/templates/remote-vllm/build.yaml b/llama_stack/templates/remote-vllm/build.yaml
index 94326e570..b344f5e5a 100644
--- a/llama_stack/templates/remote-vllm/build.yaml
+++ b/llama_stack/templates/remote-vllm/build.yaml
@@ -24,6 +24,8 @@ distribution_spec:
     - inline::braintrust
     telemetry:
     - inline::meta-reference
+    openai_responses:
+    - inline::openai-responses
     tool_runtime:
     - remote::brave-search
     - remote::tavily-search
@@ -31,6 +33,4 @@ distribution_spec:
     - inline::rag-runtime
     - remote::model-context-protocol
     - remote::wolfram-alpha
-    openai_responses:
-    - inline::openai-responses
   image_type: conda
diff --git a/llama_stack/templates/remote-vllm/run-with-safety.yaml b/llama_stack/templates/remote-vllm/run-with-safety.yaml
index c53228ed4..a58417714 100644
--- a/llama_stack/templates/remote-vllm/run-with-safety.yaml
+++ b/llama_stack/templates/remote-vllm/run-with-safety.yaml
@@ -92,6 +92,14 @@ providers:
       service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/remote-vllm/trace_store.db}
+  openai_responses:
+  - provider_id: openai-responses
+    provider_type: inline::openai-responses
+    config:
+      kvstore:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/openai_responses.db
   tool_runtime:
  - provider_id: brave-search
    provider_type: remote::brave-search
@@ -116,14 +124,6 @@ providers:
     provider_type: remote::wolfram-alpha
     config:
       api_key: ${env.WOLFRAM_ALPHA_API_KEY:}
-  openai_responses:
-  - provider_id: openai-responses
-    provider_type: inline::openai-responses
-    config:
-      kvstore:
-        type: sqlite
-        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/openai_responses.db
 metadata_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db
diff --git a/llama_stack/templates/remote-vllm/run.yaml b/llama_stack/templates/remote-vllm/run.yaml
index 282749d0d..58087bba3 100644
--- a/llama_stack/templates/remote-vllm/run.yaml
+++ b/llama_stack/templates/remote-vllm/run.yaml
@@ -85,6 +85,14 @@ providers:
       service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/remote-vllm/trace_store.db}
+  openai_responses:
+  - provider_id: openai-responses
+    provider_type: inline::openai-responses
+    config:
+      kvstore:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/openai_responses.db
   tool_runtime:
   - provider_id: brave-search
     provider_type: remote::brave-search
@@ -109,14 +117,6 @@ providers:
     provider_type: remote::wolfram-alpha
     config:
       api_key: ${env.WOLFRAM_ALPHA_API_KEY:}
-  openai_responses:
-  - provider_id: openai-responses
-    provider_type: inline::openai-responses
-    config:
-      kvstore:
-        type: sqlite
-        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/openai_responses.db
 metadata_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db
diff --git a/llama_stack/templates/remote-vllm/vllm.py b/llama_stack/templates/remote-vllm/vllm.py
index 5cddb1c76..12515d1ad 100644
--- a/llama_stack/templates/remote-vllm/vllm.py
+++ b/llama_stack/templates/remote-vllm/vllm.py
@@ -31,6 +31,7 @@ def get_distribution_template() -> DistributionTemplate:
         "datasetio": ["remote::huggingface", "inline::localfs"],
         "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
         "telemetry": ["inline::meta-reference"],
+        "openai_responses": ["inline::openai-responses"],
         "tool_runtime": [
             "remote::brave-search",
             "remote::tavily-search",
@@ -39,7 +40,6 @@ def get_distribution_template() -> DistributionTemplate:
             "remote::model-context-protocol",
             "remote::wolfram-alpha",
         ],
-        "openai_responses": ["inline::openai-responses"],
     }
     name = "remote-vllm"
     inference_provider = Provider(
diff --git a/llama_stack/templates/together/build.yaml b/llama_stack/templates/together/build.yaml
index 834a3ecaf..81a47c5cd 100644
--- a/llama_stack/templates/together/build.yaml
+++ b/llama_stack/templates/together/build.yaml
@@ -24,6 +24,8 @@ distribution_spec:
     - inline::basic
     - inline::llm-as-judge
     - inline::braintrust
+    openai_responses:
+    - inline::openai-responses
     tool_runtime:
     - remote::brave-search
     - remote::tavily-search
diff --git a/llama_stack/templates/together/run-with-safety.yaml b/llama_stack/templates/together/run-with-safety.yaml
index 105ce896d..fbeafce19 100644
--- a/llama_stack/templates/together/run-with-safety.yaml
+++ b/llama_stack/templates/together/run-with-safety.yaml
@@ -5,6 +5,7 @@ apis:
 - datasetio
 - eval
 - inference
+- openai_responses
 - safety
 - scoring
 - telemetry
@@ -87,6 +88,14 @@ providers:
     provider_type: inline::braintrust
     config:
       openai_api_key: ${env.OPENAI_API_KEY:}
+  openai_responses:
+  - provider_id: openai-responses
+    provider_type: inline::openai-responses
+    config:
+      kvstore:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/openai_responses.db
   tool_runtime:
   - provider_id: brave-search
     provider_type: remote::brave-search
diff --git a/llama_stack/templates/together/run.yaml b/llama_stack/templates/together/run.yaml
index 1f1613655..0c5d82c13 100644
--- a/llama_stack/templates/together/run.yaml
+++ b/llama_stack/templates/together/run.yaml
@@ -5,6 +5,7 @@ apis:
 - datasetio
 - eval
 - inference
+- openai_responses
 - safety
 - scoring
 - telemetry
@@ -82,6 +83,14 @@ providers:
     provider_type: inline::braintrust
     config:
      openai_api_key: ${env.OPENAI_API_KEY:}
+  openai_responses:
+  - provider_id: openai-responses
+    provider_type: inline::openai-responses
+    config:
+      kvstore:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/openai_responses.db
   tool_runtime:
   - provider_id: brave-search
     provider_type: remote::brave-search
diff --git a/llama_stack/templates/together/together.py b/llama_stack/templates/together/together.py
index a2bd87c97..85b7645b3 100644
--- a/llama_stack/templates/together/together.py
+++ b/llama_stack/templates/together/together.py
@@ -36,6 +36,7 @@ def get_distribution_template() -> DistributionTemplate:
         "eval": ["inline::meta-reference"],
         "datasetio": ["remote::huggingface", "inline::localfs"],
         "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
+        "openai_responses": ["inline::openai-responses"],
         "tool_runtime": [
             "remote::brave-search",
             "remote::tavily-search",
diff --git a/tests/integration/openai_responses/test_web_search_builtin.py b/tests/integration/openai_responses/test_web_search_builtin.py
index d7baf96d2..5f8a1afd7 100644
--- a/tests/integration/openai_responses/test_web_search_builtin.py
+++ b/tests/integration/openai_responses/test_web_search_builtin.py
@@ -36,3 +36,66 @@ def test_web_search_non_streaming(openai_client, client_with_models, text_model_
     assert response.output[1].role == "assistant"
     assert len(response.output[1].content) > 0
     assert expected.lower() in response.output_text.lower().strip()
+
+
+def test_input_image_non_streaming(openai_client, vision_model_id):
+    supported_models = ["llama-4", "gpt-4o", "llama4"]
+    if not any(model in vision_model_id.lower() for model in supported_models):
+        pytest.skip(f"Skip for non-supported model: {vision_model_id}")
+
+    response = openai_client.with_options(max_retries=0).responses.create(
+        model=vision_model_id,
+        input=[
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "input_text",
+                        "text": "Identify the type of animal in this image.",
+                    },
+                    {
+                        "type": "input_image",
+                        "image_url": "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg",
+                    },
+                ],
+            }
+        ],
+    )
+    output_text = response.output_text.lower()
+    assert "llama" in output_text
+
+
+def test_multi_turn_web_search_from_image_non_streaming(openai_client, vision_model_id):
+    supported_models = ["llama-4", "gpt-4o", "llama4"]
+    if not any(model in vision_model_id.lower() for model in supported_models):
+        pytest.skip(f"Skip for non-supported model: {vision_model_id}")
+
+    response = openai_client.with_options(max_retries=0).responses.create(
+        model=vision_model_id,
+        input=[
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "input_text",
+                        "text": "Extract a single search keyword that represents the type of animal in this image.",
+                    },
+                    {
+                        "type": "input_image",
"https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg", + }, + ], + } + ], + ) + output_text = response.output_text.lower() + assert "llama" in output_text + + search_response = openai_client.with_options(max_retries=0).responses.create( + model=vision_model_id, + input="Search the web using the search tool for those keywords plus the words 'maverick' and 'scout' and summarize the results.", + previous_response_id=response.id, + tools=[{"type": "web_search"}], + ) + output_text = search_response.output_text.lower() + assert "model" in output_text diff --git a/tests/verifications/openai-api-verification-run.yaml b/tests/verifications/openai-api-verification-run.yaml index 71885d058..58ce5344d 100644 --- a/tests/verifications/openai-api-verification-run.yaml +++ b/tests/verifications/openai-api-verification-run.yaml @@ -2,6 +2,7 @@ version: '2' image_name: openai-api-verification apis: - inference +- openai_responses - telemetry - tool_runtime - vector_io @@ -45,6 +46,14 @@ providers: service_name: "${env.OTEL_SERVICE_NAME:\u200B}" sinks: ${env.TELEMETRY_SINKS:console,sqlite} sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/openai/trace_store.db} + openai_responses: + - provider_id: openai-responses + provider_type: inline::openai-responses + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/openai}/openai_responses.db tool_runtime: - provider_id: brave-search provider_type: remote::brave-search