mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 08:44:44 +00:00
OpenAI Responses API: Stub in basic web_search tool
This commit is contained in:
parent
52a69f0bf9
commit
35b2e2646f
8 changed files with 232 additions and 15 deletions
|
@ -9,7 +9,7 @@ from typing import AsyncIterator, List, Literal, Optional, Protocol, Union, runt
|
|||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -28,6 +28,7 @@ OpenAIResponseOutputMessageContent = Annotated[
|
|||
Union[OpenAIResponseOutputMessageContentOutputText,],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -39,10 +40,21 @@ class OpenAIResponseOutputMessage(BaseModel):
|
|||
type: Literal["message"] = "message"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
|
||||
id: str
|
||||
status: str
|
||||
type: Literal["web_search_call"] = "web_search_call"
|
||||
|
||||
|
||||
OpenAIResponseOutput = Annotated[
|
||||
Union[OpenAIResponseOutputMessage,],
|
||||
Union[
|
||||
OpenAIResponseOutputMessage,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -68,6 +80,20 @@ class OpenAIResponseObjectStream(BaseModel):
|
|||
type: Literal["response.created"] = "response.created"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputToolWebSearch(BaseModel):
|
||||
type: Literal["web_search", "web_search_preview_2025_03_11"] = "web_search"
|
||||
search_context_size: Optional[str] = Field(default="medium", pattern="^low|medium|high$")
|
||||
# TODO: add user_location
|
||||
|
||||
|
||||
OpenAIResponseInputTool = Annotated[
|
||||
Union[OpenAIResponseInputToolWebSearch,],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class OpenAIResponses(Protocol):
|
||||
"""
|
||||
|
@ -88,4 +114,5 @@ class OpenAIResponses(Protocol):
|
|||
previous_response_id: Optional[str] = None,
|
||||
store: Optional[bool] = True,
|
||||
stream: Optional[bool] = False,
|
||||
tools: Optional[List[OpenAIResponseInputTool]] = None,
|
||||
) -> Union[OpenAIResponseObject, AsyncIterator[OpenAIResponseObjectStream]]: ...
|
||||
|
|
|
@ -14,6 +14,8 @@ from .config import OpenAIResponsesImplConfig
|
|||
async def get_provider_impl(config: OpenAIResponsesImplConfig, deps: Dict[Api, Any]):
|
||||
from .openai_responses import OpenAIResponsesImpl
|
||||
|
||||
impl = OpenAIResponsesImpl(config, deps[Api.models], deps[Api.inference])
|
||||
impl = OpenAIResponsesImpl(
|
||||
config, deps[Api.models], deps[Api.inference], deps[Api.tool_groups], deps[Api.tool_runtime]
|
||||
)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -4,27 +4,38 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import AsyncIterator, List, Optional, cast
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from llama_stack.apis.inference.inference import (
|
||||
Inference,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoice,
|
||||
OpenAIMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.apis.models.models import Models, ModelType
|
||||
from llama_stack.apis.openai_responses import OpenAIResponses
|
||||
from llama_stack.apis.openai_responses.openai_responses import (
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseOutput,
|
||||
OpenAIResponseOutputMessage,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
)
|
||||
from llama_stack.apis.tools.tools import ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from .config import OpenAIResponsesImplConfig
|
||||
|
@ -37,7 +48,8 @@ OPENAI_RESPONSES_PREFIX = "openai_responses:"
|
|||
async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> List[OpenAIMessageParam]:
|
||||
messages: List[OpenAIMessageParam] = []
|
||||
for output_message in previous_response.output:
|
||||
messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text))
|
||||
if isinstance(output_message, OpenAIResponseOutputMessage):
|
||||
messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text))
|
||||
return messages
|
||||
|
||||
|
||||
|
@ -61,10 +73,19 @@ async def _openai_choices_to_output_messages(choices: List[OpenAIChoice]) -> Lis
|
|||
|
||||
|
||||
class OpenAIResponsesImpl(OpenAIResponses):
|
||||
def __init__(self, config: OpenAIResponsesImplConfig, models_api: Models, inference_api: Inference):
|
||||
def __init__(
|
||||
self,
|
||||
config: OpenAIResponsesImplConfig,
|
||||
models_api: Models,
|
||||
inference_api: Inference,
|
||||
tool_groups_api: ToolGroups,
|
||||
tool_runtime_api: ToolRuntime,
|
||||
):
|
||||
self.config = config
|
||||
self.models_api = models_api
|
||||
self.inference_api = inference_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
|
@ -90,6 +111,7 @@ class OpenAIResponsesImpl(OpenAIResponses):
|
|||
previous_response_id: Optional[str] = None,
|
||||
store: Optional[bool] = True,
|
||||
stream: Optional[bool] = False,
|
||||
tools: Optional[List[OpenAIResponseInputTool]] = None,
|
||||
):
|
||||
model_obj = await self.models_api.get_model(model)
|
||||
if model_obj is None:
|
||||
|
@ -103,14 +125,24 @@ class OpenAIResponsesImpl(OpenAIResponses):
|
|||
messages.extend(await _previous_response_to_messages(previous_response))
|
||||
messages.append(OpenAIUserMessageParam(content=input))
|
||||
|
||||
chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
|
||||
chat_response = await self.inference_api.openai_chat_completion(
|
||||
model=model_obj.identifier,
|
||||
messages=messages,
|
||||
tools=chat_tools,
|
||||
)
|
||||
# type cast to appease mypy
|
||||
chat_response = cast(OpenAIChatCompletion, chat_response)
|
||||
# dump and reload to map to our pydantic types
|
||||
chat_response = OpenAIChatCompletion.model_validate_json(chat_response.model_dump_json())
|
||||
|
||||
output_messages = await _openai_choices_to_output_messages(chat_response.choices)
|
||||
output_messages: List[OpenAIResponseOutput] = []
|
||||
if chat_response.choices[0].finish_reason == "tool_calls":
|
||||
output_messages.extend(
|
||||
await self._execute_tool_and_return_final_output(model_obj.identifier, chat_response, messages)
|
||||
)
|
||||
else:
|
||||
output_messages.extend(await _openai_choices_to_output_messages(chat_response.choices))
|
||||
response = OpenAIResponseObject(
|
||||
created_at=chat_response.created,
|
||||
id=f"resp-{uuid.uuid4()}",
|
||||
|
@ -136,3 +168,108 @@ class OpenAIResponsesImpl(OpenAIResponses):
|
|||
return async_response()
|
||||
|
||||
return response
|
||||
|
||||
async def _convert_response_tools_to_chat_tools(
|
||||
self, tools: List[OpenAIResponseInputTool]
|
||||
) -> List[ChatCompletionToolParam]:
|
||||
chat_tools: List[ChatCompletionToolParam] = []
|
||||
for input_tool in tools:
|
||||
# TODO: Handle other tool types
|
||||
if input_tool.type == "web_search":
|
||||
tool_name = "web_search"
|
||||
tool = await self.tool_groups_api.get_tool(tool_name)
|
||||
tool_def = ToolDefinition(
|
||||
tool_name=tool_name,
|
||||
description=tool.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool.parameters
|
||||
},
|
||||
)
|
||||
chat_tool = convert_tooldef_to_openai_tool(tool_def)
|
||||
chat_tools.append(chat_tool)
|
||||
else:
|
||||
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
|
||||
return chat_tools
|
||||
|
||||
async def _execute_tool_and_return_final_output(
|
||||
self, model_id: str, chat_response: OpenAIChatCompletion, messages: List[OpenAIMessageParam]
|
||||
) -> List[OpenAIResponseOutput]:
|
||||
output_messages: List[OpenAIResponseOutput] = []
|
||||
choice = chat_response.choices[0]
|
||||
|
||||
# If the choice is not an assistant message, we don't need to execute any tools
|
||||
if not isinstance(choice.message, OpenAIAssistantMessageParam):
|
||||
return output_messages
|
||||
|
||||
# If the assistant message doesn't have any tool calls, we don't need to execute any tools
|
||||
if not choice.message.tool_calls:
|
||||
return output_messages
|
||||
|
||||
# TODO: handle multiple tool calls
|
||||
function = choice.message.tool_calls[0].function
|
||||
|
||||
# If the tool call is not a function, we don't need to execute it
|
||||
if not function:
|
||||
return output_messages
|
||||
|
||||
# TODO: telemetry spans for tool calls
|
||||
result = await self._execute_tool_call(function)
|
||||
|
||||
tool_call_prefix = "tc_"
|
||||
if function.name == "web_search":
|
||||
tool_call_prefix = "ws_"
|
||||
tool_call_id = f"{tool_call_prefix}{uuid.uuid4()}"
|
||||
|
||||
# Handle tool call failure
|
||||
if not result:
|
||||
output_messages.append(
|
||||
OpenAIResponseOutputMessageWebSearchToolCall(
|
||||
id=tool_call_id,
|
||||
status="failed",
|
||||
)
|
||||
)
|
||||
return output_messages
|
||||
|
||||
output_messages.append(
|
||||
OpenAIResponseOutputMessageWebSearchToolCall(
|
||||
id=tool_call_id,
|
||||
status="completed",
|
||||
),
|
||||
)
|
||||
|
||||
result_content = ""
|
||||
# TODO: handle other result content types and lists
|
||||
if isinstance(result.content, str):
|
||||
result_content = result.content
|
||||
messages.append(OpenAIToolMessageParam(content=result_content, tool_call_id=tool_call_id))
|
||||
tool_results_chat_response = await self.inference_api.openai_chat_completion(
|
||||
model=model_id,
|
||||
messages=messages,
|
||||
)
|
||||
# type cast to appease mypy
|
||||
tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
|
||||
tool_final_outputs = await _openai_choices_to_output_messages(tool_results_chat_response.choices)
|
||||
# TODO: Wire in annotations with URLs, titles, etc to these output messages
|
||||
output_messages.extend(tool_final_outputs)
|
||||
return output_messages
|
||||
|
||||
async def _execute_tool_call(
|
||||
self,
|
||||
function: OpenAIChatCompletionToolCallFunction,
|
||||
) -> Optional[ToolInvocationResult]:
|
||||
if not function.name:
|
||||
return None
|
||||
function_args = json.loads(function.arguments) if function.arguments else {}
|
||||
logger.info(f"executing tool call: {function.name} with args: {function_args}")
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=function.name,
|
||||
kwargs=function_args,
|
||||
)
|
||||
logger.debug(f"tool call {function.name} completed with result: {result}")
|
||||
return result
|
||||
|
|
|
@ -20,6 +20,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
api_dependencies=[
|
||||
Api.models,
|
||||
Api.inference,
|
||||
Api.tool_groups,
|
||||
Api.tool_runtime,
|
||||
],
|
||||
),
|
||||
]
|
||||
|
|
|
@ -14,6 +14,7 @@ from pathlib import Path
|
|||
import pytest
|
||||
import yaml
|
||||
from llama_stack_client import LlamaStackClient
|
||||
from openai import OpenAI
|
||||
|
||||
from llama_stack import LlamaStackAsLibraryClient
|
||||
from llama_stack.apis.datatypes import Api
|
||||
|
@ -207,3 +208,9 @@ def llama_stack_client(request, provider_data, text_model_id):
|
|||
raise RuntimeError("Initialization failed")
|
||||
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def openai_client(client_with_models):
|
||||
base_url = f"{client_with_models.base_url}/v1/openai/v1"
|
||||
return OpenAI(base_url=base_url, api_key="fake")
|
||||
|
|
|
@ -6,17 +6,10 @@
|
|||
|
||||
|
||||
import pytest
|
||||
from openai import OpenAI
|
||||
|
||||
from ..test_cases.test_case import TestCase
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openai_client(client_with_models):
|
||||
base_url = f"{client_with_models.base_url}/v1/openai/v1"
|
||||
return OpenAI(base_url=base_url, api_key="bar")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_case",
|
||||
[
|
||||
|
@ -24,7 +17,7 @@ def openai_client(client_with_models):
|
|||
"openai:responses:non_streaming_02",
|
||||
],
|
||||
)
|
||||
def test_openai_responses_non_streaming(openai_client, client_with_models, text_model_id, test_case):
|
||||
def test_basic_non_streaming(openai_client, client_with_models, text_model_id, test_case):
|
||||
tc = TestCase(test_case)
|
||||
question = tc["question"]
|
||||
expected = tc["expected"]
|
||||
|
@ -55,7 +48,7 @@ def test_openai_responses_non_streaming(openai_client, client_with_models, text_
|
|||
"openai:responses:streaming_02",
|
||||
],
|
||||
)
|
||||
def test_openai_responses_streaming(openai_client, client_with_models, text_model_id, test_case):
|
||||
def test_basic_streaming(openai_client, client_with_models, text_model_id, test_case):
|
||||
tc = TestCase(test_case)
|
||||
question = tc["question"]
|
||||
expected = tc["expected"]
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from ..test_cases.test_case import TestCase
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_case",
|
||||
[
|
||||
"openai:responses:tools_web_search_01",
|
||||
],
|
||||
)
|
||||
def test_web_search_non_streaming(openai_client, client_with_models, text_model_id, test_case):
|
||||
tc = TestCase(test_case)
|
||||
input = tc["input"]
|
||||
expected = tc["expected"]
|
||||
tools = tc["tools"]
|
||||
|
||||
response = openai_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=input,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
)
|
||||
assert len(response.output) > 1
|
||||
assert response.output[0].type == "web_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[1].type == "message"
|
||||
assert response.output[1].status == "completed"
|
||||
assert response.output[1].role == "assistant"
|
||||
assert len(response.output[1].content) > 0
|
||||
assert expected.lower() in response.output_text.lower().strip()
|
|
@ -22,5 +22,16 @@
|
|||
"question": "What is the name of the US captial?",
|
||||
"expected": "Washington"
|
||||
}
|
||||
},
|
||||
"tools_web_search_01": {
|
||||
"data": {
|
||||
"input": "How many experts does the Llama 4 Maverick model have?",
|
||||
"tools": [
|
||||
{
|
||||
"type": "web_search"
|
||||
}
|
||||
],
|
||||
"expected": "128"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue