mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 06:00:48 +00:00
chore(responses): Refactor Responses Impl to be civilized (#3138)
# What does this PR do? Refactors the OpenAI responses implementation by extracting streaming and tool execution logic into separate modules. This improves code organization by: 1. Creating a new `StreamingResponseOrchestrator` class in `streaming.py` to handle the streaming response generation logic 2. Moving tool execution functionality to a dedicated `ToolExecutor` class in `tool_executor.py` ## Test Plan Existing tests
This commit is contained in:
parent
e69acbafbf
commit
47d5af703c
10 changed files with 1434 additions and 1156 deletions
0
docs/source/distributions/k8s-benchmark/openai-mock-server.py
Normal file → Executable file
0
docs/source/distributions/k8s-benchmark/openai-mock-server.py
Normal file → Executable file
|
@ -48,8 +48,8 @@ from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
|||
|
||||
from .agent_instance import ChatAgent
|
||||
from .config import MetaReferenceAgentsImplConfig
|
||||
from .openai_responses import OpenAIResponsesImpl
|
||||
from .persistence import AgentInfo
|
||||
from .responses.openai_responses import OpenAIResponsesImpl
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -0,0 +1,499 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents import Order
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
AllowedToolsFilter,
|
||||
ListOpenAIResponseInputItem,
|
||||
ListOpenAIResponseObject,
|
||||
MCPListToolsTool,
|
||||
OpenAIDeleteResponseObject,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputFunctionToolCallOutput,
|
||||
OpenAIResponseInputMessageContent,
|
||||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseOutput,
|
||||
OpenAIResponseOutputMessageContent,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseText,
|
||||
OpenAIResponseTextFormat,
|
||||
WebSearchToolTypes,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIImageURL,
|
||||
OpenAIJSONSchema,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatJSONObject,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIResponseFormatParam,
|
||||
OpenAIResponseFormatText,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.apis.tools import Tool, ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_tooldef_to_openai_tool,
|
||||
)
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
|
||||
from .streaming import StreamingResponseOrchestrator
|
||||
from .tool_executor import ToolExecutor
|
||||
from .types import ChatCompletionContext
|
||||
|
||||
logger = get_logger(name=__name__, category="responses")
|
||||
|
||||
|
||||
async def _convert_response_content_to_chat_content(
|
||||
content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
|
||||
) -> str | list[OpenAIChatCompletionContentPartParam]:
|
||||
"""
|
||||
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
|
||||
|
||||
The content schemas of each API look similar, but are not exactly the same.
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
converted_parts = []
|
||||
for content_part in content:
|
||||
if isinstance(content_part, OpenAIResponseInputMessageContentText):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
|
||||
elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
|
||||
elif isinstance(content_part, OpenAIResponseInputMessageContentImage):
|
||||
if content_part.image_url:
|
||||
image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail)
|
||||
converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
|
||||
elif isinstance(content_part, str):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context"
|
||||
)
|
||||
return converted_parts
|
||||
|
||||
|
||||
async def _convert_response_input_to_chat_messages(
|
||||
input: str | list[OpenAIResponseInput],
|
||||
) -> list[OpenAIMessageParam]:
|
||||
"""
|
||||
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
|
||||
"""
|
||||
messages: list[OpenAIMessageParam] = []
|
||||
if isinstance(input, list):
|
||||
for input_item in input:
|
||||
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
|
||||
messages.append(
|
||||
OpenAIToolMessageParam(
|
||||
content=input_item.output,
|
||||
tool_call_id=input_item.call_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id=input_item.call_id,
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name=input_item.name,
|
||||
arguments=input_item.arguments,
|
||||
),
|
||||
)
|
||||
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
|
||||
else:
|
||||
content = await _convert_response_content_to_chat_content(input_item.content)
|
||||
message_type = await _get_message_type_by_role(input_item.role)
|
||||
if message_type is None:
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
|
||||
)
|
||||
messages.append(message_type(content=content))
|
||||
else:
|
||||
messages.append(OpenAIUserMessageParam(content=input))
|
||||
return messages
|
||||
|
||||
|
||||
async def _convert_response_text_to_chat_response_format(
|
||||
text: OpenAIResponseText,
|
||||
) -> OpenAIResponseFormatParam:
|
||||
"""
|
||||
Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format.
|
||||
"""
|
||||
if not text.format or text.format["type"] == "text":
|
||||
return OpenAIResponseFormatText(type="text")
|
||||
if text.format["type"] == "json_object":
|
||||
return OpenAIResponseFormatJSONObject()
|
||||
if text.format["type"] == "json_schema":
|
||||
return OpenAIResponseFormatJSONSchema(
|
||||
json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
|
||||
)
|
||||
raise ValueError(f"Unsupported text format: {text.format}")
|
||||
|
||||
|
||||
async def _get_message_type_by_role(role: str):
|
||||
role_to_type = {
|
||||
"user": OpenAIUserMessageParam,
|
||||
"system": OpenAISystemMessageParam,
|
||||
"assistant": OpenAIAssistantMessageParam,
|
||||
"developer": OpenAIDeveloperMessageParam,
|
||||
}
|
||||
return role_to_type.get(role)
|
||||
|
||||
|
||||
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
|
||||
input_items: ListOpenAIResponseInputItem
|
||||
response: OpenAIResponseObject
|
||||
|
||||
|
||||
class OpenAIResponsesImpl:
|
||||
def __init__(
|
||||
self,
|
||||
inference_api: Inference,
|
||||
tool_groups_api: ToolGroups,
|
||||
tool_runtime_api: ToolRuntime,
|
||||
responses_store: ResponsesStore,
|
||||
vector_io_api: VectorIO, # VectorIO
|
||||
):
|
||||
self.inference_api = inference_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
self.responses_store = responses_store
|
||||
self.vector_io_api = vector_io_api
|
||||
self.tool_executor = ToolExecutor(
|
||||
tool_groups_api=tool_groups_api,
|
||||
tool_runtime_api=tool_runtime_api,
|
||||
vector_io_api=vector_io_api,
|
||||
)
|
||||
|
||||
async def _prepend_previous_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
previous_response_id: str | None = None,
|
||||
):
|
||||
if previous_response_id:
|
||||
previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)
|
||||
|
||||
# previous response input items
|
||||
new_input_items = previous_response_with_input.input
|
||||
|
||||
# previous response output items
|
||||
new_input_items.extend(previous_response_with_input.output)
|
||||
|
||||
# new input items from the current request
|
||||
if isinstance(input, str):
|
||||
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
|
||||
else:
|
||||
new_input_items.extend(input)
|
||||
|
||||
input = new_input_items
|
||||
|
||||
return input
|
||||
|
||||
async def _prepend_instructions(self, messages, instructions):
|
||||
if instructions:
|
||||
messages.insert(0, OpenAISystemMessageParam(content=instructions))
|
||||
|
||||
async def get_openai_response(
|
||||
self,
|
||||
response_id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
response_with_input = await self.responses_store.get_response_object(response_id)
|
||||
return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"})
|
||||
|
||||
async def list_openai_responses(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 50,
|
||||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseObject:
|
||||
return await self.responses_store.list_responses(after, limit, model, order)
|
||||
|
||||
async def list_openai_response_input_items(
|
||||
self,
|
||||
response_id: str,
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
include: list[str] | None = None,
|
||||
limit: int | None = 20,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseInputItem:
|
||||
"""List input items for a given OpenAI response.
|
||||
|
||||
:param response_id: The ID of the response to retrieve input items for.
|
||||
:param after: An item ID to list items after, used for pagination.
|
||||
:param before: An item ID to list items before, used for pagination.
|
||||
:param include: Additional fields to include in the response.
|
||||
:param limit: A limit on the number of objects to be returned.
|
||||
:param order: The order to return the input items in.
|
||||
:returns: An ListOpenAIResponseInputItem.
|
||||
"""
|
||||
return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
|
||||
|
||||
async def _store_response(
|
||||
self,
|
||||
response: OpenAIResponseObject,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
) -> None:
|
||||
new_input_id = f"msg_{uuid.uuid4()}"
|
||||
if isinstance(input, str):
|
||||
# synthesize a message from the input string
|
||||
input_content = OpenAIResponseInputMessageContentText(text=input)
|
||||
input_content_item = OpenAIResponseMessage(
|
||||
role="user",
|
||||
content=[input_content],
|
||||
id=new_input_id,
|
||||
)
|
||||
input_items_data = [input_content_item]
|
||||
else:
|
||||
# we already have a list of messages
|
||||
input_items_data = []
|
||||
for input_item in input:
|
||||
if isinstance(input_item, OpenAIResponseMessage):
|
||||
# These may or may not already have an id, so dump to dict, check for id, and add if missing
|
||||
input_item_dict = input_item.model_dump()
|
||||
if "id" not in input_item_dict:
|
||||
input_item_dict["id"] = new_input_id
|
||||
input_items_data.append(OpenAIResponseMessage(**input_item_dict))
|
||||
else:
|
||||
input_items_data.append(input_item)
|
||||
|
||||
await self.responses_store.store_response_object(
|
||||
response_object=response,
|
||||
input=input_items_data,
|
||||
)
|
||||
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
instructions: str | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
store: bool | None = True,
|
||||
stream: bool | None = False,
|
||||
temperature: float | None = None,
|
||||
text: OpenAIResponseText | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
include: list[str] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
):
|
||||
stream = bool(stream)
|
||||
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
|
||||
|
||||
stream_gen = self._create_streaming_response(
|
||||
input=input,
|
||||
model=model,
|
||||
instructions=instructions,
|
||||
previous_response_id=previous_response_id,
|
||||
store=store,
|
||||
temperature=temperature,
|
||||
text=text,
|
||||
tools=tools,
|
||||
max_infer_iters=max_infer_iters,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return stream_gen
|
||||
else:
|
||||
response = None
|
||||
async for stream_chunk in stream_gen:
|
||||
if stream_chunk.type == "response.completed":
|
||||
if response is not None:
|
||||
raise ValueError("The response stream completed multiple times! Earlier response: {response}")
|
||||
response = stream_chunk.response
|
||||
# don't leave the generator half complete!
|
||||
|
||||
if response is None:
|
||||
raise ValueError("The response stream never completed")
|
||||
return response
|
||||
|
||||
async def _create_streaming_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
instructions: str | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
store: bool | None = True,
|
||||
temperature: float | None = None,
|
||||
text: OpenAIResponseText | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Input preprocessing
|
||||
input = await self._prepend_previous_response(input, previous_response_id)
|
||||
messages = await _convert_response_input_to_chat_messages(input)
|
||||
await self._prepend_instructions(messages, instructions)
|
||||
|
||||
# Structured outputs
|
||||
response_format = await _convert_response_text_to_chat_response_format(text)
|
||||
|
||||
# Tool setup, TODO: refactor this slightly since this can also yield events
|
||||
chat_tools, mcp_tool_to_server, mcp_list_message = (
|
||||
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
|
||||
)
|
||||
|
||||
ctx = ChatCompletionContext(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_tools=tools,
|
||||
chat_tools=chat_tools,
|
||||
mcp_tool_to_server=mcp_tool_to_server,
|
||||
temperature=temperature,
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
# Create orchestrator and delegate streaming logic
|
||||
response_id = f"resp-{uuid.uuid4()}"
|
||||
created_at = int(time.time())
|
||||
|
||||
orchestrator = StreamingResponseOrchestrator(
|
||||
inference_api=self.inference_api,
|
||||
ctx=ctx,
|
||||
response_id=response_id,
|
||||
created_at=created_at,
|
||||
text=text,
|
||||
max_infer_iters=max_infer_iters,
|
||||
tool_executor=self.tool_executor,
|
||||
mcp_list_message=mcp_list_message,
|
||||
)
|
||||
|
||||
# Stream the response
|
||||
final_response = None
|
||||
async for stream_chunk in orchestrator.create_response():
|
||||
if stream_chunk.type == "response.completed":
|
||||
final_response = stream_chunk.response
|
||||
yield stream_chunk
|
||||
|
||||
# Store the response if requested
|
||||
if store and final_response:
|
||||
await self._store_response(
|
||||
response=final_response,
|
||||
input=input,
|
||||
)
|
||||
|
||||
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
return await self.responses_store.delete_response_object(response_id)
|
||||
|
||||
async def _convert_response_tools_to_chat_tools(
|
||||
self, tools: list[OpenAIResponseInputTool]
|
||||
) -> tuple[
|
||||
list[ChatCompletionToolParam],
|
||||
dict[str, OpenAIResponseInputToolMCP],
|
||||
OpenAIResponseOutput | None,
|
||||
]:
|
||||
mcp_tool_to_server = {}
|
||||
|
||||
def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
|
||||
tool_def = ToolDefinition(
|
||||
tool_name=tool_name,
|
||||
description=tool.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool.parameters
|
||||
},
|
||||
)
|
||||
return convert_tooldef_to_openai_tool(tool_def)
|
||||
|
||||
mcp_list_message = None
|
||||
chat_tools: list[ChatCompletionToolParam] = []
|
||||
for input_tool in tools:
|
||||
# TODO: Handle other tool types
|
||||
if input_tool.type == "function":
|
||||
chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
|
||||
elif input_tool.type in WebSearchToolTypes:
|
||||
tool_name = "web_search"
|
||||
tool = await self.tool_groups_api.get_tool(tool_name)
|
||||
if not tool:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
chat_tools.append(make_openai_tool(tool_name, tool))
|
||||
elif input_tool.type == "file_search":
|
||||
tool_name = "knowledge_search"
|
||||
tool = await self.tool_groups_api.get_tool(tool_name)
|
||||
if not tool:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
chat_tools.append(make_openai_tool(tool_name, tool))
|
||||
elif input_tool.type == "mcp":
|
||||
from llama_stack.providers.utils.tools.mcp import list_mcp_tools
|
||||
|
||||
always_allowed = None
|
||||
never_allowed = None
|
||||
if input_tool.allowed_tools:
|
||||
if isinstance(input_tool.allowed_tools, list):
|
||||
always_allowed = input_tool.allowed_tools
|
||||
elif isinstance(input_tool.allowed_tools, AllowedToolsFilter):
|
||||
always_allowed = input_tool.allowed_tools.always
|
||||
never_allowed = input_tool.allowed_tools.never
|
||||
|
||||
tool_defs = await list_mcp_tools(
|
||||
endpoint=input_tool.server_url,
|
||||
headers=input_tool.headers or {},
|
||||
)
|
||||
|
||||
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
|
||||
id=f"mcp_list_{uuid.uuid4()}",
|
||||
status="completed",
|
||||
server_label=input_tool.server_label,
|
||||
tools=[],
|
||||
)
|
||||
for t in tool_defs.data:
|
||||
if never_allowed and t.name in never_allowed:
|
||||
continue
|
||||
if not always_allowed or t.name in always_allowed:
|
||||
chat_tools.append(make_openai_tool(t.name, t))
|
||||
if t.name in mcp_tool_to_server:
|
||||
raise ValueError(f"Duplicate tool name {t.name} found for server {input_tool.server_label}")
|
||||
mcp_tool_to_server[t.name] = input_tool
|
||||
mcp_list_message.tools.append(
|
||||
MCPListToolsTool(
|
||||
name=t.name,
|
||||
description=t.description,
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
p.name: {
|
||||
"type": p.parameter_type,
|
||||
"description": p.description,
|
||||
}
|
||||
for p in t.parameters
|
||||
},
|
||||
"required": [p.name for p in t.parameters if p.required],
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
|
||||
return chat_tools, mcp_tool_to_server, mcp_list_message
|
|
@ -0,0 +1,451 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseContentPartOutputText,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseObjectStreamResponseCompleted,
|
||||
OpenAIResponseObjectStreamResponseContentPartAdded,
|
||||
OpenAIResponseObjectStreamResponseContentPartDone,
|
||||
OpenAIResponseObjectStreamResponseCreated,
|
||||
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta,
|
||||
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone,
|
||||
OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta,
|
||||
OpenAIResponseObjectStreamResponseMcpCallArgumentsDone,
|
||||
OpenAIResponseObjectStreamResponseOutputItemAdded,
|
||||
OpenAIResponseObjectStreamResponseOutputItemDone,
|
||||
OpenAIResponseObjectStreamResponseOutputTextDelta,
|
||||
OpenAIResponseOutput,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseText,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChoice,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .types import ChatCompletionContext, ChatCompletionResult
|
||||
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
|
||||
|
||||
logger = get_logger(name=__name__, category="responses")
|
||||
|
||||
|
||||
class StreamingResponseOrchestrator:
|
||||
def __init__(
|
||||
self,
|
||||
inference_api: Inference,
|
||||
ctx: ChatCompletionContext,
|
||||
response_id: str,
|
||||
created_at: int,
|
||||
text: OpenAIResponseText,
|
||||
max_infer_iters: int,
|
||||
tool_executor, # Will be the tool execution logic from the main class
|
||||
mcp_list_message: OpenAIResponseOutput | None = None,
|
||||
):
|
||||
self.inference_api = inference_api
|
||||
self.ctx = ctx
|
||||
self.response_id = response_id
|
||||
self.created_at = created_at
|
||||
self.text = text
|
||||
self.max_infer_iters = max_infer_iters
|
||||
self.tool_executor = tool_executor
|
||||
self.sequence_number = 0
|
||||
self.mcp_list_message = mcp_list_message
|
||||
|
||||
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Initialize output messages with MCP list message if present
|
||||
output_messages: list[OpenAIResponseOutput] = []
|
||||
if self.mcp_list_message:
|
||||
output_messages.append(self.mcp_list_message)
|
||||
# Create initial response and emit response.created immediately
|
||||
initial_response = OpenAIResponseObject(
|
||||
created_at=self.created_at,
|
||||
id=self.response_id,
|
||||
model=self.ctx.model,
|
||||
object="response",
|
||||
status="in_progress",
|
||||
output=output_messages.copy(),
|
||||
text=self.text,
|
||||
)
|
||||
|
||||
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
|
||||
|
||||
n_iter = 0
|
||||
messages = self.ctx.messages.copy()
|
||||
|
||||
while True:
|
||||
completion_result = await self.inference_api.openai_chat_completion(
|
||||
model=self.ctx.model,
|
||||
messages=messages,
|
||||
tools=self.ctx.chat_tools,
|
||||
stream=True,
|
||||
temperature=self.ctx.temperature,
|
||||
response_format=self.ctx.response_format,
|
||||
)
|
||||
|
||||
# Process streaming chunks and build complete response
|
||||
completion_result_data = None
|
||||
async for stream_event_or_result in self._process_streaming_chunks(completion_result, output_messages):
|
||||
if isinstance(stream_event_or_result, ChatCompletionResult):
|
||||
completion_result_data = stream_event_or_result
|
||||
else:
|
||||
yield stream_event_or_result
|
||||
if not completion_result_data:
|
||||
raise ValueError("Streaming chunk processor failed to return completion data")
|
||||
current_response = self._build_chat_completion(completion_result_data)
|
||||
|
||||
function_tool_calls, non_function_tool_calls, next_turn_messages = self._separate_tool_calls(
|
||||
current_response, messages
|
||||
)
|
||||
|
||||
# Handle choices with no tool calls
|
||||
for choice in current_response.choices:
|
||||
if not (choice.message.tool_calls and self.ctx.response_tools):
|
||||
output_messages.append(await convert_chat_choice_to_response_message(choice))
|
||||
|
||||
# Execute tool calls and coordinate results
|
||||
async for stream_event in self._coordinate_tool_execution(
|
||||
function_tool_calls,
|
||||
non_function_tool_calls,
|
||||
completion_result_data,
|
||||
output_messages,
|
||||
next_turn_messages,
|
||||
):
|
||||
yield stream_event
|
||||
|
||||
if not function_tool_calls and not non_function_tool_calls:
|
||||
break
|
||||
|
||||
if function_tool_calls:
|
||||
logger.info("Exiting inference loop since there is a function (client-side) tool call")
|
||||
break
|
||||
|
||||
n_iter += 1
|
||||
if n_iter >= self.max_infer_iters:
|
||||
logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}")
|
||||
break
|
||||
|
||||
messages = next_turn_messages
|
||||
|
||||
# Create final response
|
||||
final_response = OpenAIResponseObject(
|
||||
created_at=self.created_at,
|
||||
id=self.response_id,
|
||||
model=self.ctx.model,
|
||||
object="response",
|
||||
status="completed",
|
||||
text=self.text,
|
||||
output=output_messages,
|
||||
)
|
||||
|
||||
# Emit response.completed
|
||||
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
|
||||
|
||||
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list]:
|
||||
"""Separate tool calls into function and non-function categories."""
|
||||
function_tool_calls = []
|
||||
non_function_tool_calls = []
|
||||
next_turn_messages = messages.copy()
|
||||
|
||||
for choice in current_response.choices:
|
||||
next_turn_messages.append(choice.message)
|
||||
|
||||
if choice.message.tool_calls and self.ctx.response_tools:
|
||||
for tool_call in choice.message.tool_calls:
|
||||
if is_function_tool_call(tool_call, self.ctx.response_tools):
|
||||
function_tool_calls.append(tool_call)
|
||||
else:
|
||||
non_function_tool_calls.append(tool_call)
|
||||
|
||||
return function_tool_calls, non_function_tool_calls, next_turn_messages
|
||||
|
||||
async def _process_streaming_chunks(
|
||||
self, completion_result, output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream | ChatCompletionResult]:
|
||||
"""Process streaming chunks and emit events, returning completion data."""
|
||||
# Initialize result tracking
|
||||
chat_response_id = ""
|
||||
chat_response_content = []
|
||||
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
|
||||
chunk_created = 0
|
||||
chunk_model = ""
|
||||
chunk_finish_reason = ""
|
||||
|
||||
# Create a placeholder message item for delta events
|
||||
message_item_id = f"msg_{uuid.uuid4()}"
|
||||
# Track tool call items for streaming events
|
||||
tool_call_item_ids: dict[int, str] = {}
|
||||
# Track content parts for streaming events
|
||||
content_part_emitted = False
|
||||
|
||||
async for chunk in completion_result:
|
||||
chat_response_id = chunk.id
|
||||
chunk_created = chunk.created
|
||||
chunk_model = chunk.model
|
||||
for chunk_choice in chunk.choices:
|
||||
# Emit incremental text content as delta events
|
||||
if chunk_choice.delta.content:
|
||||
# Emit content_part.added event for first text chunk
|
||||
if not content_part_emitted:
|
||||
content_part_emitted = True
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseContentPartAdded(
|
||||
response_id=self.response_id,
|
||||
item_id=message_item_id,
|
||||
part=OpenAIResponseContentPartOutputText(
|
||||
text="", # Will be filled incrementally via text deltas
|
||||
),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
|
||||
content_index=0,
|
||||
delta=chunk_choice.delta.content,
|
||||
item_id=message_item_id,
|
||||
output_index=0,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Collect content for final response
|
||||
chat_response_content.append(chunk_choice.delta.content or "")
|
||||
if chunk_choice.finish_reason:
|
||||
chunk_finish_reason = chunk_choice.finish_reason
|
||||
|
||||
# Aggregate tool call arguments across chunks
|
||||
if chunk_choice.delta.tool_calls:
|
||||
for tool_call in chunk_choice.delta.tool_calls:
|
||||
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
|
||||
# Create new tool call entry if this is the first chunk for this index
|
||||
is_new_tool_call = response_tool_call is None
|
||||
if is_new_tool_call:
|
||||
tool_call_dict: dict[str, Any] = tool_call.model_dump()
|
||||
tool_call_dict.pop("type", None)
|
||||
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
|
||||
chat_response_tool_calls[tool_call.index] = response_tool_call
|
||||
|
||||
# Create item ID for this tool call for streaming events
|
||||
tool_call_item_id = f"fc_{uuid.uuid4()}"
|
||||
tool_call_item_ids[tool_call.index] = tool_call_item_id
|
||||
|
||||
# Emit output_item.added event for the new function call
|
||||
self.sequence_number += 1
|
||||
function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
|
||||
arguments="", # Will be filled incrementally via delta events
|
||||
call_id=tool_call.id or "",
|
||||
name=tool_call.function.name if tool_call.function else "",
|
||||
id=tool_call_item_id,
|
||||
status="in_progress",
|
||||
)
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
|
||||
response_id=self.response_id,
|
||||
item=function_call_item,
|
||||
output_index=len(output_messages),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Stream tool call arguments as they arrive (differentiate between MCP and function calls)
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
tool_call_item_id = tool_call_item_ids[tool_call.index]
|
||||
self.sequence_number += 1
|
||||
|
||||
# Check if this is an MCP tool call
|
||||
is_mcp_tool = (
|
||||
tool_call.function.name and tool_call.function.name in self.ctx.mcp_tool_to_server
|
||||
)
|
||||
if is_mcp_tool:
|
||||
# Emit MCP-specific argument delta event
|
||||
yield OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta(
|
||||
delta=tool_call.function.arguments,
|
||||
item_id=tool_call_item_id,
|
||||
output_index=len(output_messages),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
else:
|
||||
# Emit function call argument delta event
|
||||
yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(
|
||||
delta=tool_call.function.arguments,
|
||||
item_id=tool_call_item_id,
|
||||
output_index=len(output_messages),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Accumulate arguments for final response (only for subsequent chunks)
|
||||
if not is_new_tool_call:
|
||||
response_tool_call.function.arguments = (
|
||||
response_tool_call.function.arguments or ""
|
||||
) + tool_call.function.arguments
|
||||
|
||||
# Emit arguments.done events for completed tool calls (differentiate between MCP and function calls)
|
||||
for tool_call_index in sorted(chat_response_tool_calls.keys()):
|
||||
tool_call_item_id = tool_call_item_ids[tool_call_index]
|
||||
final_arguments = chat_response_tool_calls[tool_call_index].function.arguments or ""
|
||||
tool_call_name = chat_response_tool_calls[tool_call_index].function.name
|
||||
|
||||
# Check if this is an MCP tool call
|
||||
is_mcp_tool = (
|
||||
self.ctx.mcp_tool_to_server and tool_call_name and tool_call_name in self.ctx.mcp_tool_to_server
|
||||
)
|
||||
self.sequence_number += 1
|
||||
done_event_cls = (
|
||||
OpenAIResponseObjectStreamResponseMcpCallArgumentsDone
|
||||
if is_mcp_tool
|
||||
else OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone
|
||||
)
|
||||
yield done_event_cls(
|
||||
arguments=final_arguments,
|
||||
item_id=tool_call_item_id,
|
||||
output_index=len(output_messages),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Emit content_part.done event if text content was streamed (before content gets cleared)
|
||||
if content_part_emitted:
|
||||
final_text = "".join(chat_response_content)
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseContentPartDone(
|
||||
response_id=self.response_id,
|
||||
item_id=message_item_id,
|
||||
part=OpenAIResponseContentPartOutputText(
|
||||
text=final_text,
|
||||
),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Clear content when there are tool calls (OpenAI spec behavior)
|
||||
if chat_response_tool_calls:
|
||||
chat_response_content = []
|
||||
|
||||
yield ChatCompletionResult(
|
||||
response_id=chat_response_id,
|
||||
content=chat_response_content,
|
||||
tool_calls=chat_response_tool_calls,
|
||||
created=chunk_created,
|
||||
model=chunk_model,
|
||||
finish_reason=chunk_finish_reason,
|
||||
message_item_id=message_item_id,
|
||||
tool_call_item_ids=tool_call_item_ids,
|
||||
content_part_emitted=content_part_emitted,
|
||||
)
|
||||
|
||||
def _build_chat_completion(self, result: ChatCompletionResult) -> OpenAIChatCompletion:
|
||||
"""Build OpenAIChatCompletion from ChatCompletionResult."""
|
||||
# Convert collected chunks to complete response
|
||||
if result.tool_calls:
|
||||
tool_calls = [result.tool_calls[i] for i in sorted(result.tool_calls.keys())]
|
||||
else:
|
||||
tool_calls = None
|
||||
|
||||
assistant_message = OpenAIAssistantMessageParam(
|
||||
content=result.content_text,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
return OpenAIChatCompletion(
|
||||
id=result.response_id,
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
message=assistant_message,
|
||||
finish_reason=result.finish_reason,
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
created=result.created,
|
||||
model=result.model,
|
||||
)
|
||||
|
||||
async def _coordinate_tool_execution(
|
||||
self,
|
||||
function_tool_calls: list,
|
||||
non_function_tool_calls: list,
|
||||
completion_result_data: ChatCompletionResult,
|
||||
output_messages: list[OpenAIResponseOutput],
|
||||
next_turn_messages: list,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
"""Coordinate execution of both function and non-function tool calls."""
|
||||
# Execute non-function tool calls
|
||||
for tool_call in non_function_tool_calls:
|
||||
# Find the item_id for this tool call
|
||||
matching_item_id = None
|
||||
for index, item_id in completion_result_data.tool_call_item_ids.items():
|
||||
response_tool_call = completion_result_data.tool_calls.get(index)
|
||||
if response_tool_call and response_tool_call.id == tool_call.id:
|
||||
matching_item_id = item_id
|
||||
break
|
||||
|
||||
# Use a fallback item_id if not found
|
||||
if not matching_item_id:
|
||||
matching_item_id = f"tc_{uuid.uuid4()}"
|
||||
|
||||
# Execute tool call with streaming
|
||||
tool_call_log = None
|
||||
tool_response_message = None
|
||||
async for result in self.tool_executor.execute_tool_call(
|
||||
tool_call, self.ctx, self.sequence_number, len(output_messages), matching_item_id
|
||||
):
|
||||
if result.stream_event:
|
||||
# Forward streaming events
|
||||
self.sequence_number = result.sequence_number
|
||||
yield result.stream_event
|
||||
|
||||
if result.final_output_message is not None:
|
||||
tool_call_log = result.final_output_message
|
||||
tool_response_message = result.final_input_message
|
||||
self.sequence_number = result.sequence_number
|
||||
|
||||
if tool_call_log:
|
||||
output_messages.append(tool_call_log)
|
||||
|
||||
# Emit output_item.done event for completed non-function tool call
|
||||
if matching_item_id:
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||
response_id=self.response_id,
|
||||
item=tool_call_log,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
if tool_response_message:
|
||||
next_turn_messages.append(tool_response_message)
|
||||
|
||||
# Execute function tool calls (client-side)
|
||||
for tool_call in function_tool_calls:
|
||||
# Find the item_id for this tool call from our tracking dictionary
|
||||
matching_item_id = None
|
||||
for index, item_id in completion_result_data.tool_call_item_ids.items():
|
||||
response_tool_call = completion_result_data.tool_calls.get(index)
|
||||
if response_tool_call and response_tool_call.id == tool_call.id:
|
||||
matching_item_id = item_id
|
||||
break
|
||||
|
||||
# Use existing item_id or create new one if not found
|
||||
final_item_id = matching_item_id or f"fc_{uuid.uuid4()}"
|
||||
|
||||
function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
|
||||
arguments=tool_call.function.arguments or "",
|
||||
call_id=tool_call.id,
|
||||
name=tool_call.function.name or "",
|
||||
id=final_item_id,
|
||||
status="completed",
|
||||
)
|
||||
output_messages.append(function_call_item)
|
||||
|
||||
# Emit output_item.done event for completed function call
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||
response_id=self.response_id,
|
||||
item=function_call_item,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
|
@ -0,0 +1,365 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputToolFileSearch,
|
||||
OpenAIResponseObjectStreamResponseMcpCallCompleted,
|
||||
OpenAIResponseObjectStreamResponseMcpCallFailed,
|
||||
OpenAIResponseObjectStreamResponseMcpCallInProgress,
|
||||
OpenAIResponseObjectStreamResponseWebSearchCallCompleted,
|
||||
OpenAIResponseObjectStreamResponseWebSearchCallInProgress,
|
||||
OpenAIResponseObjectStreamResponseWebSearchCallSearching,
|
||||
OpenAIResponseOutputMessageFileSearchToolCall,
|
||||
OpenAIResponseOutputMessageFileSearchToolCallResults,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
)
|
||||
from llama_stack.apis.common.content_types import (
|
||||
ImageContentItem,
|
||||
TextContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIImageURL,
|
||||
OpenAIToolMessageParam,
|
||||
)
|
||||
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .types import ChatCompletionContext, ToolExecutionResult
|
||||
|
||||
logger = get_logger(name=__name__, category="responses")
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
def __init__(
|
||||
self,
|
||||
tool_groups_api: ToolGroups,
|
||||
tool_runtime_api: ToolRuntime,
|
||||
vector_io_api: VectorIO,
|
||||
):
|
||||
self.tool_groups_api = tool_groups_api
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
self.vector_io_api = vector_io_api
|
||||
|
||||
async def execute_tool_call(
|
||||
self,
|
||||
tool_call: OpenAIChatCompletionToolCall,
|
||||
ctx: ChatCompletionContext,
|
||||
sequence_number: int,
|
||||
output_index: int,
|
||||
item_id: str,
|
||||
) -> AsyncIterator[ToolExecutionResult]:
|
||||
tool_call_id = tool_call.id
|
||||
function = tool_call.function
|
||||
tool_kwargs = json.loads(function.arguments) if function.arguments else {}
|
||||
|
||||
if not function or not tool_call_id or not function.name:
|
||||
yield ToolExecutionResult(sequence_number=sequence_number)
|
||||
return
|
||||
|
||||
# Emit progress events for tool execution start
|
||||
async for event_result in self._emit_progress_events(
|
||||
function.name, ctx, sequence_number, output_index, item_id
|
||||
):
|
||||
sequence_number = event_result.sequence_number
|
||||
yield event_result
|
||||
|
||||
# Execute the actual tool call
|
||||
error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx)
|
||||
|
||||
# Emit completion events for tool execution
|
||||
has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message))
|
||||
async for event_result in self._emit_completion_events(
|
||||
function.name, ctx, sequence_number, output_index, item_id, has_error
|
||||
):
|
||||
sequence_number = event_result.sequence_number
|
||||
yield event_result
|
||||
|
||||
# Build result messages from tool execution
|
||||
output_message, input_message = await self._build_result_messages(
|
||||
function, tool_call_id, tool_kwargs, ctx, error_exc, result, has_error
|
||||
)
|
||||
|
||||
# Yield the final result
|
||||
yield ToolExecutionResult(
|
||||
sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message
|
||||
)
|
||||
|
||||
async def _execute_knowledge_search_via_vector_store(
|
||||
self,
|
||||
query: str,
|
||||
response_file_search_tool: OpenAIResponseInputToolFileSearch,
|
||||
) -> ToolInvocationResult:
|
||||
"""Execute knowledge search using vector_stores.search API with filters support."""
|
||||
search_results = []
|
||||
|
||||
# Create search tasks for all vector stores
|
||||
async def search_single_store(vector_store_id):
|
||||
try:
|
||||
search_response = await self.vector_io_api.openai_search_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
query=query,
|
||||
filters=response_file_search_tool.filters,
|
||||
max_num_results=response_file_search_tool.max_num_results,
|
||||
ranking_options=response_file_search_tool.ranking_options,
|
||||
rewrite_query=False,
|
||||
)
|
||||
return search_response.data
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to search vector store {vector_store_id}: {e}")
|
||||
return []
|
||||
|
||||
# Run all searches in parallel using gather
|
||||
search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids]
|
||||
all_results = await asyncio.gather(*search_tasks)
|
||||
|
||||
# Flatten results
|
||||
for results in all_results:
|
||||
search_results.extend(results)
|
||||
|
||||
# Convert search results to tool result format matching memory.py
|
||||
# Format the results as interleaved content similar to memory.py
|
||||
content_items = []
|
||||
content_items.append(
|
||||
TextContentItem(
|
||||
text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n"
|
||||
)
|
||||
)
|
||||
|
||||
for i, result_item in enumerate(search_results):
|
||||
chunk_text = result_item.content[0].text if result_item.content else ""
|
||||
metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
|
||||
if result_item.attributes:
|
||||
metadata_text += f", attributes: {result_item.attributes}"
|
||||
text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"
|
||||
content_items.append(TextContentItem(text=text_content))
|
||||
|
||||
content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
|
||||
content_items.append(
|
||||
TextContentItem(
|
||||
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
|
||||
)
|
||||
)
|
||||
|
||||
return ToolInvocationResult(
|
||||
content=content_items,
|
||||
metadata={
|
||||
"document_ids": [r.file_id for r in search_results],
|
||||
"chunks": [r.content[0].text if r.content else "" for r in search_results],
|
||||
"scores": [r.score for r in search_results],
|
||||
},
|
||||
)
|
||||
|
||||
async def _emit_progress_events(
|
||||
self, function_name: str, ctx: ChatCompletionContext, sequence_number: int, output_index: int, item_id: str
|
||||
) -> AsyncIterator[ToolExecutionResult]:
|
||||
"""Emit progress events for tool execution start."""
|
||||
# Emit in_progress event based on tool type (only for tools with specific streaming events)
|
||||
progress_event = None
|
||||
if ctx.mcp_tool_to_server and function_name in ctx.mcp_tool_to_server:
|
||||
sequence_number += 1
|
||||
progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
elif function_name == "web_search":
|
||||
sequence_number += 1
|
||||
progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
# Note: knowledge_search and other custom tools don't have specific streaming events in OpenAI spec
|
||||
|
||||
if progress_event:
|
||||
yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
|
||||
|
||||
# For web search, emit searching event
|
||||
if function_name == "web_search":
|
||||
sequence_number += 1
|
||||
searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
|
||||
|
||||
async def _execute_tool(
|
||||
self, function_name: str, tool_kwargs: dict, ctx: ChatCompletionContext
|
||||
) -> tuple[Exception | None, any]:
|
||||
"""Execute the tool and return error exception and result."""
|
||||
error_exc = None
|
||||
result = None
|
||||
|
||||
try:
|
||||
if ctx.mcp_tool_to_server and function_name in ctx.mcp_tool_to_server:
|
||||
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool
|
||||
|
||||
mcp_tool = ctx.mcp_tool_to_server[function_name]
|
||||
result = await invoke_mcp_tool(
|
||||
endpoint=mcp_tool.server_url,
|
||||
headers=mcp_tool.headers or {},
|
||||
tool_name=function_name,
|
||||
kwargs=tool_kwargs,
|
||||
)
|
||||
elif function_name == "knowledge_search":
|
||||
response_file_search_tool = next(
|
||||
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
|
||||
None,
|
||||
)
|
||||
if response_file_search_tool:
|
||||
# Use vector_stores.search API instead of knowledge_search tool
|
||||
# to support filters and ranking_options
|
||||
query = tool_kwargs.get("query", "")
|
||||
result = await self._execute_knowledge_search_via_vector_store(
|
||||
query=query,
|
||||
response_file_search_tool=response_file_search_tool,
|
||||
)
|
||||
else:
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=function_name,
|
||||
kwargs=tool_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
error_exc = e
|
||||
|
||||
return error_exc, result
|
||||
|
||||
async def _emit_completion_events(
|
||||
self,
|
||||
function_name: str,
|
||||
ctx: ChatCompletionContext,
|
||||
sequence_number: int,
|
||||
output_index: int,
|
||||
item_id: str,
|
||||
has_error: bool,
|
||||
) -> AsyncIterator[ToolExecutionResult]:
|
||||
"""Emit completion or failure events for tool execution."""
|
||||
completion_event = None
|
||||
|
||||
if ctx.mcp_tool_to_server and function_name in ctx.mcp_tool_to_server:
|
||||
sequence_number += 1
|
||||
if has_error:
|
||||
completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
else:
|
||||
completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
elif function_name == "web_search":
|
||||
sequence_number += 1
|
||||
completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
# Note: knowledge_search and other custom tools don't have specific completion events in OpenAI spec
|
||||
|
||||
if completion_event:
|
||||
yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
|
||||
|
||||
async def _build_result_messages(
|
||||
self,
|
||||
function,
|
||||
tool_call_id: str,
|
||||
tool_kwargs: dict,
|
||||
ctx: ChatCompletionContext,
|
||||
error_exc: Exception | None,
|
||||
result: any,
|
||||
has_error: bool,
|
||||
) -> tuple[any, any]:
|
||||
"""Build output and input messages from tool execution results."""
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
# Build output message
|
||||
if function.name in ctx.mcp_tool_to_server:
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
)
|
||||
|
||||
message = OpenAIResponseOutputMessageMCPCall(
|
||||
id=tool_call_id,
|
||||
arguments=function.arguments,
|
||||
name=function.name,
|
||||
server_label=ctx.mcp_tool_to_server[function.name].server_label,
|
||||
)
|
||||
if error_exc:
|
||||
message.error = str(error_exc)
|
||||
elif (result and result.error_code and result.error_code > 0) or (result and result.error_message):
|
||||
message.error = f"Error (code {result.error_code}): {result.error_message}"
|
||||
elif result and result.content:
|
||||
message.output = interleaved_content_as_str(result.content)
|
||||
else:
|
||||
if function.name == "web_search":
|
||||
message = OpenAIResponseOutputMessageWebSearchToolCall(
|
||||
id=tool_call_id,
|
||||
status="completed",
|
||||
)
|
||||
if has_error:
|
||||
message.status = "failed"
|
||||
elif function.name == "knowledge_search":
|
||||
message = OpenAIResponseOutputMessageFileSearchToolCall(
|
||||
id=tool_call_id,
|
||||
queries=[tool_kwargs.get("query", "")],
|
||||
status="completed",
|
||||
)
|
||||
if result and "document_ids" in result.metadata:
|
||||
message.results = []
|
||||
for i, doc_id in enumerate(result.metadata["document_ids"]):
|
||||
text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
|
||||
score = result.metadata["scores"][i] if "scores" in result.metadata else None
|
||||
message.results.append(
|
||||
OpenAIResponseOutputMessageFileSearchToolCallResults(
|
||||
file_id=doc_id,
|
||||
filename=doc_id,
|
||||
text=text,
|
||||
score=score,
|
||||
attributes={},
|
||||
)
|
||||
)
|
||||
if has_error:
|
||||
message.status = "failed"
|
||||
else:
|
||||
raise ValueError(f"Unknown tool {function.name} called")
|
||||
|
||||
# Build input message
|
||||
input_message = None
|
||||
if result and result.content:
|
||||
if isinstance(result.content, str):
|
||||
content = result.content
|
||||
elif isinstance(result.content, list):
|
||||
content = []
|
||||
for item in result.content:
|
||||
if isinstance(item, TextContentItem):
|
||||
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
|
||||
elif isinstance(item, ImageContentItem):
|
||||
if item.image.data:
|
||||
url = f"data:image;base64,{item.image.data}"
|
||||
else:
|
||||
url = item.image.url
|
||||
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
|
||||
else:
|
||||
raise ValueError(f"Unknown result content type: {type(item)}")
|
||||
content.append(part)
|
||||
else:
|
||||
raise ValueError(f"Unknown result content type: {type(result.content)}")
|
||||
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
|
||||
else:
|
||||
text = str(error_exc) if error_exc else "Tool execution failed"
|
||||
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
|
||||
|
||||
return message, input_message
|
|
@ -0,0 +1,62 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseOutput,
|
||||
)
|
||||
from llama_stack.apis.inference import OpenAIChatCompletionToolCall, OpenAIMessageParam, OpenAIResponseFormatParam
|
||||
|
||||
|
||||
class ToolExecutionResult(BaseModel):
|
||||
"""Result of streaming tool execution."""
|
||||
|
||||
stream_event: OpenAIResponseObjectStream | None = None
|
||||
sequence_number: int
|
||||
final_output_message: OpenAIResponseOutput | None = None
|
||||
final_input_message: OpenAIMessageParam | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatCompletionResult:
|
||||
"""Result of processing streaming chat completion chunks."""
|
||||
|
||||
response_id: str
|
||||
content: list[str]
|
||||
tool_calls: dict[int, OpenAIChatCompletionToolCall]
|
||||
created: int
|
||||
model: str
|
||||
finish_reason: str
|
||||
message_item_id: str # For streaming events
|
||||
tool_call_item_ids: dict[int, str] # For streaming events
|
||||
content_part_emitted: bool # Tracking state
|
||||
|
||||
@property
|
||||
def content_text(self) -> str:
|
||||
"""Get joined content as string."""
|
||||
return "".join(self.content)
|
||||
|
||||
@property
|
||||
def has_tool_calls(self) -> bool:
|
||||
"""Check if there are any tool calls."""
|
||||
return bool(self.tool_calls)
|
||||
|
||||
|
||||
class ChatCompletionContext(BaseModel):
|
||||
model: str
|
||||
messages: list[OpenAIMessageParam]
|
||||
response_tools: list[OpenAIResponseInputTool] | None = None
|
||||
chat_tools: list[ChatCompletionToolParam] | None = None
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
|
||||
temperature: float | None
|
||||
response_format: OpenAIResponseFormatParam
|
|
@ -0,0 +1,50 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChoice,
|
||||
)
|
||||
|
||||
|
||||
async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
|
||||
"""Convert an OpenAI Chat Completion choice into an OpenAI Response output message."""
|
||||
output_content = ""
|
||||
if isinstance(choice.message.content, str):
|
||||
output_content = choice.message.content
|
||||
elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
|
||||
output_content = choice.message.content.text
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
|
||||
)
|
||||
|
||||
return OpenAIResponseMessage(
|
||||
id=f"msg_{uuid.uuid4()}",
|
||||
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
|
||||
status="completed",
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
|
||||
def is_function_tool_call(
|
||||
tool_call: OpenAIChatCompletionToolCall,
|
||||
tools: list[OpenAIResponseInputTool],
|
||||
) -> bool:
|
||||
if not tool_call.function:
|
||||
return False
|
||||
for t in tools:
|
||||
if t.type == "function" and t.name == tool_call.function.name:
|
||||
return True
|
||||
return False
|
|
@ -41,7 +41,7 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime
|
||||
from llama_stack.core.access_control.access_control import default_policy
|
||||
from llama_stack.providers.inline.agents.meta_reference.openai_responses import (
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue