Merge branch 'main' into ragtool-migration

This commit is contained in:
Francisco Arceo 2025-09-10 13:58:57 -06:00 committed by GitHub
commit 9b27f81b75
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 318 additions and 40 deletions

View file

@ -3,6 +3,9 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import Any
from llama_stack.apis.inference import (
ListOpenAIChatCompletionResponse,
OpenAIChatCompletion,
@ -10,24 +13,43 @@ from llama_stack.apis.inference import (
OpenAIMessageParam,
Order,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.core.datatypes import AccessRule, InferenceStoreConfig
from llama_stack.log import get_logger
from ..sqlstore.api import ColumnDefinition, ColumnType
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl
from ..sqlstore.sqlstore import SqlStoreConfig, SqlStoreType, sqlstore_impl
logger = get_logger(name=__name__, category="inference_store")
class InferenceStore:
def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]):
if not sql_store_config:
sql_store_config = SqliteSqlStoreConfig(
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
def __init__(
self,
config: InferenceStoreConfig | SqlStoreConfig,
policy: list[AccessRule],
):
# Handle backward compatibility
if not isinstance(config, InferenceStoreConfig):
# Legacy: SqlStoreConfig passed directly as config
config = InferenceStoreConfig(
sql_store_config=config,
)
self.sql_store_config = sql_store_config
self.config = config
self.sql_store_config = config.sql_store_config
self.sql_store = None
self.policy = policy
# Disable write queue for SQLite to avoid concurrency issues
self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite
# Async write queue and worker control
self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
self._worker_tasks: list[asyncio.Task[Any]] = []
self._max_write_queue_size: int = config.max_write_queue_size
self._num_writers: int = max(1, config.num_writers)
async def initialize(self):
"""Create the necessary tables if they don't exist."""
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
@ -42,10 +64,68 @@ class InferenceStore:
},
)
if self.enable_write_queue:
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
for _ in range(self._num_writers):
self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
else:
logger.info("Write queue disabled for SQLite to avoid concurrency issues")
async def shutdown(self) -> None:
if not self._worker_tasks:
return
if self._queue is not None:
await self._queue.join()
for t in self._worker_tasks:
if not t.done():
t.cancel()
for t in self._worker_tasks:
try:
await t
except asyncio.CancelledError:
pass
self._worker_tasks.clear()
async def flush(self) -> None:
"""Wait for all queued writes to complete. Useful for testing."""
if self.enable_write_queue and self._queue is not None:
await self._queue.join()
async def store_chat_completion(
self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
) -> None:
if not self.sql_store:
if self.enable_write_queue:
if self._queue is None:
raise ValueError("Inference store is not initialized")
try:
self._queue.put_nowait((chat_completion, input_messages))
except asyncio.QueueFull:
logger.warning(
f"Write queue full; adding chat completion id={getattr(chat_completion, 'id', '<unknown>')}"
)
await self._queue.put((chat_completion, input_messages))
else:
await self._write_chat_completion(chat_completion, input_messages)
async def _worker_loop(self) -> None:
assert self._queue is not None
while True:
try:
item = await self._queue.get()
except asyncio.CancelledError:
break
chat_completion, input_messages = item
try:
await self._write_chat_completion(chat_completion, input_messages)
except Exception as e: # noqa: BLE001
logger.error(f"Error writing chat completion: {e}")
finally:
self._queue.task_done()
async def _write_chat_completion(
self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
) -> None:
if self.sql_store is None:
raise ValueError("Inference store is not initialized")
data = chat_completion.model_dump()

View file

@ -18,6 +18,7 @@ from functools import wraps
from typing import Any
from llama_stack.apis.telemetry import (
Event,
LogSeverity,
Span,
SpanEndPayload,
@ -98,7 +99,7 @@ class BackgroundLogger:
def __init__(self, api: Telemetry, capacity: int = 100000):
self.api = api
self.log_queue: queue.Queue[Any] = queue.Queue(maxsize=capacity)
self.worker_thread = threading.Thread(target=self._process_logs, daemon=True)
self.worker_thread = threading.Thread(target=self._worker, daemon=True)
self.worker_thread.start()
self._last_queue_full_log_time: float = 0.0
self._dropped_since_last_notice: int = 0
@ -118,12 +119,16 @@ class BackgroundLogger:
self._last_queue_full_log_time = current_time
self._dropped_since_last_notice = 0
def _process_logs(self):
def _worker(self):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self._process_logs())
async def _process_logs(self):
while True:
try:
event = self.log_queue.get()
# figure out how to use a thread's native loop
asyncio.run(self.api.log_event(event))
await self.api.log_event(event)
except Exception:
import traceback
@ -136,6 +141,19 @@ class BackgroundLogger:
self.log_queue.join()
def enqueue_event(event: Event) -> None:
"""Enqueue a telemetry event to the background logger if available.
This provides a non-blocking path for routers and other hot paths to
submit telemetry without awaiting the Telemetry API, reducing contention
with the main event loop.
"""
global BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None:
raise RuntimeError("Telemetry API not initialized")
BACKGROUND_LOGGER.log_event(event)
class TraceContext:
spans: list[Span] = []
@ -256,11 +274,7 @@ class TelemetryHandler(logging.Handler):
if record.module in ("asyncio", "selector_events"):
return
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None:
raise RuntimeError("Telemetry API not initialized")
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if context is None:
return
@ -269,7 +283,7 @@ class TelemetryHandler(logging.Handler):
if span is None:
return
BACKGROUND_LOGGER.log_event(
enqueue_event(
UnstructuredLogEvent(
trace_id=span.trace_id,
span_id=span.span_id,