mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
268 lines
9.9 KiB
Python
268 lines
9.9 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
import asyncio
|
|
from typing import Any
|
|
|
|
from llama_stack.apis.agents import (
|
|
Order,
|
|
)
|
|
from llama_stack.apis.agents.openai_responses import (
|
|
ListOpenAIResponseInputItem,
|
|
ListOpenAIResponseObject,
|
|
OpenAIDeleteResponseObject,
|
|
OpenAIResponseInput,
|
|
OpenAIResponseObject,
|
|
OpenAIResponseObjectWithInput,
|
|
)
|
|
from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig
|
|
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
|
|
from llama_stack.log import get_logger
|
|
|
|
from ..sqlstore.api import ColumnDefinition, ColumnType
|
|
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
|
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreType, sqlstore_impl
|
|
|
|
# Module-level logger for this file, created under the "responses_store" category.
logger = get_logger(name=__name__, category="responses_store")
|
|
|
|
|
|
class ResponsesStore:
    """Persist OpenAI-style response objects in the ``openai_responses`` SQL table.

    Writes go either straight to the SQL store or, for non-SQLite backends,
    through a bounded in-memory queue drained by background worker tasks.
    Read queries forward ``policy`` to the underlying ``AuthorizedSqlStore``.

    :meth:`initialize` must be awaited before any method that touches the
    store; those methods raise ``ValueError`` otherwise.
    """

    def __init__(
        self,
        config: ResponsesStoreConfig | SqlStoreConfig,
        policy: list[AccessRule],
    ):
        """Record configuration only; the store itself is created in :meth:`initialize`.

        :param config: Store configuration. A bare ``SqlStoreConfig`` is still
            accepted and gets wrapped in a ``ResponsesStoreConfig``.
        :param policy: Access rules forwarded to the SQL store on read queries.
        """
        # Handle backward compatibility
        if not isinstance(config, ResponsesStoreConfig):
            # Legacy: SqlStoreConfig passed directly as config
            config = ResponsesStoreConfig(
                sql_store_config=config,
            )

        self.config = config
        self.sql_store_config = config.sql_store_config
        if not self.sql_store_config:
            # No SQL backend configured: fall back to a SQLite file in the runtime dir.
            self.sql_store_config = SqliteSqlStoreConfig(
                db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
            )
        # Stays None until initialize() has run.
        self.sql_store: AuthorizedSqlStore | None = None
        self.policy = policy

        # Disable write queue for SQLite to avoid concurrency issues
        self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite

        # Async write queue and worker control
        self._queue: asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput]]] | None = None
        self._worker_tasks: list[asyncio.Task[Any]] = []
        self._max_write_queue_size: int = config.max_write_queue_size
        # Always run at least one writer, whatever the configured value.
        self._num_writers: int = max(1, config.num_writers)

    def _require_sql_store(self) -> AuthorizedSqlStore:
        """Return the SQL store, or raise ``ValueError`` if :meth:`initialize` has not run."""
        if self.sql_store is None:
            raise ValueError("Responses store is not initialized")
        return self.sql_store

    async def initialize(self):
        """Create the necessary tables if they don't exist.

        When the write queue is enabled, this also creates the queue and
        starts ``_num_writers`` background worker tasks.
        """
        self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
        await self.sql_store.create_table(
            "openai_responses",
            {
                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
                "created_at": ColumnType.INTEGER,
                "response_object": ColumnType.JSON,
                "model": ColumnType.STRING,
            },
        )

        if self.enable_write_queue:
            self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
            for _ in range(self._num_writers):
                self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
        else:
            logger.info("Write queue disabled for SQLite to avoid concurrency issues")

    async def shutdown(self) -> None:
        """Drain the write queue, then stop and reap all worker tasks.

        Does nothing when no workers were started.
        """
        if not self._worker_tasks:
            return
        if self._queue is not None:
            # Let workers finish everything already queued before stopping them.
            await self._queue.join()
        for task in self._worker_tasks:
            # cancel() on an already-finished task is a no-op.
            task.cancel()
        # return_exceptions=True collects each worker's outcome (cancellation or
        # error) as a result, so one failing worker cannot prevent the others
        # from being awaited or the task list from being cleared.
        await asyncio.gather(*self._worker_tasks, return_exceptions=True)
        self._worker_tasks.clear()

    async def flush(self) -> None:
        """Wait for all queued writes to complete. Useful for testing."""
        if self.enable_write_queue and self._queue is not None:
            await self._queue.join()

    async def store_response_object(
        self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput]
    ) -> None:
        """Store a response together with its input items.

        With the write queue enabled the pair is enqueued for a background
        worker; otherwise it is written before this coroutine returns.

        :param response_object: The response to persist.
        :param input: The input items that produced the response.
        :raises ValueError: If the store has not been initialized.
        """
        if self.enable_write_queue:
            if self._queue is None:
                raise ValueError("Responses store is not initialized")
            try:
                # Fast path: enqueue without suspending.
                self._queue.put_nowait((response_object, input))
            except asyncio.QueueFull:
                logger.warning(
                    f"Write queue full; adding response id={getattr(response_object, 'id', '<unknown>')}"
                )
                # Queue is full: wait for a free slot rather than dropping the write.
                await self._queue.put((response_object, input))
        else:
            await self._write_response_object(response_object, input)

    async def _worker_loop(self) -> None:
        """Background task: take queued items and write them until cancelled."""
        assert self._queue is not None
        while True:
            try:
                item = await self._queue.get()
            except asyncio.CancelledError:
                # Cancelled while idle (see shutdown): exit the loop cleanly.
                break
            response_object, input = item
            try:
                await self._write_response_object(response_object, input)
            except Exception as e:  # noqa: BLE001
                # A failed write is logged and dropped so the worker keeps running.
                logger.error(f"Error writing response object: {e}")
            finally:
                # Always mark the item done, otherwise queue.join() would never return.
                self._queue.task_done()

    async def _write_response_object(
        self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput]
    ) -> None:
        """Insert one response row, embedding the input items in the stored object.

        :raises ValueError: If the store has not been initialized.
        """
        sql_store = self._require_sql_store()

        data = response_object.model_dump()
        data["input"] = [input_item.model_dump() for input_item in input]

        await sql_store.insert(
            "openai_responses",
            {
                "id": data["id"],
                "created_at": data["created_at"],
                "model": data["model"],
                "response_object": data,
            },
        )

    async def list_responses(
        self,
        after: str | None = None,
        limit: int | None = 50,
        model: str | None = None,
        order: Order | None = Order.desc,
    ) -> ListOpenAIResponseObject:
        """
        List responses from the database.

        :param after: A response ID used as the pagination cursor; passed to the store as an ``id`` cursor.
        :param limit: The maximum number of responses to return.
        :param model: The model to filter by.
        :param order: The order to sort the responses by (on ``created_at``); falsy values mean ``Order.desc``.
        :raises ValueError: If the store has not been initialized.
        """
        sql_store = self._require_sql_store()

        if not order:
            order = Order.desc

        where_conditions = {}
        if model:
            where_conditions["model"] = model

        paginated_result = await sql_store.fetch_all(
            table="openai_responses",
            where=where_conditions if where_conditions else None,
            order_by=[("created_at", order.value)],
            cursor=("id", after) if after else None,
            limit=limit,
            policy=self.policy,
        )

        data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
        return ListOpenAIResponseObject(
            data=data,
            has_more=paginated_result.has_more,
            # Empty strings stand in for the IDs when the page has no rows.
            first_id=data[0].id if data else "",
            last_id=data[-1].id if data else "",
        )

    async def get_response_object(self, response_id: str) -> OpenAIResponseObjectWithInput:
        """
        Get a response object with automatic access control checking.

        :param response_id: The ID of the response to fetch.
        :raises ValueError: If the store has not been initialized, or no row is returned for ``response_id``.
        """
        sql_store = self._require_sql_store()

        row = await sql_store.fetch_one(
            "openai_responses",
            where={"id": response_id},
            policy=self.policy,
        )

        if not row:
            # AuthorizedSqlStore will return None if record doesn't exist OR access is denied
            # This provides security by not revealing whether the record exists
            raise ValueError(f"Response with id {response_id} not found") from None

        return OpenAIResponseObjectWithInput(**row["response_object"])

    async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
        """Delete a stored response.

        The row is first fetched with ``policy`` applied; the delete is only
        issued when that fetch returns a row.

        :param response_id: The ID of the response to delete.
        :raises ValueError: If the store has not been initialized, or no row is returned for ``response_id``.
        """
        sql_store = self._require_sql_store()

        row = await sql_store.fetch_one("openai_responses", where={"id": response_id}, policy=self.policy)
        if not row:
            raise ValueError(f"Response with id {response_id} not found")
        await sql_store.delete("openai_responses", where={"id": response_id})
        return OpenAIDeleteResponseObject(id=response_id)

    async def list_response_input_items(
        self,
        response_id: str,
        after: str | None = None,
        before: str | None = None,
        include: list[str] | None = None,
        limit: int | None = 20,
        order: Order | None = Order.desc,
    ) -> ListOpenAIResponseInputItem:
        """
        List input items for a given response.

        :param response_id: The ID of the response to retrieve input items for.
        :param after: An item ID to list items after, used for pagination.
        :param before: An item ID to list items before, used for pagination.
        :param include: Additional fields to include in the response.
        :param limit: A limit on the number of objects to be returned.
        :param order: The order to return the input items in. Only ``Order.desc`` reverses the
            stored items; any other value (including ``None``) leaves them in stored order.
        :raises NotImplementedError: If ``include`` is given.
        :raises ValueError: If both ``before`` and ``after`` are given, if the given
            ``before``/``after`` item ID is not found, or if the response cannot be fetched.
        """
        if include:
            raise NotImplementedError("Include is not supported yet")
        if before and after:
            raise ValueError("Cannot specify both 'before' and 'after' parameters")

        response_with_input = await self.get_response_object(response_id)
        items = response_with_input.input

        if order == Order.desc:
            items = list(reversed(items))

        start_index = 0
        end_index = len(items)

        if after or before:
            for i, item in enumerate(items):
                item_id = getattr(item, "id", None)
                if after and item_id == after:
                    start_index = i + 1
                if before and item_id == before:
                    end_index = i
                    break

            # A found `after` always yields start_index >= 1 and a found `before`
            # always yields end_index < len(items), so the untouched initial
            # values mean the ID was not among the items.
            if after and start_index == 0:
                raise ValueError(f"Input item with id '{after}' not found for response '{response_id}'")
            if before and end_index == len(items):
                raise ValueError(f"Input item with id '{before}' not found for response '{response_id}'")

        items = items[start_index:end_index]

        # Apply limit
        if limit is not None:
            items = items[:limit]

        return ListOpenAIResponseInputItem(data=items)