From 35b2e2646f821756f5a2fd413524236702134613 Mon Sep 17 00:00:00 2001
From: Ben Browning
Date: Thu, 17 Apr 2025 20:25:36 -0400
Subject: [PATCH] OpenAI Responses API: Stub in basic web_search tool

---
 .../apis/openai_responses/openai_responses.py |  31 +++-
 .../inline/openai_responses/__init__.py       |   4 +-
 .../openai_responses/openai_responses.py      | 143 +++++++++++++++++-
 .../providers/registry/openai_responses.py    |   2 +
 tests/integration/fixtures/common.py          |   7 +
 ...test_openai_responses.py => test_basic.py} |  11 +-
 .../test_web_search_builtin.py                |  38 +++++
 .../test_cases/openai/responses.json          |  11 ++
 8 files changed, 232 insertions(+), 15 deletions(-)
 rename tests/integration/openai_responses/{test_openai_responses.py => test_basic.py} (86%)
 create mode 100644 tests/integration/openai_responses/test_web_search_builtin.py

diff --git a/llama_stack/apis/openai_responses/openai_responses.py b/llama_stack/apis/openai_responses/openai_responses.py
index c8324a13a..4741f2502 100644
--- a/llama_stack/apis/openai_responses/openai_responses.py
+++ b/llama_stack/apis/openai_responses/openai_responses.py
@@ -9,7 +9,7 @@ from typing import AsyncIterator, List, Literal, Optional, Protocol, Union, runt
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated

-from llama_stack.schema_utils import json_schema_type, webmethod
+from llama_stack.schema_utils import json_schema_type, register_schema, webmethod


 @json_schema_type
@@ -28,6 +28,7 @@ OpenAIResponseOutputMessageContent = Annotated[
     Union[OpenAIResponseOutputMessageContentOutputText,],
     Field(discriminator="type"),
 ]
+register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent")


 @json_schema_type
@@ -39,10 +40,21 @@ class OpenAIResponseOutputMessage(BaseModel):
     type: Literal["message"] = "message"


+@json_schema_type
+class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
+    id: str
+    status: str
+    type: Literal["web_search_call"] = "web_search_call"
+
+
 OpenAIResponseOutput = Annotated[
-    Union[OpenAIResponseOutputMessage,],
+    Union[
+        OpenAIResponseOutputMessage,
+        OpenAIResponseOutputMessageWebSearchToolCall,
+    ],
     Field(discriminator="type"),
 ]
+register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")


 @json_schema_type
@@ -68,6 +80,20 @@ class OpenAIResponseObjectStream(BaseModel):
     type: Literal["response.created"] = "response.created"


+@json_schema_type
+class OpenAIResponseInputToolWebSearch(BaseModel):
+    type: Literal["web_search", "web_search_preview_2025_03_11"] = "web_search"
+    search_context_size: Optional[str] = Field(default="medium", pattern="^low|medium|high$")
+    # TODO: add user_location
+
+
+OpenAIResponseInputTool = Annotated[
+    Union[OpenAIResponseInputToolWebSearch,],
+    Field(discriminator="type"),
+]
+register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
+
+
 @runtime_checkable
 class OpenAIResponses(Protocol):
     """
@@ -88,4 +114,5 @@ class OpenAIResponses(Protocol):
         previous_response_id: Optional[str] = None,
         store: Optional[bool] = True,
         stream: Optional[bool] = False,
+        tools: Optional[List[OpenAIResponseInputTool]] = None,
     ) -> Union[OpenAIResponseObject, AsyncIterator[OpenAIResponseObjectStream]]: ...
diff --git a/llama_stack/providers/inline/openai_responses/__init__.py b/llama_stack/providers/inline/openai_responses/__init__.py
index 6d461e81a..76f15d478 100644
--- a/llama_stack/providers/inline/openai_responses/__init__.py
+++ b/llama_stack/providers/inline/openai_responses/__init__.py
@@ -14,6 +14,8 @@ from .config import OpenAIResponsesImplConfig
 async def get_provider_impl(config: OpenAIResponsesImplConfig, deps: Dict[Api, Any]):
     from .openai_responses import OpenAIResponsesImpl

-    impl = OpenAIResponsesImpl(config, deps[Api.models], deps[Api.inference])
+    impl = OpenAIResponsesImpl(
+        config, deps[Api.models], deps[Api.inference], deps[Api.tool_groups], deps[Api.tool_runtime]
+    )
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/inline/openai_responses/openai_responses.py b/llama_stack/providers/inline/openai_responses/openai_responses.py
index bb26a10fa..2825fed95 100644
--- a/llama_stack/providers/inline/openai_responses/openai_responses.py
+++ b/llama_stack/providers/inline/openai_responses/openai_responses.py
@@ -4,27 +4,38 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import json
 import uuid
 from typing import AsyncIterator, List, Optional, cast

+from openai.types.chat import ChatCompletionToolParam
+
 from llama_stack.apis.inference.inference import (
     Inference,
     OpenAIAssistantMessageParam,
     OpenAIChatCompletion,
     OpenAIChatCompletionContentPartTextParam,
+    OpenAIChatCompletionToolCallFunction,
     OpenAIChoice,
     OpenAIMessageParam,
+    OpenAIToolMessageParam,
     OpenAIUserMessageParam,
 )
 from llama_stack.apis.models.models import Models, ModelType
 from llama_stack.apis.openai_responses import OpenAIResponses
 from llama_stack.apis.openai_responses.openai_responses import (
+    OpenAIResponseInputTool,
     OpenAIResponseObject,
     OpenAIResponseObjectStream,
+    OpenAIResponseOutput,
     OpenAIResponseOutputMessage,
     OpenAIResponseOutputMessageContentOutputText,
+    OpenAIResponseOutputMessageWebSearchToolCall,
 )
+from llama_stack.apis.tools.tools import ToolGroups, ToolInvocationResult, ToolRuntime
 from llama_stack.log import get_logger
+from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
+from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
 from llama_stack.providers.utils.kvstore import kvstore_impl

 from .config import OpenAIResponsesImplConfig
@@ -37,7 +48,8 @@ OPENAI_RESPONSES_PREFIX = "openai_responses:"
 async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> List[OpenAIMessageParam]:
     messages: List[OpenAIMessageParam] = []
     for output_message in previous_response.output:
-        messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text))
+        if isinstance(output_message, OpenAIResponseOutputMessage):
+            messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text))
     return messages


@@ -61,10 +73,19 @@ async def _openai_choices_to_output_messages(choices: List[OpenAIChoice]) -> Lis


 class OpenAIResponsesImpl(OpenAIResponses):
-    def __init__(self, config: OpenAIResponsesImplConfig, models_api: Models, inference_api: Inference):
+    def __init__(
+        self,
+        config: OpenAIResponsesImplConfig,
+        models_api: Models,
+        inference_api: Inference,
+        tool_groups_api: ToolGroups,
+        tool_runtime_api: ToolRuntime,
+    ):
         self.config = config
         self.models_api = models_api
         self.inference_api = inference_api
+        self.tool_groups_api = tool_groups_api
+        self.tool_runtime_api = tool_runtime_api

     async def initialize(self) -> None:
         self.kvstore = await kvstore_impl(self.config.kvstore)
@@ -90,6 +111,7 @@ class OpenAIResponsesImpl(OpenAIResponses):
         previous_response_id: Optional[str] = None,
         store: Optional[bool] = True,
         stream: Optional[bool] = False,
+        tools: Optional[List[OpenAIResponseInputTool]] = None,
     ):
         model_obj = await self.models_api.get_model(model)
         if model_obj is None:
@@ -103,14 +125,24 @@ class OpenAIResponsesImpl(OpenAIResponses):
             messages.extend(await _previous_response_to_messages(previous_response))
         messages.append(OpenAIUserMessageParam(content=input))

+        chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
         chat_response = await self.inference_api.openai_chat_completion(
             model=model_obj.identifier,
             messages=messages,
+            tools=chat_tools,
         )
         # type cast to appease mypy
         chat_response = cast(OpenAIChatCompletion, chat_response)
+        # dump and reload to map to our pydantic types
+        chat_response = OpenAIChatCompletion.model_validate_json(chat_response.model_dump_json())

-        output_messages = await _openai_choices_to_output_messages(chat_response.choices)
+        output_messages: List[OpenAIResponseOutput] = []
+        if chat_response.choices[0].finish_reason == "tool_calls":
+            output_messages.extend(
+                await self._execute_tool_and_return_final_output(model_obj.identifier, chat_response, messages)
+            )
+        else:
+            output_messages.extend(await _openai_choices_to_output_messages(chat_response.choices))
         response = OpenAIResponseObject(
             created_at=chat_response.created,
             id=f"resp-{uuid.uuid4()}",
@@ -136,3 +168,108 @@
             return async_response()

         return response
+
+    async def _convert_response_tools_to_chat_tools(
+        self, tools: List[OpenAIResponseInputTool]
+    ) -> List[ChatCompletionToolParam]:
+        chat_tools: List[ChatCompletionToolParam] = []
+        for input_tool in tools:
+            # TODO: Handle other tool types
+            if input_tool.type == "web_search":
+                tool_name = "web_search"
+                tool = await self.tool_groups_api.get_tool(tool_name)
+                tool_def = ToolDefinition(
+                    tool_name=tool_name,
+                    description=tool.description,
+                    parameters={
+                        param.name: ToolParamDefinition(
+                            param_type=param.parameter_type,
+                            description=param.description,
+                            required=param.required,
+                            default=param.default,
+                        )
+                        for param in tool.parameters
+                    },
+                )
+                chat_tool = convert_tooldef_to_openai_tool(tool_def)
+                chat_tools.append(chat_tool)
+            else:
+                raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
+        return chat_tools
+
+    async def _execute_tool_and_return_final_output(
+        self, model_id: str, chat_response: OpenAIChatCompletion, messages: List[OpenAIMessageParam]
+    ) -> List[OpenAIResponseOutput]:
+        output_messages: List[OpenAIResponseOutput] = []
+        choice = chat_response.choices[0]
+
+        # If the choice is not an assistant message, we don't need to execute any tools
+        if not isinstance(choice.message, OpenAIAssistantMessageParam):
+            return output_messages
+
+        # If the assistant message doesn't have any tool calls, we don't need to execute any tools
+        if not choice.message.tool_calls:
+            return output_messages
+
+        # TODO: handle multiple tool calls
+        function = choice.message.tool_calls[0].function
+
+        # If the tool call is not a function, we don't need to execute it
+        if not function:
+            return output_messages
+
+        # TODO: telemetry spans for tool calls
+        result = await self._execute_tool_call(function)
+
+        tool_call_prefix = "tc_"
+        if function.name == "web_search":
+            tool_call_prefix = "ws_"
+        tool_call_id = f"{tool_call_prefix}{uuid.uuid4()}"
+
+        # Handle tool call failure
+        if not result:
+            output_messages.append(
+                OpenAIResponseOutputMessageWebSearchToolCall(
+                    id=tool_call_id,
+                    status="failed",
+                )
+            )
+            return output_messages
+
+        output_messages.append(
+            OpenAIResponseOutputMessageWebSearchToolCall(
+                id=tool_call_id,
+                status="completed",
+            ),
+        )
+
+        result_content = ""
+        # TODO: handle other result content types and lists
+        if isinstance(result.content, str):
+            result_content = result.content
+        messages.append(OpenAIToolMessageParam(content=result_content, tool_call_id=tool_call_id))
+        tool_results_chat_response = await self.inference_api.openai_chat_completion(
+            model=model_id,
+            messages=messages,
+        )
+        # type cast to appease mypy
+        tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
+        tool_final_outputs = await _openai_choices_to_output_messages(tool_results_chat_response.choices)
+        # TODO: Wire in annotations with URLs, titles, etc to these output messages
+        output_messages.extend(tool_final_outputs)
+        return output_messages
+
+    async def _execute_tool_call(
+        self,
+        function: OpenAIChatCompletionToolCallFunction,
+    ) -> Optional[ToolInvocationResult]:
+        if not function.name:
+            return None
+        function_args = json.loads(function.arguments) if function.arguments else {}
+        logger.info(f"executing tool call: {function.name} with args: {function_args}")
+        result = await self.tool_runtime_api.invoke_tool(
+            tool_name=function.name,
+            kwargs=function_args,
+        )
+        logger.debug(f"tool call {function.name} completed with result: {result}")
+        return result
diff --git a/llama_stack/providers/registry/openai_responses.py b/llama_stack/providers/registry/openai_responses.py
index dd60f19dc..b7f8d17a0 100644
--- a/llama_stack/providers/registry/openai_responses.py
+++ b/llama_stack/providers/registry/openai_responses.py
@@ -20,6 +20,8 @@ def available_providers() -> List[ProviderSpec]:
             api_dependencies=[
                 Api.models,
                 Api.inference,
+                Api.tool_groups,
+                Api.tool_runtime,
             ],
         ),
     ]
diff --git a/tests/integration/fixtures/common.py b/tests/integration/fixtures/common.py
index 1878c9e88..809a00897 100644
--- a/tests/integration/fixtures/common.py
+++ b/tests/integration/fixtures/common.py
@@ -14,6 +14,7 @@ from pathlib import Path
 import pytest
 import yaml
 from llama_stack_client import LlamaStackClient
+from openai import OpenAI

 from llama_stack import LlamaStackAsLibraryClient
 from llama_stack.apis.datatypes import Api
@@ -207,3 +208,9 @@ def llama_stack_client(request, provider_data, text_model_id):
         raise RuntimeError("Initialization failed")

     return client
+
+
+@pytest.fixture(scope="session")
+def openai_client(client_with_models):
+    base_url = f"{client_with_models.base_url}/v1/openai/v1"
+    return OpenAI(base_url=base_url, api_key="fake")
diff --git a/tests/integration/openai_responses/test_openai_responses.py b/tests/integration/openai_responses/test_basic.py
similarity index 86%
rename from tests/integration/openai_responses/test_openai_responses.py
rename to tests/integration/openai_responses/test_basic.py
index 870c14636..49e94388b 100644
--- a/tests/integration/openai_responses/test_openai_responses.py
+++ b/tests/integration/openai_responses/test_basic.py
@@ -6,17 +6,10 @@

 import pytest

-from openai import OpenAI

 from ..test_cases.test_case import TestCase


-@pytest.fixture
-def openai_client(client_with_models):
-    base_url = f"{client_with_models.base_url}/v1/openai/v1"
-    return OpenAI(base_url=base_url, api_key="bar")
-
-
 @pytest.mark.parametrize(
     "test_case",
     [
@@ -24,7 +17,7 @@ def openai_client(client_with_models):
         "openai:responses:non_streaming_02",
     ],
 )
-def test_openai_responses_non_streaming(openai_client, client_with_models, text_model_id, test_case):
+def test_basic_non_streaming(openai_client, client_with_models, text_model_id, test_case):
     tc = TestCase(test_case)
     question = tc["question"]
     expected = tc["expected"]
@@ -55,7 +48,7 @@ def test_openai_responses_non_streaming(openai_client, client_with_models, text_
         "openai:responses:streaming_02",
     ],
 )
-def test_openai_responses_streaming(openai_client, client_with_models, text_model_id, test_case):
+def test_basic_streaming(openai_client, client_with_models, text_model_id, test_case):
     tc = TestCase(test_case)
     question = tc["question"]
     expected = tc["expected"]
diff --git a/tests/integration/openai_responses/test_web_search_builtin.py b/tests/integration/openai_responses/test_web_search_builtin.py
new file mode 100644
index 000000000..d7baf96d2
--- /dev/null
+++ b/tests/integration/openai_responses/test_web_search_builtin.py
@@ -0,0 +1,38 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+import pytest
+
+from ..test_cases.test_case import TestCase
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "openai:responses:tools_web_search_01",
+    ],
+)
+def test_web_search_non_streaming(openai_client, client_with_models, text_model_id, test_case):
+    tc = TestCase(test_case)
+    input = tc["input"]
+    expected = tc["expected"]
+    tools = tc["tools"]
+
+    response = openai_client.responses.create(
+        model=text_model_id,
+        input=input,
+        tools=tools,
+        stream=False,
+    )
+    assert len(response.output) > 1
+    assert response.output[0].type == "web_search_call"
+    assert response.output[0].status == "completed"
+    assert response.output[1].type == "message"
+    assert response.output[1].status == "completed"
+    assert response.output[1].role == "assistant"
+    assert len(response.output[1].content) > 0
+    assert expected.lower() in response.output_text.lower().strip()
diff --git a/tests/integration/test_cases/openai/responses.json b/tests/integration/test_cases/openai/responses.json
index e7a132826..d17d0cd4f 100644
--- a/tests/integration/test_cases/openai/responses.json
+++ b/tests/integration/test_cases/openai/responses.json
@@ -22,5 +22,16 @@
       "question": "What is the name of the US captial?",
       "expected": "Washington"
     }
+  },
+  "tools_web_search_01": {
+    "data": {
+      "input": "How many experts does the Llama 4 Maverick model have?",
+      "tools": [
+        {
+          "type": "web_search"
+        }
+      ],
+      "expected": "128"
+    }
   }
 }
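
Usage sketch (illustrative only, not part of the patch): with a Llama Stack server running this change, the new built-in web_search tool can be exercised through the OpenAI client pointed at the OpenAI-compatible endpoint, mirroring the integration test added above. The base URL, port, and model id below are assumptions; substitute your own deployment values.

    from openai import OpenAI

    # Assumed local server address and model id; adjust to your deployment.
    client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="fake")

    response = client.responses.create(
        model="meta-llama/Llama-3.1-8B-Instruct",
        input="How many experts does the Llama 4 Maverick model have?",
        tools=[{"type": "web_search"}],
        stream=False,
    )

    # Expect a web_search_call output item followed by the assistant message.
    print([output.type for output in response.output])
    print(response.output_text)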