mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 2s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.13) (push) Failing after 2s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 9s
Vector IO Integration Tests / test-matrix (push) Failing after 4s
API Conformance Tests / check-schema-compatibility (push) Successful in 7s
Unit Tests / unit-tests (3.12) (push) Failing after 4s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3s
Python Package Build Test / build (3.12) (push) Failing after 3s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 8s
Unit Tests / unit-tests (3.13) (push) Failing after 5s
Update ReadTheDocs / update-readthedocs (push) Failing after 23s
Test External API and Providers / test-external (venv) (push) Failing after 30s
UI Tests / ui-tests (22) (push) Successful in 35s
Pre-commit / pre-commit (push) Successful in 1m45s
# What does this PR do? Duplicate chat completion IDs can be generated during tests, especially when recorded responses are replayed across different tests. There is no need to warn or error under those circumstances. In the wild this is unlikely to happen at all (there is no evidence of it), so we aren't really hiding any problem.
246 lines
9.2 KiB
Python
246 lines
9.2 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
import asyncio
|
|
from typing import Any
|
|
|
|
from sqlalchemy.exc import IntegrityError
|
|
|
|
from llama_stack.apis.inference import (
|
|
ListOpenAIChatCompletionResponse,
|
|
OpenAIChatCompletion,
|
|
OpenAICompletionWithInputMessages,
|
|
OpenAIMessageParam,
|
|
Order,
|
|
)
|
|
from llama_stack.core.datatypes import AccessRule, InferenceStoreConfig
|
|
from llama_stack.log import get_logger
|
|
|
|
from ..sqlstore.api import ColumnDefinition, ColumnType
|
|
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
|
from ..sqlstore.sqlstore import SqlStoreConfig, SqlStoreType, sqlstore_impl
|
|
|
|
logger = get_logger(name=__name__, category="inference_store")
|
|
|
|
|
|
class InferenceStore:
    """Persist chat completions (and their input messages) in a SQL table.

    Records are written to the ``chat_completions`` table. Unless the
    configured SQL store type is SQLite, writes are placed on an in-memory
    ``asyncio.Queue`` and drained by background worker tasks; for SQLite the
    write is performed inline by ``store_chat_completion``.
    """

    def __init__(
        self,
        config: InferenceStoreConfig | SqlStoreConfig,
        policy: list[AccessRule],
    ):
        """Capture configuration; no I/O happens until ``initialize()``.

        :param config: Either an ``InferenceStoreConfig`` or, for backward
            compatibility, a bare ``SqlStoreConfig`` that is wrapped into one.
        :param policy: Access rules passed to the store on reads.
        """
        # Handle backward compatibility
        if not isinstance(config, InferenceStoreConfig):
            # Legacy: SqlStoreConfig passed directly as config
            config = InferenceStoreConfig(
                sql_store_config=config,
            )

        self.config = config
        self.sql_store_config = config.sql_store_config
        # Set by initialize(); methods raise ValueError while this is unset.
        self.sql_store: AuthorizedSqlStore | None = None
        self.policy = policy

        # Disable write queue for SQLite to avoid concurrency issues
        self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite

        # Async write queue and worker control
        self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
        self._worker_tasks: list[asyncio.Task[Any]] = []
        self._max_write_queue_size: int = config.max_write_queue_size
        # Always run at least one writer, even if the config asks for fewer.
        self._num_writers: int = max(1, config.num_writers)

    async def initialize(self):
        """Create the necessary tables if they don't exist.

        When the write queue is enabled, also creates the queue and starts
        ``_num_writers`` background worker tasks.
        """
        self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
        await self.sql_store.create_table(
            "chat_completions",
            {
                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
                "created": ColumnType.INTEGER,
                "model": ColumnType.STRING,
                "choices": ColumnType.JSON,
                "input_messages": ColumnType.JSON,
            },
        )

        if self.enable_write_queue:
            self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
            for _ in range(self._num_writers):
                self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
        else:
            logger.info("Write queue disabled for SQLite to avoid concurrency issues")

    async def shutdown(self) -> None:
        """Drain the write queue, then stop and await the worker tasks.

        Does nothing when no workers were started.
        """
        if not self._worker_tasks:
            return
        if self._queue is not None:
            # Let queued writes finish before the workers are cancelled.
            await self._queue.join()
        for task in self._worker_tasks:
            if not task.done():
                task.cancel()
        try:
            # return_exceptions=True collects each worker's outcome (including
            # its CancelledError) as a result instead of raising it, so a
            # CancelledError raised here can only be a cancellation of the
            # caller of shutdown() and is allowed to propagate.
            results = await asyncio.gather(*self._worker_tasks, return_exceptions=True)
        finally:
            # Always drop the task references, even if shutdown is interrupted.
            self._worker_tasks.clear()
        for result in results:
            # CancelledError derives from BaseException, so it is skipped here.
            if isinstance(result, Exception):
                logger.error(f"Inference store writer exited with error: {result}")

    async def flush(self) -> None:
        """Wait for all queued writes to complete. Useful for testing."""
        if self.enable_write_queue and self._queue is not None:
            await self._queue.join()

    async def store_chat_completion(
        self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
    ) -> None:
        """Record a chat completion together with the messages that produced it.

        With the write queue enabled the item is enqueued; if the queue is
        full a warning is logged and the call waits for space. Otherwise the
        record is written inline.

        :raises ValueError: if the write queue is enabled but ``initialize()``
            has not created it yet.
        """
        if self.enable_write_queue:
            if self._queue is None:
                raise ValueError("Inference store is not initialized")
            try:
                self._queue.put_nowait((chat_completion, input_messages))
            except asyncio.QueueFull:
                logger.warning(
                    f"Write queue full; adding chat completion id={getattr(chat_completion, 'id', '<unknown>')}"
                )
                # Fall back to a blocking put so the record is not dropped.
                await self._queue.put((chat_completion, input_messages))
        else:
            await self._write_chat_completion(chat_completion, input_messages)

    async def _worker_loop(self) -> None:
        """Take items off the write queue and persist them until cancelled."""
        assert self._queue is not None
        while True:
            try:
                item = await self._queue.get()
            except asyncio.CancelledError:
                # Cancelled while idle: exit the loop cleanly.
                break
            chat_completion, input_messages = item
            try:
                await self._write_chat_completion(chat_completion, input_messages)
            except Exception as e:  # noqa: BLE001
                # A failed write must not kill the worker; log it with the
                # traceback so background failures can be diagnosed.
                logger.exception(f"Error writing chat completion: {e}")
            finally:
                # Always mark the item done so queue.join() cannot hang.
                self._queue.task_done()

    async def _write_chat_completion(
        self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
    ) -> None:
        """Insert the completion row, updating the existing row on a duplicate id.

        :raises ValueError: if the store has not been initialized.
        :raises IntegrityError: for integrity errors that are not recognised
            as unique-constraint violations.
        """
        if self.sql_store is None:
            raise ValueError("Inference store is not initialized")

        data = chat_completion.model_dump()
        record_data = {
            "id": data["id"],
            "created": data["created"],
            "model": data["model"],
            "choices": data["choices"],
            "input_messages": [message.model_dump() for message in input_messages],
        }

        try:
            await self.sql_store.insert(
                table="chat_completions",
                data=record_data,
            )
        except IntegrityError as e:
            # Duplicate chat completion IDs can be generated during tests, especially when recorded
            # responses are replayed across different tests. No need to warn or error under those
            # circumstances. In the wild, this is not likely to happen at all (no evidence), so we
            # aren't really hiding any problem.

            # Check if it's a unique constraint violation
            error_message = str(e.orig) if e.orig else str(e)
            if self._is_unique_constraint_error(error_message):
                # Update the existing record instead
                await self.sql_store.update(table="chat_completions", data=record_data, where={"id": data["id"]})
            else:
                # Re-raise if it's not a unique constraint error
                raise

    def _is_unique_constraint_error(self, error_message: str) -> bool:
        """Check if the error is specifically a unique constraint violation.

        The match is a case-insensitive substring search for known
        driver-specific phrases.
        """
        error_lower = error_message.lower()
        return any(
            indicator in error_lower
            for indicator in [
                "unique constraint failed",  # SQLite
                "duplicate key",  # PostgreSQL
                "unique violation",  # PostgreSQL alternative
                "duplicate entry",  # MySQL
            ]
        )

    async def list_chat_completions(
        self,
        after: str | None = None,
        limit: int | None = 50,
        model: str | None = None,
        order: Order | None = Order.desc,
    ) -> ListOpenAIChatCompletionResponse:
        """
        List chat completions from the database.

        :param after: The ID of the last chat completion to return; passed to
            the store as an ``id`` cursor when set.
        :param limit: The maximum number of chat completions to return.
        :param model: The model to filter by.
        :param order: The order to sort the chat completions by (applied to
            the ``created`` column); falls back to ``Order.desc`` when falsy.
        :raises ValueError: if the store has not been initialized.
        """
        if not self.sql_store:
            raise ValueError("Inference store is not initialized")

        if not order:
            order = Order.desc

        where_conditions = {}
        if model:
            where_conditions["model"] = model

        paginated_result = await self.sql_store.fetch_all(
            table="chat_completions",
            where=where_conditions if where_conditions else None,
            order_by=[("created", order.value)],
            cursor=("id", after) if after else None,
            limit=limit,
            policy=self.policy,
        )

        data = [
            OpenAICompletionWithInputMessages(
                id=row["id"],
                created=row["created"],
                model=row["model"],
                choices=row["choices"],
                input_messages=row["input_messages"],
            )
            for row in paginated_result.data
        ]
        return ListOpenAIChatCompletionResponse(
            data=data,
            has_more=paginated_result.has_more,
            # Empty strings stand in for the ids when the page has no rows.
            first_id=data[0].id if data else "",
            last_id=data[-1].id if data else "",
        )

    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
        """Fetch a single stored chat completion by id.

        :param completion_id: The id of the chat completion to look up.
        :raises ValueError: if the store has not been initialized, or if no
            row is returned for ``completion_id``.
        """
        if not self.sql_store:
            raise ValueError("Inference store is not initialized")

        row = await self.sql_store.fetch_one(
            table="chat_completions",
            where={"id": completion_id},
            policy=self.policy,
        )

        if not row:
            # AuthorizedSqlStore will return None if record doesn't exist OR access is denied
            # This provides security by not revealing whether the record exists
            raise ValueError(f"Chat completion with id {completion_id} not found") from None

        return OpenAICompletionWithInputMessages(
            id=row["id"],
            created=row["created"],
            model=row["model"],
            choices=row["choices"],
            input_messages=row["input_messages"],
        )
|