mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
store prev messages
# What does this PR do? ## Test Plan
This commit is contained in:
parent
4819a2e0ee
commit
2ec9f8770e
7 changed files with 202 additions and 58 deletions
|
@ -14,6 +14,7 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseInput,
|
||||
OpenAIResponseObject,
|
||||
)
|
||||
from llama_stack.apis.inference import OpenAIMessageParam, OpenAIUserMessageParam
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
|
||||
|
@ -44,6 +45,11 @@ def create_test_response_input(content: str, input_id: str) -> OpenAIResponseInp
|
|||
)
|
||||
|
||||
|
||||
def create_test_messages(content: str) -> list[OpenAIMessageParam]:
|
||||
"""Helper to create test messages for chat completion."""
|
||||
return [OpenAIUserMessageParam(content=content)]
|
||||
|
||||
|
||||
async def test_responses_store_pagination_basic():
|
||||
"""Test basic pagination functionality for responses store."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
|
@ -65,7 +71,8 @@ async def test_responses_store_pagination_basic():
|
|||
for response_id, timestamp in test_data:
|
||||
response = create_test_response_object(response_id, timestamp)
|
||||
input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
|
||||
await store.store_response_object(response, input_list)
|
||||
messages = create_test_messages(f"Input for {response_id}")
|
||||
await store.store_response_object(response, input_list, messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
@ -111,7 +118,8 @@ async def test_responses_store_pagination_ascending():
|
|||
for response_id, timestamp in test_data:
|
||||
response = create_test_response_object(response_id, timestamp)
|
||||
input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
|
||||
await store.store_response_object(response, input_list)
|
||||
messages = create_test_messages(f"Input for {response_id}")
|
||||
await store.store_response_object(response, input_list, messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
@ -149,7 +157,8 @@ async def test_responses_store_pagination_with_model_filter():
|
|||
for response_id, timestamp, model in test_data:
|
||||
response = create_test_response_object(response_id, timestamp, model)
|
||||
input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
|
||||
await store.store_response_object(response, input_list)
|
||||
messages = create_test_messages(f"Input for {response_id}")
|
||||
await store.store_response_object(response, input_list, messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
@ -199,7 +208,8 @@ async def test_responses_store_pagination_no_limit():
|
|||
for response_id, timestamp in test_data:
|
||||
response = create_test_response_object(response_id, timestamp)
|
||||
input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
|
||||
await store.store_response_object(response, input_list)
|
||||
messages = create_test_messages(f"Input for {response_id}")
|
||||
await store.store_response_object(response, input_list, messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
@ -222,7 +232,8 @@ async def test_responses_store_get_response_object():
|
|||
# Store a test response
|
||||
response = create_test_response_object("test-resp", int(time.time()))
|
||||
input_list = [create_test_response_input("Test input content", "input-test-resp")]
|
||||
await store.store_response_object(response, input_list)
|
||||
messages = create_test_messages("Test input content")
|
||||
await store.store_response_object(response, input_list, messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
@ -255,7 +266,8 @@ async def test_responses_store_input_items_pagination():
|
|||
create_test_response_input("Fourth input", "input-4"),
|
||||
create_test_response_input("Fifth input", "input-5"),
|
||||
]
|
||||
await store.store_response_object(response, input_list)
|
||||
messages = create_test_messages("First input")
|
||||
await store.store_response_object(response, input_list, messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
@ -335,7 +347,8 @@ async def test_responses_store_input_items_before_pagination():
|
|||
create_test_response_input("Fourth input", "before-4"),
|
||||
create_test_response_input("Fifth input", "before-5"),
|
||||
]
|
||||
await store.store_response_object(response, input_list)
|
||||
messages = create_test_messages("First input")
|
||||
await store.store_response_object(response, input_list, messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue