mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat: function tools in OpenAI Responses (#2094)
# What does this PR do? This is a combination of what was previously 3 separate PRs - #2069, #2075, and #2083. It turns out all 3 of those are needed to land a working function calling Responses implementation. The web search builtin tool was already working, but this wires in support for custom function calling. I ended up combining all three into one PR because they all had lots of merge conflicts, both with each other but also with #1806 that just landed. And, because landing any of them individually would have only left a partially working implementation merged. The new things added here are: * Storing of input items from previous responses and restoring of those input items when adding previous responses to the conversation state * Handling of multiple input item messages roles, not just "user" messages. * Support for custom tools passed into the Responses API to enable function calling outside of just the builtin websearch tool. Closes #2074 Closes #2080 ## Test Plan ### Unit Tests Several new unit tests were added, and they all pass. Ran via: ``` python -m pytest -s -v tests/unit/providers/agents/meta_reference/test_openai_responses.py ``` ### Responses API Verification Tests I ran our verification run.yaml against multiple providers to ensure we were getting a decent pass rate. Specifically, I ensured the new custom tool verification test passed across multiple providers and that the multi-turn examples passed across at least some of the providers (some providers struggle with the multi-turn workflows still). Running the stack setup for verification testing: ``` llama stack run --image-type venv tests/verifications/openai-api-verification-run.yaml ``` Together, passing 100% as an example: ``` pytest -s -v 'tests/verifications/openai_api/test_responses.py' --provider=together-llama-stack ``` ## Documentation We will need to start documenting the OpenAI APIs, but for now the Responses stuff is still rapidly evolving so delaying that. --------- Signed-off-by: Derek Higgins <derekh@redhat.com> Signed-off-by: Ben Browning <bbrownin@redhat.com> Co-authored-by: Derek Higgins <derekh@redhat.com> Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
parent
e0d10dd0b1
commit
8e316c9b1e
13 changed files with 1099 additions and 370 deletions
|
@ -31,7 +31,7 @@ from llama_stack.apis.tools import ToolDef
|
|||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
from .openai_responses import (
|
||||
OpenAIResponseInputMessage,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
|
@ -593,7 +593,7 @@ class Agents(Protocol):
|
|||
@webmethod(route="/openai/v1/responses", method="POST")
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInputMessage],
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
previous_response_id: str | None = None,
|
||||
store: bool | None = True,
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Annotated, Literal
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
@ -17,6 +17,28 @@ class OpenAIResponseError(BaseModel):
|
|||
message: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputMessageContentText(BaseModel):
|
||||
text: str
|
||||
type: Literal["input_text"] = "input_text"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputMessageContentImage(BaseModel):
|
||||
detail: Literal["low"] | Literal["high"] | Literal["auto"] = "auto"
|
||||
type: Literal["input_image"] = "input_image"
|
||||
# TODO: handle file_id
|
||||
image_url: str | None = None
|
||||
|
||||
|
||||
# TODO: handle file content types
|
||||
OpenAIResponseInputMessageContent = Annotated[
|
||||
OpenAIResponseInputMessageContentText | OpenAIResponseInputMessageContentImage,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseOutputMessageContentOutputText(BaseModel):
|
||||
text: str
|
||||
|
@ -31,13 +53,22 @@ register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMe
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseOutputMessage(BaseModel):
|
||||
id: str
|
||||
content: list[OpenAIResponseOutputMessageContent]
|
||||
role: Literal["assistant"] = "assistant"
|
||||
status: str
|
||||
class OpenAIResponseMessage(BaseModel):
|
||||
"""
|
||||
Corresponds to the various Message types in the Responses API.
|
||||
They are all under one type because the Responses API gives them all
|
||||
the same "type" value, and there is no way to tell them apart in certain
|
||||
scenarios.
|
||||
"""
|
||||
|
||||
content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]
|
||||
role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
|
||||
type: Literal["message"] = "message"
|
||||
|
||||
# The fields below are not used in all scenarios, but are required in others.
|
||||
id: str | None = None
|
||||
status: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
|
||||
|
@ -46,8 +77,18 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
|
|||
type: Literal["web_search_call"] = "web_search_call"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
|
||||
arguments: str
|
||||
call_id: str
|
||||
name: str
|
||||
type: Literal["function_call"] = "function_call"
|
||||
id: str
|
||||
status: str
|
||||
|
||||
|
||||
OpenAIResponseOutput = Annotated[
|
||||
OpenAIResponseOutputMessage | OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
OpenAIResponseMessage | OpenAIResponseOutputMessageWebSearchToolCall | OpenAIResponseOutputMessageFunctionToolCall,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
|
||||
|
@ -90,32 +131,29 @@ register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputMessageContentText(BaseModel):
|
||||
text: str
|
||||
type: Literal["input_text"] = "input_text"
|
||||
class OpenAIResponseInputFunctionToolCallOutput(BaseModel):
|
||||
"""
|
||||
This represents the output of a function call that gets passed back to the model.
|
||||
"""
|
||||
|
||||
call_id: str
|
||||
output: str
|
||||
type: Literal["function_call_output"] = "function_call_output"
|
||||
id: str | None = None
|
||||
status: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputMessageContentImage(BaseModel):
|
||||
detail: Literal["low"] | Literal["high"] | Literal["auto"] = "auto"
|
||||
type: Literal["input_image"] = "input_image"
|
||||
# TODO: handle file_id
|
||||
image_url: str | None = None
|
||||
|
||||
|
||||
# TODO: handle file content types
|
||||
OpenAIResponseInputMessageContent = Annotated[
|
||||
OpenAIResponseInputMessageContentText | OpenAIResponseInputMessageContentImage,
|
||||
Field(discriminator="type"),
|
||||
OpenAIResponseInput = Annotated[
|
||||
# Responses API allows output messages to be passed in as input
|
||||
OpenAIResponseOutputMessageWebSearchToolCall
|
||||
| OpenAIResponseOutputMessageFunctionToolCall
|
||||
| OpenAIResponseInputFunctionToolCallOutput
|
||||
|
|
||||
# Fallback to the generic message type as a last resort
|
||||
OpenAIResponseMessage,
|
||||
Field(union_mode="left_to_right"),
|
||||
]
|
||||
register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputMessage(BaseModel):
|
||||
content: str | list[OpenAIResponseInputMessageContent]
|
||||
role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
|
||||
type: Literal["message"] | None = "message"
|
||||
register_schema(OpenAIResponseInput, name="OpenAIResponseInput")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -126,8 +164,35 @@ class OpenAIResponseInputToolWebSearch(BaseModel):
|
|||
# TODO: add user_location
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputToolFunction(BaseModel):
|
||||
type: Literal["function"] = "function"
|
||||
name: str
|
||||
description: str | None = None
|
||||
parameters: dict[str, Any] | None
|
||||
strict: bool | None = None
|
||||
|
||||
|
||||
class FileSearchRankingOptions(BaseModel):
|
||||
ranker: str | None = None
|
||||
score_threshold: float | None = Field(default=0.0, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputToolFileSearch(BaseModel):
|
||||
type: Literal["file_search"] = "file_search"
|
||||
vector_store_id: list[str]
|
||||
ranking_options: FileSearchRankingOptions | None = None
|
||||
# TODO: add filters
|
||||
|
||||
|
||||
OpenAIResponseInputTool = Annotated[
|
||||
OpenAIResponseInputToolWebSearch,
|
||||
OpenAIResponseInputToolWebSearch | OpenAIResponseInputToolFileSearch | OpenAIResponseInputToolFunction,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
|
||||
|
||||
|
||||
class OpenAIResponseInputItemList(BaseModel):
|
||||
data: list[OpenAIResponseInput]
|
||||
object: Literal["list"] = "list"
|
||||
|
|
|
@ -18,6 +18,7 @@ from importlib.metadata import version as parse_version
|
|||
from pathlib import Path
|
||||
from typing import Annotated, Any
|
||||
|
||||
import rich.pretty
|
||||
import yaml
|
||||
from fastapi import Body, FastAPI, HTTPException, Request
|
||||
from fastapi import Path as FastapiPath
|
||||
|
@ -187,11 +188,30 @@ async def sse_generator(event_gen_coroutine):
|
|||
)
|
||||
|
||||
|
||||
async def log_request_pre_validation(request: Request):
|
||||
if request.method in ("POST", "PUT", "PATCH"):
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
if body_bytes:
|
||||
try:
|
||||
parsed_body = json.loads(body_bytes.decode())
|
||||
log_output = rich.pretty.pretty_repr(parsed_body)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
log_output = repr(body_bytes)
|
||||
logger.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}")
|
||||
else:
|
||||
logger.debug(f"Incoming {request.method} {request.url.path} request with empty body.")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
|
||||
|
||||
|
||||
def create_dynamic_typed_route(func: Any, method: str, route: str):
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
# Get auth attributes from the request scope
|
||||
user_attributes = request.scope.get("user_attributes", {})
|
||||
|
||||
await log_request_pre_validation(request)
|
||||
|
||||
# Use context manager with both provider data and auth attributes
|
||||
with request_provider_data_context(request.headers, user_attributes):
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
|
|
|
@ -20,7 +20,7 @@ from llama_stack.apis.agents import (
|
|||
AgentTurnCreateRequest,
|
||||
AgentTurnResumeRequest,
|
||||
Document,
|
||||
OpenAIResponseInputMessage,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseObject,
|
||||
Session,
|
||||
|
@ -311,7 +311,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInputMessage],
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
previous_response_id: str | None = None,
|
||||
store: bool | None = True,
|
||||
|
|
|
@ -10,19 +10,26 @@ from collections.abc import AsyncIterator
|
|||
from typing import cast
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputMessage,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputFunctionToolCallOutput,
|
||||
OpenAIResponseInputItemList,
|
||||
OpenAIResponseInputMessageContent,
|
||||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseObjectStreamResponseCompleted,
|
||||
OpenAIResponseObjectStreamResponseCreated,
|
||||
OpenAIResponseOutput,
|
||||
OpenAIResponseOutputMessage,
|
||||
OpenAIResponseOutputMessageContent,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
|
@ -32,10 +39,13 @@ from llama_stack.apis.inference.inference import (
|
|||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoice,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIImageURL,
|
||||
OpenAIMessageParam,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
|
@ -50,31 +60,110 @@ logger = get_logger(name=__name__, category="openai_responses")
|
|||
OPENAI_RESPONSES_PREFIX = "openai_responses:"
|
||||
|
||||
|
||||
async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> list[OpenAIMessageParam]:
|
||||
async def _convert_response_content_to_chat_content(
|
||||
content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent],
|
||||
) -> str | list[OpenAIChatCompletionContentPartParam]:
|
||||
"""
|
||||
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
|
||||
|
||||
The content schemas of each API look similar, but are not exactly the same.
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
converted_parts = []
|
||||
for content_part in content:
|
||||
if isinstance(content_part, OpenAIResponseInputMessageContentText):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
|
||||
elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
|
||||
elif isinstance(content_part, OpenAIResponseInputMessageContentImage):
|
||||
if content_part.image_url:
|
||||
image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail)
|
||||
converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
|
||||
elif isinstance(content_part, str):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context"
|
||||
)
|
||||
return converted_parts
|
||||
|
||||
|
||||
async def _convert_response_input_to_chat_messages(
|
||||
input: str | list[OpenAIResponseInput],
|
||||
) -> list[OpenAIMessageParam]:
|
||||
"""
|
||||
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
|
||||
"""
|
||||
messages: list[OpenAIMessageParam] = []
|
||||
for output_message in previous_response.output:
|
||||
if isinstance(output_message, OpenAIResponseOutputMessage):
|
||||
messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text))
|
||||
if isinstance(input, list):
|
||||
for input_item in input:
|
||||
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
|
||||
messages.append(
|
||||
OpenAIToolMessageParam(
|
||||
content=input_item.output,
|
||||
tool_call_id=input_item.call_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id=input_item.call_id,
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name=input_item.name,
|
||||
arguments=input_item.arguments,
|
||||
),
|
||||
)
|
||||
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
|
||||
else:
|
||||
content = await _convert_response_content_to_chat_content(input_item.content)
|
||||
message_type = await _get_message_type_by_role(input_item.role)
|
||||
if message_type is None:
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
|
||||
)
|
||||
messages.append(message_type(content=content))
|
||||
else:
|
||||
messages.append(OpenAIUserMessageParam(content=input))
|
||||
return messages
|
||||
|
||||
|
||||
async def _openai_choices_to_output_messages(choices: list[OpenAIChoice]) -> list[OpenAIResponseOutputMessage]:
|
||||
output_messages = []
|
||||
for choice in choices:
|
||||
output_content = ""
|
||||
if isinstance(choice.message.content, str):
|
||||
output_content = choice.message.content
|
||||
elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
|
||||
output_content = choice.message.content.text
|
||||
# TODO: handle image content
|
||||
output_messages.append(
|
||||
OpenAIResponseOutputMessage(
|
||||
id=f"msg_{uuid.uuid4()}",
|
||||
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
|
||||
status="completed",
|
||||
)
|
||||
async def _convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
|
||||
"""
|
||||
Convert an OpenAI Chat Completion choice into an OpenAI Response output message.
|
||||
"""
|
||||
output_content = ""
|
||||
if isinstance(choice.message.content, str):
|
||||
output_content = choice.message.content
|
||||
elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
|
||||
output_content = choice.message.content.text
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
|
||||
)
|
||||
return output_messages
|
||||
|
||||
return OpenAIResponseMessage(
|
||||
id=f"msg_{uuid.uuid4()}",
|
||||
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
|
||||
status="completed",
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
|
||||
async def _get_message_type_by_role(role: str):
|
||||
role_to_type = {
|
||||
"user": OpenAIUserMessageParam,
|
||||
"system": OpenAISystemMessageParam,
|
||||
"assistant": OpenAIAssistantMessageParam,
|
||||
"developer": OpenAIDeveloperMessageParam,
|
||||
}
|
||||
return role_to_type.get(role)
|
||||
|
||||
|
||||
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
|
||||
input_items: OpenAIResponseInputItemList
|
||||
response: OpenAIResponseObject
|
||||
|
||||
|
||||
class OpenAIResponsesImpl:
|
||||
|
@ -90,19 +179,45 @@ class OpenAIResponsesImpl:
|
|||
self.tool_groups_api = tool_groups_api
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
|
||||
async def get_openai_response(
|
||||
self,
|
||||
id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
async def _get_previous_response_with_input(self, id: str) -> OpenAIResponsePreviousResponseWithInputItems:
|
||||
key = f"{OPENAI_RESPONSES_PREFIX}{id}"
|
||||
response_json = await self.persistence_store.get(key=key)
|
||||
if response_json is None:
|
||||
raise ValueError(f"OpenAI response with id '{id}' not found")
|
||||
return OpenAIResponseObject.model_validate_json(response_json)
|
||||
return OpenAIResponsePreviousResponseWithInputItems.model_validate_json(response_json)
|
||||
|
||||
async def _prepend_previous_response(
|
||||
self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None
|
||||
):
|
||||
if previous_response_id:
|
||||
previous_response_with_input = await self._get_previous_response_with_input(previous_response_id)
|
||||
|
||||
# previous response input items
|
||||
new_input_items = previous_response_with_input.input_items.data
|
||||
|
||||
# previous response output items
|
||||
new_input_items.extend(previous_response_with_input.response.output)
|
||||
|
||||
# new input items from the current request
|
||||
if isinstance(input, str):
|
||||
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
|
||||
else:
|
||||
new_input_items.extend(input)
|
||||
|
||||
input = new_input_items
|
||||
|
||||
return input
|
||||
|
||||
async def get_openai_response(
|
||||
self,
|
||||
id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
response_with_input = await self._get_previous_response_with_input(id)
|
||||
return response_with_input.response
|
||||
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInputMessage],
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
previous_response_id: str | None = None,
|
||||
store: bool | None = True,
|
||||
|
@ -112,31 +227,8 @@ class OpenAIResponsesImpl:
|
|||
):
|
||||
stream = False if stream is None else stream
|
||||
|
||||
messages: list[OpenAIMessageParam] = []
|
||||
if previous_response_id:
|
||||
previous_response = await self.get_openai_response(previous_response_id)
|
||||
messages.extend(await _previous_response_to_messages(previous_response))
|
||||
# TODO: refactor this user_content parsing out into a separate method
|
||||
user_content: str | list[OpenAIChatCompletionContentPartParam] = ""
|
||||
if isinstance(input, list):
|
||||
user_content = []
|
||||
for user_input in input:
|
||||
if isinstance(user_input.content, list):
|
||||
for user_input_content in user_input.content:
|
||||
if isinstance(user_input_content, OpenAIResponseInputMessageContentText):
|
||||
user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input_content.text))
|
||||
elif isinstance(user_input_content, OpenAIResponseInputMessageContentImage):
|
||||
if user_input_content.image_url:
|
||||
image_url = OpenAIImageURL(
|
||||
url=user_input_content.image_url, detail=user_input_content.detail
|
||||
)
|
||||
user_content.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
|
||||
else:
|
||||
user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input.content))
|
||||
else:
|
||||
user_content = input
|
||||
messages.append(OpenAIUserMessageParam(content=user_content))
|
||||
|
||||
input = await self._prepend_previous_response(input, previous_response_id)
|
||||
messages = await _convert_response_input_to_chat_messages(input)
|
||||
chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
|
||||
chat_response = await self.inference_api.openai_chat_completion(
|
||||
model=model,
|
||||
|
@ -150,6 +242,7 @@ class OpenAIResponsesImpl:
|
|||
# TODO: refactor this into a separate method that handles streaming
|
||||
chat_response_id = ""
|
||||
chat_response_content = []
|
||||
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
|
||||
# TODO: these chunk_ fields are hacky and only take the last chunk into account
|
||||
chunk_created = 0
|
||||
chunk_model = ""
|
||||
|
@ -163,7 +256,26 @@ class OpenAIResponsesImpl:
|
|||
chat_response_content.append(chunk_choice.delta.content or "")
|
||||
if chunk_choice.finish_reason:
|
||||
chunk_finish_reason = chunk_choice.finish_reason
|
||||
assistant_message = OpenAIAssistantMessageParam(content="".join(chat_response_content))
|
||||
|
||||
# Aggregate tool call arguments across chunks, using their index as the aggregation key
|
||||
if chunk_choice.delta.tool_calls:
|
||||
for tool_call in chunk_choice.delta.tool_calls:
|
||||
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
|
||||
if response_tool_call:
|
||||
response_tool_call.function.arguments += tool_call.function.arguments
|
||||
else:
|
||||
response_tool_call = OpenAIChatCompletionToolCall(**tool_call.model_dump())
|
||||
chat_response_tool_calls[tool_call.index] = response_tool_call
|
||||
|
||||
# Convert the dict of tool calls by index to a list of tool calls to pass back in our response
|
||||
if chat_response_tool_calls:
|
||||
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
|
||||
else:
|
||||
tool_calls = None
|
||||
assistant_message = OpenAIAssistantMessageParam(
|
||||
content="".join(chat_response_content),
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
chat_response = OpenAIChatCompletion(
|
||||
id=chat_response_id,
|
||||
choices=[
|
||||
|
@ -181,12 +293,26 @@ class OpenAIResponsesImpl:
|
|||
chat_response = OpenAIChatCompletion(**chat_response.model_dump())
|
||||
|
||||
output_messages: list[OpenAIResponseOutput] = []
|
||||
if chat_response.choices[0].message.tool_calls:
|
||||
output_messages.extend(
|
||||
await self._execute_tool_and_return_final_output(model, stream, chat_response, messages, temperature)
|
||||
)
|
||||
else:
|
||||
output_messages.extend(await _openai_choices_to_output_messages(chat_response.choices))
|
||||
for choice in chat_response.choices:
|
||||
if choice.message.tool_calls and tools:
|
||||
# Assume if the first tool is a function, all tools are functions
|
||||
if isinstance(tools[0], OpenAIResponseInputToolFunction):
|
||||
for tool_call in choice.message.tool_calls:
|
||||
output_messages.append(
|
||||
OpenAIResponseOutputMessageFunctionToolCall(
|
||||
arguments=tool_call.function.arguments or "",
|
||||
call_id=tool_call.id,
|
||||
name=tool_call.function.name or "",
|
||||
id=f"fc_{uuid.uuid4()}",
|
||||
status="completed",
|
||||
)
|
||||
)
|
||||
else:
|
||||
output_messages.extend(
|
||||
await self._execute_tool_and_return_final_output(model, stream, choice, messages, temperature)
|
||||
)
|
||||
else:
|
||||
output_messages.append(await _convert_chat_choice_to_response_message(choice))
|
||||
response = OpenAIResponseObject(
|
||||
created_at=chat_response.created,
|
||||
id=f"resp-{uuid.uuid4()}",
|
||||
|
@ -195,13 +321,43 @@ class OpenAIResponsesImpl:
|
|||
status="completed",
|
||||
output=output_messages,
|
||||
)
|
||||
logger.debug(f"OpenAI Responses response: {response}")
|
||||
|
||||
if store:
|
||||
# Store in kvstore
|
||||
|
||||
new_input_id = f"msg_{uuid.uuid4()}"
|
||||
if isinstance(input, str):
|
||||
# synthesize a message from the input string
|
||||
input_content = OpenAIResponseInputMessageContentText(text=input)
|
||||
input_content_item = OpenAIResponseMessage(
|
||||
role="user",
|
||||
content=[input_content],
|
||||
id=new_input_id,
|
||||
)
|
||||
input_items_data = [input_content_item]
|
||||
else:
|
||||
# we already have a list of messages
|
||||
input_items_data = []
|
||||
for input_item in input:
|
||||
if isinstance(input_item, OpenAIResponseMessage):
|
||||
# These may or may not already have an id, so dump to dict, check for id, and add if missing
|
||||
input_item_dict = input_item.model_dump()
|
||||
if "id" not in input_item_dict:
|
||||
input_item_dict["id"] = new_input_id
|
||||
input_items_data.append(OpenAIResponseMessage(**input_item_dict))
|
||||
else:
|
||||
input_items_data.append(input_item)
|
||||
|
||||
input_items = OpenAIResponseInputItemList(data=input_items_data)
|
||||
prev_response = OpenAIResponsePreviousResponseWithInputItems(
|
||||
input_items=input_items,
|
||||
response=response,
|
||||
)
|
||||
key = f"{OPENAI_RESPONSES_PREFIX}{response.id}"
|
||||
await self.persistence_store.set(
|
||||
key=key,
|
||||
value=response.model_dump_json(),
|
||||
value=prev_response.model_dump_json(),
|
||||
)
|
||||
|
||||
if stream:
|
||||
|
@ -221,7 +377,9 @@ class OpenAIResponsesImpl:
|
|||
chat_tools: list[ChatCompletionToolParam] = []
|
||||
for input_tool in tools:
|
||||
# TODO: Handle other tool types
|
||||
if input_tool.type == "web_search":
|
||||
if input_tool.type == "function":
|
||||
chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
|
||||
elif input_tool.type == "web_search":
|
||||
tool_name = "web_search"
|
||||
tool = await self.tool_groups_api.get_tool(tool_name)
|
||||
tool_def = ToolDefinition(
|
||||
|
@ -247,12 +405,11 @@ class OpenAIResponsesImpl:
|
|||
self,
|
||||
model_id: str,
|
||||
stream: bool,
|
||||
chat_response: OpenAIChatCompletion,
|
||||
choice: OpenAIChoice,
|
||||
messages: list[OpenAIMessageParam],
|
||||
temperature: float,
|
||||
) -> list[OpenAIResponseOutput]:
|
||||
output_messages: list[OpenAIResponseOutput] = []
|
||||
choice = chat_response.choices[0]
|
||||
|
||||
# If the choice is not an assistant message, we don't need to execute any tools
|
||||
if not isinstance(choice.message, OpenAIAssistantMessageParam):
|
||||
|
@ -262,6 +419,9 @@ class OpenAIResponsesImpl:
|
|||
if not choice.message.tool_calls:
|
||||
return output_messages
|
||||
|
||||
# Copy the messages list to avoid mutating the original list
|
||||
messages = messages.copy()
|
||||
|
||||
# Add the assistant message with tool_calls response to the messages list
|
||||
messages.append(choice.message)
|
||||
|
||||
|
@ -307,7 +467,9 @@ class OpenAIResponsesImpl:
|
|||
)
|
||||
# type cast to appease mypy
|
||||
tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
|
||||
tool_final_outputs = await _openai_choices_to_output_messages(tool_results_chat_response.choices)
|
||||
tool_final_outputs = [
|
||||
await _convert_chat_choice_to_response_message(choice) for choice in tool_results_chat_response.choices
|
||||
]
|
||||
# TODO: Wire in annotations with URLs, titles, etc to these output messages
|
||||
output_messages.extend(tool_final_outputs)
|
||||
return output_messages
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue