Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-03 19:57:35 +00:00)
store prev messages
# What does this PR do?

Store the chat completion messages that produced a response alongside the response object in the responses store. When a follow-up request passes `previous_response_id`, the stored messages are reused directly (re-validated via a pydantic `TypeAdapter`) and only the new input is converted, instead of reconstructing the whole chat history from the stored input items. Responses stored before this change have no `messages` field and fall back to the old reconstruction path.

## Test Plan

Adds an integration test (`test_response_sequential_file_search`) that chains two responses with `previous_response_id`, and updates the unit tests for `OpenAIResponsesImpl` and `ResponsesStore` to cover the new `messages` argument.
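Minimal, self-contained sketch of the mechanism (illustrative only; `UserMessage` and `StoredResponse` are simplified stand-ins, not the real `llama_stack` types): the stored `messages` field is optional so rows written before this change still validate, and stored messages are re-validated with a pydantic `TypeAdapter` when a previous response is reused.

```python
from pydantic import BaseModel, TypeAdapter


class UserMessage(BaseModel):
    role: str = "user"
    content: str


class StoredResponse(BaseModel):
    id: str
    input: list[str]
    # Optional so rows written before messages were stored still validate.
    messages: list[UserMessage] | None = None


message_adapter = TypeAdapter(list[UserMessage])

# Row written after this change: messages were stored alongside the response.
new_row = {"id": "resp_1", "input": ["hi"], "messages": [{"content": "hi"}]}
# Row written before this change: no messages, so rebuild from the input items.
old_row = {"id": "resp_0", "input": ["hi"]}

for row in (new_row, old_row):
    stored = StoredResponse(**row)
    if stored.messages is not None:
        # Reuse the stored chat messages directly.
        messages = message_adapter.validate_python([m.model_dump() for m in stored.messages])
    else:
        # Backward compatibility: reconstruct chat messages from the stored input.
        messages = [UserMessage(content=item) for item in stored.input]
    print(stored.id, [m.content for m in messages])
```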
This commit is contained in:
parent
4819a2e0ee
commit
2ec9f8770e
7 changed files with 202 additions and 58 deletions
@@ -888,6 +888,10 @@ class OpenAIResponseObjectWithInput(OpenAIResponseObject):
 
     input: list[OpenAIResponseInput]
 
+    def to_response_object(self) -> OpenAIResponseObject:
+        """Convert to OpenAIResponseObject by excluding input field."""
+        return OpenAIResponseObject(**{k: v for k, v in self.model_dump().items() if k != "input"})
+
 
 @json_schema_type
 class ListOpenAIResponseObject(BaseModel):

@@ -8,7 +8,7 @@ import time
 import uuid
 from collections.abc import AsyncIterator
 
-from pydantic import BaseModel
+from pydantic import BaseModel, TypeAdapter
 
 from llama_stack.apis.agents import Order
 from llama_stack.apis.agents.openai_responses import (
@@ -26,12 +26,16 @@ from llama_stack.apis.agents.openai_responses import (
 )
 from llama_stack.apis.inference import (
     Inference,
+    OpenAIMessageParam,
     OpenAISystemMessageParam,
 )
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.responses.responses_store import ResponsesStore
+from llama_stack.providers.utils.responses.responses_store import (
+    ResponsesStore,
+    _OpenAIResponseObjectWithInputAndMessages,
+)
 
 from .streaming import StreamingResponseOrchestrator
 from .tool_executor import ToolExecutor
@@ -72,26 +76,48 @@ class OpenAIResponsesImpl:
     async def _prepend_previous_response(
         self,
         input: str | list[OpenAIResponseInput],
-        previous_response_id: str | None = None,
+        previous_response: _OpenAIResponseObjectWithInputAndMessages,
     ):
+        new_input_items = previous_response.input.copy()
+        new_input_items.extend(previous_response.output)
+
+        if isinstance(input, str):
+            new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
+        else:
+            new_input_items.extend(input)
+
+        return new_input_items
+
+    async def _process_input_with_previous_response(
+        self,
+        input: str | list[OpenAIResponseInput],
+        previous_response_id: str | None,
+    ) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]:
+        """Process input with optional previous response context.
+
+        Returns:
+            tuple: (all_input for storage, messages for chat completion)
+        """
         if previous_response_id:
-            previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)
-
-            # previous response input items
-            new_input_items = previous_response_with_input.input
-
-            # previous response output items
-            new_input_items.extend(previous_response_with_input.output)
-
-            # new input items from the current request
-            if isinstance(input, str):
-                new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
-            else:
-                new_input_items.extend(input)
-
-            input = new_input_items
-
-        return input
+            previous_response: _OpenAIResponseObjectWithInputAndMessages = (
+                await self.responses_store.get_response_object(previous_response_id)
+            )
+            all_input = await self._prepend_previous_response(input, previous_response)
+
+            if previous_response.messages:
+                # Use stored messages directly and convert only new input
+                message_adapter = TypeAdapter(list[OpenAIMessageParam])
+                messages = message_adapter.validate_python(previous_response.messages)
+                new_messages = await convert_response_input_to_chat_messages(input)
+                messages.extend(new_messages)
+            else:
+                # Backward compatibility: reconstruct from inputs
+                messages = await convert_response_input_to_chat_messages(all_input)
+        else:
+            all_input = input
+            messages = await convert_response_input_to_chat_messages(input)
+
+        return all_input, messages
 
     async def _prepend_instructions(self, messages, instructions):
         if instructions:
@@ -102,7 +128,7 @@ class OpenAIResponsesImpl:
         response_id: str,
     ) -> OpenAIResponseObject:
         response_with_input = await self.responses_store.get_response_object(response_id)
-        return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"})
+        return response_with_input.to_response_object()
 
     async def list_openai_responses(
         self,
@@ -138,6 +164,7 @@ class OpenAIResponsesImpl:
         self,
         response: OpenAIResponseObject,
         input: str | list[OpenAIResponseInput],
+        messages: list[OpenAIMessageParam],
     ) -> None:
         new_input_id = f"msg_{uuid.uuid4()}"
         if isinstance(input, str):
@@ -165,6 +192,7 @@ class OpenAIResponsesImpl:
         await self.responses_store.store_response_object(
             response_object=response,
             input=input_items_data,
+            messages=messages,
         )
 
     async def create_openai_response(
@@ -224,8 +252,7 @@ class OpenAIResponsesImpl:
         max_infer_iters: int | None = 10,
     ) -> AsyncIterator[OpenAIResponseObjectStream]:
         # Input preprocessing
-        input = await self._prepend_previous_response(input, previous_response_id)
-        messages = await convert_response_input_to_chat_messages(input)
+        all_input, messages = await self._process_input_with_previous_response(input, previous_response_id)
         await self._prepend_instructions(messages, instructions)
 
         # Structured outputs
@@ -265,7 +292,8 @@ class OpenAIResponsesImpl:
         if store and final_response:
             await self._store_response(
                 response=final_response,
-                input=input,
+                input=all_input,
+                messages=orchestrator.final_messages,
             )
 
     async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:

@@ -43,6 +43,7 @@ from llama_stack.apis.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionToolCall,
     OpenAIChoice,
+    OpenAIMessageParam,
 )
 from llama_stack.log import get_logger
 
@@ -103,6 +104,8 @@ class StreamingResponseOrchestrator:
         self.sequence_number = 0
         # Store MCP tool mapping that gets built during tool processing
         self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
+        # Track final messages after all tool executions
+        self.final_messages: list[OpenAIMessageParam] = []
 
     async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
         # Initialize output messages
@@ -192,6 +195,8 @@ class StreamingResponseOrchestrator:
 
             messages = next_turn_messages
 
+        self.final_messages = messages.copy() + [current_response.choices[0].message]
+
         # Create final response
         final_response = OpenAIResponseObject(
             created_at=self.created_at,

@@ -17,6 +17,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseObject,
     OpenAIResponseObjectWithInput,
 )
+from llama_stack.apis.inference import OpenAIMessageParam
 from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig
 from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
 from llama_stack.log import get_logger
@@ -28,6 +29,19 @@ from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreTy
 logger = get_logger(name=__name__, category="responses_store")
 
 
+class _OpenAIResponseObjectWithInputAndMessages(OpenAIResponseObjectWithInput):
+    """Internal class for storing responses with chat completion messages.
+
+    This extends the public OpenAIResponseObjectWithInput with messages field
+    for internal storage. The messages field is not exposed in the public API.
+
+    The messages field is optional for backward compatibility with responses
+    stored before this feature was added.
+    """
+
+    messages: list[OpenAIMessageParam] | None = None
+
+
 class ResponsesStore:
     def __init__(
         self,
@@ -54,7 +68,9 @@ class ResponsesStore:
         self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite
 
         # Async write queue and worker control
-        self._queue: asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput]]] | None = None
+        self._queue: (
+            asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput], list[OpenAIMessageParam]]] | None
+        ) = None
         self._worker_tasks: list[asyncio.Task[Any]] = []
         self._max_write_queue_size: int = config.max_write_queue_size
         self._num_writers: int = max(1, config.num_writers)
@@ -100,18 +116,21 @@ class ResponsesStore:
         await self._queue.join()
 
     async def store_response_object(
-        self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput]
+        self,
+        response_object: OpenAIResponseObject,
+        input: list[OpenAIResponseInput],
+        messages: list[OpenAIMessageParam],
     ) -> None:
         if self.enable_write_queue:
             if self._queue is None:
                 raise ValueError("Responses store is not initialized")
             try:
-                self._queue.put_nowait((response_object, input))
+                self._queue.put_nowait((response_object, input, messages))
             except asyncio.QueueFull:
                 logger.warning(f"Write queue full; adding response id={getattr(response_object, 'id', '<unknown>')}")
-                await self._queue.put((response_object, input))
+                await self._queue.put((response_object, input, messages))
         else:
-            await self._write_response_object(response_object, input)
+            await self._write_response_object(response_object, input, messages)
 
     async def _worker_loop(self) -> None:
         assert self._queue is not None
@@ -120,22 +139,26 @@ class ResponsesStore:
                 item = await self._queue.get()
             except asyncio.CancelledError:
                 break
-            response_object, input = item
+            response_object, input, messages = item
             try:
-                await self._write_response_object(response_object, input)
+                await self._write_response_object(response_object, input, messages)
             except Exception as e:  # noqa: BLE001
                 logger.error(f"Error writing response object: {e}")
             finally:
                 self._queue.task_done()
 
     async def _write_response_object(
-        self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput]
+        self,
+        response_object: OpenAIResponseObject,
+        input: list[OpenAIResponseInput],
+        messages: list[OpenAIMessageParam],
     ) -> None:
         if self.sql_store is None:
             raise ValueError("Responses store is not initialized")
 
         data = response_object.model_dump()
         data["input"] = [input_item.model_dump() for input_item in input]
+        data["messages"] = [msg.model_dump() for msg in messages]
 
         await self.sql_store.insert(
             "openai_responses",
@@ -188,7 +211,7 @@ class ResponsesStore:
             last_id=data[-1].id if data else "",
         )
 
-    async def get_response_object(self, response_id: str) -> OpenAIResponseObjectWithInput:
+    async def get_response_object(self, response_id: str) -> _OpenAIResponseObjectWithInputAndMessages:
         """
         Get a response object with automatic access control checking.
         """
@@ -205,7 +228,7 @@ class ResponsesStore:
             # This provides security by not revealing whether the record exists
             raise ValueError(f"Response with id {response_id} not found") from None
 
-        return OpenAIResponseObjectWithInput(**row["response_object"])
+        return _OpenAIResponseObjectWithInputAndMessages(**row["response_object"])
 
     async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
         if not self.sql_store:
@@ -241,8 +264,8 @@ class ResponsesStore:
         if before and after:
             raise ValueError("Cannot specify both 'before' and 'after' parameters")
 
-        response_with_input = await self.get_response_object(response_id)
-        items = response_with_input.input
+        response_with_input_and_messages = await self.get_response_object(response_id)
+        items = response_with_input_and_messages.input
 
         if order == Order.desc:
             items = list(reversed(items))

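For context on the store changes above: the async write queue now carries three-element tuples of (response_object, input, messages), which a worker drains and writes. A minimal, self-contained asyncio sketch of that producer/worker pattern (simplified stand-in values, not the actual ResponsesStore types):

```python
import asyncio


async def worker(queue: asyncio.Queue) -> None:
    while True:
        response_id, input_items, messages = await queue.get()
        try:
            # Stand-in for _write_response_object(response_object, input, messages).
            print(f"writing {response_id}: {len(input_items)} input items, {len(messages)} messages")
        finally:
            queue.task_done()


async def main() -> None:
    # Queue items mirror the (response_object, input, messages) tuples used by the store.
    queue: asyncio.Queue = asyncio.Queue(maxsize=8)
    task = asyncio.create_task(worker(queue))

    queue.put_nowait(("resp_1", ["hi"], ["user: hi", "assistant: hello"]))
    await queue.join()  # wait for queued writes to drain, like store.flush()

    task.cancel()
    await asyncio.gather(task, return_exceptions=True)


asyncio.run(main())
```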
@@ -127,6 +127,70 @@ def test_response_non_streaming_file_search_empty_vector_store(compat_client, te
     assert response.output_text
 
 
+def test_response_sequential_file_search(compat_client, text_model_id, tmp_path):
+    """Test file search with sequential responses using previous_response_id."""
+    if isinstance(compat_client, LlamaStackAsLibraryClient):
+        pytest.skip("Responses API file search is not yet supported in library client.")
+
+    vector_store = new_vector_store(compat_client, "test_vector_store")
+
+    # Create a test file with content
+    file_content = "The Llama 4 Maverick model has 128 experts in its mixture of experts architecture."
+    file_name = "test_sequential_file_search.txt"
+    file_path = tmp_path / file_name
+    file_path.write_text(file_content)
+
+    file_response = upload_file(compat_client, file_name, file_path)
+
+    # Attach the file to the vector store
+    compat_client.vector_stores.files.create(
+        vector_store_id=vector_store.id,
+        file_id=file_response.id,
+    )
+
+    # Wait for the file to be attached
+    wait_for_file_attachment(compat_client, vector_store.id, file_response.id)
+
+    tools = [{"type": "file_search", "vector_store_ids": [vector_store.id]}]
+
+    # First response request with file search
+    response = compat_client.responses.create(
+        model=text_model_id,
+        input="How many experts does the Llama 4 Maverick model have?",
+        tools=tools,
+        stream=False,
+        include=["file_search_call.results"],
+    )
+
+    # Verify the file_search_tool was called
+    assert len(response.output) > 1
+    assert response.output[0].type == "file_search_call"
+    assert response.output[0].status == "completed"
+    assert response.output[0].queries
+    assert response.output[0].results
+    assert "128" in response.output_text or "experts" in response.output_text.lower()
+
+    # Second response request using previous_response_id
+    response2 = compat_client.responses.create(
+        model=text_model_id,
+        input="Can you tell me more about the architecture?",
+        tools=tools,
+        stream=False,
+        previous_response_id=response.id,
+        include=["file_search_call.results"],
+    )
+
+    # Verify the second response has output
+    assert len(response2.output) >= 1
+    assert response2.output_text
+
+    # The second response should maintain context from the first
+    final_message = [output for output in response2.output if output.type == "message"]
+    assert len(final_message) >= 1
+    assert final_message[-1].role == "assistant"
+    assert final_message[-1].status == "completed"
+
+
 @pytest.mark.parametrize("case", mcp_tool_test_cases)
 def test_response_non_streaming_mcp_tool(compat_client, text_model_id, case):
     if not isinstance(compat_client, LlamaStackAsLibraryClient):

@@ -22,7 +22,6 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseInputToolFunction,
     OpenAIResponseInputToolWebSearch,
     OpenAIResponseMessage,
-    OpenAIResponseObjectWithInput,
     OpenAIResponseOutputMessageContentOutputText,
     OpenAIResponseOutputMessageMCPCall,
     OpenAIResponseOutputMessageWebSearchToolCall,
@@ -45,7 +44,10 @@ from llama_stack.core.datatypes import ResponsesStoreConfig
 from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
     OpenAIResponsesImpl,
 )
-from llama_stack.providers.utils.responses.responses_store import ResponsesStore
+from llama_stack.providers.utils.responses.responses_store import (
+    ResponsesStore,
+    _OpenAIResponseObjectWithInputAndMessages,
+)
 from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
 from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture
 
@@ -498,13 +500,6 @@ async def test_create_openai_response_with_multiple_messages(openai_responses_im
     assert isinstance(inference_messages[i], OpenAIDeveloperMessageParam)
 
 
-async def test_prepend_previous_response_none(openai_responses_impl):
-    """Test prepending no previous response to a new response."""
-    input = await openai_responses_impl._prepend_previous_response("fake_input", None)
-    assert input == "fake_input"
-
-
 async def test_prepend_previous_response_basic(openai_responses_impl, mock_responses_store):
     """Test prepending a basic previous response to a new response."""
 
@@ -519,7 +514,7 @@ async def test_prepend_previous_response_basic(openai_responses_impl, mock_respo
         status="completed",
         role="assistant",
     )
-    previous_response = OpenAIResponseObjectWithInput(
+    previous_response = _OpenAIResponseObjectWithInputAndMessages(
         created_at=1,
         id="resp_123",
         model="fake_model",
@@ -527,10 +522,11 @@ async def test_prepend_previous_response_basic(openai_responses_impl, mock_respo
         status="completed",
         text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
         input=[input_item_message],
+        messages=[OpenAIUserMessageParam(content="fake_previous_input")],
     )
     mock_responses_store.get_response_object.return_value = previous_response
 
-    input = await openai_responses_impl._prepend_previous_response("fake_input", "resp_123")
+    input = await openai_responses_impl._prepend_previous_response("fake_input", previous_response)
 
     assert len(input) == 3
     # Check for previous input
@@ -561,7 +557,7 @@ async def test_prepend_previous_response_web_search(openai_responses_impl, mock_
         status="completed",
         role="assistant",
     )
-    response = OpenAIResponseObjectWithInput(
+    response = _OpenAIResponseObjectWithInputAndMessages(
         created_at=1,
         id="resp_123",
         model="fake_model",
@@ -569,11 +565,12 @@ async def test_prepend_previous_response_web_search(openai_responses_impl, mock_
         status="completed",
         text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
         input=[input_item_message],
+        messages=[OpenAIUserMessageParam(content="test input")],
     )
     mock_responses_store.get_response_object.return_value = response
 
     input_messages = [OpenAIResponseMessage(content="fake_input", role="user")]
-    input = await openai_responses_impl._prepend_previous_response(input_messages, "resp_123")
+    input = await openai_responses_impl._prepend_previous_response(input_messages, response)
 
     assert len(input) == 4
     # Check for previous input
@@ -608,7 +605,7 @@ async def test_prepend_previous_response_mcp_tool_call(openai_responses_impl, mo
         status="completed",
         role="assistant",
     )
-    response = OpenAIResponseObjectWithInput(
+    response = _OpenAIResponseObjectWithInputAndMessages(
         created_at=1,
         id="resp_123",
         model="fake_model",
@@ -616,11 +613,12 @@ async def test_prepend_previous_response_mcp_tool_call(openai_responses_impl, mo
         status="completed",
         text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
         input=[input_item_message],
+        messages=[OpenAIUserMessageParam(content="test input")],
     )
     mock_responses_store.get_response_object.return_value = response
 
     input_messages = [OpenAIResponseMessage(content="fake_input", role="user")]
-    input = await openai_responses_impl._prepend_previous_response(input_messages, "resp_123")
+    input = await openai_responses_impl._prepend_previous_response(input_messages, response)
 
     assert len(input) == 4
     # Check for previous input
@@ -724,7 +722,7 @@ async def test_create_openai_response_with_instructions_and_previous_response(
         status="completed",
         role="assistant",
     )
-    response = OpenAIResponseObjectWithInput(
+    response = _OpenAIResponseObjectWithInputAndMessages(
         created_at=1,
         id="resp_123",
         model="fake_model",
@@ -732,6 +730,10 @@ async def test_create_openai_response_with_instructions_and_previous_response(
         status="completed",
         text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
         input=[input_item_message],
+        messages=[
+            OpenAIUserMessageParam(content="Name some towns in Ireland"),
+            OpenAIAssistantMessageParam(content="Galway, Longford, Sligo"),
+        ],
     )
     mock_responses_store.get_response_object.return_value = response
 
@@ -817,7 +819,7 @@ async def test_responses_store_list_input_items_logic():
         OpenAIResponseMessage(id="msg_4", content="Fourth message", role="user"),
     ]
 
-    response_with_input = OpenAIResponseObjectWithInput(
+    response_with_input = _OpenAIResponseObjectWithInputAndMessages(
         id="resp_123",
         model="test_model",
         created_at=1234567890,
@@ -826,6 +828,7 @@ async def test_responses_store_list_input_items_logic():
         output=[],
         text=OpenAIResponseText(format=(OpenAIResponseTextFormat(type="text"))),
         input=input_items,
+        messages=[OpenAIUserMessageParam(content="First message")],
     )
 
     # Mock the get_response_object method to return our test data
@@ -886,7 +889,7 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
     rather than just the original input when previous_response_id is provided."""
 
     # Setup - Create a previous response that should be included in the stored input
-    previous_response = OpenAIResponseObjectWithInput(
+    previous_response = _OpenAIResponseObjectWithInputAndMessages(
         id="resp-previous-123",
         object="response",
         created_at=1234567890,
@@ -905,6 +908,10 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
                 content=[OpenAIResponseOutputMessageContentOutputText(text="2+2 equals 4.")],
             )
         ],
+        messages=[
+            OpenAIUserMessageParam(content="What is 2+2?"),
+            OpenAIAssistantMessageParam(content="2+2 equals 4."),
+        ],
     )
 
     mock_responses_store.get_response_object.return_value = previous_response

@@ -14,6 +14,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseInput,
     OpenAIResponseObject,
 )
+from llama_stack.apis.inference import OpenAIMessageParam, OpenAIUserMessageParam
 from llama_stack.providers.utils.responses.responses_store import ResponsesStore
 from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
 
@@ -44,6 +45,11 @@ def create_test_response_input(content: str, input_id: str) -> OpenAIResponseInp
     )
 
 
+def create_test_messages(content: str) -> list[OpenAIMessageParam]:
+    """Helper to create test messages for chat completion."""
+    return [OpenAIUserMessageParam(content=content)]
+
+
 async def test_responses_store_pagination_basic():
     """Test basic pagination functionality for responses store."""
     with TemporaryDirectory() as tmp_dir:
@@ -65,7 +71,8 @@ async def test_responses_store_pagination_basic():
         for response_id, timestamp in test_data:
             response = create_test_response_object(response_id, timestamp)
             input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
-            await store.store_response_object(response, input_list)
+            messages = create_test_messages(f"Input for {response_id}")
+            await store.store_response_object(response, input_list, messages)
 
         # Wait for all queued writes to complete
         await store.flush()
@@ -111,7 +118,8 @@ async def test_responses_store_pagination_ascending():
         for response_id, timestamp in test_data:
             response = create_test_response_object(response_id, timestamp)
             input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
-            await store.store_response_object(response, input_list)
+            messages = create_test_messages(f"Input for {response_id}")
+            await store.store_response_object(response, input_list, messages)
 
         # Wait for all queued writes to complete
         await store.flush()
@@ -149,7 +157,8 @@ async def test_responses_store_pagination_with_model_filter():
         for response_id, timestamp, model in test_data:
             response = create_test_response_object(response_id, timestamp, model)
             input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
-            await store.store_response_object(response, input_list)
+            messages = create_test_messages(f"Input for {response_id}")
+            await store.store_response_object(response, input_list, messages)
 
         # Wait for all queued writes to complete
         await store.flush()
@@ -199,7 +208,8 @@ async def test_responses_store_pagination_no_limit():
        for response_id, timestamp in test_data:
             response = create_test_response_object(response_id, timestamp)
             input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
-            await store.store_response_object(response, input_list)
+            messages = create_test_messages(f"Input for {response_id}")
+            await store.store_response_object(response, input_list, messages)
 
         # Wait for all queued writes to complete
         await store.flush()
@@ -222,7 +232,8 @@ async def test_responses_store_get_response_object():
         # Store a test response
         response = create_test_response_object("test-resp", int(time.time()))
         input_list = [create_test_response_input("Test input content", "input-test-resp")]
-        await store.store_response_object(response, input_list)
+        messages = create_test_messages("Test input content")
+        await store.store_response_object(response, input_list, messages)
 
         # Wait for all queued writes to complete
         await store.flush()
@@ -255,7 +266,8 @@ async def test_responses_store_input_items_pagination():
             create_test_response_input("Fourth input", "input-4"),
             create_test_response_input("Fifth input", "input-5"),
         ]
-        await store.store_response_object(response, input_list)
+        messages = create_test_messages("First input")
+        await store.store_response_object(response, input_list, messages)
 
         # Wait for all queued writes to complete
         await store.flush()
@@ -335,7 +347,8 @@ async def test_responses_store_input_items_before_pagination():
             create_test_response_input("Fourth input", "before-4"),
             create_test_response_input("Fifth input", "before-5"),
         ]
-        await store.store_response_object(response, input_list)
+        messages = create_test_messages("First input")
+        await store.store_response_object(response, input_list, messages)
 
         # Wait for all queued writes to complete
         await store.flush()