Merge branch 'main' into add-mcp-streamable-http-support

This commit is contained in:
Calum Murray 2025-07-18 14:38:54 -04:00 committed by GitHub
commit c715f30e65
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
247 changed files with 9685 additions and 5249 deletions

View file

@ -13,7 +13,6 @@ from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.common.errors import UnsupportedModelError
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@ -39,7 +38,6 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
@ -90,12 +88,6 @@ class LiteLLMOpenAIMixin(
async def shutdown(self):
pass
async def register_model(self, model: Model) -> Model:
model_id = self.get_provider_model_id(model.provider_resource_id)
if model_id is None:
raise UnsupportedModelError(model.provider_resource_id, self.alias_to_provider_id_map.keys())
return model
def get_litellm_model_name(self, model_id: str) -> str:
# users may be using openai/ prefix in their model names. the openai/models.py did this by default.
# model_id.startswith("openai/") is for backwards compatibility.

View file

@ -44,6 +44,7 @@ def build_hf_repo_model_entry(
]
if additional_aliases:
aliases.extend(additional_aliases)
aliases = [alias for alias in aliases if alias is not None]
return ProviderModelEntry(
provider_model_id=provider_model_id,
aliases=aliases,
@ -82,15 +83,43 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
def get_llama_model(self, provider_model_id: str) -> str | None:
return self.provider_id_to_llama_model_map.get(provider_model_id, None)
async def check_model_availability(self, model: str) -> bool:
"""
Check if a specific model is available from the provider (non-static check).
This is for subclassing purposes, so providers can check if a specific
model is currently available for use through dynamic means (e.g., API calls).
This method should NOT check statically configured model entries in
`self.alias_to_provider_id_map` - that is handled separately in register_model.
Default implementation returns False (no dynamic models available).
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
"""
return False
async def register_model(self, model: Model) -> Model:
if not (supported_model_id := self.get_provider_model_id(model.provider_resource_id)):
raise UnsupportedModelError(model.provider_resource_id, self.alias_to_provider_id_map.keys())
# Check if model is supported in static configuration
supported_model_id = self.get_provider_model_id(model.provider_resource_id)
# If not found in static config, check if it's available dynamically from provider
if not supported_model_id:
if await self.check_model_availability(model.provider_resource_id):
supported_model_id = model.provider_resource_id
else:
# note: we cannot provide a complete list of supported models without
# getting a complete list from the provider, so we return "..."
all_supported_models = [*self.alias_to_provider_id_map.keys(), "..."]
raise UnsupportedModelError(model.provider_resource_id, all_supported_models)
provider_resource_id = self.get_provider_model_id(model.model_id)
if model.model_type == ModelType.embedding:
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
provider_resource_id = model.provider_resource_id
if provider_resource_id:
if provider_resource_id != supported_model_id: # be idemopotent, only reject differences
if provider_resource_id != supported_model_id: # be idempotent, only reject differences
raise ValueError(
f"Model id '{model.model_id}' is already registered. Please use a different id or unregister it first."
)
@ -113,6 +142,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
)
# Register the model alias, ensuring it maps to the correct provider model id
self.alias_to_provider_id_map[model.model_id] = supported_model_id
return model

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import asyncio
import json
import logging
import mimetypes
import time
@ -35,6 +36,7 @@ from llama_stack.apis.vector_io import (
VectorStoreSearchResponse,
VectorStoreSearchResponsePage,
)
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks
logger = logging.getLogger(__name__)
@ -59,26 +61,45 @@ class OpenAIVectorStoreMixin(ABC):
# These should be provided by the implementing class
openai_vector_stores: dict[str, dict[str, Any]]
files_api: Files | None
# KV store for persisting OpenAI vector store metadata
kvstore: KVStore | None
@abstractmethod
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to persistent storage."""
pass
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info))
# update in-memory cache
self.openai_vector_stores[store_id] = store_info
@abstractmethod
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
"""Load all vector store metadata from persistent storage."""
pass
assert self.kvstore is not None
start_key = OPENAI_VECTOR_STORES_PREFIX
end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
stored_data = await self.kvstore.values_in_range(start_key, end_key)
stores: dict[str, dict[str, Any]] = {}
for item in stored_data:
info = json.loads(item)
stores[info["id"]] = info
return stores
@abstractmethod
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Update vector store metadata in persistent storage."""
pass
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info))
# update in-memory cache
self.openai_vector_stores[store_id] = store_info
@abstractmethod
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
"""Delete vector store metadata from persistent storage."""
pass
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.delete(key)
# remove from in-memory cache
self.openai_vector_stores.pop(store_id, None)
@abstractmethod
async def _save_openai_vector_store_file(
@ -117,6 +138,10 @@ class OpenAIVectorStoreMixin(ABC):
"""Unregister a vector database (provider-specific implementation)."""
pass
async def initialize_openai_vector_stores(self) -> None:
"""Load existing OpenAI vector stores into the in-memory cache."""
self.openai_vector_stores = await self._load_openai_vector_stores()
@abstractmethod
async def insert_chunks(
self,
@ -147,8 +172,9 @@ class OpenAIVectorStoreMixin(ABC):
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
"""Creates a vector store."""
store_id = name or str(uuid.uuid4())
created_at = int(time.time())
# Derive the canonical vector_db_id (allow override, else generate)
vector_db_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"
if provider_id is None:
raise ValueError("Provider ID is required")
@ -156,19 +182,19 @@ class OpenAIVectorStoreMixin(ABC):
if embedding_model is None:
raise ValueError("Embedding model is required")
# Use provided embedding dimension or default to 384
# Embedding dimension is required (defaulted to 384 if not provided)
if embedding_dimension is None:
raise ValueError("Embedding dimension is required")
provider_vector_db_id = provider_vector_db_id or store_id
# Register the VectorDB backing this vector store
vector_db = VectorDB(
identifier=store_id,
identifier=vector_db_id,
embedding_dimension=embedding_dimension,
embedding_model=embedding_model,
provider_id=provider_id,
provider_resource_id=provider_vector_db_id,
provider_resource_id=vector_db_id,
vector_db_name=name,
)
# Register the vector DB
await self.register_vector_db(vector_db)
# Create OpenAI vector store metadata
@ -182,11 +208,11 @@ class OpenAIVectorStoreMixin(ABC):
in_progress=0,
total=0,
)
store_info = {
"id": store_id,
store_info: dict[str, Any] = {
"id": vector_db_id,
"object": "vector_store",
"created_at": created_at,
"name": store_id,
"name": name,
"usage_bytes": 0,
"file_counts": file_counts.model_dump(),
"status": status,
@ -206,18 +232,18 @@ class OpenAIVectorStoreMixin(ABC):
store_info["metadata"] = metadata
# Save to persistent storage (provider-specific)
await self._save_openai_vector_store(store_id, store_info)
await self._save_openai_vector_store(vector_db_id, store_info)
# Store in memory cache
self.openai_vector_stores[store_id] = store_info
self.openai_vector_stores[vector_db_id] = store_info
# Now that our vector store is created, attach any files that were provided
file_ids = file_ids or []
tasks = [self.openai_attach_file_to_vector_store(store_id, file_id) for file_id in file_ids]
tasks = [self.openai_attach_file_to_vector_store(vector_db_id, file_id) for file_id in file_ids]
await asyncio.gather(*tasks)
# Get the updated store info and return it
store_info = self.openai_vector_stores[store_id]
store_info = self.openai_vector_stores[vector_db_id]
return VectorStoreObject.model_validate(store_info)
async def openai_list_vector_stores(

View file

@ -15,6 +15,7 @@ from llama_stack.distribution.request_headers import get_authenticated_user
from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
from .sqlstore import SqlStoreType
logger = get_logger(name=__name__, category="authorized_sqlstore")
@ -38,22 +39,10 @@ SQL_OPTIMIZED_POLICY = [
class SqlRecord(ProtectedResource):
"""Simple ProtectedResource implementation for SQL records."""
def __init__(self, record_id: str, table_name: str, access_attributes: dict[str, list[str]] | None = None):
def __init__(self, record_id: str, table_name: str, owner: User):
self.type = f"sql_record::{table_name}"
self.identifier = record_id
if access_attributes:
self.owner = User(
principal="system",
attributes=access_attributes,
)
else:
self.owner = User(
principal="system_public",
attributes=None,
)
self.owner = owner
class AuthorizedSqlStore:
@ -71,9 +60,18 @@ class AuthorizedSqlStore:
:param sql_store: Base SqlStore implementation to wrap
"""
self.sql_store = sql_store
self._detect_database_type()
self._validate_sql_optimized_policy()
def _detect_database_type(self) -> None:
"""Detect the database type from the underlying SQL store."""
if not hasattr(self.sql_store, "config"):
raise ValueError("SqlStore must have a config attribute to be used with AuthorizedSqlStore")
self.database_type = self.sql_store.config.type
if self.database_type not in [SqlStoreType.postgres, SqlStoreType.sqlite]:
raise ValueError(f"Unsupported database type: {self.database_type}")
def _validate_sql_optimized_policy(self) -> None:
"""Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy().
@ -91,22 +89,27 @@ class AuthorizedSqlStore:
async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
"""Create a table with built-in access control support."""
await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
enhanced_schema = dict(schema)
if "access_attributes" not in enhanced_schema:
enhanced_schema["access_attributes"] = ColumnType.JSON
if "owner_principal" not in enhanced_schema:
enhanced_schema["owner_principal"] = ColumnType.STRING
await self.sql_store.create_table(table, enhanced_schema)
await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING)
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
"""Insert a row with automatic access control attribute capture."""
enhanced_data = dict(data)
current_user = get_authenticated_user()
if current_user and current_user.attributes:
if current_user:
enhanced_data["owner_principal"] = current_user.principal
enhanced_data["access_attributes"] = current_user.attributes
else:
enhanced_data["owner_principal"] = None
enhanced_data["access_attributes"] = None
await self.sql_store.insert(table, enhanced_data)
@ -136,9 +139,12 @@ class AuthorizedSqlStore:
for row in rows.data:
stored_access_attrs = row.get("access_attributes")
stored_owner_principal = row.get("owner_principal") or ""
record_id = row.get("id", "unknown")
sql_record = SqlRecord(str(record_id), table, stored_access_attrs)
sql_record = SqlRecord(
str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
)
if is_action_allowed(policy, Action.READ, sql_record, current_user):
filtered_rows.append(row)
@ -176,43 +182,90 @@ class AuthorizedSqlStore:
Only applies SQL filtering for the default policy to ensure correctness.
For custom policies, uses conservative filtering to avoid blocking legitimate access.
"""
current_user = get_authenticated_user()
if not policy or policy == SQL_OPTIMIZED_POLICY:
return self._build_default_policy_where_clause()
return self._build_default_policy_where_clause(current_user)
else:
return self._build_conservative_where_clause()
def _build_default_policy_where_clause(self) -> str:
def _json_extract(self, column: str, path: str) -> str:
"""Extract JSON value (keeping JSON type).
Args:
column: The JSON column name
path: The JSON path (e.g., 'roles', 'teams')
Returns:
SQL expression to extract JSON value
"""
if self.database_type == SqlStoreType.postgres:
return f"{column}->'{path}'"
elif self.database_type == SqlStoreType.sqlite:
return f"JSON_EXTRACT({column}, '$.{path}')"
else:
raise ValueError(f"Unsupported database type: {self.database_type}")
def _json_extract_text(self, column: str, path: str) -> str:
"""Extract JSON value as text.
Args:
column: The JSON column name
path: The JSON path (e.g., 'roles', 'teams')
Returns:
SQL expression to extract JSON value as text
"""
if self.database_type == SqlStoreType.postgres:
return f"{column}->>'{path}'"
elif self.database_type == SqlStoreType.sqlite:
return f"JSON_EXTRACT({column}, '$.{path}')"
else:
raise ValueError(f"Unsupported database type: {self.database_type}")
def _get_public_access_conditions(self) -> list[str]:
"""Get the SQL conditions for public access."""
# Public records are records that have no owner_principal or access_attributes
conditions = ["owner_principal = ''"]
if self.database_type == SqlStoreType.postgres:
# Postgres stores JSON null as 'null'
conditions.append("access_attributes::text = 'null'")
elif self.database_type == SqlStoreType.sqlite:
conditions.append("access_attributes = 'null'")
else:
raise ValueError(f"Unsupported database type: {self.database_type}")
return conditions
def _build_default_policy_where_clause(self, current_user: User | None) -> str:
"""Build SQL WHERE clause for the default policy.
Default policy: permit all actions when user in owners [roles, teams, projects, namespaces]
This means user must match ALL attribute categories that exist in the resource.
"""
current_user = get_authenticated_user()
if not current_user or not current_user.attributes:
return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')"
else:
base_conditions = ["access_attributes IS NULL", "access_attributes = 'null'", "access_attributes = '{}'"]
user_attr_conditions = []
base_conditions = self._get_public_access_conditions()
user_attr_conditions = []
if current_user and current_user.attributes:
for attr_key, user_values in current_user.attributes.items():
if user_values:
value_conditions = []
for value in user_values:
value_conditions.append(f"JSON_EXTRACT(access_attributes, '$.{attr_key}') LIKE '%\"{value}\"%'")
# Check if JSON array contains the value
escaped_value = value.replace("'", "''")
json_text = self._json_extract_text("access_attributes", attr_key)
value_conditions.append(f"({json_text} LIKE '%\"{escaped_value}\"%')")
if value_conditions:
category_missing = f"JSON_EXTRACT(access_attributes, '$.{attr_key}') IS NULL"
# Check if the category is missing (NULL)
category_missing = f"{self._json_extract('access_attributes', attr_key)} IS NULL"
user_matches_category = f"({' OR '.join(value_conditions)})"
user_attr_conditions.append(f"({category_missing} OR {user_matches_category})")
if user_attr_conditions:
all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
base_conditions.append(all_requirements_met)
return f"({' OR '.join(base_conditions)})"
else:
return f"({' OR '.join(base_conditions)})"
return f"({' OR '.join(base_conditions)})"
def _build_conservative_where_clause(self) -> str:
"""Conservative SQL filtering for custom policies.
@ -222,5 +275,8 @@ class AuthorizedSqlStore:
current_user = get_authenticated_user()
if not current_user:
return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')"
# Only allow public records
base_conditions = self._get_public_access_conditions()
return f"({' OR '.join(base_conditions)})"
return "1=1"

View file

@ -244,35 +244,41 @@ class SqlAlchemySqlStoreImpl(SqlStore):
engine = create_async_engine(self.config.engine_str)
try:
inspector = inspect(engine)
table_names = inspector.get_table_names()
if table not in table_names:
return
existing_columns = inspector.get_columns(table)
column_names = [col["name"] for col in existing_columns]
if column_name in column_names:
return
sqlalchemy_type = TYPE_MAPPING.get(column_type)
if not sqlalchemy_type:
raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.")
# Create the ALTER TABLE statement
# Note: We need to get the dialect-specific type name
dialect = engine.dialect
type_impl = sqlalchemy_type()
compiled_type = type_impl.compile(dialect=dialect)
nullable_clause = "" if nullable else " NOT NULL"
add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
async with engine.begin() as conn:
def check_column_exists(sync_conn):
inspector = inspect(sync_conn)
table_names = inspector.get_table_names()
if table not in table_names:
return False, False # table doesn't exist, column doesn't exist
existing_columns = inspector.get_columns(table)
column_names = [col["name"] for col in existing_columns]
return True, column_name in column_names # table exists, column exists or not
table_exists, column_exists = await conn.run_sync(check_column_exists)
if not table_exists or column_exists:
return
sqlalchemy_type = TYPE_MAPPING.get(column_type)
if not sqlalchemy_type:
raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.")
# Create the ALTER TABLE statement
# Note: We need to get the dialect-specific type name
dialect = engine.dialect
type_impl = sqlalchemy_type()
compiled_type = type_impl.compile(dialect=dialect)
nullable_clause = "" if nullable else " NOT NULL"
add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
await conn.execute(add_column_sql)
except Exception:
except Exception as e:
# If any error occurs during migration, log it but don't fail
# The table creation will handle adding the column
logger.error(f"Error adding column {column_name} to table {table}: {e}")
pass

View file

@ -4,9 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import abstractmethod
from enum import Enum
from enum import StrEnum
from pathlib import Path
from typing import Annotated, Literal
@ -19,7 +18,7 @@ from .api import SqlStore
sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]
class SqlStoreType(Enum):
class SqlStoreType(StrEnum):
sqlite = "sqlite"
postgres = "postgres"
@ -36,7 +35,7 @@ class SqlAlchemySqlStoreConfig(BaseModel):
class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
type: Literal["sqlite"] = SqlStoreType.sqlite.value
type: Literal[SqlStoreType.sqlite] = SqlStoreType.sqlite
db_path: str = Field(
default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
description="Database path, e.g. ~/.llama/distributions/ollama/sqlstore.db",
@ -59,7 +58,7 @@ class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
type: Literal["postgres"] = SqlStoreType.postgres.value
type: Literal[SqlStoreType.postgres] = SqlStoreType.postgres
host: str = "localhost"
port: int = 5432
db: str = "llamastack"
@ -107,7 +106,7 @@ def get_pip_packages(store_config: dict | SqlStoreConfig) -> list[str]:
def sqlstore_impl(config: SqlStoreConfig) -> SqlStore:
if config.type in [SqlStoreType.sqlite.value, SqlStoreType.postgres.value]:
if config.type in [SqlStoreType.sqlite, SqlStoreType.postgres]:
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
impl = SqlAlchemySqlStoreImpl(config)

View file

@ -9,14 +9,12 @@ import inspect
import json
from collections.abc import AsyncGenerator, Callable
from functools import wraps
from typing import Any, TypeVar
from typing import Any
from pydantic import BaseModel
from llama_stack.models.llama.datatypes import Primitive
T = TypeVar("T")
def serialize_value(value: Any) -> Primitive:
return str(_prepare_for_json(value))
@ -44,7 +42,7 @@ def _prepare_for_json(value: Any) -> str:
return str(value)
def trace_protocol(cls: type[T]) -> type[T]:
def trace_protocol[T](cls: type[T]) -> type[T]:
"""
A class decorator that automatically traces all methods in a protocol/base class
and its inheriting classes.