Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-23 08:02:27 +00:00
Merge branch 'main' into add-mcp-streamable-http-support
Commit c715f30e65
247 changed files with 9685 additions and 5249 deletions
@@ -13,7 +13,6 @@ from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
 )
-from llama_stack.apis.common.errors import UnsupportedModelError
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -39,7 +38,6 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.apis.models import Model
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
@@ -90,12 +88,6 @@ class LiteLLMOpenAIMixin(
     async def shutdown(self):
         pass

-    async def register_model(self, model: Model) -> Model:
-        model_id = self.get_provider_model_id(model.provider_resource_id)
-        if model_id is None:
-            raise UnsupportedModelError(model.provider_resource_id, self.alias_to_provider_id_map.keys())
-        return model
-
     def get_litellm_model_name(self, model_id: str) -> str:
         # users may be using openai/ prefix in their model names. the openai/models.py did this by default.
         # model_id.startswith("openai/") is for backwards compatibility.
@@ -44,6 +44,7 @@ def build_hf_repo_model_entry(
     ]
     if additional_aliases:
         aliases.extend(additional_aliases)
+    aliases = [alias for alias in aliases if alias is not None]
     return ProviderModelEntry(
         provider_model_id=provider_model_id,
         aliases=aliases,
@@ -82,15 +83,43 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
     def get_llama_model(self, provider_model_id: str) -> str | None:
         return self.provider_id_to_llama_model_map.get(provider_model_id, None)

+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available from the provider (non-static check).
+
+        This is for subclassing purposes, so providers can check if a specific
+        model is currently available for use through dynamic means (e.g., API calls).
+
+        This method should NOT check statically configured model entries in
+        `self.alias_to_provider_id_map` - that is handled separately in register_model.
+
+        Default implementation returns False (no dynamic models available).
+
+        :param model: The model identifier to check.
+        :return: True if the model is available dynamically, False otherwise.
+        """
+        return False
+
     async def register_model(self, model: Model) -> Model:
-        if not (supported_model_id := self.get_provider_model_id(model.provider_resource_id)):
-            raise UnsupportedModelError(model.provider_resource_id, self.alias_to_provider_id_map.keys())
+        # Check if model is supported in static configuration
+        supported_model_id = self.get_provider_model_id(model.provider_resource_id)
+
+        # If not found in static config, check if it's available dynamically from provider
+        if not supported_model_id:
+            if await self.check_model_availability(model.provider_resource_id):
+                supported_model_id = model.provider_resource_id
+            else:
+                # note: we cannot provide a complete list of supported models without
+                # getting a complete list from the provider, so we return "..."
+                all_supported_models = [*self.alias_to_provider_id_map.keys(), "..."]
+                raise UnsupportedModelError(model.provider_resource_id, all_supported_models)
+
         provider_resource_id = self.get_provider_model_id(model.model_id)
         if model.model_type == ModelType.embedding:
             # embedding models are always registered by their provider model id and does not need to be mapped to a llama model
             provider_resource_id = model.provider_resource_id
         if provider_resource_id:
-            if provider_resource_id != supported_model_id:  # be idemopotent, only reject differences
+            if provider_resource_id != supported_model_id:  # be idempotent, only reject differences
                 raise ValueError(
                     f"Model id '{model.model_id}' is already registered. Please use a different id or unregister it first."
                 )
@@ -113,6 +142,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
                 ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
             )

+        # Register the model alias, ensuring it maps to the correct provider model id
         self.alias_to_provider_id_map[model.model_id] = supported_model_id

         return model
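
The new check_model_availability hook lets an adapter accept models it discovers at runtime, not just statically registered aliases; register_model now falls through to it before rejecting with the "..." marker in the supported-models list. A minimal sketch of an override — the adapter name, its client, and list_models() are hypothetical, not part of this diff:

class MyAdapter(ModelRegistryHelper):
    """Sketch only: MyAdapter and its SDK client are illustrative."""

    def __init__(self, client):
        super().__init__(model_entries=[])  # no static entries; rely on the dynamic check
        self.client = client

    async def check_model_availability(self, model: str) -> bool:
        # Probe the provider for currently served models instead of
        # consulting only the static alias_to_provider_id_map.
        available = await self.client.list_models()  # hypothetical SDK call
        return model in available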
@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 import asyncio
+import json
 import logging
 import mimetypes
 import time
@@ -35,6 +36,7 @@ from llama_stack.apis.vector_io import (
     VectorStoreSearchResponse,
     VectorStoreSearchResponsePage,
 )
+from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks

 logger = logging.getLogger(__name__)
@@ -59,26 +61,45 @@ class OpenAIVectorStoreMixin(ABC):
     # These should be provided by the implementing class
     openai_vector_stores: dict[str, dict[str, Any]]
     files_api: Files | None
+    # KV store for persisting OpenAI vector store metadata
+    kvstore: KVStore | None

-    @abstractmethod
     async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
         """Save vector store metadata to persistent storage."""
-        pass
+        assert self.kvstore is not None
+        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+        await self.kvstore.set(key=key, value=json.dumps(store_info))
+        # update in-memory cache
+        self.openai_vector_stores[store_id] = store_info

-    @abstractmethod
     async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
         """Load all vector store metadata from persistent storage."""
-        pass
+        assert self.kvstore is not None
+        start_key = OPENAI_VECTOR_STORES_PREFIX
+        end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
+        stored_data = await self.kvstore.values_in_range(start_key, end_key)
+
+        stores: dict[str, dict[str, Any]] = {}
+        for item in stored_data:
+            info = json.loads(item)
+            stores[info["id"]] = info
+        return stores

-    @abstractmethod
     async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
         """Update vector store metadata in persistent storage."""
-        pass
+        assert self.kvstore is not None
+        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+        await self.kvstore.set(key=key, value=json.dumps(store_info))
+        # update in-memory cache
+        self.openai_vector_stores[store_id] = store_info

-    @abstractmethod
     async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
         """Delete vector store metadata from persistent storage."""
-        pass
+        assert self.kvstore is not None
+        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+        await self.kvstore.delete(key)
+        # remove from in-memory cache
+        self.openai_vector_stores.pop(store_id, None)

     @abstractmethod
     async def _save_openai_vector_store_file(
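
These four helpers are no longer abstract: they share one implementation keyed on OPENAI_VECTOR_STORES_PREFIX and only need a KVStore supporting set, delete, and values_in_range. A toy in-memory stand-in that satisfies exactly the calls made above (illustrative only, not the real KVStore API surface):

class DictKVStore:
    """Minimal async key-value store, for illustration."""

    def __init__(self) -> None:
        self._data: dict[str, str] = {}

    async def set(self, key: str, value: str) -> None:
        self._data[key] = value

    async def delete(self, key: str) -> None:
        self._data.pop(key, None)

    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
        # Lexicographic range scan, mirroring how _load_openai_vector_stores
        # uses the "\xff" sentinel as an upper bound.
        return [v for k, v in sorted(self._data.items()) if start_key <= k <= end_key]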
@@ -117,6 +138,10 @@ class OpenAIVectorStoreMixin(ABC):
         """Unregister a vector database (provider-specific implementation)."""
         pass

+    async def initialize_openai_vector_stores(self) -> None:
+        """Load existing OpenAI vector stores into the in-memory cache."""
+        self.openai_vector_stores = await self._load_openai_vector_stores()
+
     @abstractmethod
     async def insert_chunks(
         self,
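
A provider would typically call this once at startup, after its kvstore is ready, so the cache is warm before any OpenAI-compatible call arrives. A sketch — the adapter name and the kvstore_impl factory wiring are assumptions here:

class MyVectorIOAdapter(OpenAIVectorStoreMixin):  # hypothetical adapter
    async def initialize(self) -> None:
        self.kvstore = await kvstore_impl(self.config.kvstore)  # assumed kvstore factory
        await self.initialize_openai_vector_stores()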
@@ -147,8 +172,9 @@ class OpenAIVectorStoreMixin(ABC):
         provider_vector_db_id: str | None = None,
     ) -> VectorStoreObject:
         """Creates a vector store."""
-        store_id = name or str(uuid.uuid4())
         created_at = int(time.time())
+        # Derive the canonical vector_db_id (allow override, else generate)
+        vector_db_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"

         if provider_id is None:
             raise ValueError("Provider ID is required")
@@ -156,19 +182,19 @@ class OpenAIVectorStoreMixin(ABC):
         if embedding_model is None:
             raise ValueError("Embedding model is required")

-        # Use provided embedding dimension or default to 384
+        # Embedding dimension is required (defaulted to 384 if not provided)
         if embedding_dimension is None:
             raise ValueError("Embedding dimension is required")

-        provider_vector_db_id = provider_vector_db_id or store_id
+        # Register the VectorDB backing this vector store
         vector_db = VectorDB(
-            identifier=store_id,
+            identifier=vector_db_id,
             embedding_dimension=embedding_dimension,
             embedding_model=embedding_model,
             provider_id=provider_id,
-            provider_resource_id=provider_vector_db_id,
+            provider_resource_id=vector_db_id,
             vector_db_name=name,
         )
-        # Register the vector DB
         await self.register_vector_db(vector_db)

         # Create OpenAI vector store metadata
@@ -182,11 +208,11 @@ class OpenAIVectorStoreMixin(ABC):
             in_progress=0,
             total=0,
         )
-        store_info = {
-            "id": store_id,
+        store_info: dict[str, Any] = {
+            "id": vector_db_id,
             "object": "vector_store",
             "created_at": created_at,
-            "name": store_id,
+            "name": name,
             "usage_bytes": 0,
             "file_counts": file_counts.model_dump(),
             "status": status,
@@ -206,18 +232,18 @@ class OpenAIVectorStoreMixin(ABC):
             store_info["metadata"] = metadata

         # Save to persistent storage (provider-specific)
-        await self._save_openai_vector_store(store_id, store_info)
+        await self._save_openai_vector_store(vector_db_id, store_info)

         # Store in memory cache
-        self.openai_vector_stores[store_id] = store_info
+        self.openai_vector_stores[vector_db_id] = store_info

         # Now that our vector store is created, attach any files that were provided
         file_ids = file_ids or []
-        tasks = [self.openai_attach_file_to_vector_store(store_id, file_id) for file_id in file_ids]
+        tasks = [self.openai_attach_file_to_vector_store(vector_db_id, file_id) for file_id in file_ids]
         await asyncio.gather(*tasks)

         # Get the updated store info and return it
-        store_info = self.openai_vector_stores[store_id]
+        store_info = self.openai_vector_stores[vector_db_id]
         return VectorStoreObject.model_validate(store_info)

     async def openai_list_vector_stores(
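
Net effect of the store_id → vector_db_id rename across these hunks: the store's public id is now a generated "vs_" handle (or the caller-supplied provider_vector_db_id) rather than the display name, and the name survives only in the "name" field. Roughly, with illustrative argument values:

store = await provider.openai_create_vector_store(
    name="my-docs",
    embedding_model="all-MiniLM-L6-v2",
    embedding_dimension=384,
    provider_id="faiss",
)
assert store.id.startswith("vs_")  # no longer "my-docs"
assert store.name == "my-docs"

This also means two stores can share a name without colliding on id.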
@@ -15,6 +15,7 @@ from llama_stack.distribution.request_headers import get_authenticated_user
 from llama_stack.log import get_logger

 from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
+from .sqlstore import SqlStoreType

 logger = get_logger(name=__name__, category="authorized_sqlstore")

@@ -38,22 +39,10 @@ SQL_OPTIMIZED_POLICY = [


 class SqlRecord(ProtectedResource):
     """Simple ProtectedResource implementation for SQL records."""

-    def __init__(self, record_id: str, table_name: str, access_attributes: dict[str, list[str]] | None = None):
+    def __init__(self, record_id: str, table_name: str, owner: User):
         self.type = f"sql_record::{table_name}"
         self.identifier = record_id
-
-        if access_attributes:
-            self.owner = User(
-                principal="system",
-                attributes=access_attributes,
-            )
-        else:
-            self.owner = User(
-                principal="system_public",
-                attributes=None,
-            )
+        self.owner = owner


 class AuthorizedSqlStore:
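
Ownership is now supplied by the caller rather than synthesized from raw attributes, so the "system"/"system_public" pseudo-principals disappear. For example (values are illustrative):

owner = User(principal="alice", attributes={"roles": ["admin"]})
record = SqlRecord("row-42", "documents", owner)
# record.type == "sql_record::documents"; record.owner.principal == "alice"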
@@ -71,9 +60,18 @@ class AuthorizedSqlStore:
         :param sql_store: Base SqlStore implementation to wrap
         """
         self.sql_store = sql_store

+        self._detect_database_type()
         self._validate_sql_optimized_policy()

+    def _detect_database_type(self) -> None:
+        """Detect the database type from the underlying SQL store."""
+        if not hasattr(self.sql_store, "config"):
+            raise ValueError("SqlStore must have a config attribute to be used with AuthorizedSqlStore")
+
+        self.database_type = self.sql_store.config.type
+        if self.database_type not in [SqlStoreType.postgres, SqlStoreType.sqlite]:
+            raise ValueError(f"Unsupported database type: {self.database_type}")
+
     def _validate_sql_optimized_policy(self) -> None:
         """Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy().
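
Construction now fails fast on stores that cannot be filtered in SQL. A sketch of wrapping a SQLite-backed store (the path is illustrative):

config = SqliteSqlStoreConfig(db_path="/tmp/authz.db")
store = AuthorizedSqlStore(SqlAlchemySqlStoreImpl(config))
assert store.database_type == SqlStoreType.sqlite  # detected from config.type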
@@ -91,22 +89,27 @@ class AuthorizedSqlStore:

     async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
         """Create a table with built-in access control support."""
-        await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
-
         enhanced_schema = dict(schema)
         if "access_attributes" not in enhanced_schema:
             enhanced_schema["access_attributes"] = ColumnType.JSON
+        if "owner_principal" not in enhanced_schema:
+            enhanced_schema["owner_principal"] = ColumnType.STRING

         await self.sql_store.create_table(table, enhanced_schema)
+        await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
+        await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING)

     async def insert(self, table: str, data: Mapping[str, Any]) -> None:
         """Insert a row with automatic access control attribute capture."""
         enhanced_data = dict(data)

         current_user = get_authenticated_user()
-        if current_user and current_user.attributes:
+        if current_user:
+            enhanced_data["owner_principal"] = current_user.principal
             enhanced_data["access_attributes"] = current_user.attributes
+        else:
+            enhanced_data["owner_principal"] = None
+            enhanced_data["access_attributes"] = None

         await self.sql_store.insert(table, enhanced_data)
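
So every row now carries its writer's identity. Illustratively, the two cases persist as:

# authenticated as "alice" with {"roles": ["admin"]}:
row_authenticated = {"owner_principal": "alice", "access_attributes": {"roles": ["admin"]}}
# no authenticated user: both columns NULL, which the public-access
# conditions below treat as a public record
row_anonymous = {"owner_principal": None, "access_attributes": None}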
@@ -136,9 +139,12 @@ class AuthorizedSqlStore:

             for row in rows.data:
                 stored_access_attrs = row.get("access_attributes")
+                stored_owner_principal = row.get("owner_principal") or ""

                 record_id = row.get("id", "unknown")
-                sql_record = SqlRecord(str(record_id), table, stored_access_attrs)
+                sql_record = SqlRecord(
+                    str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
+                )

                 if is_action_allowed(policy, Action.READ, sql_record, current_user):
                     filtered_rows.append(row)
@@ -176,43 +182,90 @@ class AuthorizedSqlStore:
         Only applies SQL filtering for the default policy to ensure correctness.
         For custom policies, uses conservative filtering to avoid blocking legitimate access.
         """
+        current_user = get_authenticated_user()
+
         if not policy or policy == SQL_OPTIMIZED_POLICY:
-            return self._build_default_policy_where_clause()
+            return self._build_default_policy_where_clause(current_user)
         else:
             return self._build_conservative_where_clause()

-    def _build_default_policy_where_clause(self) -> str:
+    def _json_extract(self, column: str, path: str) -> str:
+        """Extract JSON value (keeping JSON type).
+
+        Args:
+            column: The JSON column name
+            path: The JSON path (e.g., 'roles', 'teams')
+
+        Returns:
+            SQL expression to extract JSON value
+        """
+        if self.database_type == SqlStoreType.postgres:
+            return f"{column}->'{path}'"
+        elif self.database_type == SqlStoreType.sqlite:
+            return f"JSON_EXTRACT({column}, '$.{path}')"
+        else:
+            raise ValueError(f"Unsupported database type: {self.database_type}")
+
+    def _json_extract_text(self, column: str, path: str) -> str:
+        """Extract JSON value as text.
+
+        Args:
+            column: The JSON column name
+            path: The JSON path (e.g., 'roles', 'teams')
+
+        Returns:
+            SQL expression to extract JSON value as text
+        """
+        if self.database_type == SqlStoreType.postgres:
+            return f"{column}->>'{path}'"
+        elif self.database_type == SqlStoreType.sqlite:
+            return f"JSON_EXTRACT({column}, '$.{path}')"
+        else:
+            raise ValueError(f"Unsupported database type: {self.database_type}")
+
+    def _get_public_access_conditions(self) -> list[str]:
+        """Get the SQL conditions for public access."""
+        # Public records are records that have no owner_principal or access_attributes
+        conditions = ["owner_principal = ''"]
+        if self.database_type == SqlStoreType.postgres:
+            # Postgres stores JSON null as 'null'
+            conditions.append("access_attributes::text = 'null'")
+        elif self.database_type == SqlStoreType.sqlite:
+            conditions.append("access_attributes = 'null'")
+        else:
+            raise ValueError(f"Unsupported database type: {self.database_type}")
+        return conditions
+
+    def _build_default_policy_where_clause(self, current_user: User | None) -> str:
         """Build SQL WHERE clause for the default policy.

         Default policy: permit all actions when user in owners [roles, teams, projects, namespaces]
         This means user must match ALL attribute categories that exist in the resource.
         """
-        current_user = get_authenticated_user()
-
-        if not current_user or not current_user.attributes:
-            return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')"
-        else:
-            base_conditions = ["access_attributes IS NULL", "access_attributes = 'null'", "access_attributes = '{}'"]
-
-            user_attr_conditions = []
+        base_conditions = self._get_public_access_conditions()
+        user_attr_conditions = []
+
+        if current_user and current_user.attributes:
             for attr_key, user_values in current_user.attributes.items():
                 if user_values:
                     value_conditions = []
                     for value in user_values:
-                        value_conditions.append(f"JSON_EXTRACT(access_attributes, '$.{attr_key}') LIKE '%\"{value}\"%'")
+                        # Check if JSON array contains the value
+                        escaped_value = value.replace("'", "''")
+                        json_text = self._json_extract_text("access_attributes", attr_key)
+                        value_conditions.append(f"({json_text} LIKE '%\"{escaped_value}\"%')")

                     if value_conditions:
-                        category_missing = f"JSON_EXTRACT(access_attributes, '$.{attr_key}') IS NULL"
+                        # Check if the category is missing (NULL)
+                        category_missing = f"{self._json_extract('access_attributes', attr_key)} IS NULL"
                         user_matches_category = f"({' OR '.join(value_conditions)})"
                         user_attr_conditions.append(f"({category_missing} OR {user_matches_category})")

             if user_attr_conditions:
                 all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
                 base_conditions.append(all_requirements_met)
-                return f"({' OR '.join(base_conditions)})"
-            else:
-                return f"({' OR '.join(base_conditions)})"
+
+        return f"({' OR '.join(base_conditions)})"

     def _build_conservative_where_clause(self) -> str:
         """Conservative SQL filtering for custom policies.
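
Hand-expanding what these helpers should emit for a user whose attributes are {"roles": ["admin"]} (my derivation from the code above, not captured output):

SQLite:
  (owner_principal = '' OR access_attributes = 'null'
   OR ((JSON_EXTRACT(access_attributes, '$.roles') IS NULL
        OR ((JSON_EXTRACT(access_attributes, '$.roles') LIKE '%"admin"%')))))

Postgres:
  (owner_principal = '' OR access_attributes::text = 'null'
   OR ((access_attributes->'roles' IS NULL
        OR ((access_attributes->>'roles' LIKE '%"admin"%')))))

The split exists because JSON_EXTRACT is SQLite syntax while Postgres uses -> and ->>, and escaping single quotes in user-supplied values closes the injection hole left by the old direct interpolation.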
@@ -222,5 +275,8 @@ class AuthorizedSqlStore:
         current_user = get_authenticated_user()

         if not current_user:
-            return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')"
+            # Only allow public records
+            base_conditions = self._get_public_access_conditions()
+            return f"({' OR '.join(base_conditions)})"

         return "1=1"
@@ -244,35 +244,41 @@ class SqlAlchemySqlStoreImpl(SqlStore):
         engine = create_async_engine(self.config.engine_str)

         try:
-            inspector = inspect(engine)
-
-            table_names = inspector.get_table_names()
-            if table not in table_names:
-                return
-
-            existing_columns = inspector.get_columns(table)
-            column_names = [col["name"] for col in existing_columns]
-
-            if column_name in column_names:
-                return
-
-            sqlalchemy_type = TYPE_MAPPING.get(column_type)
-            if not sqlalchemy_type:
-                raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.")
-
-            # Create the ALTER TABLE statement
-            # Note: We need to get the dialect-specific type name
-            dialect = engine.dialect
-            type_impl = sqlalchemy_type()
-            compiled_type = type_impl.compile(dialect=dialect)
-
-            nullable_clause = "" if nullable else " NOT NULL"
-            add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
-
             async with engine.begin() as conn:
+
+                def check_column_exists(sync_conn):
+                    inspector = inspect(sync_conn)
+
+                    table_names = inspector.get_table_names()
+                    if table not in table_names:
+                        return False, False  # table doesn't exist, column doesn't exist
+
+                    existing_columns = inspector.get_columns(table)
+                    column_names = [col["name"] for col in existing_columns]
+
+                    return True, column_name in column_names  # table exists, column exists or not
+
+                table_exists, column_exists = await conn.run_sync(check_column_exists)
+                if not table_exists or column_exists:
+                    return
+
+                sqlalchemy_type = TYPE_MAPPING.get(column_type)
+                if not sqlalchemy_type:
+                    raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.")
+
+                # Create the ALTER TABLE statement
+                # Note: We need to get the dialect-specific type name
+                dialect = engine.dialect
+                type_impl = sqlalchemy_type()
+                compiled_type = type_impl.compile(dialect=dialect)
+
+                nullable_clause = "" if nullable else " NOT NULL"
+                add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
+
                 await conn.execute(add_column_sql)

-        except Exception:
+        except Exception as e:
             # If any error occurs during migration, log it but don't fail
             # The table creation will handle adding the column
+            logger.error(f"Error adding column {column_name} to table {table}: {e}")
             pass
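
The rewrite exists because SQLAlchemy's inspect() needs a live synchronous connection; calling it directly on an AsyncEngine raises NoInspectionAvailable. The standard bridge is AsyncConnection.run_sync, as in this self-contained sketch of the same pattern:

from sqlalchemy import inspect
from sqlalchemy.ext.asyncio import create_async_engine

async def list_tables(url: str) -> list[str]:
    engine = create_async_engine(url)
    async with engine.begin() as conn:
        # run_sync hands a synchronous Connection to the callable,
        # where inspection is legal.
        return await conn.run_sync(lambda sync_conn: inspect(sync_conn).get_table_names())

Folding the existence check into the same engine.begin() block also keeps the check and the ALTER TABLE in one transaction.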
@@ -4,9 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-
 from abc import abstractmethod
-from enum import Enum
+from enum import StrEnum
 from pathlib import Path
 from typing import Annotated, Literal

@@ -19,7 +18,7 @@ from .api import SqlStore
 sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]


-class SqlStoreType(Enum):
+class SqlStoreType(StrEnum):
     sqlite = "sqlite"
     postgres = "postgres"

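
StrEnum (Python 3.11+) members are themselves strings, so values parsed from YAML configs compare equal to enum members without the .value round-trips removed in the hunks below, and the Literal[SqlStoreType.sqlite] annotations become valid Pydantic discriminators. A quick illustration:

from enum import Enum, StrEnum

class OldType(Enum):
    sqlite = "sqlite"

class NewType(StrEnum):
    sqlite = "sqlite"

assert OldType.sqlite != "sqlite"   # plain Enum never equals its value
assert NewType.sqlite == "sqlite"   # StrEnum is a str subclass
assert f"{NewType.sqlite}" == "sqlite"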
@@ -36,7 +35,7 @@ class SqlAlchemySqlStoreConfig(BaseModel):


 class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
-    type: Literal["sqlite"] = SqlStoreType.sqlite.value
+    type: Literal[SqlStoreType.sqlite] = SqlStoreType.sqlite
     db_path: str = Field(
         default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
         description="Database path, e.g. ~/.llama/distributions/ollama/sqlstore.db",
@@ -59,7 +58,7 @@ class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):


 class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
-    type: Literal["postgres"] = SqlStoreType.postgres.value
+    type: Literal[SqlStoreType.postgres] = SqlStoreType.postgres
     host: str = "localhost"
     port: int = 5432
     db: str = "llamastack"
@@ -107,7 +106,7 @@ def get_pip_packages(store_config: dict | SqlStoreConfig) -> list[str]:


 def sqlstore_impl(config: SqlStoreConfig) -> SqlStore:
-    if config.type in [SqlStoreType.sqlite.value, SqlStoreType.postgres.value]:
+    if config.type in [SqlStoreType.sqlite, SqlStoreType.postgres]:
         from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl

         impl = SqlAlchemySqlStoreImpl(config)
@@ -9,14 +9,12 @@ import inspect
 import json
 from collections.abc import AsyncGenerator, Callable
 from functools import wraps
-from typing import Any, TypeVar
+from typing import Any

 from pydantic import BaseModel

 from llama_stack.models.llama.datatypes import Primitive

-T = TypeVar("T")
-

 def serialize_value(value: Any) -> Primitive:
     return str(_prepare_for_json(value))
@@ -44,7 +42,7 @@ def _prepare_for_json(value: Any) -> str:
         return str(value)


-def trace_protocol(cls: type[T]) -> type[T]:
+def trace_protocol[T](cls: type[T]) -> type[T]:
     """
     A class decorator that automatically traces all methods in a protocol/base class
     and its inheriting classes.
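
The bracketed type parameter is PEP 695 syntax (Python 3.12+): T is scoped to the function itself, which is why the module-level TypeVar could be deleted in the import hunk above. The two spellings are equivalent:

from typing import TypeVar

U = TypeVar("U")

def old_style(cls: type[U]) -> type[U]:  # pre-3.12 spelling
    return cls

def new_style[T](cls: type[T]) -> type[T]:  # PEP 695, 3.12+
    return cls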