# What does this PR do?

These log messages aren't controllable by `LLAMA_STACK_LOGGING` because their categories are unknown to the logger:

```
tests/integration/agents/test_persistence.py::test_delete_agents_and_sessions SKIPPED (This ...) [ 3%]
tests/integration/agents/test_persistence.py::test_get_agent_turns_and_steps SKIPPED (This t...) [ 7%]
tests/integration/agents/test_openai_responses.py::test_responses_store[openai_client-txt=openai/gpt-4o-tools0-True]
instantiating llama_stack_client
WARNING 2025-10-02 13:14:33,472 root:258 uncategorized: Unknown logging category: testing. Falling back to default 'root' level: 20
WARNING 2025-10-02 13:14:33,477 root:258 uncategorized: Unknown logging category: providers::utils. Falling back to default 'root' level: 20
WARNING 2025-10-02 13:14:33,960 root:258 uncategorized: Unknown logging category: tokenizer_utils. Falling back to default 'root' level: 20
WARNING 2025-10-02 13:14:33,962 root:258 uncategorized: Unknown logging category: models::llama. Falling back to default 'root' level: 20
WARNING 2025-10-02 13:14:33,963 root:258 uncategorized: Unknown logging category: models::llama. Falling back to default 'root' level: 20
WARNING 2025-10-02 13:14:33,968 root:258 uncategorized: Unknown logging category: providers::utils. Falling back to default 'root' level: 20
WARNING 2025-10-02 13:14:33,974 root:258 uncategorized: Unknown logging category: providers::utils. Falling back to default 'root' level: 20
WARNING 2025-10-02 13:14:33,978 root:258 uncategorized: Unknown logging category: providers::utils. Falling back to default 'root' level: 20
WARNING 2025-10-02 13:14:35,350 root:258 uncategorized: Unknown logging category: providers::utils. Falling back to default 'root' level: 20
WARNING 2025-10-02 13:14:35,366 root:258 uncategorized: Unknown logging category: providers::utils. Falling back to default 'root' level: 20
WARNING 2025-10-02 13:14:35,489 root:258 uncategorized: Unknown logging category: providers::utils. Falling back to default 'root' level: 20
WARNING 2025-10-02 13:14:35,490 root:258 uncategorized: Unknown logging category: inference_store. Falling back to default 'root' level: 20
WARNING 2025-10-02 13:14:35,697 root:258 uncategorized: Unknown logging category: providers::utils. Falling back to default 'root' level: 20
WARNING 2025-10-02 13:14:35,918 root:258 uncategorized: Unknown logging category: providers::utils. Falling back to default 'root' level: 20
INFO 2025-10-02 13:14:35,945 llama_stack.providers.utils.inference.inference_store:74 inference_store: Write queue disabled for SQLite to avoid concurrency issues
WARNING 2025-10-02 13:14:36,172 root:258 uncategorized: Unknown logging category: files. Falling back to default 'root' level: 20
WARNING 2025-10-02 13:14:36,218 root:258 uncategorized: Unknown logging category: providers::utils. Falling back to default 'root' level: 20
WARNING 2025-10-02 13:14:36,219 root:258 uncategorized: Unknown logging category: vector_io. Falling back to default 'root' level: 20
WARNING 2025-10-02 13:14:36,231 root:258 uncategorized: Unknown logging category: vector_io. Falling back to default 'root' level: 20
WARNING 2025-10-02 13:14:36,255 root:258 uncategorized: Unknown logging category: tool_runtime. Falling back to default 'root' level: 20
WARNING 2025-10-02 13:14:36,486 root:258 uncategorized: Unknown logging category: responses_store. Falling back to default 'root' level: 20
WARNING 2025-10-02 13:14:36,503 root:258 uncategorized: Unknown logging category: openai::responses. Falling back to default 'root' level: 20
INFO 2025-10-02 13:14:36,524 llama_stack.providers.utils.responses.responses_store:80 responses_store: Write queue disabled for SQLite to avoid concurrency issues
WARNING 2025-10-02 13:14:36,528 root:258 uncategorized: Unknown logging category: providers::utils. Falling back to default 'root' level: 20
WARNING 2025-10-02 13:14:36,703 root:258 uncategorized: Unknown logging category: uncategorized. Falling back to default 'root' level: 20
```

## Test Plan
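A quick way to exercise the change, as a sketch: register-and-use one of the categories from the logs above with the same `get_logger` API the store uses. The `category=level` syntax for `LLAMA_STACK_LOGGING` is an assumption based on its documented usage.

```python
from llama_stack.log import get_logger

# Previously this emitted "Unknown logging category: providers::utils" and fell
# back to the root level; with the category registered, its verbosity should be
# controllable, e.g. LLAMA_STACK_LOGGING=providers::utils=debug (assumed syntax).
logger = get_logger(name=__name__, category="providers::utils")
logger.debug("visible only when the category level allows it")
```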
`llama_stack/providers/utils/inference/inference_store.py` (244 lines, 9.1 KiB, Python):
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
from typing import Any

from sqlalchemy.exc import IntegrityError

from llama_stack.apis.inference import (
    ListOpenAIChatCompletionResponse,
    OpenAIChatCompletion,
    OpenAICompletionWithInputMessages,
    OpenAIMessageParam,
    Order,
)
from llama_stack.core.datatypes import AccessRule, InferenceStoreConfig
from llama_stack.log import get_logger

from ..sqlstore.api import ColumnDefinition, ColumnType
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
from ..sqlstore.sqlstore import SqlStoreConfig, SqlStoreType, sqlstore_impl

logger = get_logger(name=__name__, category="inference")


class InferenceStore:
    def __init__(
        self,
        config: InferenceStoreConfig | SqlStoreConfig,
        policy: list[AccessRule],
    ):
        # Handle backward compatibility
        if not isinstance(config, InferenceStoreConfig):
            # Legacy: SqlStoreConfig passed directly as config
            config = InferenceStoreConfig(
                sql_store_config=config,
            )

        self.config = config
        self.sql_store_config = config.sql_store_config
        self.sql_store = None
        self.policy = policy

        # Disable write queue for SQLite to avoid concurrency issues
        self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite

        # Async write queue and worker control
        self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
        self._worker_tasks: list[asyncio.Task[Any]] = []
        self._max_write_queue_size: int = config.max_write_queue_size
        self._num_writers: int = max(1, config.num_writers)

    async def initialize(self):
        """Create the necessary tables if they don't exist."""
        self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy)
        await self.sql_store.create_table(
            "chat_completions",
            {
                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
                "created": ColumnType.INTEGER,
                "model": ColumnType.STRING,
                "choices": ColumnType.JSON,
                "input_messages": ColumnType.JSON,
            },
        )

        if self.enable_write_queue:
            self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
            for _ in range(self._num_writers):
                self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
        else:
            logger.info("Write queue disabled for SQLite to avoid concurrency issues")

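    # Shutdown drains the queue first (queue.join), then cancels and awaits the
    # worker tasks, so writes accepted before shutdown are not lost.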
    async def shutdown(self) -> None:
        if not self._worker_tasks:
            return
        if self._queue is not None:
            await self._queue.join()
        for t in self._worker_tasks:
            if not t.done():
                t.cancel()
        for t in self._worker_tasks:
            try:
                await t
            except asyncio.CancelledError:
                pass
        self._worker_tasks.clear()

    async def flush(self) -> None:
        """Wait for all queued writes to complete. Useful for testing."""
        if self.enable_write_queue and self._queue is not None:
            await self._queue.join()

    async def store_chat_completion(
        self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
    ) -> None:
        if self.enable_write_queue:
            if self._queue is None:
                raise ValueError("Inference store is not initialized")
            try:
                self._queue.put_nowait((chat_completion, input_messages))
            except asyncio.QueueFull:
                logger.warning(
                    f"Write queue full; adding chat completion id={getattr(chat_completion, 'id', '<unknown>')}"
                )
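                # Fall back to a blocking put so the completion is not dropped;
                # this applies backpressure on the caller until a writer frees space.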
                await self._queue.put((chat_completion, input_messages))
        else:
            await self._write_chat_completion(chat_completion, input_messages)

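    # Each writer task pulls one item at a time; the task_done() call pairs with
    # queue.join() in flush()/shutdown() so callers can wait for a full drain.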
    async def _worker_loop(self) -> None:
        assert self._queue is not None
        while True:
            try:
                item = await self._queue.get()
            except asyncio.CancelledError:
                break
            chat_completion, input_messages = item
            try:
                await self._write_chat_completion(chat_completion, input_messages)
            except Exception as e:  # noqa: BLE001
                logger.error(f"Error writing chat completion: {e}")
            finally:
                self._queue.task_done()

    async def _write_chat_completion(
        self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
    ) -> None:
        if self.sql_store is None:
            raise ValueError("Inference store is not initialized")

        data = chat_completion.model_dump()
        record_data = {
            "id": data["id"],
            "created": data["created"],
            "model": data["model"],
            "choices": data["choices"],
            "input_messages": [message.model_dump() for message in input_messages],
        }

        try:
            await self.sql_store.insert(
                table="chat_completions",
                data=record_data,
            )
        except IntegrityError as e:
            # Duplicate chat completion IDs can be generated during tests, especially when replaying
            # recorded responses across different tests. No need to warn or error under those circumstances.
            # In the wild, this is not likely to happen at all (no evidence) so we aren't really hiding any problem.

            # Check if it's a unique constraint violation
            error_message = str(e.orig) if e.orig else str(e)
            if self._is_unique_constraint_error(error_message):
                # Update the existing record instead
                await self.sql_store.update(table="chat_completions", data=record_data, where={"id": data["id"]})
            else:
                # Re-raise if it's not a unique constraint error
                raise

    def _is_unique_constraint_error(self, error_message: str) -> bool:
        """Check if the error is specifically a unique constraint violation."""
        error_lower = error_message.lower()
        return any(
            indicator in error_lower
            for indicator in [
                "unique constraint failed",  # SQLite
                "duplicate key",  # PostgreSQL
                "unique violation",  # PostgreSQL alternative
                "duplicate entry",  # MySQL
            ]
        )

    async def list_chat_completions(
        self,
        after: str | None = None,
        limit: int | None = 50,
        model: str | None = None,
        order: Order | None = Order.desc,
    ) -> ListOpenAIChatCompletionResponse:
        """
        List chat completions from the database.

        :param after: The ID of the last chat completion to return.
        :param limit: The maximum number of chat completions to return.
        :param model: The model to filter by.
        :param order: The order to sort the chat completions by.
        """
        if not self.sql_store:
            raise ValueError("Inference store is not initialized")

        if not order:
            order = Order.desc

        where_conditions = {}
        if model:
            where_conditions["model"] = model

        paginated_result = await self.sql_store.fetch_all(
            table="chat_completions",
            where=where_conditions if where_conditions else None,
            order_by=[("created", order.value)],
            cursor=("id", after) if after else None,
            limit=limit,
        )

        data = [
            OpenAICompletionWithInputMessages(
                id=row["id"],
                created=row["created"],
                model=row["model"],
                choices=row["choices"],
                input_messages=row["input_messages"],
            )
            for row in paginated_result.data
        ]
        return ListOpenAIChatCompletionResponse(
            data=data,
            has_more=paginated_result.has_more,
            first_id=data[0].id if data else "",
            last_id=data[-1].id if data else "",
        )

    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
        if not self.sql_store:
            raise ValueError("Inference store is not initialized")

        row = await self.sql_store.fetch_one(
            table="chat_completions",
            where={"id": completion_id},
        )

        if not row:
            # SecureSqlStore will return None if record doesn't exist OR access is denied
            # This provides security by not revealing whether the record exists
            raise ValueError(f"Chat completion with id {completion_id} not found") from None

        return OpenAICompletionWithInputMessages(
            id=row["id"],
            created=row["created"],
            model=row["model"],
            choices=row["choices"],
            input_messages=row["input_messages"],
        )
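For reference, a minimal usage sketch of the store. `SqliteSqlStoreConfig` and its `db_path` field are assumptions about what `..sqlstore.sqlstore` exports, and the empty access policy is purely illustrative:

```python
import asyncio

from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig  # assumed name


async def main() -> None:
    # With a SQLite config, InferenceStore disables the write queue, which is
    # exactly the "Write queue disabled" INFO line in the PR logs above.
    store = InferenceStore(SqliteSqlStoreConfig(db_path="/tmp/inference.db"), policy=[])
    await store.initialize()

    # After store_chat_completion(...) has been called elsewhere, page through
    # stored completions (cursor-based pagination via `after`).
    page = await store.list_chat_completions(limit=10)
    for item in page.data:
        print(item.id, item.model)

    await store.flush()  # no-op for SQLite; drains the write queue otherwise
    await store.shutdown()


asyncio.run(main())
```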