OpenAI Responses API: Stub in basic web_search tool

Authored by Ben Browning on 2025-04-17 20:25:36 -04:00; committed by Ashwin Bharambe
parent 52a69f0bf9
commit 35b2e2646f
8 changed files with 232 additions and 15 deletions
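For orientation, here is a minimal sketch of what this stub enables from a client's point of view, modeled on the conftest fixture and integration test added below. The base URL, API key, and model ID are illustrative placeholders, not values taken from the commit:

from openai import OpenAI

# Point the standard OpenAI client at a running Llama Stack server's
# OpenAI-compatible endpoint (the conftest fixture in this commit builds
# the same "/v1/openai/v1" base URL); URL and key below are placeholders.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="fake")

response = client.responses.create(
    model="meta-llama/Llama-3.3-70B-Instruct",  # placeholder model ID
    input="How many experts does the Llama 4 Maverick model have?",
    tools=[{"type": "web_search"}],
    stream=False,
)

# When the web_search tool fires, the output contains a web_search_call
# item followed by the assistant's final message.
print(response.output[0].type)  # "web_search_call"
print(response.output_text)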

View file

@@ -9,7 +9,7 @@ from typing import AsyncIterator, List, Literal, Optional, Protocol, Union, runtime_checkable
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
-from llama_stack.schema_utils import json_schema_type, webmethod
+from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
 @json_schema_type
@@ -28,6 +28,7 @@ OpenAIResponseOutputMessageContent = Annotated[
     Union[OpenAIResponseOutputMessageContentOutputText,],
     Field(discriminator="type"),
 ]
+register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent")
 @json_schema_type
@@ -39,10 +40,21 @@ class OpenAIResponseOutputMessage(BaseModel):
     type: Literal["message"] = "message"
+@json_schema_type
+class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
+    id: str
+    status: str
+    type: Literal["web_search_call"] = "web_search_call"
 OpenAIResponseOutput = Annotated[
-    Union[OpenAIResponseOutputMessage,],
+    Union[
+        OpenAIResponseOutputMessage,
+        OpenAIResponseOutputMessageWebSearchToolCall,
+    ],
     Field(discriminator="type"),
 ]
+register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
 @json_schema_type
@@ -68,6 +80,20 @@ class OpenAIResponseObjectStream(BaseModel):
     type: Literal["response.created"] = "response.created"
+@json_schema_type
+class OpenAIResponseInputToolWebSearch(BaseModel):
+    type: Literal["web_search", "web_search_preview_2025_03_11"] = "web_search"
+    search_context_size: Optional[str] = Field(default="medium", pattern="^low|medium|high$")
+    # TODO: add user_location
+OpenAIResponseInputTool = Annotated[
+    Union[OpenAIResponseInputToolWebSearch,],
+    Field(discriminator="type"),
+]
+register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
 @runtime_checkable
 class OpenAIResponses(Protocol):
     """
@@ -88,4 +114,5 @@ class OpenAIResponses(Protocol):
         previous_response_id: Optional[str] = None,
         store: Optional[bool] = True,
         stream: Optional[bool] = False,
+        tools: Optional[List[OpenAIResponseInputTool]] = None,
     ) -> Union[OpenAIResponseObject, AsyncIterator[OpenAIResponseObjectStream]]: ...
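For reference, once the web_search tool runs, the response.output a client sees pairs a web_search_call item with the assistant's final message. A rough dict-level sketch based on the models above and the integration test added later in this commit (field values are illustrative, not actual server output):

output = [
    # emitted for the executed tool call; the "ws_" id prefix comes from the provider impl
    {"type": "web_search_call", "id": "ws_<uuid>", "status": "completed"},
    # the final assistant message, synthesized from the follow-up chat completion
    {
        "type": "message",
        "role": "assistant",
        "status": "completed",
        "content": [{"type": "output_text", "text": "<answer grounded in the search result>"}],
    },
]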

View file

@@ -14,6 +14,8 @@ from .config import OpenAIResponsesImplConfig
 async def get_provider_impl(config: OpenAIResponsesImplConfig, deps: Dict[Api, Any]):
     from .openai_responses import OpenAIResponsesImpl
-    impl = OpenAIResponsesImpl(config, deps[Api.models], deps[Api.inference])
+    impl = OpenAIResponsesImpl(
+        config, deps[Api.models], deps[Api.inference], deps[Api.tool_groups], deps[Api.tool_runtime]
+    )
     await impl.initialize()
     return impl

View file

@@ -4,27 +4,38 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import json
 import uuid
 from typing import AsyncIterator, List, Optional, cast
+from openai.types.chat import ChatCompletionToolParam
 from llama_stack.apis.inference.inference import (
     Inference,
     OpenAIAssistantMessageParam,
     OpenAIChatCompletion,
     OpenAIChatCompletionContentPartTextParam,
+    OpenAIChatCompletionToolCallFunction,
     OpenAIChoice,
     OpenAIMessageParam,
+    OpenAIToolMessageParam,
     OpenAIUserMessageParam,
 )
 from llama_stack.apis.models.models import Models, ModelType
 from llama_stack.apis.openai_responses import OpenAIResponses
 from llama_stack.apis.openai_responses.openai_responses import (
+    OpenAIResponseInputTool,
     OpenAIResponseObject,
     OpenAIResponseObjectStream,
+    OpenAIResponseOutput,
     OpenAIResponseOutputMessage,
     OpenAIResponseOutputMessageContentOutputText,
+    OpenAIResponseOutputMessageWebSearchToolCall,
 )
+from llama_stack.apis.tools.tools import ToolGroups, ToolInvocationResult, ToolRuntime
 from llama_stack.log import get_logger
+from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
+from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from .config import OpenAIResponsesImplConfig
@@ -37,7 +48,8 @@ OPENAI_RESPONSES_PREFIX = "openai_responses:"
 async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> List[OpenAIMessageParam]:
     messages: List[OpenAIMessageParam] = []
     for output_message in previous_response.output:
-        messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text))
+        if isinstance(output_message, OpenAIResponseOutputMessage):
+            messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text))
     return messages
@@ -61,10 +73,19 @@ async def _openai_choices_to_output_messages(choices: List[OpenAIChoice]) -> Lis
 class OpenAIResponsesImpl(OpenAIResponses):
-    def __init__(self, config: OpenAIResponsesImplConfig, models_api: Models, inference_api: Inference):
+    def __init__(
+        self,
+        config: OpenAIResponsesImplConfig,
+        models_api: Models,
+        inference_api: Inference,
+        tool_groups_api: ToolGroups,
+        tool_runtime_api: ToolRuntime,
+    ):
         self.config = config
         self.models_api = models_api
         self.inference_api = inference_api
+        self.tool_groups_api = tool_groups_api
+        self.tool_runtime_api = tool_runtime_api
     async def initialize(self) -> None:
         self.kvstore = await kvstore_impl(self.config.kvstore)
@@ -90,6 +111,7 @@ class OpenAIResponsesImpl(OpenAIResponses):
         previous_response_id: Optional[str] = None,
         store: Optional[bool] = True,
         stream: Optional[bool] = False,
+        tools: Optional[List[OpenAIResponseInputTool]] = None,
     ):
         model_obj = await self.models_api.get_model(model)
         if model_obj is None:
@@ -103,14 +125,24 @@ class OpenAIResponsesImpl(OpenAIResponses):
             messages.extend(await _previous_response_to_messages(previous_response))
         messages.append(OpenAIUserMessageParam(content=input))
+        chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
         chat_response = await self.inference_api.openai_chat_completion(
             model=model_obj.identifier,
             messages=messages,
+            tools=chat_tools,
         )
         # type cast to appease mypy
         chat_response = cast(OpenAIChatCompletion, chat_response)
+        # dump and reload to map to our pydantic types
+        chat_response = OpenAIChatCompletion.model_validate_json(chat_response.model_dump_json())
-        output_messages = await _openai_choices_to_output_messages(chat_response.choices)
+        output_messages: List[OpenAIResponseOutput] = []
+        if chat_response.choices[0].finish_reason == "tool_calls":
+            output_messages.extend(
+                await self._execute_tool_and_return_final_output(model_obj.identifier, chat_response, messages)
+            )
+        else:
+            output_messages.extend(await _openai_choices_to_output_messages(chat_response.choices))
         response = OpenAIResponseObject(
             created_at=chat_response.created,
             id=f"resp-{uuid.uuid4()}",
@@ -136,3 +168,108 @@ class OpenAIResponsesImpl(OpenAIResponses):
             return async_response()
         return response
+    async def _convert_response_tools_to_chat_tools(
+        self, tools: List[OpenAIResponseInputTool]
+    ) -> List[ChatCompletionToolParam]:
+        chat_tools: List[ChatCompletionToolParam] = []
+        for input_tool in tools:
+            # TODO: Handle other tool types
+            if input_tool.type == "web_search":
+                tool_name = "web_search"
+                tool = await self.tool_groups_api.get_tool(tool_name)
+                tool_def = ToolDefinition(
+                    tool_name=tool_name,
+                    description=tool.description,
+                    parameters={
+                        param.name: ToolParamDefinition(
+                            param_type=param.parameter_type,
+                            description=param.description,
+                            required=param.required,
+                            default=param.default,
+                        )
+                        for param in tool.parameters
+                    },
+                )
+                chat_tool = convert_tooldef_to_openai_tool(tool_def)
+                chat_tools.append(chat_tool)
+            else:
+                raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
+        return chat_tools
+    async def _execute_tool_and_return_final_output(
+        self, model_id: str, chat_response: OpenAIChatCompletion, messages: List[OpenAIMessageParam]
+    ) -> List[OpenAIResponseOutput]:
+        output_messages: List[OpenAIResponseOutput] = []
+        choice = chat_response.choices[0]
+        # If the choice is not an assistant message, we don't need to execute any tools
+        if not isinstance(choice.message, OpenAIAssistantMessageParam):
+            return output_messages
+        # If the assistant message doesn't have any tool calls, we don't need to execute any tools
+        if not choice.message.tool_calls:
+            return output_messages
+        # TODO: handle multiple tool calls
+        function = choice.message.tool_calls[0].function
+        # If the tool call is not a function, we don't need to execute it
+        if not function:
+            return output_messages
+        # TODO: telemetry spans for tool calls
+        result = await self._execute_tool_call(function)
+        tool_call_prefix = "tc_"
+        if function.name == "web_search":
+            tool_call_prefix = "ws_"
+        tool_call_id = f"{tool_call_prefix}{uuid.uuid4()}"
+        # Handle tool call failure
+        if not result:
+            output_messages.append(
+                OpenAIResponseOutputMessageWebSearchToolCall(
+                    id=tool_call_id,
+                    status="failed",
+                )
+            )
+            return output_messages
+        output_messages.append(
+            OpenAIResponseOutputMessageWebSearchToolCall(
+                id=tool_call_id,
+                status="completed",
+            ),
+        )
+        result_content = ""
+        # TODO: handle other result content types and lists
+        if isinstance(result.content, str):
+            result_content = result.content
+        messages.append(OpenAIToolMessageParam(content=result_content, tool_call_id=tool_call_id))
+        tool_results_chat_response = await self.inference_api.openai_chat_completion(
+            model=model_id,
+            messages=messages,
+        )
+        # type cast to appease mypy
+        tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
+        tool_final_outputs = await _openai_choices_to_output_messages(tool_results_chat_response.choices)
+        # TODO: Wire in annotations with URLs, titles, etc to these output messages
+        output_messages.extend(tool_final_outputs)
+        return output_messages
+    async def _execute_tool_call(
+        self,
+        function: OpenAIChatCompletionToolCallFunction,
+    ) -> Optional[ToolInvocationResult]:
+        if not function.name:
+            return None
+        function_args = json.loads(function.arguments) if function.arguments else {}
+        logger.info(f"executing tool call: {function.name} with args: {function_args}")
+        result = await self.tool_runtime_api.invoke_tool(
+            tool_name=function.name,
+            kwargs=function_args,
+        )
+        logger.debug(f"tool call {function.name} completed with result: {result}")
+        return result

View file

@@ -20,6 +20,8 @@ def available_providers() -> List[ProviderSpec]:
             api_dependencies=[
                 Api.models,
                 Api.inference,
+                Api.tool_groups,
+                Api.tool_runtime,
             ],
         ),
     ]

View file

@@ -14,6 +14,7 @@ from pathlib import Path
 import pytest
 import yaml
 from llama_stack_client import LlamaStackClient
+from openai import OpenAI
 from llama_stack import LlamaStackAsLibraryClient
 from llama_stack.apis.datatypes import Api
@@ -207,3 +208,9 @@ def llama_stack_client(request, provider_data, text_model_id):
         raise RuntimeError("Initialization failed")
     return client
+@pytest.fixture(scope="session")
+def openai_client(client_with_models):
+    base_url = f"{client_with_models.base_url}/v1/openai/v1"
+    return OpenAI(base_url=base_url, api_key="fake")

View file

@@ -6,17 +6,10 @@
 import pytest
-from openai import OpenAI
 from ..test_cases.test_case import TestCase
-@pytest.fixture
-def openai_client(client_with_models):
-    base_url = f"{client_with_models.base_url}/v1/openai/v1"
-    return OpenAI(base_url=base_url, api_key="bar")
 @pytest.mark.parametrize(
     "test_case",
     [
@@ -24,7 +17,7 @@ def openai_client(client_with_models):
         "openai:responses:non_streaming_02",
     ],
 )
-def test_openai_responses_non_streaming(openai_client, client_with_models, text_model_id, test_case):
+def test_basic_non_streaming(openai_client, client_with_models, text_model_id, test_case):
     tc = TestCase(test_case)
     question = tc["question"]
     expected = tc["expected"]
@@ -55,7 +48,7 @@ def test_openai_responses_non_streaming(openai_client, client_with_models, text_
         "openai:responses:streaming_02",
     ],
 )
-def test_openai_responses_streaming(openai_client, client_with_models, text_model_id, test_case):
+def test_basic_streaming(openai_client, client_with_models, text_model_id, test_case):
     tc = TestCase(test_case)
     question = tc["question"]
     expected = tc["expected"]

View file

@@ -0,0 +1,38 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+import pytest
+from ..test_cases.test_case import TestCase
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "openai:responses:tools_web_search_01",
+    ],
+)
+def test_web_search_non_streaming(openai_client, client_with_models, text_model_id, test_case):
+    tc = TestCase(test_case)
+    input = tc["input"]
+    expected = tc["expected"]
+    tools = tc["tools"]
+    response = openai_client.responses.create(
+        model=text_model_id,
+        input=input,
+        tools=tools,
+        stream=False,
+    )
+    assert len(response.output) > 1
+    assert response.output[0].type == "web_search_call"
+    assert response.output[0].status == "completed"
+    assert response.output[1].type == "message"
+    assert response.output[1].status == "completed"
+    assert response.output[1].role == "assistant"
+    assert len(response.output[1].content) > 0
+    assert expected.lower() in response.output_text.lower().strip()

View file

@@ -22,5 +22,16 @@
       "question": "What is the name of the US captial?",
       "expected": "Washington"
     }
-  }
+  },
+  "tools_web_search_01": {
+    "data": {
+      "input": "How many experts does the Llama 4 Maverick model have?",
+      "tools": [
+        {
+          "type": "web_search"
+        }
+      ],
+      "expected": "128"
+    }
+  }
 }