Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-05 20:27:35 +00:00)

Merge branch 'main' into ragtool-migration
Commit 9b27f81b75
12 changed files with 318 additions and 40 deletions
@@ -431,6 +431,12 @@ class ServerConfig(BaseModel):
     )


+class InferenceStoreConfig(BaseModel):
+    sql_store_config: SqlStoreConfig
+    max_write_queue_size: int = Field(default=10000, description="Max queued writes for inference store")
+    num_writers: int = Field(default=4, description="Number of concurrent background writers")
+
+
 class StackRunConfig(BaseModel):
     version: int = LLAMA_STACK_RUN_CONFIG_VERSION
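For orientation, a minimal sketch of constructing the new config block directly in Python; the module paths come from the imports elsewhere in this diff, and the db_path value is illustrative. Note that, per the InferenceStore changes below, the write queue is disabled when the backing store is SQLite, so the queue parameters only take effect for other SQL backends.

from llama_stack.core.datatypes import InferenceStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig

# Wraps the SQL store config and adds tuning for the background write queue.
inference_store_config = InferenceStoreConfig(
    sql_store_config=SqliteSqlStoreConfig(db_path="/tmp/inference_store.db"),
    max_write_queue_size=10000,  # queued chat completions before backpressure kicks in
    num_writers=4,               # concurrent background writer tasks
)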
@@ -464,11 +470,12 @@ Configuration for the persistence store used by the distribution registry. If not specified,
 a default SQLite store will be used.""",
     )

-    inference_store: SqlStoreConfig | None = Field(
+    inference_store: InferenceStoreConfig | SqlStoreConfig | None = Field(
         default=None,
         description="""
-Configuration for the persistence store used by the inference API. If not specified,
-a default SQLite store will be used.""",
+Configuration for the persistence store used by the inference API. Can be either a
+InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (deprecated).
+If not specified, a default SQLite store will be used.""",
     )

     # registry of "resources" in the distribution
@@ -78,7 +78,10 @@ async def get_auto_router_impl(

     # TODO: move pass configs to routers instead
     if api == Api.inference and run_config.inference_store:
-        inference_store = InferenceStore(run_config.inference_store, policy)
+        inference_store = InferenceStore(
+            config=run_config.inference_store,
+            policy=policy,
+        )
         await inference_store.initialize()
         api_to_dep_impl["store"] = inference_store
@@ -63,7 +63,7 @@ from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
 from llama_stack.providers.utils.inference.inference_store import InferenceStore
-from llama_stack.providers.utils.telemetry.tracing import get_current_span
+from llama_stack.providers.utils.telemetry.tracing import enqueue_event, get_current_span

 logger = get_logger(name=__name__, category="core::routers")
@@ -90,6 +90,11 @@ class InferenceRouter(Inference):

     async def shutdown(self) -> None:
         logger.debug("InferenceRouter.shutdown")
+        if self.store:
+            try:
+                await self.store.shutdown()
+            except Exception as e:
+                logger.warning(f"Error during InferenceStore shutdown: {e}")

     async def register_model(
         self,
@@ -160,7 +165,7 @@ class InferenceRouter(Inference):
         metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
         if self.telemetry:
             for metric in metrics:
-                await self.telemetry.log_event(metric)
+                enqueue_event(metric)
         return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]

     async def _count_tokens(
@@ -431,7 +436,7 @@ class InferenceRouter(Inference):
                 model=model_obj,
             )
             for metric in metrics:
-                await self.telemetry.log_event(metric)
+                enqueue_event(metric)

             # these metrics will show up in the client response.
             response.metrics = (
@@ -537,7 +542,7 @@ class InferenceRouter(Inference):
                 model=model_obj,
             )
             for metric in metrics:
-                await self.telemetry.log_event(metric)
+                enqueue_event(metric)
             # these metrics will show up in the client response.
             response.metrics = (
                 metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
@@ -664,7 +669,7 @@ class InferenceRouter(Inference):
                 "completion_tokens",
                 "total_tokens",
             ]:  # Only log completion and total tokens
-                await self.telemetry.log_event(metric)
+                enqueue_event(metric)

         # Return metrics in response
         async_metrics = [
@@ -710,7 +715,7 @@ class InferenceRouter(Inference):
             )
             for metric in completion_metrics:
                 if metric.metric in ["completion_tokens", "total_tokens"]:  # Only log completion and total tokens
-                    await self.telemetry.log_event(metric)
+                    enqueue_event(metric)

         # Return metrics in response
         return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
@@ -806,7 +811,7 @@ class InferenceRouter(Inference):
                         model=model,
                     )
                     for metric in metrics:
-                        await self.telemetry.log_event(metric)
+                        enqueue_event(metric)

                 yield chunk
         finally:
@@ -3,6 +3,9 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import asyncio
+from typing import Any
+
 from llama_stack.apis.inference import (
     ListOpenAIChatCompletionResponse,
     OpenAIChatCompletion,
@@ -10,24 +13,43 @@ from llama_stack.apis.inference import (
     OpenAIMessageParam,
     Order,
 )
-from llama_stack.core.datatypes import AccessRule
-from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
+from llama_stack.core.datatypes import AccessRule, InferenceStoreConfig
 from llama_stack.log import get_logger

 from ..sqlstore.api import ColumnDefinition, ColumnType
 from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
-from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl
+from ..sqlstore.sqlstore import SqlStoreConfig, SqlStoreType, sqlstore_impl

 logger = get_logger(name=__name__, category="inference_store")


 class InferenceStore:
-    def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]):
-        if not sql_store_config:
-            sql_store_config = SqliteSqlStoreConfig(
-                db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
+    def __init__(
+        self,
+        config: InferenceStoreConfig | SqlStoreConfig,
+        policy: list[AccessRule],
+    ):
+        # Handle backward compatibility
+        if not isinstance(config, InferenceStoreConfig):
+            # Legacy: SqlStoreConfig passed directly as config
+            config = InferenceStoreConfig(
+                sql_store_config=config,
             )
-        self.sql_store_config = sql_store_config
+
+        self.config = config
+        self.sql_store_config = config.sql_store_config
         self.sql_store = None
         self.policy = policy
+
+        # Disable write queue for SQLite to avoid concurrency issues
+        self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite
+
+        # Async write queue and worker control
+        self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
+        self._worker_tasks: list[asyncio.Task[Any]] = []
+        self._max_write_queue_size: int = config.max_write_queue_size
+        self._num_writers: int = max(1, config.num_writers)

     async def initialize(self):
         """Create the necessary tables if they don't exist."""
         self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
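Both call forms should now be accepted by the constructor above; a hedged sketch (the empty policy list and db_path are illustrative):

from llama_stack.core.datatypes import InferenceStoreConfig
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig

sqlite_cfg = SqliteSqlStoreConfig(db_path="/tmp/inference_store.db")

# Legacy form: a bare SqlStoreConfig is wrapped into an InferenceStoreConfig internally.
legacy_store = InferenceStore(config=sqlite_cfg, policy=[])

# New form: queue tuning parameters are honored (though the queue stays off for SQLite).
new_store = InferenceStore(
    config=InferenceStoreConfig(sql_store_config=sqlite_cfg, num_writers=2),
    policy=[],
)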
@@ -42,10 +64,68 @@ class InferenceStore:
             },
         )

+        if self.enable_write_queue:
+            self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
+            for _ in range(self._num_writers):
+                self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
+        else:
+            logger.info("Write queue disabled for SQLite to avoid concurrency issues")
+
+    async def shutdown(self) -> None:
+        if not self._worker_tasks:
+            return
+        if self._queue is not None:
+            await self._queue.join()
+        for t in self._worker_tasks:
+            if not t.done():
+                t.cancel()
+        for t in self._worker_tasks:
+            try:
+                await t
+            except asyncio.CancelledError:
+                pass
+        self._worker_tasks.clear()
+
+    async def flush(self) -> None:
+        """Wait for all queued writes to complete. Useful for testing."""
+        if self.enable_write_queue and self._queue is not None:
+            await self._queue.join()
+
     async def store_chat_completion(
         self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
     ) -> None:
-        if not self.sql_store:
+        if self.enable_write_queue:
+            if self._queue is None:
+                raise ValueError("Inference store is not initialized")
+            try:
+                self._queue.put_nowait((chat_completion, input_messages))
+            except asyncio.QueueFull:
+                logger.warning(
+                    f"Write queue full; adding chat completion id={getattr(chat_completion, 'id', '<unknown>')}"
+                )
+                await self._queue.put((chat_completion, input_messages))
+        else:
+            await self._write_chat_completion(chat_completion, input_messages)
+
+    async def _worker_loop(self) -> None:
+        assert self._queue is not None
+        while True:
+            try:
+                item = await self._queue.get()
+            except asyncio.CancelledError:
+                break
+            chat_completion, input_messages = item
+            try:
+                await self._write_chat_completion(chat_completion, input_messages)
+            except Exception as e:  # noqa: BLE001
+                logger.error(f"Error writing chat completion: {e}")
+            finally:
+                self._queue.task_done()
+
+    async def _write_chat_completion(
+        self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
+    ) -> None:
+        if self.sql_store is None:
+            raise ValueError("Inference store is not initialized")

         data = chat_completion.model_dump()
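Because store_chat_completion() now returns before the row is committed (when the queue is enabled), callers that read a completion back immediately need flush(). A sketch of the intended sequence, assumed to run inside an async test function; the read-back step is left as a comment since it is not part of this hunk:

await store.initialize()
await store.store_chat_completion(chat_completion, input_messages)
await store.flush()      # blocks until the asyncio.Queue has been fully drained
# ... read the completion back from the SQL store and assert on it ...
await store.shutdown()   # joins the queue, then cancels the worker tasks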
@@ -18,6 +18,7 @@ from functools import wraps
 from typing import Any

 from llama_stack.apis.telemetry import (
+    Event,
     LogSeverity,
     Span,
     SpanEndPayload,
@@ -98,7 +99,7 @@ class BackgroundLogger:
     def __init__(self, api: Telemetry, capacity: int = 100000):
         self.api = api
         self.log_queue: queue.Queue[Any] = queue.Queue(maxsize=capacity)
-        self.worker_thread = threading.Thread(target=self._process_logs, daemon=True)
+        self.worker_thread = threading.Thread(target=self._worker, daemon=True)
         self.worker_thread.start()
         self._last_queue_full_log_time: float = 0.0
         self._dropped_since_last_notice: int = 0
@@ -118,12 +119,16 @@ class BackgroundLogger:
             self._last_queue_full_log_time = current_time
             self._dropped_since_last_notice = 0

-    def _process_logs(self):
+    def _worker(self):
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        loop.run_until_complete(self._process_logs())
+
+    async def _process_logs(self):
         while True:
             try:
                 event = self.log_queue.get()
-                # figure out how to use a thread's native loop
-                asyncio.run(self.api.log_event(event))
+                await self.api.log_event(event)
             except Exception:
                 import traceback
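The _worker/_process_logs split above replaces a per-event asyncio.run() call with one long-lived event loop owned by the logging thread. A self-contained sketch of that pattern with illustrative names (QueueDrainer is not part of the codebase):

import asyncio
import queue
import threading


class QueueDrainer:
    def __init__(self) -> None:
        self.q: queue.Queue = queue.Queue(maxsize=1000)
        # One daemon thread owns one event loop for the life of the process,
        # avoiding asyncio.run() setup/teardown on every item.
        self.thread = threading.Thread(target=self._worker, daemon=True)
        self.thread.start()

    def _worker(self) -> None:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self._drain())

    async def _drain(self) -> None:
        while True:
            item = self.q.get()  # blocking get is fine: this loop runs nothing else
            try:
                await self._handle(item)  # any coroutine-based sink, e.g. an async API call
            finally:
                self.q.task_done()

    async def _handle(self, item) -> None:
        print("processed", item)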
@@ -136,6 +141,19 @@ class BackgroundLogger:
         self.log_queue.join()


+def enqueue_event(event: Event) -> None:
+    """Enqueue a telemetry event to the background logger if available.
+
+    This provides a non-blocking path for routers and other hot paths to
+    submit telemetry without awaiting the Telemetry API, reducing contention
+    with the main event loop.
+    """
+    global BACKGROUND_LOGGER
+    if BACKGROUND_LOGGER is None:
+        raise RuntimeError("Telemetry API not initialized")
+    BACKGROUND_LOGGER.log_event(event)
+
+
 class TraceContext:
     spans: list[Span] = []
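With enqueue_event() available, call sites on the request path hand events to the background thread synchronously instead of awaiting the Telemetry API; the InferenceRouter hunks earlier in this diff all follow this shape. A small illustrative helper (emit_token_metrics is hypothetical, not from the diff):

from llama_stack.providers.utils.telemetry.tracing import enqueue_event


def emit_token_metrics(metrics) -> None:
    # Non-blocking hand-off to the BackgroundLogger thread; the previous code
    # awaited self.telemetry.log_event(metric) for each metric on the hot path.
    for metric in metrics:
        enqueue_event(metric)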
@@ -256,11 +274,7 @@ class TelemetryHandler(logging.Handler):
         if record.module in ("asyncio", "selector_events"):
             return

-        global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
-
-        if BACKGROUND_LOGGER is None:
-            raise RuntimeError("Telemetry API not initialized")
-
+        global CURRENT_TRACE_CONTEXT
         context = CURRENT_TRACE_CONTEXT.get()
         if context is None:
             return
@@ -269,7 +283,7 @@ class TelemetryHandler(logging.Handler):
         if span is None:
             return

-        BACKGROUND_LOGGER.log_event(
+        enqueue_event(
             UnstructuredLogEvent(
                 trace_id=span.trace_id,
                 span_id=span.span_id,