Merge branch 'main' into feat/gunicorn-production-server

This commit is contained in:
Ashwin Bharambe 2025-11-03 17:39:30 -08:00 committed by GitHub
commit b728307427
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
332 changed files with 50191 additions and 68996 deletions

View file

@ -90,12 +90,14 @@ class OpenAIModel(BaseModel):
:object: The object type, which will be "model"
:created: The Unix timestamp in seconds when the model was created
:owned_by: The owner of the model
:custom_metadata: Llama Stack-specific metadata including model_type, provider info, and additional metadata
"""
id: str
object: Literal["model"] = "model"
created: int
owned_by: str
custom_metadata: dict[str, Any] | None = None
class OpenAIListModelsResponse(BaseModel):
@ -113,7 +115,7 @@ class Models(Protocol):
"""
...
@webmethod(route="/openai/v1/models", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/openai/v1/models", method="GET", level=LLAMA_STACK_API_V1)
async def openai_list_models(self) -> OpenAIListModelsResponse:
"""List models using the OpenAI API.

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .synthetic_data_generation import *

View file

@ -1,77 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Protocol
from pydantic import BaseModel
from llama_stack.apis.inference import Message
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import json_schema_type, webmethod
class FilteringFunction(Enum):
    """The type of filtering function.

    Each member's value is the same string as its name.

    :cvar none: No filtering applied, accept all generated synthetic data
    :cvar random: Random sampling of generated data points
    :cvar top_k: Keep only the top-k highest scoring synthetic data samples
    :cvar top_p: Nucleus-style filtering, keep samples exceeding cumulative score threshold
    :cvar top_k_top_p: Combined top-k and top-p filtering strategy
    :cvar sigmoid: Apply sigmoid function for probability-based filtering
    """

    none = "none"
    random = "random"
    top_k = "top_k"
    top_p = "top_p"
    top_k_top_p = "top_k_top_p"
    sigmoid = "sigmoid"
@json_schema_type
class SyntheticDataGenerationRequest(BaseModel):
    """Request to generate synthetic data. A small batch of prompts and a filtering function

    :param dialogs: List of conversation messages to use as input for synthetic data generation
    :param filtering_function: Type of filtering to apply to generated synthetic data samples
    :param model: (Optional) The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint
    """

    dialogs: list[Message]
    # Defaults to accepting every generated sample.
    filtering_function: FilteringFunction = FilteringFunction.none
    model: str | None = None
@json_schema_type
class SyntheticDataGenerationResponse(BaseModel):
    """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.

    :param synthetic_data: List of generated synthetic data samples that passed the filtering criteria
    :param statistics: (Optional) Statistical information about the generation process and filtering results
    """

    synthetic_data: list[dict[str, Any]]
    statistics: dict[str, Any] | None = None
class SyntheticDataGeneration(Protocol):
    """Protocol describing the synthetic-data-generation API surface."""

    @webmethod(route="/synthetic-data-generation/generate", level=LLAMA_STACK_API_V1)
    def synthetic_data_generate(
        self,
        dialogs: list[Message],
        filtering_function: FilteringFunction = FilteringFunction.none,
        model: str | None = None,
    ) -> SyntheticDataGenerationResponse:
        """Generate synthetic data based on input dialogs and apply filtering.

        :param dialogs: List of conversation messages to use as input for synthetic data generation
        :param filtering_function: Type of filtering to apply to generated synthetic data samples
        :param model: (Optional) The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint
        :returns: Response containing filtered synthetic data samples and optional statistics
        """
        ...

View file

@ -31,6 +31,7 @@ from llama_stack.core.storage.datatypes import (
)
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import LoggingConfig, get_logger
REPO_ROOT = Path(__file__).parent.parent.parent.parent
@ -132,8 +133,14 @@ class StackRun(Subcommand):
)
sys.exit(1)
if provider_type in providers_for_api:
config_type = instantiate_class_type(providers_for_api[provider_type].config_class)
if config_type is not None and hasattr(config_type, "sample_run_config"):
config = config_type.sample_run_config(__distro_dir__="~/.llama/distributions/providers-run")
else:
config = {}
provider = Provider(
provider_type=provider_type,
config=config,
provider_id=provider_type.split("::")[1],
)
provider_list.setdefault(api, []).append(provider)

View file

@ -134,6 +134,12 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
object="model",
created=int(time.time()),
owned_by="llama_stack",
custom_metadata={
"model_type": model.model_type,
"provider_id": model.provider_id,
"provider_resource_id": model.provider_resource_id,
**model.metadata,
},
)
for model in all_models
]

View file

@ -31,7 +31,6 @@ from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
@ -66,7 +65,6 @@ class LlamaStack(
Agents,
Batches,
Safety,
SyntheticDataGeneration,
Datasets,
PostTraining,
VectorIO,

View file

@ -12,7 +12,7 @@ from llama_stack.core.ui.modules.api import llama_stack_api
def models() -> None:
    """Render the "Models" section: a model picker plus the selected model's JSON."""
    # Models Section
    st.header("Models")

    # Key the listing by each model's ``id`` and keep its ``model_dump()`` for
    # display. The list call is made exactly once; the earlier
    # ``identifier`` / ``to_dict()`` variant was dead code that was immediately
    # overwritten.
    models_info = {m.id: m.model_dump() for m in llama_stack_api.client.models.list()}

    selected_model = st.selectbox("Select a model", list(models_info.keys()))
    # NOTE(review): with no models available the selectbox presumably yields
    # None; skip rendering rather than fail on a missing key - confirm.
    if selected_model is not None:
        st.json(models_info[selected_model])

View file

@ -12,7 +12,11 @@ from llama_stack.core.ui.modules.api import llama_stack_api
with st.sidebar:
st.header("Configuration")
available_models = llama_stack_api.client.models.list()
available_models = [model.identifier for model in available_models if model.model_type == "llm"]
available_models = [
model.id
for model in available_models
if model.custom_metadata and model.custom_metadata.get("model_type") == "llm"
]
selected_model = st.selectbox(
"Choose a model",
available_models,

View file

@ -1015,7 +1015,7 @@ async def load_data_from_url(url: str) -> str:
if url.startswith("http"):
async with httpx.AsyncClient() as client:
r = await client.get(url)
resp = r.text
resp: str = r.text
return resp
raise ValueError(f"Unexpected URL: {type(url)}")

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Iterable
import google.auth.transport.requests
from google.auth import default
@ -42,3 +43,12 @@ class VertexAIInferenceAdapter(OpenAIMixin):
Source: https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai
"""
return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi"
async def list_provider_model_ids(self) -> Iterable[str]:
    """Return the model IDs this provider advertises.

    VertexAI doesn't currently offer a way to query a list of available
    models from Google's Model Garden, so for now a hardcoded list of the
    available models is returned.

    :return: An iterable of model IDs
    """
    return ["vertexai/gemini-2.0-flash", "vertexai/gemini-2.5-flash", "vertexai/gemini-2.5-pro"]

View file

@ -35,6 +35,7 @@ class InferenceStore:
self.reference = reference
self.sql_store = None
self.policy = policy
self.enable_write_queue = True
# Async write queue and worker control
self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
@ -47,14 +48,13 @@ class InferenceStore:
base_store = sqlstore_impl(self.reference)
self.sql_store = AuthorizedSqlStore(base_store, self.policy)
# Disable write queue for SQLite to avoid concurrency issues
backend_name = self.reference.backend
backend_config = _SQLSTORE_BACKENDS.get(backend_name)
if backend_config is None:
raise ValueError(
f"Unregistered SQL backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
)
self.enable_write_queue = backend_config.type != StorageBackendType.SQL_SQLITE
# Disable write queue for SQLite since WAL mode handles concurrency
# Keep it enabled for other backends (like Postgres) for performance
backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend)
if backend_config and backend_config.type == StorageBackendType.SQL_SQLITE:
self.enable_write_queue = False
logger.debug("Write queue disabled for SQLite (WAL mode handles concurrency)")
await self.sql_store.create_table(
"chat_completions",
{
@ -70,8 +70,9 @@ class InferenceStore:
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
for _ in range(self._num_writers):
self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
else:
logger.info("Write queue disabled for SQLite to avoid concurrency issues")
logger.debug(
f"Inference store write queue enabled with {self._num_writers} writers, max queue size {self._max_write_queue_size}"
)
async def shutdown(self) -> None:
if not self._worker_tasks:

View file

@ -20,7 +20,7 @@ logger = get_logger(name=__name__, category="providers::utils")
class RemoteInferenceProviderConfig(BaseModel):
allowed_models: list[str] | None = Field( # TODO: make this non-optional and give a list() default
allowed_models: list[str] | None = Field(
default=None,
description="List of models that should be registered with the model registry. If None, all models are allowed.",
)

View file

@ -83,9 +83,6 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
# This is set in list_models() and used in check_model_availability()
_model_cache: dict[str, Model] = {}
# List of allowed models for this provider, if empty all models allowed
allowed_models: list[str] = []
# Optional field name in provider data to look for API key, which takes precedence
provider_data_api_key_field: str | None = None
@ -441,7 +438,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
for provider_model_id in provider_models_ids:
if not isinstance(provider_model_id, str):
raise ValueError(f"Model ID {provider_model_id} from list_provider_model_ids() is not a string")
if self.allowed_models and provider_model_id not in self.allowed_models:
if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list")
continue
model = self.construct_model_from_identifier(provider_model_id)

View file

@ -70,13 +70,13 @@ class ResponsesStore:
base_store = sqlstore_impl(self.reference)
self.sql_store = AuthorizedSqlStore(base_store, self.policy)
# Disable write queue for SQLite since WAL mode handles concurrency
# Keep it enabled for other backends (like Postgres) for performance
backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend)
if backend_config is None:
raise ValueError(
f"Unregistered SQL backend '{self.reference.backend}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
)
if backend_config.type == StorageBackendType.SQL_SQLITE:
if backend_config and backend_config.type == StorageBackendType.SQL_SQLITE:
self.enable_write_queue = False
logger.debug("Write queue disabled for SQLite (WAL mode handles concurrency)")
await self.sql_store.create_table(
"openai_responses",
{
@ -99,8 +99,9 @@ class ResponsesStore:
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
for _ in range(self._num_writers):
self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
else:
logger.debug("Write queue disabled for SQLite to avoid concurrency issues")
logger.debug(
f"Responses store write queue enabled with {self._num_writers} writers, max queue size {self._max_write_queue_size}"
)
async def shutdown(self) -> None:
if not self._worker_tasks:

View file

@ -17,6 +17,7 @@ from sqlalchemy import (
String,
Table,
Text,
event,
inspect,
select,
text,
@ -75,7 +76,36 @@ class SqlAlchemySqlStoreImpl(SqlStore):
self.metadata = MetaData()
def create_engine(self) -> AsyncEngine:
    """Create the async SQLAlchemy engine for ``self.config.engine_str``.

    When the engine string mentions ``sqlite``, the connection is given a
    5 second lock timeout and ``check_same_thread=False``, and every new
    DBAPI connection is switched to WAL journaling with a 5 second busy
    timeout and ``synchronous=NORMAL``. Other backends get the plain engine
    with ``pool_pre_ping=True``.

    :return: The configured ``AsyncEngine``
    """
    is_sqlite = "sqlite" in self.config.engine_str

    # Configure connection args for better concurrency support
    connect_args: dict[str, object] = {}
    if is_sqlite:
        # SQLite-specific optimizations for concurrent access
        # With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases
        connect_args["timeout"] = 5.0
        connect_args["check_same_thread"] = False  # Allow usage across asyncio tasks

    engine = create_async_engine(
        self.config.engine_str,
        pool_pre_ping=True,
        connect_args=connect_args,
    )

    # Enable WAL mode for SQLite to support concurrent readers and writers
    if is_sqlite:

        @event.listens_for(engine.sync_engine, "connect")
        def set_sqlite_pragma(dbapi_conn, connection_record):
            cursor = dbapi_conn.cursor()
            try:
                # Enable Write-Ahead Logging for better concurrency
                cursor.execute("PRAGMA journal_mode=WAL")
                # Set busy timeout to 5 seconds (retry instead of immediate failure)
                # With WAL mode, locks should be brief; if we hit 5s there's a bigger issue
                cursor.execute("PRAGMA busy_timeout=5000")
                # Use NORMAL synchronous mode for better performance (still safe with WAL)
                cursor.execute("PRAGMA synchronous=NORMAL")
            finally:
                # Release the cursor even if one of the PRAGMAs fails.
                cursor.close()

    return engine
async def create_table(
self,

View file

@ -156,7 +156,7 @@ def normalize_inference_request(method: str, url: str, headers: dict[str, Any],
}
# Include test_id for isolation, except for shared infrastructure endpoints
if parsed.path not in ("/api/tags", "/v1/models"):
if parsed.path not in ("/api/tags", "/v1/models", "/v1/openai/v1/models"):
normalized["test_id"] = test_id
normalized_json = json.dumps(normalized, sort_keys=True)
@ -430,7 +430,7 @@ class ResponseStorage:
# For model-list endpoints, include digest in filename to distinguish different model sets
endpoint = request.get("endpoint")
if endpoint in ("/api/tags", "/v1/models"):
if endpoint in ("/api/tags", "/v1/models", "/v1/openai/v1/models"):
digest = _model_identifiers_digest(endpoint, response)
response_file = f"models-{request_hash}-{digest}.json"
@ -554,13 +554,14 @@ def _model_identifiers_digest(endpoint: str, response: dict[str, Any]) -> str:
Supported endpoints:
- '/api/tags' (Ollama): response body has 'models': [ { name/model/digest/id/... }, ... ]
- '/v1/models' (OpenAI): response body is: [ { id: ... }, ... ]
- '/v1/openai/v1/models' (OpenAI): response body is: [ { id: ... }, ... ]
Returns a list of unique identifiers or None if structure doesn't match.
"""
if "models" in response["body"]:
# ollama
items = response["body"]["models"]
else:
# openai
# openai or openai-style endpoints
items = response["body"]
idents = [m.model if endpoint == "/api/tags" else m.id for m in items]
return sorted(set(idents))
@ -581,7 +582,7 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]])
seen: dict[str, dict[str, Any]] = {}
for rec in records:
body = rec["response"]["body"]
if endpoint == "/v1/models":
if endpoint in ("/v1/models", "/v1/openai/v1/models"):
for m in body:
key = m.id
seen[key] = m
@ -665,7 +666,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
logger.info(f" Test context: {get_test_context()}")
if mode == APIRecordingMode.LIVE or storage is None:
if endpoint == "/v1/models":
if endpoint in ("/v1/models", "/v1/openai/v1/models"):
return original_method(self, *args, **kwargs)
else:
return await original_method(self, *args, **kwargs)
@ -699,7 +700,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
recording = None
if mode == APIRecordingMode.REPLAY or mode == APIRecordingMode.RECORD_IF_MISSING:
# Special handling for model-list endpoints: merge all recordings with this hash
if endpoint in ("/api/tags", "/v1/models"):
if endpoint in ("/api/tags", "/v1/models", "/v1/openai/v1/models"):
records = storage._model_list_responses(request_hash)
recording = _combine_model_list_responses(endpoint, records)
else:
@ -739,13 +740,13 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
)
if mode == APIRecordingMode.RECORD or (mode == APIRecordingMode.RECORD_IF_MISSING and not recording):
if endpoint == "/v1/models":
if endpoint in ("/v1/models", "/v1/openai/v1/models"):
response = original_method(self, *args, **kwargs)
else:
response = await original_method(self, *args, **kwargs)
# we want to store the result of the iterator, not the iterator itself
if endpoint == "/v1/models":
if endpoint in ("/v1/models", "/v1/openai/v1/models"):
response = [m async for m in response]
request_data = {