Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 09:53:45 +00:00)

Merge branch 'main' into feat/gunicorn-production-server
Commit b728307427
332 changed files with 50191 additions and 68996 deletions
@@ -90,12 +90,14 @@ class OpenAIModel(BaseModel):
     :object: The object type, which will be "model"
     :created: The Unix timestamp in seconds when the model was created
     :owned_by: The owner of the model
+    :custom_metadata: Llama Stack-specific metadata including model_type, provider info, and additional metadata
     """

     id: str
     object: Literal["model"] = "model"
     created: int
     owned_by: str
+    custom_metadata: dict[str, Any] | None = None


 class OpenAIListModelsResponse(BaseModel):
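As a self-contained sketch (not the actual definitions in llama_stack.apis.models), the fields above correspond to a Pydantic model along these lines; the example values are illustrative.

    from typing import Any, Literal
    from pydantic import BaseModel

    class OpenAIModel(BaseModel):
        # Mirrors the fields shown in the hunk above.
        id: str
        object: Literal["model"] = "model"
        created: int
        owned_by: str
        custom_metadata: dict[str, Any] | None = None

    m = OpenAIModel(
        id="ollama/llama3.2:3b",
        created=1733220000,
        owned_by="llama_stack",
        custom_metadata={"model_type": "llm", "provider_id": "ollama"},
    )
    print(m.model_dump_json(indent=2))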
@@ -113,7 +115,7 @@ class Models(Protocol):
         """
         ...

-    @webmethod(route="/openai/v1/models", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
+    @webmethod(route="/openai/v1/models", method="GET", level=LLAMA_STACK_API_V1)
     async def openai_list_models(self) -> OpenAIListModelsResponse:
         """List models using the OpenAI API.

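For orientation, here is a minimal client-side sketch of hitting this listing route over HTTP and reading the new custom_metadata field. The host, port, and the assumption that the response uses the usual OpenAI list envelope (an object with a data array) are illustrative, not taken from this diff; /v1/openai/v1/models is the path the recording changes further below treat as a model-list endpoint.

    import httpx

    # Assumes a locally running Llama Stack server; adjust host/port as needed.
    resp = httpx.get("http://localhost:8321/v1/openai/v1/models")
    resp.raise_for_status()
    for model in resp.json().get("data", []):
        meta = model.get("custom_metadata") or {}
        print(model["id"], meta.get("model_type"), meta.get("provider_id"))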
@@ -1,7 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .synthetic_data_generation import *
@@ -1,77 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from enum import Enum
-from typing import Any, Protocol
-
-from pydantic import BaseModel
-
-from llama_stack.apis.inference import Message
-from llama_stack.apis.version import LLAMA_STACK_API_V1
-from llama_stack.schema_utils import json_schema_type, webmethod
-
-
-class FilteringFunction(Enum):
-    """The type of filtering function.
-
-    :cvar none: No filtering applied, accept all generated synthetic data
-    :cvar random: Random sampling of generated data points
-    :cvar top_k: Keep only the top-k highest scoring synthetic data samples
-    :cvar top_p: Nucleus-style filtering, keep samples exceeding cumulative score threshold
-    :cvar top_k_top_p: Combined top-k and top-p filtering strategy
-    :cvar sigmoid: Apply sigmoid function for probability-based filtering
-    """
-
-    none = "none"
-    random = "random"
-    top_k = "top_k"
-    top_p = "top_p"
-    top_k_top_p = "top_k_top_p"
-    sigmoid = "sigmoid"
-
-
-@json_schema_type
-class SyntheticDataGenerationRequest(BaseModel):
-    """Request to generate synthetic data. A small batch of prompts and a filtering function
-
-    :param dialogs: List of conversation messages to use as input for synthetic data generation
-    :param filtering_function: Type of filtering to apply to generated synthetic data samples
-    :param model: (Optional) The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint
-    """
-
-    dialogs: list[Message]
-    filtering_function: FilteringFunction = FilteringFunction.none
-    model: str | None = None
-
-
-@json_schema_type
-class SyntheticDataGenerationResponse(BaseModel):
-    """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.
-
-    :param synthetic_data: List of generated synthetic data samples that passed the filtering criteria
-    :param statistics: (Optional) Statistical information about the generation process and filtering results
-    """
-
-    synthetic_data: list[dict[str, Any]]
-    statistics: dict[str, Any] | None = None
-
-
-class SyntheticDataGeneration(Protocol):
-    @webmethod(route="/synthetic-data-generation/generate", level=LLAMA_STACK_API_V1)
-    def synthetic_data_generate(
-        self,
-        dialogs: list[Message],
-        filtering_function: FilteringFunction = FilteringFunction.none,
-        model: str | None = None,
-    ) -> SyntheticDataGenerationResponse:
-        """Generate synthetic data based on input dialogs and apply filtering.
-
-        :param dialogs: List of conversation messages to use as input for synthetic data generation
-        :param filtering_function: Type of filtering to apply to generated synthetic data samples
-        :param model: (Optional) The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint
-        :returns: Response containing filtered synthetic data samples and optional statistics
-        """
-        ...
@@ -31,6 +31,7 @@ from llama_stack.core.storage.datatypes import (
 )
 from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
 from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
+from llama_stack.core.utils.dynamic import instantiate_class_type
 from llama_stack.log import LoggingConfig, get_logger

 REPO_ROOT = Path(__file__).parent.parent.parent.parent
@@ -132,8 +133,14 @@ class StackRun(Subcommand):
                 )
                 sys.exit(1)
+            if provider_type in providers_for_api:
+                config_type = instantiate_class_type(providers_for_api[provider_type].config_class)
+                if config_type is not None and hasattr(config_type, "sample_run_config"):
+                    config = config_type.sample_run_config(__distro_dir__="~/.llama/distributions/providers-run")
+                else:
+                    config = {}
             provider = Provider(
                 provider_type=provider_type,
                 config=config,
                 provider_id=provider_type.split("::")[1],
             )
             provider_list.setdefault(api, []).append(provider)
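The sample_run_config fallback above can be read in isolation as follows; DummyConfig and build_config are hypothetical stand-ins for illustration, not names from the repository.

    # Hedged sketch: prefer a provider config class's sample_run_config() when it
    # exists, otherwise fall back to an empty config dict.
    class DummyConfig:
        @classmethod
        def sample_run_config(cls, __distro_dir__: str) -> dict:
            return {"db_path": f"{__distro_dir__}/dummy.db"}

    def build_config(config_type: type | None) -> dict:
        if config_type is not None and hasattr(config_type, "sample_run_config"):
            return config_type.sample_run_config(__distro_dir__="~/.llama/distributions/providers-run")
        return {}

    print(build_config(DummyConfig))  # {'db_path': '~/.llama/distributions/providers-run/dummy.db'}
    print(build_config(None))         # {}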
@@ -134,6 +134,12 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
                 object="model",
                 created=int(time.time()),
                 owned_by="llama_stack",
+                custom_metadata={
+                    "model_type": model.model_type,
+                    "provider_id": model.provider_id,
+                    "provider_resource_id": model.provider_resource_id,
+                    **model.metadata,
+                },
             )
             for model in all_models
         ]
@@ -31,7 +31,6 @@ from llama_stack.apis.safety import Safety
 from llama_stack.apis.scoring import Scoring
 from llama_stack.apis.scoring_functions import ScoringFunctions
 from llama_stack.apis.shields import Shields
-from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
 from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl

@@ -66,7 +65,6 @@ class LlamaStack(
     Agents,
     Batches,
     Safety,
-    SyntheticDataGeneration,
     Datasets,
     PostTraining,
     VectorIO,
@@ -12,7 +12,7 @@ from llama_stack.core.ui.modules.api import llama_stack_api
 def models():
     # Models Section
     st.header("Models")
-    models_info = {m.identifier: m.to_dict() for m in llama_stack_api.client.models.list()}
+    models_info = {m.id: m.model_dump() for m in llama_stack_api.client.models.list()}

     selected_model = st.selectbox("Select a model", list(models_info.keys()))
     st.json(models_info[selected_model])
@@ -12,7 +12,11 @@ from llama_stack.core.ui.modules.api import llama_stack_api
 with st.sidebar:
     st.header("Configuration")
     available_models = llama_stack_api.client.models.list()
-    available_models = [model.identifier for model in available_models if model.model_type == "llm"]
+    available_models = [
+        model.id
+        for model in available_models
+        if model.custom_metadata and model.custom_metadata.get("model_type") == "llm"
+    ]
     selected_model = st.selectbox(
         "Choose a model",
         available_models,
@@ -1015,7 +1015,7 @@ async def load_data_from_url(url: str) -> str:
     if url.startswith("http"):
         async with httpx.AsyncClient() as client:
             r = await client.get(url)
-            resp = r.text
+            resp: str = r.text
             return resp
     raise ValueError(f"Unexpected URL: {type(url)}")

@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from collections.abc import Iterable

 import google.auth.transport.requests
 from google.auth import default

@@ -42,3 +43,12 @@ class VertexAIInferenceAdapter(OpenAIMixin):
         Source: https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai
         """
         return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi"
+
+    async def list_provider_model_ids(self) -> Iterable[str]:
+        """
+        VertexAI doesn't currently offer a way to query a list of available models from Google's Model Garden
+        For now we return a hardcoded version of the available models
+
+        :return: An iterable of model IDs
+        """
+        return ["vertexai/gemini-2.0-flash", "vertexai/gemini-2.5-flash", "vertexai/gemini-2.5-pro"]
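Since list_provider_model_ids is a coroutine, a stub adapter can illustrate how the mixin consumes it; StubAdapter below is a hypothetical stand-in for illustration, not a class from the repository.

    import asyncio
    from collections.abc import Iterable

    class StubAdapter:
        # Hypothetical stand-in that mimics the hardcoded model list above.
        async def list_provider_model_ids(self) -> Iterable[str]:
            return ["vertexai/gemini-2.0-flash", "vertexai/gemini-2.5-flash", "vertexai/gemini-2.5-pro"]

    async def main() -> None:
        for provider_model_id in await StubAdapter().list_provider_model_ids():
            print(provider_model_id)

    asyncio.run(main())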
@@ -35,6 +35,7 @@ class InferenceStore:
         self.reference = reference
         self.sql_store = None
         self.policy = policy
+        self.enable_write_queue = True

         # Async write queue and worker control
         self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
@@ -47,14 +48,13 @@ class InferenceStore:
         base_store = sqlstore_impl(self.reference)
         self.sql_store = AuthorizedSqlStore(base_store, self.policy)

-        # Disable write queue for SQLite to avoid concurrency issues
-        backend_name = self.reference.backend
-        backend_config = _SQLSTORE_BACKENDS.get(backend_name)
-        if backend_config is None:
-            raise ValueError(
-                f"Unregistered SQL backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
-            )
-        self.enable_write_queue = backend_config.type != StorageBackendType.SQL_SQLITE
+        # Disable write queue for SQLite since WAL mode handles concurrency
+        # Keep it enabled for other backends (like Postgres) for performance
+        backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend)
+        if backend_config and backend_config.type == StorageBackendType.SQL_SQLITE:
+            self.enable_write_queue = False
+            logger.debug("Write queue disabled for SQLite (WAL mode handles concurrency)")

         await self.sql_store.create_table(
             "chat_completions",
             {
@@ -70,8 +70,9 @@ class InferenceStore:
             self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
             for _ in range(self._num_writers):
                 self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
-        else:
-            logger.info("Write queue disabled for SQLite to avoid concurrency issues")
+            logger.debug(
+                f"Inference store write queue enabled with {self._num_writers} writers, max queue size {self._max_write_queue_size}"
+            )

     async def shutdown(self) -> None:
         if not self._worker_tasks:
@@ -20,7 +20,7 @@ logger = get_logger(name=__name__, category="providers::utils")


 class RemoteInferenceProviderConfig(BaseModel):
-    allowed_models: list[str] | None = Field(  # TODO: make this non-optional and give a list() default
+    allowed_models: list[str] | None = Field(
         default=None,
         description="List of models that should be registered with the model registry. If None, all models are allowed.",
     )
@@ -83,9 +83,6 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
     # This is set in list_models() and used in check_model_availability()
     _model_cache: dict[str, Model] = {}

-    # List of allowed models for this provider, if empty all models allowed
-    allowed_models: list[str] = []
-
     # Optional field name in provider data to look for API key, which takes precedence
     provider_data_api_key_field: str | None = None

@@ -441,7 +438,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         for provider_model_id in provider_models_ids:
             if not isinstance(provider_model_id, str):
                 raise ValueError(f"Model ID {provider_model_id} from list_provider_model_ids() is not a string")
-            if self.allowed_models and provider_model_id not in self.allowed_models:
+            if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
                 logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list")
                 continue
             model = self.construct_model_from_identifier(provider_model_id)
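The changed condition is subtle: allowed_models set to None now means "no restriction", while any explicit list (including an empty one) restricts registration. A minimal stand-alone sketch of that gate (is_allowed is a hypothetical helper, not a function in the repository):

    def is_allowed(provider_model_id: str, allowed_models: list[str] | None) -> bool:
        # None -> no restriction; an explicit list (even empty) restricts registration.
        if allowed_models is not None and provider_model_id not in allowed_models:
            return False
        return True

    assert is_allowed("vertexai/gemini-2.5-pro", None)
    assert is_allowed("vertexai/gemini-2.5-pro", ["vertexai/gemini-2.5-pro"])
    assert not is_allowed("vertexai/gemini-2.5-pro", [])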
@@ -70,13 +70,13 @@ class ResponsesStore:
         base_store = sqlstore_impl(self.reference)
         self.sql_store = AuthorizedSqlStore(base_store, self.policy)

+        # Disable write queue for SQLite since WAL mode handles concurrency
+        # Keep it enabled for other backends (like Postgres) for performance
         backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend)
-        if backend_config is None:
-            raise ValueError(
-                f"Unregistered SQL backend '{self.reference.backend}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
-            )
-        if backend_config.type == StorageBackendType.SQL_SQLITE:
+        if backend_config and backend_config.type == StorageBackendType.SQL_SQLITE:
             self.enable_write_queue = False
             logger.debug("Write queue disabled for SQLite (WAL mode handles concurrency)")

         await self.sql_store.create_table(
             "openai_responses",
             {
@@ -99,8 +99,9 @@ class ResponsesStore:
             self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
             for _ in range(self._num_writers):
                 self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
-        else:
-            logger.debug("Write queue disabled for SQLite to avoid concurrency issues")
+            logger.debug(
+                f"Responses store write queue enabled with {self._num_writers} writers, max queue size {self._max_write_queue_size}"
+            )

     async def shutdown(self) -> None:
         if not self._worker_tasks:
@@ -17,6 +17,7 @@ from sqlalchemy import (
     String,
     Table,
     Text,
+    event,
     inspect,
     select,
     text,

@@ -75,7 +76,36 @@ class SqlAlchemySqlStoreImpl(SqlStore):
         self.metadata = MetaData()

     def create_engine(self) -> AsyncEngine:
-        return create_async_engine(self.config.engine_str, pool_pre_ping=True)
+        # Configure connection args for better concurrency support
+        connect_args = {}
+        if "sqlite" in self.config.engine_str:
+            # SQLite-specific optimizations for concurrent access
+            # With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases
+            connect_args["timeout"] = 5.0
+            connect_args["check_same_thread"] = False  # Allow usage across asyncio tasks
+
+        engine = create_async_engine(
+            self.config.engine_str,
+            pool_pre_ping=True,
+            connect_args=connect_args,
+        )
+
+        # Enable WAL mode for SQLite to support concurrent readers and writers
+        if "sqlite" in self.config.engine_str:
+
+            @event.listens_for(engine.sync_engine, "connect")
+            def set_sqlite_pragma(dbapi_conn, connection_record):
+                cursor = dbapi_conn.cursor()
+                # Enable Write-Ahead Logging for better concurrency
+                cursor.execute("PRAGMA journal_mode=WAL")
+                # Set busy timeout to 5 seconds (retry instead of immediate failure)
+                # With WAL mode, locks should be brief; if we hit 5s there's a bigger issue
+                cursor.execute("PRAGMA busy_timeout=5000")
+                # Use NORMAL synchronous mode for better performance (still safe with WAL)
+                cursor.execute("PRAGMA synchronous=NORMAL")
+                cursor.close()
+
+        return engine

     async def create_table(
         self,
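The same PRAGMAs can be exercised outside SQLAlchemy to see what a given database file reports; this is a hedged, standalone sqlite3 sketch (the file path is illustrative), not code from the repository.

    import sqlite3

    conn = sqlite3.connect("/tmp/wal_example.db", timeout=5.0, check_same_thread=False)
    cur = conn.cursor()
    # Same settings as the connect-event listener above applies per connection.
    cur.execute("PRAGMA journal_mode=WAL")
    cur.execute("PRAGMA busy_timeout=5000")
    cur.execute("PRAGMA synchronous=NORMAL")
    print(cur.execute("PRAGMA journal_mode").fetchone())   # expected: ('wal',)
    print(cur.execute("PRAGMA busy_timeout").fetchone())   # expected: (5000,)
    conn.close()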
@@ -156,7 +156,7 @@ def normalize_inference_request(method: str, url: str, headers: dict[str, Any],
     }

     # Include test_id for isolation, except for shared infrastructure endpoints
-    if parsed.path not in ("/api/tags", "/v1/models"):
+    if parsed.path not in ("/api/tags", "/v1/models", "/v1/openai/v1/models"):
         normalized["test_id"] = test_id

     normalized_json = json.dumps(normalized, sort_keys=True)
@@ -430,7 +430,7 @@ class ResponseStorage:

         # For model-list endpoints, include digest in filename to distinguish different model sets
         endpoint = request.get("endpoint")
-        if endpoint in ("/api/tags", "/v1/models"):
+        if endpoint in ("/api/tags", "/v1/models", "/v1/openai/v1/models"):
             digest = _model_identifiers_digest(endpoint, response)
             response_file = f"models-{request_hash}-{digest}.json"
@@ -554,13 +554,14 @@ def _model_identifiers_digest(endpoint: str, response: dict[str, Any]) -> str:
     Supported endpoints:
     - '/api/tags' (Ollama): response body has 'models': [ { name/model/digest/id/... }, ... ]
     - '/v1/models' (OpenAI): response body is: [ { id: ... }, ... ]
+    - '/v1/openai/v1/models' (OpenAI): response body is: [ { id: ... }, ... ]
     Returns a list of unique identifiers or None if structure doesn't match.
     """
     if "models" in response["body"]:
         # ollama
         items = response["body"]["models"]
     else:
-        # openai
+        # openai or openai-style endpoints
         items = response["body"]
     idents = [m.model if endpoint == "/api/tags" else m.id for m in items]
     return sorted(set(idents))
@@ -581,7 +582,7 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]])
     seen: dict[str, dict[str, Any]] = {}
     for rec in records:
         body = rec["response"]["body"]
-        if endpoint == "/v1/models":
+        if endpoint in ("/v1/models", "/v1/openai/v1/models"):
             for m in body:
                 key = m.id
                 seen[key] = m
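The merging idea, stripped of the recording machinery, is that later records win per model id. A hedged, self-contained sketch (the SimpleNamespace objects stand in for recorded model entries):

    from types import SimpleNamespace

    records = [
        [SimpleNamespace(id="m1"), SimpleNamespace(id="m2")],
        [SimpleNamespace(id="m2"), SimpleNamespace(id="m3")],
    ]
    seen: dict[str, object] = {}
    for body in records:
        for m in body:
            seen[m.id] = m  # later recordings overwrite earlier ones for the same id
    print(sorted(seen))  # ['m1', 'm2', 'm3']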
@@ -665,7 +666,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
     logger.info(f" Test context: {get_test_context()}")

     if mode == APIRecordingMode.LIVE or storage is None:
-        if endpoint == "/v1/models":
+        if endpoint in ("/v1/models", "/v1/openai/v1/models"):
             return original_method(self, *args, **kwargs)
         else:
             return await original_method(self, *args, **kwargs)
@@ -699,7 +700,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
     recording = None
     if mode == APIRecordingMode.REPLAY or mode == APIRecordingMode.RECORD_IF_MISSING:
         # Special handling for model-list endpoints: merge all recordings with this hash
-        if endpoint in ("/api/tags", "/v1/models"):
+        if endpoint in ("/api/tags", "/v1/models", "/v1/openai/v1/models"):
             records = storage._model_list_responses(request_hash)
             recording = _combine_model_list_responses(endpoint, records)
         else:
@@ -739,13 +740,13 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         )

     if mode == APIRecordingMode.RECORD or (mode == APIRecordingMode.RECORD_IF_MISSING and not recording):
-        if endpoint == "/v1/models":
+        if endpoint in ("/v1/models", "/v1/openai/v1/models"):
             response = original_method(self, *args, **kwargs)
         else:
             response = await original_method(self, *args, **kwargs)

         # we want to store the result of the iterator, not the iterator itself
-        if endpoint == "/v1/models":
+        if endpoint in ("/v1/models", "/v1/openai/v1/models"):
             response = [m async for m in response]

         request_data = {