mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00

commit 2ec9f8770e (parent 4819a2e0ee)

store prev messages

# What does this PR do?

## Test Plan

7 changed files with 202 additions and 58 deletions
```diff
@@ -8,7 +8,7 @@ import time
 import uuid
 from collections.abc import AsyncIterator
 
-from pydantic import BaseModel
+from pydantic import BaseModel, TypeAdapter
 
 from llama_stack.apis.agents import Order
 from llama_stack.apis.agents.openai_responses import (
```
```diff
@@ -26,12 +26,16 @@ from llama_stack.apis.agents.openai_responses import (
 )
 from llama_stack.apis.inference import (
     Inference,
+    OpenAIMessageParam,
     OpenAISystemMessageParam,
 )
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.responses.responses_store import ResponsesStore
+from llama_stack.providers.utils.responses.responses_store import (
+    ResponsesStore,
+    _OpenAIResponseObjectWithInputAndMessages,
+)
 
 from .streaming import StreamingResponseOrchestrator
 from .tool_executor import ToolExecutor
```
```diff
@@ -72,26 +76,48 @@ class OpenAIResponsesImpl:
     async def _prepend_previous_response(
         self,
         input: str | list[OpenAIResponseInput],
-        previous_response_id: str | None = None,
+        previous_response: _OpenAIResponseObjectWithInputAndMessages,
     ):
-        if previous_response_id:
-            previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)
-
-            # previous response input items
-            new_input_items = previous_response_with_input.input
-
-            # previous response output items
-            new_input_items.extend(previous_response_with_input.output)
-
-            # new input items from the current request
-            if isinstance(input, str):
-                new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
-            else:
-                new_input_items.extend(input)
-
-            input = new_input_items
-
-        return input
+        new_input_items = previous_response.input.copy()
+        new_input_items.extend(previous_response.output)
+
+        if isinstance(input, str):
+            new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
+        else:
+            new_input_items.extend(input)
+
+        return new_input_items
+
+    async def _process_input_with_previous_response(
+        self,
+        input: str | list[OpenAIResponseInput],
+        previous_response_id: str | None,
+    ) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]:
+        """Process input with optional previous response context.
+
+        Returns:
+            tuple: (all_input for storage, messages for chat completion)
+        """
+        if previous_response_id:
+            previous_response: _OpenAIResponseObjectWithInputAndMessages = (
+                await self.responses_store.get_response_object(previous_response_id)
+            )
+            all_input = await self._prepend_previous_response(input, previous_response)
+
+            if previous_response.messages:
+                # Use stored messages directly and convert only new input
+                message_adapter = TypeAdapter(list[OpenAIMessageParam])
+                messages = message_adapter.validate_python(previous_response.messages)
+                new_messages = await convert_response_input_to_chat_messages(input)
+                messages.extend(new_messages)
+            else:
+                # Backward compatibility: reconstruct from inputs
+                messages = await convert_response_input_to_chat_messages(all_input)
+        else:
+            all_input = input
+            messages = await convert_response_input_to_chat_messages(input)
+
+        return all_input, messages
 
     async def _prepend_instructions(self, messages, instructions):
         if instructions:
```
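The key behavior change is the `previous_response.messages` branch: rows come back from the SQL store as plain dicts (whatever `model_dump()` produced at write time), and `TypeAdapter` revalidates them into typed message models, so prior turns never have to be reconverted from input items. A minimal sketch of that round-trip, using a stand-in `Message` model rather than llama-stack's actual `OpenAIMessageParam` union:

```python
from pydantic import BaseModel, TypeAdapter


class Message(BaseModel):
    """Stand-in for llama-stack's OpenAIMessageParam union."""

    role: str
    content: str


# What a stored row might look like: plain dicts, not model instances.
stored_rows = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "Paris."},
]

# TypeAdapter revalidates the dicts back into typed models in one call.
adapter = TypeAdapter(list[Message])
messages = adapter.validate_python(stored_rows)  # -> list[Message]

# New turn: only the *new* input needs converting; prior turns are reused.
messages.append(Message(role="user", content="And its population?"))
print([m.role for m in messages])  # ['user', 'assistant', 'user']
```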
```diff
@@ -102,7 +128,7 @@ class OpenAIResponsesImpl:
         response_id: str,
     ) -> OpenAIResponseObject:
         response_with_input = await self.responses_store.get_response_object(response_id)
-        return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"})
+        return response_with_input.to_response_object()
 
     async def list_openai_responses(
         self,
```
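`to_response_object()` itself is not shown in this diff; presumably it moves the hand-rolled `model_dump()` filter onto the internal storage class, which now has two internal-only fields to hide rather than one. A hypothetical sketch of that shape, with stand-in models:

```python
from pydantic import BaseModel


class ResponseObject(BaseModel):
    """Stand-in for the public OpenAIResponseObject."""

    id: str
    status: str


class StoredResponse(ResponseObject):
    """Stand-in for _OpenAIResponseObjectWithInputAndMessages."""

    input: list[dict] = []
    messages: list[dict] | None = None

    def to_response_object(self) -> ResponseObject:
        # Drop the internal-only fields in one place instead of filtering
        # a model_dump() by hand at every call site, as the old one-liner did.
        return ResponseObject(**self.model_dump(exclude={"input", "messages"}))


stored = StoredResponse(id="resp_123", status="completed", messages=[{"role": "user", "content": "hi"}])
print(stored.to_response_object())  # id='resp_123' status='completed'
```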
```diff
@@ -138,6 +164,7 @@ class OpenAIResponsesImpl:
         self,
         response: OpenAIResponseObject,
         input: str | list[OpenAIResponseInput],
+        messages: list[OpenAIMessageParam],
     ) -> None:
         new_input_id = f"msg_{uuid.uuid4()}"
         if isinstance(input, str):
```
```diff
@@ -165,6 +192,7 @@ class OpenAIResponsesImpl:
         await self.responses_store.store_response_object(
             response_object=response,
             input=input_items_data,
+            messages=messages,
         )
 
     async def create_openai_response(
```
```diff
@@ -224,8 +252,7 @@ class OpenAIResponsesImpl:
         max_infer_iters: int | None = 10,
     ) -> AsyncIterator[OpenAIResponseObjectStream]:
         # Input preprocessing
-        input = await self._prepend_previous_response(input, previous_response_id)
-        messages = await convert_response_input_to_chat_messages(input)
+        all_input, messages = await self._process_input_with_previous_response(input, previous_response_id)
         await self._prepend_instructions(messages, instructions)
 
         # Structured outputs
```
```diff
@@ -265,7 +292,8 @@ class OpenAIResponsesImpl:
         if store and final_response:
             await self._store_response(
                 response=final_response,
-                input=input,
+                input=all_input,
+                messages=orchestrator.final_messages,
             )
 
     async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
```
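For context, this is the user-visible flow the change speeds up (the remaining hunks touch `StreamingResponseOrchestrator` and `ResponsesStore`): chaining turns via `previous_response_id`. An illustrative client snippet, assuming a local llama-stack server exposing its OpenAI-compatible Responses API; the base URL, API key, and model name are placeholders:

```python
from openai import OpenAI

# Placeholder endpoint/model; llama-stack serves an OpenAI-compatible API.
client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

first = client.responses.create(
    model="meta-llama/Llama-3.3-70B-Instruct",
    input="Summarize the plot of Dune in one sentence.",
)

# With this commit, the server replays the stored chat messages for the
# previous response instead of reconverting all prior input items.
second = client.responses.create(
    model="meta-llama/Llama-3.3-70B-Instruct",
    input="Now do it in French.",
    previous_response_id=first.id,
)
print(second.output_text)
```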
```diff
@@ -43,6 +43,7 @@ from llama_stack.apis.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionToolCall,
     OpenAIChoice,
+    OpenAIMessageParam,
 )
 from llama_stack.log import get_logger
 
```
```diff
@@ -103,6 +104,8 @@ class StreamingResponseOrchestrator:
         self.sequence_number = 0
         # Store MCP tool mapping that gets built during tool processing
         self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
+        # Track final messages after all tool executions
+        self.final_messages: list[OpenAIMessageParam] = []
 
     async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
         # Initialize output messages
```
```diff
@@ -192,6 +195,8 @@ class StreamingResponseOrchestrator:
 
             messages = next_turn_messages
 
+        self.final_messages = messages.copy() + [current_response.choices[0].message]
+
         # Create final response
         final_response = OpenAIResponseObject(
             created_at=self.created_at,
```
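`final_messages` snapshots the conversation once the turn loop exits: the working message list plus the last assistant message. A small sketch of why the copy matters — later mutation of the working `messages` list must not leak into what gets stored:

```python
# Minimal sketch of the snapshot pattern above, with plain dicts standing
# in for typed message models.
messages: list[dict] = [{"role": "user", "content": "hi"}]
assistant = {"role": "assistant", "content": "hello!"}

# Shallow copy + concat builds a fresh list for storage.
final_messages = messages.copy() + [assistant]

messages.append({"role": "user", "content": "next turn"})  # later mutation
print(len(final_messages))  # still 2: the snapshot is independent
```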
```diff
@@ -17,6 +17,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseObject,
     OpenAIResponseObjectWithInput,
 )
+from llama_stack.apis.inference import OpenAIMessageParam
 from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig
 from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
 from llama_stack.log import get_logger
```
```diff
@@ -28,6 +29,19 @@ from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreTy
 logger = get_logger(name=__name__, category="responses_store")
 
 
+class _OpenAIResponseObjectWithInputAndMessages(OpenAIResponseObjectWithInput):
+    """Internal class for storing responses with chat completion messages.
+
+    This extends the public OpenAIResponseObjectWithInput with messages field
+    for internal storage. The messages field is not exposed in the public API.
+
+    The messages field is optional for backward compatibility with responses
+    stored before this feature was added.
+    """
+
+    messages: list[OpenAIMessageParam] | None = None
+
+
 class ResponsesStore:
     def __init__(
         self,
```
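Because `messages` is optional with a `None` default, rows written before this commit still validate, and `None` is the signal for the caller to fall back to reconverting `input` (the backward-compatibility branch shown earlier). A sketch with a stand-in model:

```python
from pydantic import BaseModel


class StoredResponse(BaseModel):
    """Stand-in for _OpenAIResponseObjectWithInputAndMessages."""

    id: str
    input: list[dict]
    messages: list[dict] | None = None  # absent in rows written before this change


# Old row, persisted before messages were stored: still validates.
old = StoredResponse(id="resp_old", input=[{"role": "user", "content": "hi"}])
assert old.messages is None  # caller falls back to reconverting `input`

# New row: messages round-trip and can be reused directly.
new = StoredResponse(
    id="resp_new",
    input=[{"role": "user", "content": "hi"}],
    messages=[{"role": "user", "content": "hi"}],
)
assert new.messages is not None
```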
```diff
@@ -54,7 +68,9 @@ class ResponsesStore:
         self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite
 
         # Async write queue and worker control
-        self._queue: asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput]]] | None = None
+        self._queue: (
+            asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput], list[OpenAIMessageParam]]] | None
+        ) = None
        self._worker_tasks: list[asyncio.Task[Any]] = []
        self._max_write_queue_size: int = config.max_write_queue_size
        self._num_writers: int = max(1, config.num_writers)
```
```diff
@@ -100,18 +116,21 @@ class ResponsesStore:
         await self._queue.join()
 
     async def store_response_object(
-        self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput]
+        self,
+        response_object: OpenAIResponseObject,
+        input: list[OpenAIResponseInput],
+        messages: list[OpenAIMessageParam],
     ) -> None:
         if self.enable_write_queue:
             if self._queue is None:
                 raise ValueError("Responses store is not initialized")
             try:
-                self._queue.put_nowait((response_object, input))
+                self._queue.put_nowait((response_object, input, messages))
             except asyncio.QueueFull:
                 logger.warning(f"Write queue full; adding response id={getattr(response_object, 'id', '<unknown>')}")
-                await self._queue.put((response_object, input))
+                await self._queue.put((response_object, input, messages))
         else:
-            await self._write_response_object(response_object, input)
+            await self._write_response_object(response_object, input, messages)
 
     async def _worker_loop(self) -> None:
         assert self._queue is not None
```
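The enqueue strategy is unchanged apart from the wider tuple: a non-blocking `put_nowait` first, then on `QueueFull` an awaited `put` that applies backpressure to the caller rather than dropping the write. A self-contained sketch of that pattern:

```python
import asyncio


async def main() -> None:
    queue: asyncio.Queue[str] = asyncio.Queue(maxsize=1)
    queue.put_nowait("first")  # queue is now full

    try:
        queue.put_nowait("second")  # raises: no free slot
    except asyncio.QueueFull:
        # Schedule a consumer to drain one item, then wait for space.
        asyncio.get_running_loop().call_soon(queue.get_nowait)
        await queue.put("second")  # backpressure: blocks until the drain runs

    print(queue.qsize())  # 1 -> "second" was written, nothing was dropped


asyncio.run(main())
```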
```diff
@@ -120,22 +139,26 @@ class ResponsesStore:
                 item = await self._queue.get()
             except asyncio.CancelledError:
                 break
-            response_object, input = item
+            response_object, input, messages = item
             try:
-                await self._write_response_object(response_object, input)
+                await self._write_response_object(response_object, input, messages)
             except Exception as e:  # noqa: BLE001
                 logger.error(f"Error writing response object: {e}")
             finally:
                 self._queue.task_done()
 
     async def _write_response_object(
-        self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput]
+        self,
+        response_object: OpenAIResponseObject,
+        input: list[OpenAIResponseInput],
+        messages: list[OpenAIMessageParam],
     ) -> None:
         if self.sql_store is None:
             raise ValueError("Responses store is not initialized")
 
         data = response_object.model_dump()
         data["input"] = [input_item.model_dump() for input_item in input]
+        data["messages"] = [msg.model_dump() for msg in messages]
 
         await self.sql_store.insert(
             "openai_responses",
```
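Each row thus stores one JSON-able dict: the dumped response fields, the dumped input items, and — new in this commit — the dumped chat messages. A minimal sketch of the serialization step with a stand-in message model:

```python
from pydantic import BaseModel


class Message(BaseModel):
    """Stand-in for OpenAIMessageParam."""

    role: str
    content: str


messages = [Message(role="user", content="hi"), Message(role="assistant", content="hello")]

# Build the row the same way _write_response_object does: plain dicts only,
# so the SQL store can persist everything as JSON.
data: dict = {"id": "resp_123", "status": "completed"}
data["messages"] = [msg.model_dump() for msg in messages]

print(data["messages"][0])  # {'role': 'user', 'content': 'hi'}
```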
```diff
@@ -188,7 +211,7 @@ class ResponsesStore:
             last_id=data[-1].id if data else "",
         )
 
-    async def get_response_object(self, response_id: str) -> OpenAIResponseObjectWithInput:
+    async def get_response_object(self, response_id: str) -> _OpenAIResponseObjectWithInputAndMessages:
         """
         Get a response object with automatic access control checking.
         """
```
```diff
@@ -205,7 +228,7 @@ class ResponsesStore:
             # This provides security by not revealing whether the record exists
             raise ValueError(f"Response with id {response_id} not found") from None
 
-        return OpenAIResponseObjectWithInput(**row["response_object"])
+        return _OpenAIResponseObjectWithInputAndMessages(**row["response_object"])
 
     async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
         if not self.sql_store:
```
```diff
@@ -241,8 +264,8 @@ class ResponsesStore:
         if before and after:
             raise ValueError("Cannot specify both 'before' and 'after' parameters")
 
-        response_with_input = await self.get_response_object(response_id)
-        items = response_with_input.input
+        response_with_input_and_messages = await self.get_response_object(response_id)
+        items = response_with_input_and_messages.input
 
         if order == Order.desc:
             items = list(reversed(items))
```