store prev messages

# What does this PR do?
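
Stores the chat completion messages that were sent to the model alongside each stored response. When a follow-up request arrives with `previous_response_id`, the stored messages are reused directly and only the new input items are converted to chat messages, instead of reconstructing the whole history from the stored input and output items. The new `messages` field is optional, so responses stored before this change still load and fall back to the old reconstruction path.

Below is a minimal sketch of the round-trip this enables, assuming `llama_stack` is installed; the message contents and variable names are illustrative, not from this PR. Messages are persisted as `model_dump()` dicts and re-validated with a Pydantic `TypeAdapter` on the next chained request, mirroring what `_write_response_object` and `_process_input_with_previous_response` do.

```python
from pydantic import TypeAdapter

from llama_stack.apis.inference import OpenAIMessageParam

adapter = TypeAdapter(list[OpenAIMessageParam])

# Chat messages from a completed turn (illustrative content).
messages = adapter.validate_python(
    [
        {"role": "user", "content": "Tell me a joke."},
        {"role": "assistant", "content": "Why did the llama cross the road? ..."},
    ]
)

# What gets persisted with the response row ...
stored = [m.model_dump() for m in messages]

# ... and what a follow-up request with previous_response_id reconstructs,
# without re-deriving the history from the stored input/output items.
restored = adapter.validate_python(stored)
assert [m.model_dump() for m in restored] == stored
```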


## Test Plan
Eric Huang 2025-10-02 15:53:31 -07:00
parent 4819a2e0ee
commit 2ec9f8770e
7 changed files with 202 additions and 58 deletions

View file

@@ -8,7 +8,7 @@ import time
 import uuid
 from collections.abc import AsyncIterator

-from pydantic import BaseModel
+from pydantic import BaseModel, TypeAdapter

 from llama_stack.apis.agents import Order
 from llama_stack.apis.agents.openai_responses import (
@@ -26,12 +26,16 @@ from llama_stack.apis.agents.openai_responses import (
 )
 from llama_stack.apis.inference import (
     Inference,
+    OpenAIMessageParam,
     OpenAISystemMessageParam,
 )
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.responses.responses_store import ResponsesStore
+from llama_stack.providers.utils.responses.responses_store import (
+    ResponsesStore,
+    _OpenAIResponseObjectWithInputAndMessages,
+)

 from .streaming import StreamingResponseOrchestrator
 from .tool_executor import ToolExecutor
@@ -72,26 +76,48 @@ class OpenAIResponsesImpl:
     async def _prepend_previous_response(
         self,
         input: str | list[OpenAIResponseInput],
-        previous_response_id: str | None = None,
+        previous_response: _OpenAIResponseObjectWithInputAndMessages,
     ):
+        new_input_items = previous_response.input.copy()
+        new_input_items.extend(previous_response.output)
+
+        if isinstance(input, str):
+            new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
+        else:
+            new_input_items.extend(input)
+
+        return new_input_items
+
+    async def _process_input_with_previous_response(
+        self,
+        input: str | list[OpenAIResponseInput],
+        previous_response_id: str | None,
+    ) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]:
+        """Process input with optional previous response context.
+
+        Returns:
+            tuple: (all_input for storage, messages for chat completion)
+        """
         if previous_response_id:
-            previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)
+            previous_response: _OpenAIResponseObjectWithInputAndMessages = (
+                await self.responses_store.get_response_object(previous_response_id)
+            )
+            all_input = await self._prepend_previous_response(input, previous_response)

-            # previous response input items
-            new_input_items = previous_response_with_input.input
-
-            # previous response output items
-            new_input_items.extend(previous_response_with_input.output)
-
-            # new input items from the current request
-            if isinstance(input, str):
-                new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
+            if previous_response.messages:
+                # Use stored messages directly and convert only new input
+                message_adapter = TypeAdapter(list[OpenAIMessageParam])
+                messages = message_adapter.validate_python(previous_response.messages)
+                new_messages = await convert_response_input_to_chat_messages(input)
+                messages.extend(new_messages)
             else:
-                new_input_items.extend(input)
+                # Backward compatibility: reconstruct from inputs
+                messages = await convert_response_input_to_chat_messages(all_input)
+        else:
+            all_input = input
+            messages = await convert_response_input_to_chat_messages(input)

-            input = new_input_items
-
-        return input
+        return all_input, messages

     async def _prepend_instructions(self, messages, instructions):
         if instructions:
@@ -102,7 +128,7 @@ class OpenAIResponsesImpl:
         response_id: str,
     ) -> OpenAIResponseObject:
         response_with_input = await self.responses_store.get_response_object(response_id)
-        return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"})
+        return response_with_input.to_response_object()

     async def list_openai_responses(
         self,
@@ -138,6 +164,7 @@
         self,
         response: OpenAIResponseObject,
         input: str | list[OpenAIResponseInput],
+        messages: list[OpenAIMessageParam],
     ) -> None:
         new_input_id = f"msg_{uuid.uuid4()}"
         if isinstance(input, str):
@@ -165,6 +192,7 @@
         await self.responses_store.store_response_object(
             response_object=response,
             input=input_items_data,
+            messages=messages,
         )

     async def create_openai_response(
@@ -224,8 +252,7 @@
         max_infer_iters: int | None = 10,
     ) -> AsyncIterator[OpenAIResponseObjectStream]:
         # Input preprocessing
-        input = await self._prepend_previous_response(input, previous_response_id)
-        messages = await convert_response_input_to_chat_messages(input)
+        all_input, messages = await self._process_input_with_previous_response(input, previous_response_id)
         await self._prepend_instructions(messages, instructions)

         # Structured outputs
@@ -265,7 +292,8 @@
         if store and final_response:
             await self._store_response(
                 response=final_response,
-                input=input,
+                input=all_input,
+                messages=orchestrator.final_messages,
             )

     async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:

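The core of the change is the branching in `_process_input_with_previous_response` above: reuse the stored messages when the previous response has them, otherwise rebuild the history from the stored items. A condensed, hypothetical restatement follows; `convert_to_chat_messages` is a stub standing in for the real `convert_response_input_to_chat_messages` helper, and `previous` is whatever `get_response_object` returned (or `None` when the request does not chain).

```python
from pydantic import TypeAdapter

from llama_stack.apis.inference import OpenAIMessageParam

_adapter = TypeAdapter(list[OpenAIMessageParam])


async def convert_to_chat_messages(items) -> list[OpenAIMessageParam]:
    # Stub: treat every item as a user message; the real helper handles the
    # full range of response input and output item types.
    return _adapter.validate_python([{"role": "user", "content": str(item)} for item in items])


async def messages_for_next_turn(new_input: list, previous) -> list[OpenAIMessageParam]:
    if previous is None:
        # No previous_response_id: derive the chat messages from this request alone.
        return await convert_to_chat_messages(new_input)
    if previous.messages:
        # Fast path: reuse the stored chat history and convert only the new items.
        messages = _adapter.validate_python(previous.messages)
        messages.extend(await convert_to_chat_messages(new_input))
        return messages
    # Backward compatibility: rows stored before this change carry no messages,
    # so rebuild the history from the stored input and output items.
    return await convert_to_chat_messages(list(previous.input) + list(previous.output) + list(new_input))
```

Either branch feeds `_prepend_instructions` and the chat completion exactly as before; only how the message history is obtained changes.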
View file

@@ -43,6 +43,7 @@ from llama_stack.apis.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionToolCall,
     OpenAIChoice,
+    OpenAIMessageParam,
 )
 from llama_stack.log import get_logger
@@ -103,6 +104,8 @@ class StreamingResponseOrchestrator:
         self.sequence_number = 0
         # Store MCP tool mapping that gets built during tool processing
         self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
+        # Track final messages after all tool executions
+        self.final_messages: list[OpenAIMessageParam] = []

     async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
         # Initialize output messages
@@ -192,6 +195,8 @@
             messages = next_turn_messages

+        self.final_messages = messages.copy() + [current_response.choices[0].message]
+
         # Create final response
         final_response = OpenAIResponseObject(
             created_at=self.created_at,

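`final_messages` captures, roughly, the chat messages from the last inference call plus the assistant reply from the final completion; it is the list that `create_openai_response` then hands to `_store_response(messages=orchestrator.final_messages)`. A rough sketch of its shape for a single turn with no tool calls, assuming `llama_stack` is installed and with made-up message contents:

```python
from pydantic import TypeAdapter

from llama_stack.apis.inference import OpenAIMessageParam

adapter = TypeAdapter(list[OpenAIMessageParam])

turn_messages = adapter.validate_python(
    [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]
)
assistant_reply = adapter.validate_python([{"role": "assistant", "content": "Hi there!"}])[0]

# Mirrors: self.final_messages = messages.copy() + [current_response.choices[0].message]
final_messages = turn_messages.copy() + [assistant_reply]
```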
View file

@@ -17,6 +17,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseObject,
     OpenAIResponseObjectWithInput,
 )
+from llama_stack.apis.inference import OpenAIMessageParam
 from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig
 from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
 from llama_stack.log import get_logger
@@ -28,6 +29,19 @@ from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreType
 logger = get_logger(name=__name__, category="responses_store")


+class _OpenAIResponseObjectWithInputAndMessages(OpenAIResponseObjectWithInput):
+    """Internal class for storing responses with chat completion messages.
+
+    This extends the public OpenAIResponseObjectWithInput with messages field
+    for internal storage. The messages field is not exposed in the public API.
+
+    The messages field is optional for backward compatibility with responses
+    stored before this feature was added.
+    """
+
+    messages: list[OpenAIMessageParam] | None = None
+
+
 class ResponsesStore:
     def __init__(
         self,
@@ -54,7 +68,9 @@ class ResponsesStore:
         self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite

         # Async write queue and worker control
-        self._queue: asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput]]] | None = None
+        self._queue: (
+            asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput], list[OpenAIMessageParam]]] | None
+        ) = None
         self._worker_tasks: list[asyncio.Task[Any]] = []
         self._max_write_queue_size: int = config.max_write_queue_size
         self._num_writers: int = max(1, config.num_writers)
@@ -100,18 +116,21 @@ class ResponsesStore:
         await self._queue.join()

     async def store_response_object(
-        self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput]
+        self,
+        response_object: OpenAIResponseObject,
+        input: list[OpenAIResponseInput],
+        messages: list[OpenAIMessageParam],
     ) -> None:
         if self.enable_write_queue:
             if self._queue is None:
                 raise ValueError("Responses store is not initialized")
             try:
-                self._queue.put_nowait((response_object, input))
+                self._queue.put_nowait((response_object, input, messages))
             except asyncio.QueueFull:
                 logger.warning(f"Write queue full; adding response id={getattr(response_object, 'id', '<unknown>')}")
-                await self._queue.put((response_object, input))
+                await self._queue.put((response_object, input, messages))
         else:
-            await self._write_response_object(response_object, input)
+            await self._write_response_object(response_object, input, messages)

     async def _worker_loop(self) -> None:
         assert self._queue is not None
@@ -120,22 +139,26 @@ class ResponsesStore:
                 item = await self._queue.get()
             except asyncio.CancelledError:
                 break
-            response_object, input = item
+            response_object, input, messages = item
             try:
-                await self._write_response_object(response_object, input)
+                await self._write_response_object(response_object, input, messages)
             except Exception as e:  # noqa: BLE001
                 logger.error(f"Error writing response object: {e}")
             finally:
                 self._queue.task_done()

     async def _write_response_object(
-        self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput]
+        self,
+        response_object: OpenAIResponseObject,
+        input: list[OpenAIResponseInput],
+        messages: list[OpenAIMessageParam],
     ) -> None:
         if self.sql_store is None:
             raise ValueError("Responses store is not initialized")

         data = response_object.model_dump()
         data["input"] = [input_item.model_dump() for input_item in input]
+        data["messages"] = [msg.model_dump() for msg in messages]

         await self.sql_store.insert(
             "openai_responses",
@@ -188,7 +211,7 @@ class ResponsesStore:
             last_id=data[-1].id if data else "",
         )

-    async def get_response_object(self, response_id: str) -> OpenAIResponseObjectWithInput:
+    async def get_response_object(self, response_id: str) -> _OpenAIResponseObjectWithInputAndMessages:
         """
         Get a response object with automatic access control checking.
         """
@@ -205,7 +228,7 @@ class ResponsesStore:
             # This provides security by not revealing whether the record exists
             raise ValueError(f"Response with id {response_id} not found") from None

-        return OpenAIResponseObjectWithInput(**row["response_object"])
+        return _OpenAIResponseObjectWithInputAndMessages(**row["response_object"])

     async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
         if not self.sql_store:
@@ -241,8 +264,8 @@ class ResponsesStore:
         if before and after:
             raise ValueError("Cannot specify both 'before' and 'after' parameters")

-        response_with_input = await self.get_response_object(response_id)
-        items = response_with_input.input
+        response_with_input_and_messages = await self.get_response_object(response_id)
+        items = response_with_input_and_messages.input

         if order == Order.desc:
             items = list(reversed(items))
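
Because `messages` is optional on `_OpenAIResponseObjectWithInputAndMessages`, rows written before this change deserialize cleanly and simply report `messages=None`, which is what triggers the backward-compatibility branch in `_process_input_with_previous_response`. A minimal stand-in model (not the real class) illustrating that behaviour:

```python
from pydantic import BaseModel


class StoredRow(BaseModel):
    """Hypothetical stand-in for _OpenAIResponseObjectWithInputAndMessages."""

    id: str
    input: list[dict]
    messages: list[dict] | None = None


old_row = {"id": "resp_123", "input": [{"role": "user", "content": "hi"}]}
new_row = {**old_row, "messages": [{"role": "user", "content": "hi"}]}

# Pre-existing row: no "messages" key, so the field defaults to None and callers
# rebuild the chat history from the stored input/output items.
assert StoredRow(**old_row).messages is None

# Row written after this change: the stored messages are reused directly.
assert StoredRow(**new_row).messages is not None
```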