Merge branch 'main' of github.com:llamastack/llama-stack

rh-pre-commit.version: 2.3.2
rh-pre-commit.check-secrets: ENABLED
This commit is contained in:
Antony Sallas 2025-10-22 18:22:40 +08:00
commit b469ffce35
334 changed files with 96977 additions and 19023 deletions

View file

@ -17,7 +17,7 @@ from llama_stack.apis.models import Model
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.shields import Shield
from llama_stack.apis.tools import ToolGroup
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.schema_utils import json_schema_type
@ -68,10 +68,10 @@ class ShieldsProtocolPrivate(Protocol):
async def unregister_shield(self, identifier: str) -> None: ...
class VectorDBsProtocolPrivate(Protocol):
async def register_vector_db(self, vector_db: VectorDB) -> None: ...
class VectorStoresProtocolPrivate(Protocol):
async def register_vector_store(self, vector_store: VectorStore) -> None: ...
async def unregister_vector_db(self, vector_db_id: str) -> None: ...
async def unregister_vector_store(self, vector_store_id: str) -> None: ...
class DatasetsProtocolPrivate(Protocol):

View file

@ -83,8 +83,8 @@ class MetaReferenceAgentsImpl(Agents):
self.policy = policy
async def initialize(self) -> None:
self.persistence_store = await kvstore_impl(self.config.persistence_store)
self.responses_store = ResponsesStore(self.config.responses_store, self.policy)
self.persistence_store = await kvstore_impl(self.config.persistence.agent_state)
self.responses_store = ResponsesStore(self.config.persistence.responses, self.policy)
await self.responses_store.initialize()
self.openai_responses_impl = OpenAIResponsesImpl(
inference_api=self.inference_api,

View file

@ -8,24 +8,30 @@ from typing import Any
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore import KVStoreConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
from llama_stack.core.storage.datatypes import KVStoreReference, ResponsesStoreReference
class AgentPersistenceConfig(BaseModel):
"""Nested persistence configuration for agents."""
agent_state: KVStoreReference
responses: ResponsesStoreReference
class MetaReferenceAgentsImplConfig(BaseModel):
persistence_store: KVStoreConfig
responses_store: SqlStoreConfig
persistence: AgentPersistenceConfig
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"persistence_store": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="agents_store.db",
),
"responses_store": SqliteSqlStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="responses_store.db",
),
"persistence": {
"agent_state": KVStoreReference(
backend="kv_default",
namespace="agents",
).model_dump(exclude_none=True),
"responses": ResponsesStoreReference(
backend="sql_default",
table_name="responses",
).model_dump(exclude_none=True),
}
}

View file

@ -359,6 +359,7 @@ class OpenAIResponsesImpl:
tool_executor=self.tool_executor,
safety_api=self.safety_api,
guardrail_ids=guardrail_ids,
instructions=instructions,
)
# Stream the response

View file

@ -110,6 +110,7 @@ class StreamingResponseOrchestrator:
text: OpenAIResponseText,
max_infer_iters: int,
tool_executor, # Will be the tool execution logic from the main class
instructions: str,
safety_api,
guardrail_ids: list[str] | None = None,
):
@ -133,6 +134,8 @@ class StreamingResponseOrchestrator:
self.accumulated_usage: OpenAIResponseUsage | None = None
# Track if we've sent a refusal response
self.violation_detected = False
# system message that is inserted into the model's context
self.instructions = instructions
async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
"""Create a refusal response to replace streaming content."""
@ -176,6 +179,7 @@ class StreamingResponseOrchestrator:
tools=self.ctx.available_tools(),
error=error,
usage=self.accumulated_usage,
instructions=self.instructions,
)
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:

View file

@ -6,13 +6,13 @@
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.core.storage.datatypes import KVStoreReference
class ReferenceBatchesImplConfig(BaseModel):
"""Configuration for the Reference Batches implementation."""
kvstore: KVStoreConfig = Field(
kvstore: KVStoreReference = Field(
description="Configuration for the key-value store backend.",
)
@ -33,8 +33,8 @@ class ReferenceBatchesImplConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict:
return {
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="batches.db",
),
"kvstore": KVStoreReference(
backend="kv_default",
namespace="batches",
).model_dump(exclude_none=True),
}

View file

@ -7,20 +7,17 @@ from typing import Any
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from llama_stack.core.storage.datatypes import KVStoreReference
class LocalFSDatasetIOConfig(BaseModel):
kvstore: KVStoreConfig
kvstore: KVStoreReference
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="localfs_datasetio.db",
)
"kvstore": KVStoreReference(
backend="kv_default",
namespace="datasetio::localfs",
).model_dump(exclude_none=True)
}

View file

@ -7,20 +7,17 @@ from typing import Any
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from llama_stack.core.storage.datatypes import KVStoreReference
class MetaReferenceEvalConfig(BaseModel):
kvstore: KVStoreConfig
kvstore: KVStoreReference
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="meta_reference_eval.db",
)
"kvstore": KVStoreReference(
backend="kv_default",
namespace="eval",
).model_dump(exclude_none=True)
}

View file

@ -8,14 +8,14 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
from llama_stack.core.storage.datatypes import SqlStoreReference
class LocalfsFilesImplConfig(BaseModel):
storage_dir: str = Field(
description="Directory to store uploaded files",
)
metadata_store: SqlStoreConfig = Field(
metadata_store: SqlStoreReference = Field(
description="SQL store configuration for file metadata",
)
ttl_secs: int = 365 * 24 * 60 * 60 # 1 year
@ -24,8 +24,8 @@ class LocalfsFilesImplConfig(BaseModel):
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"storage_dir": "${env.FILES_STORAGE_DIR:=" + __distro_dir__ + "/files}",
"metadata_store": SqliteSqlStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="files_metadata.db",
),
"metadata_store": SqlStoreReference(
backend="sql_default",
table_name="files_metadata",
).model_dump(exclude_none=True),
}

View file

@ -59,7 +59,6 @@ class SentenceTransformersInferenceImpl(
provider_id=self.__provider_id__,
metadata={
"embedding_dimension": 768,
"default_configured": True,
},
model_type=ModelType.embedding,
),

View file

@ -1,75 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from datetime import UTC, datetime
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import SpanProcessor
from opentelemetry.trace.status import StatusCode
from llama_stack.log import get_logger
logger = get_logger(name="console_span_processor", category="telemetry")
class ConsoleSpanProcessor(SpanProcessor):
def __init__(self, print_attributes: bool = False):
self.print_attributes = print_attributes
def on_start(self, span: ReadableSpan, parent_context=None) -> None:
if span.attributes and span.attributes.get("__autotraced__"):
return
timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
logger.info(f"[dim]{timestamp}[/dim] [bold magenta][START][/bold magenta] [dim]{span.name}[/dim]")
def on_end(self, span: ReadableSpan) -> None:
timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
span_context = f"[dim]{timestamp}[/dim] [bold magenta][END][/bold magenta] [dim]{span.name}[/dim]"
if span.status.status_code == StatusCode.ERROR:
span_context += " [bold red][ERROR][/bold red]"
elif span.status.status_code != StatusCode.UNSET:
span_context += f" [{span.status.status_code}]"
duration_ms = (span.end_time - span.start_time) / 1e6
span_context += f" ({duration_ms:.2f}ms)"
logger.info(span_context)
if self.print_attributes and span.attributes:
for key, value in span.attributes.items():
if key.startswith("__"):
continue
str_value = str(value)
if len(str_value) > 1000:
str_value = str_value[:997] + "..."
logger.info(f" [dim]{key}[/dim]: {str_value}")
for event in span.events:
event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
severity = event.attributes.get("severity", "info")
message = event.attributes.get("message", event.name)
if isinstance(message, dict) or isinstance(message, list):
message = json.dumps(message, indent=2)
severity_color = {
"error": "red",
"warn": "yellow",
"info": "white",
"debug": "dim",
}.get(severity, "white")
logger.info(f" {event_time} [bold {severity_color}][{severity.upper()}][/bold {severity_color}] {message}")
if event.attributes:
for key, value in event.attributes.items():
if key.startswith("__") or key in ["message", "severity"]:
continue
logger.info(f"[dim]{key}[/dim]: {value}")
def shutdown(self) -> None:
"""Shutdown the processor."""
pass
def force_flush(self, timeout_millis: float | None = None) -> bool:
"""Force flush any pending spans."""
return True

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import threading
from typing import Any
@ -60,33 +61,38 @@ class TelemetryAdapter(Telemetry):
# Recreating the telemetry adapter multiple times will result in duplicate span processors.
# Since the library client can be recreated multiple times in a notebook,
# the kernel will hold on to the span processor and cause duplicate spans to be written.
if _TRACER_PROVIDER is None:
provider = TracerProvider()
trace.set_tracer_provider(provider)
_TRACER_PROVIDER = provider
if os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT"):
if _TRACER_PROVIDER is None:
provider = TracerProvider()
trace.set_tracer_provider(provider)
_TRACER_PROVIDER = provider
# Use single OTLP endpoint for all telemetry signals
# Use single OTLP endpoint for all telemetry signals
# Let OpenTelemetry SDK handle endpoint construction automatically
# The SDK will read OTEL_EXPORTER_OTLP_ENDPOINT and construct appropriate URLs
# https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter
span_exporter = OTLPSpanExporter()
span_processor = BatchSpanProcessor(span_exporter)
trace.get_tracer_provider().add_span_processor(span_processor)
# Let OpenTelemetry SDK handle endpoint construction automatically
# The SDK will read OTEL_EXPORTER_OTLP_ENDPOINT and construct appropriate URLs
# https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter
span_exporter = OTLPSpanExporter()
span_processor = BatchSpanProcessor(span_exporter)
trace.get_tracer_provider().add_span_processor(span_processor)
metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
metric_provider = MeterProvider(metric_readers=[metric_reader])
metrics.set_meter_provider(metric_provider)
metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
metric_provider = MeterProvider(metric_readers=[metric_reader])
metrics.set_meter_provider(metric_provider)
self.is_otel_endpoint_set = True
else:
logger.warning("OTEL_EXPORTER_OTLP_ENDPOINT is not set, skipping telemetry")
self.is_otel_endpoint_set = False
self.meter = metrics.get_meter(__name__)
self._lock = _global_lock
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
trace.get_tracer_provider().force_flush()
if self.is_otel_endpoint_set:
trace.get_tracer_provider().force_flush()
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
if isinstance(event, UnstructuredLogEvent):

View file

@ -12,15 +12,8 @@ from .config import ChromaVectorIOConfig
async def get_provider_impl(config: ChromaVectorIOConfig, deps: dict[Api, Any]):
from llama_stack.providers.remote.vector_io.chroma.chroma import (
ChromaVectorIOAdapter,
)
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaVectorIOAdapter
impl = ChromaVectorIOAdapter(
config,
deps[Api.inference],
deps[Api.models],
deps.get(Api.files),
)
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -8,14 +8,14 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class ChromaVectorIOConfig(BaseModel):
db_path: str
kvstore: KVStoreConfig = Field(description="Config for KV store backend")
persistence: KVStoreReference = Field(description="Config for KV store backend")
@classmethod
def sample_run_config(
@ -23,8 +23,8 @@ class ChromaVectorIOConfig(BaseModel):
) -> dict[str, Any]:
return {
"db_path": db_path,
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="chroma_inline_registry.db",
),
"persistence": KVStoreReference(
backend="kv_default",
namespace="vector_io::chroma",
).model_dump(exclude_none=True),
}

View file

@ -16,11 +16,6 @@ async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
impl = FaissVectorIOAdapter(
config,
deps[Api.inference],
deps[Api.models],
deps.get(Api.files),
)
impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -8,22 +8,19 @@ from typing import Any
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class FaissVectorIOConfig(BaseModel):
kvstore: KVStoreConfig
persistence: KVStoreReference
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="faiss_store.db",
)
"persistence": KVStoreReference(
backend="kv_default",
namespace="vector_io::faiss",
).model_dump(exclude_none=True)
}

View file

@ -17,34 +17,21 @@ from numpy.typing import NDArray
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
VectorDBsProtocolPrivate,
)
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, VectorStoresProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
from .config import FaissVectorIOConfig
logger = get_logger(name=__name__, category="vector_io")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
VECTOR_DBS_PREFIX = f"vector_stores:{VERSION}::"
FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:{VERSION}::"
@ -155,12 +142,7 @@ class FaissIndex(EmbeddingIndex):
await self._save_index()
async def query_vector(
self,
embedding: NDArray,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
chunks = []
scores = []
@ -175,12 +157,7 @@ class FaissIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
raise NotImplementedError(
"Keyword search is not supported - underlying DB FAISS does not support this search mode"
)
@ -199,35 +176,28 @@ class FaissIndex(EmbeddingIndex):
)
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__(
self,
config: FaissVectorIOConfig,
inference_api: Inference,
models_api: Models,
files_api: Files | None,
) -> None:
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.models_api = models_api
self.cache: dict[str, VectorDBWithIndex] = {}
self.cache: dict[str, VectorStoreWithIndex] = {}
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
self.kvstore = await kvstore_impl(self.config.persistence)
# Load existing banks from kvstore
start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
for vector_db_data in stored_vector_dbs:
vector_db = VectorDB.model_validate_json(vector_db_data)
index = VectorDBWithIndex(
vector_db,
await FaissIndex.create(vector_db.embedding_dimension, self.kvstore, vector_db.identifier),
for vector_store_data in stored_vector_stores:
vector_store = VectorStore.model_validate_json(vector_store_data)
index = VectorStoreWithIndex(
vector_store,
await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
self.inference_api,
)
self.cache[vector_db.identifier] = index
self.cache[vector_store.identifier] = index
# Load existing OpenAI vector stores into the in-memory cache
await self.initialize_openai_vector_stores()
@ -252,45 +222,33 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
async def register_vector_db(
self,
vector_db: VectorDB,
) -> None:
async def register_vector_store(self, vector_store: VectorStore) -> None:
assert self.kvstore is not None
key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
await self.kvstore.set(
key=key,
value=vector_db.model_dump_json(),
)
key = f"{VECTOR_DBS_PREFIX}{vector_store.identifier}"
await self.kvstore.set(key=key, value=vector_store.model_dump_json())
# Store in cache
self.cache[vector_db.identifier] = VectorDBWithIndex(
vector_db=vector_db,
index=await FaissIndex.create(vector_db.embedding_dimension, self.kvstore, vector_db.identifier),
self.cache[vector_store.identifier] = VectorStoreWithIndex(
vector_store=vector_store,
index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
inference_api=self.inference_api,
)
async def list_vector_dbs(self) -> list[VectorDB]:
return [i.vector_db for i in self.cache.values()]
async def list_vector_stores(self) -> list[VectorStore]:
return [i.vector_store for i in self.cache.values()]
async def unregister_vector_db(self, vector_db_id: str) -> None:
async def unregister_vector_store(self, vector_store_id: str) -> None:
assert self.kvstore is not None
if vector_db_id not in self.cache:
logger.warning(f"Vector DB {vector_db_id} not found")
if vector_store_id not in self.cache:
return
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_db_id}")
await self.cache[vector_store_id].index.delete()
del self.cache[vector_store_id]
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_store_id}")
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = self.cache.get(vector_db_id)
if index is None:
raise ValueError(f"Vector DB {vector_db_id} not found. found: {self.cache.keys()}")
@ -298,10 +256,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
await index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
index = self.cache.get(vector_db_id)
if index is None:

View file

@ -14,11 +14,6 @@ from .config import MilvusVectorIOConfig
async def get_provider_impl(config: MilvusVectorIOConfig, deps: dict[Api, Any]):
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
impl = MilvusVectorIOAdapter(
config,
deps[Api.inference],
deps[Api.models],
deps.get(Api.files),
)
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -8,25 +8,22 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class MilvusVectorIOConfig(BaseModel):
db_path: str
kvstore: KVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
persistence: KVStoreReference = Field(description="Config for KV store backend (SQLite only for now)")
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"db_path": "${env.MILVUS_DB_PATH:=" + __distro_dir__ + "}/" + "milvus.db",
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="milvus_registry.db",
),
"persistence": KVStoreReference(
backend="kv_default",
namespace="vector_io::milvus",
).model_dump(exclude_none=True),
}

View file

@ -15,11 +15,6 @@ async def get_provider_impl(config: QdrantVectorIOConfig, deps: dict[Api, Any]):
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
assert isinstance(config, QdrantVectorIOConfig), f"Unexpected config type: {type(config)}"
impl = QdrantVectorIOAdapter(
config,
deps[Api.inference],
deps[Api.models],
deps.get(Api.files),
)
impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -9,23 +9,21 @@ from typing import Any
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class QdrantVectorIOConfig(BaseModel):
path: str
kvstore: KVStoreConfig
persistence: KVStoreReference
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"path": "${env.QDRANT_PATH:=~/.llama/" + __distro_dir__ + "}/" + "qdrant.db",
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__, db_name="qdrant_registry.db"
),
"persistence": KVStoreReference(
backend="kv_default",
namespace="vector_io::qdrant",
).model_dump(exclude_none=True),
}

View file

@ -15,11 +15,6 @@ async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
from .sqlite_vec import SQLiteVecVectorIOAdapter
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
impl = SQLiteVecVectorIOAdapter(
config,
deps[Api.inference],
deps[Api.models],
deps.get(Api.files),
)
impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -8,22 +8,19 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from llama_stack.core.storage.datatypes import KVStoreReference
class SQLiteVectorIOConfig(BaseModel):
db_path: str = Field(description="Path to the SQLite database file")
kvstore: KVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
persistence: KVStoreReference = Field(description="Config for KV store backend (SQLite only for now)")
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + "sqlite_vec.db",
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="sqlite_vec_registry.db",
),
"persistence": KVStoreReference(
backend="kv_default",
namespace="vector_io::sqlite_vec",
).model_dump(exclude_none=True),
}

View file

@ -17,15 +17,10 @@ from numpy.typing import NDArray
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
@ -33,7 +28,7 @@ from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_RRF,
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
VectorStoreWithIndex,
)
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator
@ -46,7 +41,7 @@ HYBRID_SEARCH = "hybrid"
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH, HYBRID_SEARCH}
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:sqlite_vec:{VERSION}::"
VECTOR_DBS_PREFIX = f"vector_stores:sqlite_vec:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:sqlite_vec:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:sqlite_vec:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:sqlite_vec:{VERSION}::"
@ -175,32 +170,18 @@ class SQLiteVecIndex(EmbeddingIndex):
# Insert vector embeddings
embedding_data = [
(
(
chunk.chunk_id,
serialize_vector(emb.tolist()),
)
)
((chunk.chunk_id, serialize_vector(emb.tolist())))
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
]
cur.executemany(
f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);",
embedding_data,
)
cur.executemany(f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);", embedding_data)
# Insert FTS content
fts_data = [(chunk.chunk_id, chunk.content) for chunk in batch_chunks]
# DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
cur.executemany(
f"DELETE FROM [{self.fts_table}] WHERE id = ?;",
[(row[0],) for row in fts_data],
)
cur.executemany(f"DELETE FROM [{self.fts_table}] WHERE id = ?;", [(row[0],) for row in fts_data])
# INSERT new entries
cur.executemany(
f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);",
fts_data,
)
cur.executemany(f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);", fts_data)
connection.commit()
@ -216,12 +197,7 @@ class SQLiteVecIndex(EmbeddingIndex):
# Run batch insertion in a background thread
await asyncio.to_thread(_execute_all_batch_inserts)
async def query_vector(
self,
embedding: NDArray,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
"""
Performs vector-based search using a virtual table for vector similarity.
"""
@ -261,12 +237,7 @@ class SQLiteVecIndex(EmbeddingIndex):
scores.append(score)
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
"""
Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search.
"""
@ -403,41 +374,32 @@ class SQLiteVecIndex(EmbeddingIndex):
await asyncio.to_thread(_delete_chunks)
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
"""
A VectorIO implementation using SQLite + sqlite_vec.
This class handles vector database registration (with metadata stored in a table named `vector_dbs`)
and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
This class handles vector database registration (with metadata stored in a table named `vector_stores`)
and creates a cache of VectorStoreWithIndex instances (each wrapping a SQLiteVecIndex).
"""
def __init__(
self,
config,
inference_api: Inference,
models_api: Models,
files_api: Files | None,
) -> None:
def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.models_api = models_api
self.cache: dict[str, VectorDBWithIndex] = {}
self.vector_db_store = None
self.cache: dict[str, VectorStoreWithIndex] = {}
self.vector_store_table = None
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
self.kvstore = await kvstore_impl(self.config.persistence)
start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
for db_json in stored_vector_dbs:
vector_db = VectorDB.model_validate_json(db_json)
stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
for db_json in stored_vector_stores:
vector_store = VectorStore.model_validate_json(db_json)
index = await SQLiteVecIndex.create(
vector_db.embedding_dimension,
self.config.db_path,
vector_db.identifier,
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
# Load existing OpenAI vector stores into the in-memory cache
await self.initialize_openai_vector_stores()
@ -446,67 +408,64 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def list_vector_dbs(self) -> list[VectorDB]:
return [v.vector_db for v in self.cache.values()]
async def list_vector_stores(self) -> list[VectorStore]:
return [v.vector_store for v in self.cache.values()]
async def register_vector_db(self, vector_db: VectorDB) -> None:
async def register_vector_store(self, vector_store: VectorStore) -> None:
index = await SQLiteVecIndex.create(
vector_db.embedding_dimension,
self.config.db_path,
vector_db.identifier,
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
if vector_db_id in self.cache:
return self.cache[vector_db_id]
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
if vector_store_id in self.cache:
return self.cache[vector_store_id]
if self.vector_db_store is None:
raise VectorStoreNotFoundError(vector_db_id)
if self.vector_store_table is None:
raise VectorStoreNotFoundError(vector_store_id)
vector_db = self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db:
raise VectorStoreNotFoundError(vector_db_id)
vector_store = self.vector_store_table.get_vector_store(vector_store_id)
if not vector_store:
raise VectorStoreNotFoundError(vector_store_id)
index = VectorDBWithIndex(
vector_db=vector_db,
index = VectorStoreWithIndex(
vector_store=vector_store,
index=SQLiteVecIndex(
dimension=vector_db.embedding_dimension,
dimension=vector_store.embedding_dimension,
db_path=self.config.db_path,
bank_id=vector_db.identifier,
bank_id=vector_store.identifier,
kvstore=self.kvstore,
),
inference_api=self.inference_api,
)
self.cache[vector_db_id] = index
self.cache[vector_store_id] = index
return index
async def unregister_vector_db(self, vector_db_id: str) -> None:
if vector_db_id not in self.cache:
logger.warning(f"Vector DB {vector_db_id} not found")
async def unregister_vector_store(self, vector_store_id: str) -> None:
if vector_store_id not in self.cache:
return
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
await self.cache[vector_store_id].index.delete()
del self.cache[vector_store_id]
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
index = await self._get_and_cache_vector_store_index(vector_db_id)
if not index:
raise VectorStoreNotFoundError(vector_db_id)
# The VectorDBWithIndex helper is expected to compute embeddings via the inference_api
# The VectorStoreWithIndex helper is expected to compute embeddings via the inference_api
# and then call our index's add_chunks.
await index.insert_chunks(chunks)
async def query_chunks(
self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
index = await self._get_and_cache_vector_store_index(vector_db_id)
if not index:
raise VectorStoreNotFoundError(vector_db_id)
return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from a sqlite_vec index."""
index = await self._get_and_cache_vector_db_index(store_id)
index = await self._get_and_cache_vector_store_index(store_id)
if not index:
raise VectorStoreNotFoundError(store_id)

View file

@ -7,20 +7,17 @@ from typing import Any
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from llama_stack.core.storage.datatypes import KVStoreReference
class HuggingfaceDatasetIOConfig(BaseModel):
kvstore: KVStoreConfig
kvstore: KVStoreReference
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="huggingface_datasetio.db",
)
"kvstore": KVStoreReference(
backend="kv_default",
namespace="datasetio::huggingface",
).model_dump(exclude_none=True)
}

View file

@ -20,7 +20,7 @@ This provider enables dataset management using NVIDIA's NeMo Customizer service.
Build the NVIDIA environment:
```bash
llama stack build --distro nvidia --image-type venv
uv run llama stack list-deps nvidia | xargs -L1 uv pip install
```
### Basic Usage using the LlamaStack Python Client

View file

@ -8,7 +8,7 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
from llama_stack.core.storage.datatypes import SqlStoreReference
class S3FilesImplConfig(BaseModel):
@ -24,7 +24,7 @@ class S3FilesImplConfig(BaseModel):
auto_create_bucket: bool = Field(
default=False, description="Automatically create the S3 bucket if it doesn't exist"
)
metadata_store: SqlStoreConfig = Field(description="SQL store configuration for file metadata")
metadata_store: SqlStoreReference = Field(description="SQL store configuration for file metadata")
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
@ -35,8 +35,8 @@ class S3FilesImplConfig(BaseModel):
"aws_secret_access_key": "${env.AWS_SECRET_ACCESS_KEY:=}",
"endpoint_url": "${env.S3_ENDPOINT_URL:=}",
"auto_create_bucket": "${env.S3_AUTO_CREATE_BUCKET:=false}",
"metadata_store": SqliteSqlStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="s3_files_metadata.db",
),
"metadata_store": SqlStoreReference(
backend="sql_default",
table_name="s3_files_metadata",
).model_dump(exclude_none=True),
}

View file

@ -18,7 +18,7 @@ This provider enables running inference using NVIDIA NIM.
Build the NVIDIA environment:
```bash
llama stack build --distro nvidia --image-type venv
uv run llama stack list-deps nvidia | xargs -L1 uv pip install
```
### Basic Usage using the LlamaStack Python Client
@ -45,7 +45,7 @@ The following example shows how to create a chat completion for an NVIDIA NIM.
```python
response = client.chat.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
model="nvidia/meta/llama-3.1-8b-instruct",
messages=[
{
"role": "system",
@ -67,37 +67,40 @@ print(f"Response: {response.choices[0].message.content}")
The following example shows how to do tool calling for an NVIDIA NIM.
```python
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
tool_definition = ToolDefinition(
tool_name="get_weather",
description="Get current weather information for a location",
parameters={
"location": ToolParamDefinition(
param_type="string",
description="The city and state, e.g. San Francisco, CA",
required=True,
),
"unit": ToolParamDefinition(
param_type="string",
description="Temperature unit (celsius or fahrenheit)",
required=False,
default="celsius",
),
tool_definition = {
"type": "function",
"function": {
"name": "get_weather",
"description": "Get current weather information for a location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
"type": "string",
"description": "Temperature unit (celsius or fahrenheit)",
"default": "celsius",
},
},
"required": ["location"],
},
},
)
}
tool_response = client.chat.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
model="nvidia/meta/llama-3.1-8b-instruct",
messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
tools=[tool_definition],
)
print(f"Tool Response: {tool_response.choices[0].message.content}")
print(f"Response content: {tool_response.choices[0].message.content}")
if tool_response.choices[0].message.tool_calls:
for tool_call in tool_response.choices[0].message.tool_calls:
print(f"Tool Called: {tool_call.tool_name}")
print(f"Arguments: {tool_call.arguments}")
print(f"Tool Called: {tool_call.function.name}")
print(f"Arguments: {tool_call.function.arguments}")
```
### Structured Output Example
@ -105,33 +108,26 @@ if tool_response.choices[0].message.tool_calls:
The following example shows how to do structured output for an NVIDIA NIM.
```python
from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType
person_schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
"age": {"type": "number"},
"occupation": {"type": "string"},
},
"required": ["name", "age", "occupation"],
}
response_format = JsonSchemaResponseFormat(
type=ResponseFormatType.json_schema, json_schema=person_schema
)
structured_response = client.chat.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
model="nvidia/meta/llama-3.1-8b-instruct",
messages=[
{
"role": "user",
"content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer. ",
}
],
response_format=response_format,
extra_body={"nvext": {"guided_json": person_schema}},
)
print(f"Structured Response: {structured_response.choices[0].message.content}")
```
@ -141,7 +137,7 @@ The following example shows how to create embeddings for an NVIDIA NIM.
```python
response = client.embeddings.create(
model="nvidia/llama-3.2-nv-embedqa-1b-v2",
model="nvidia/nvidia/llama-3.2-nv-embedqa-1b-v2",
input=["What is the capital of France?"],
extra_body={"input_type": "query"},
)
@ -163,15 +159,15 @@ image_path = {path_to_the_image}
demo_image_b64 = load_image_as_base64(image_path)
vlm_response = client.chat.completions.create(
model="nvidia/vila",
model="nvidia/meta/llama-3.2-11b-vision-instruct",
messages=[
{
"role": "user",
"content": [
{
"type": "image",
"image": {
"data": demo_image_b64,
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{demo_image_b64}",
},
},
{

View file

@ -10,7 +10,7 @@ from .config import NVIDIAConfig
async def get_adapter_impl(config: NVIDIAConfig, _deps) -> Inference:
# import dynamically so `llama stack build` does not fail due to missing dependencies
# import dynamically so `llama stack list-deps` does not fail due to missing dependencies
from .nvidia import NVIDIAInferenceAdapter
if not isinstance(config, NVIDIAConfig):

View file

@ -19,15 +19,6 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
"""
NVIDIA Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
ModelRegistryHelper to ensure that OpenAIMixin.check_model_availability()
is used instead of ModelRegistryHelper.check_model_availability(). It also
must come before Inference to ensure that OpenAIMixin methods are available
in the Inference interface.
- OpenAIMixin.check_model_availability() queries the NVIDIA API to check if a model exists
- ModelRegistryHelper.check_model_availability() just returns False and shows a warning
"""
# source: https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html

View file

@ -22,7 +22,7 @@ This provider enables fine-tuning of LLMs using NVIDIA's NeMo Customizer service
Build the NVIDIA environment:
```bash
llama stack build --distro nvidia --image-type venv
uv run llama stack list-deps nvidia | xargs -L1 uv pip install
```
### Basic Usage using the LlamaStack Python Client

View file

@ -19,7 +19,7 @@ This provider enables safety checks and guardrails for LLM interactions using NV
Build the NVIDIA environment:
```bash
llama stack build --distro nvidia --image-type venv
uv run llama stack list-deps nvidia | xargs -L1 uv pip install
```
### Basic Usage using the LlamaStack Python Client

View file

@ -12,11 +12,6 @@ from .config import ChromaVectorIOConfig
async def get_adapter_impl(config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .chroma import ChromaVectorIOAdapter
impl = ChromaVectorIOAdapter(
config,
deps[Api.inference],
deps[Api.models],
deps.get(Api.files),
)
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -12,24 +12,16 @@ import chromadb
from numpy.typing import NDArray
from llama_stack.apis.files import Files
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
)
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
@ -38,7 +30,7 @@ log = get_logger(name=__name__, category="vector_io::chroma")
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:chroma:{VERSION}::"
VECTOR_DBS_PREFIX = f"vector_stores:chroma:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:chroma:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:chroma:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:chroma:{VERSION}::"
@ -68,19 +60,13 @@ class ChromaIndex(EmbeddingIndex):
ids = [f"{c.metadata.get('document_id', '')}:{c.chunk_id}" for c in chunks]
await maybe_await(
self.collection.add(
documents=[chunk.model_dump_json() for chunk in chunks],
embeddings=embeddings,
ids=ids,
)
self.collection.add(documents=[chunk.model_dump_json() for chunk in chunks], embeddings=embeddings, ids=ids)
)
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
results = await maybe_await(
self.collection.query(
query_embeddings=[embedding.tolist()],
n_results=k,
include=["documents", "distances"],
query_embeddings=[embedding.tolist()], n_results=k, include=["documents", "distances"]
)
)
distances = results["distances"][0]
@ -108,12 +94,7 @@ class ChromaIndex(EmbeddingIndex):
async def delete(self):
await maybe_await(self.client.delete_collection(self.collection.name))
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Chroma")
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
@ -133,26 +114,24 @@ class ChromaIndex(EmbeddingIndex):
raise NotImplementedError("Hybrid search is not supported in Chroma")
class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
def __init__(
self,
config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
inference_api: Api.inference,
models_apis: Api.models,
inference_api: Inference,
files_api: Files | None,
) -> None:
super().__init__(files_api=files_api, kvstore=None)
log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
self.config = config
self.inference_api = inference_api
self.models_api = models_apis
self.client = None
self.cache = {}
self.vector_db_store = None
self.vector_store_table = None
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
self.vector_db_store = self.kvstore
self.kvstore = await kvstore_impl(self.config.persistence)
self.vector_store_table = self.kvstore
if isinstance(self.config, RemoteChromaVectorIOConfig):
log.info(f"Connecting to Chroma server at: {self.config.url}")
@ -172,70 +151,58 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db(
self,
vector_db: VectorDB,
) -> None:
async def register_vector_store(self, vector_store: VectorStore) -> None:
collection = await maybe_await(
self.client.get_or_create_collection(
name=vector_db.identifier,
metadata={"vector_db": vector_db.model_dump_json()},
name=vector_store.identifier, metadata={"vector_store": vector_store.model_dump_json()}
)
)
self.cache[vector_db.identifier] = VectorDBWithIndex(
vector_db, ChromaIndex(self.client, collection), self.inference_api
self.cache[vector_store.identifier] = VectorStoreWithIndex(
vector_store, ChromaIndex(self.client, collection), self.inference_api
)
async def unregister_vector_db(self, vector_db_id: str) -> None:
if vector_db_id not in self.cache:
log.warning(f"Vector DB {vector_db_id} not found")
async def unregister_vector_store(self, vector_store_id: str) -> None:
if vector_store_id not in self.cache:
log.warning(f"Vector DB {vector_store_id} not found")
return
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
await self.cache[vector_store_id].index.delete()
del self.cache[vector_store_id]
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_store_index(vector_db_id)
if index is None:
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
await index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
index = await self._get_and_cache_vector_store_index(vector_db_id)
if index is None:
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
return await index.query_chunks(query, params)
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
if vector_db_id in self.cache:
return self.cache[vector_db_id]
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex:
if vector_store_id in self.cache:
return self.cache[vector_store_id]
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db:
raise ValueError(f"Vector DB {vector_db_id} not found in Llama Stack")
collection = await maybe_await(self.client.get_collection(vector_db_id))
vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
if not vector_store:
raise ValueError(f"Vector DB {vector_store_id} not found in Llama Stack")
collection = await maybe_await(self.client.get_collection(vector_store_id))
if not collection:
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
index = VectorDBWithIndex(vector_db, ChromaIndex(self.client, collection), self.inference_api)
self.cache[vector_db_id] = index
raise ValueError(f"Vector DB {vector_store_id} not found in Chroma")
index = VectorStoreWithIndex(vector_store, ChromaIndex(self.client, collection), self.inference_api)
self.cache[vector_store_id] = index
return index
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from a Chroma vector store."""
index = await self._get_and_cache_vector_db_index(store_id)
index = await self._get_and_cache_vector_store_index(store_id)
if not index:
raise ValueError(f"Vector DB {store_id} not found")

View file

@ -8,21 +8,21 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class ChromaVectorIOConfig(BaseModel):
url: str | None
kvstore: KVStoreConfig = Field(description="Config for KV store backend")
persistence: KVStoreReference = Field(description="Config for KV store backend")
@classmethod
def sample_run_config(cls, __distro_dir__: str, url: str = "${env.CHROMADB_URL}", **kwargs: Any) -> dict[str, Any]:
return {
"url": url,
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="chroma_remote_registry.db",
),
"persistence": KVStoreReference(
backend="kv_default",
namespace="vector_io::chroma_remote",
).model_dump(exclude_none=True),
}

View file

@ -13,12 +13,6 @@ async def get_adapter_impl(config: MilvusVectorIOConfig, deps: dict[Api, Provide
from .milvus import MilvusVectorIOAdapter
assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"
impl = MilvusVectorIOAdapter(
config,
deps[Api.inference],
deps[Api.models],
deps.get(Api.files),
)
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -8,7 +8,7 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.schema_utils import json_schema_type
@ -17,7 +17,7 @@ class MilvusVectorIOConfig(BaseModel):
uri: str = Field(description="The URI of the Milvus server")
token: str | None = Field(description="The token of the Milvus server")
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
kvstore: KVStoreConfig = Field(description="Config for KV store backend")
persistence: KVStoreReference = Field(description="Config for KV store backend")
# This configuration allows additional fields to be passed through to the underlying Milvus client.
# See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general.
@ -28,8 +28,8 @@ class MilvusVectorIOConfig(BaseModel):
return {
"uri": "${env.MILVUS_ENDPOINT}",
"token": "${env.MILVUS_TOKEN}",
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="milvus_remote_registry.db",
),
"persistence": KVStoreReference(
backend="kv_default",
namespace="vector_io::milvus_remote",
).model_dump(exclude_none=True),
}

View file

@ -14,15 +14,10 @@ from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusC
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
@ -31,7 +26,7 @@ from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_WEIGHTED,
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
VectorStoreWithIndex,
)
from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collection_name
@ -40,7 +35,7 @@ from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
logger = get_logger(name=__name__, category="vector_io::milvus")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"
VECTOR_DBS_PREFIX = f"vector_stores:milvus:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:milvus:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:milvus:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:milvus:{VERSION}::"
@ -74,46 +69,23 @@ class MilvusIndex(EmbeddingIndex):
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
# Create schema for vector search
schema = self.client.create_schema()
schema.add_field(
field_name="chunk_id",
datatype=DataType.VARCHAR,
is_primary=True,
max_length=100,
)
schema.add_field(field_name="chunk_id", datatype=DataType.VARCHAR, is_primary=True, max_length=100)
schema.add_field(
field_name="content",
datatype=DataType.VARCHAR,
max_length=65535,
enable_analyzer=True, # Enable text analysis for BM25
)
schema.add_field(
field_name="vector",
datatype=DataType.FLOAT_VECTOR,
dim=len(embeddings[0]),
)
schema.add_field(
field_name="chunk_content",
datatype=DataType.JSON,
)
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=len(embeddings[0]))
schema.add_field(field_name="chunk_content", datatype=DataType.JSON)
# Add sparse vector field for BM25 (required by the function)
schema.add_field(
field_name="sparse",
datatype=DataType.SPARSE_FLOAT_VECTOR,
)
schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR)
# Create indexes
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="vector",
index_type="FLAT",
metric_type="COSINE",
)
index_params.add_index(field_name="vector", index_type="FLAT", metric_type="COSINE")
# Add index for sparse field (required by BM25 function)
index_params.add_index(
field_name="sparse",
index_type="SPARSE_INVERTED_INDEX",
metric_type="BM25",
)
index_params.add_index(field_name="sparse", index_type="SPARSE_INVERTED_INDEX", metric_type="BM25")
# Add BM25 function for full-text search
bm25_function = Function(
@ -144,11 +116,7 @@ class MilvusIndex(EmbeddingIndex):
}
)
try:
await asyncio.to_thread(
self.client.insert,
self.collection_name,
data=data,
)
await asyncio.to_thread(self.client.insert, self.collection_name, data=data)
except Exception as e:
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
raise e
@ -167,12 +135,7 @@ class MilvusIndex(EmbeddingIndex):
scores = [res["distance"] for res in search_res[0]]
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
"""
Perform BM25-based keyword search using Milvus's built-in full-text search.
"""
@ -210,12 +173,7 @@ class MilvusIndex(EmbeddingIndex):
# Fallback to simple text search
return await self._fallback_keyword_search(query_string, k, score_threshold)
async def _fallback_keyword_search(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def _fallback_keyword_search(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
"""
Fallback to simple text search when BM25 search is not available.
"""
@ -303,12 +261,11 @@ class MilvusIndex(EmbeddingIndex):
raise
class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
def __init__(
self,
config: RemoteMilvusVectorIOConfig | InlineMilvusVectorIOConfig,
inference_api: Inference,
models_api: Models,
files_api: Files | None,
) -> None:
super().__init__(files_api=files_api, kvstore=None)
@ -316,29 +273,28 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.cache = {}
self.client = None
self.inference_api = inference_api
self.models_api = models_api
self.vector_db_store = None
self.vector_store_table = None
self.metadata_collection_name = "openai_vector_stores_metadata"
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
self.kvstore = await kvstore_impl(self.config.persistence)
start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
for vector_db_data in stored_vector_dbs:
vector_db = VectorDB.model_validate_json(vector_db_data)
index = VectorDBWithIndex(
vector_db,
for vector_store_data in stored_vector_stores:
vector_store = VectorStore.model_validate_json(vector_store_data)
index = VectorStoreWithIndex(
vector_store,
index=MilvusIndex(
client=self.client,
collection_name=vector_db.identifier,
collection_name=vector_store.identifier,
consistency_level=self.config.consistency_level,
kvstore=self.kvstore,
),
inference_api=self.inference_api,
)
self.cache[vector_db.identifier] = index
self.cache[vector_store.identifier] = index
if isinstance(self.config, RemoteMilvusVectorIOConfig):
logger.info(f"Connecting to Milvus server at {self.config.uri}")
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
@ -355,72 +311,61 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db(
self,
vector_db: VectorDB,
) -> None:
async def register_vector_store(self, vector_store: VectorStore) -> None:
if isinstance(self.config, RemoteMilvusVectorIOConfig):
consistency_level = self.config.consistency_level
else:
consistency_level = "Strong"
index = VectorDBWithIndex(
vector_db=vector_db,
index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level),
index = VectorStoreWithIndex(
vector_store=vector_store,
index=MilvusIndex(self.client, vector_store.identifier, consistency_level=consistency_level),
inference_api=self.inference_api,
)
self.cache[vector_db.identifier] = index
self.cache[vector_store.identifier] = index
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
if vector_db_id in self.cache:
return self.cache[vector_db_id]
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
if vector_store_id in self.cache:
return self.cache[vector_store_id]
if self.vector_db_store is None:
raise VectorStoreNotFoundError(vector_db_id)
if self.vector_store_table is None:
raise VectorStoreNotFoundError(vector_store_id)
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db:
raise VectorStoreNotFoundError(vector_db_id)
vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
if not vector_store:
raise VectorStoreNotFoundError(vector_store_id)
index = VectorDBWithIndex(
vector_db=vector_db,
index=MilvusIndex(client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore),
index = VectorStoreWithIndex(
vector_store=vector_store,
index=MilvusIndex(client=self.client, collection_name=vector_store.identifier, kvstore=self.kvstore),
inference_api=self.inference_api,
)
self.cache[vector_db_id] = index
self.cache[vector_store_id] = index
return index
async def unregister_vector_db(self, vector_db_id: str) -> None:
if vector_db_id in self.cache:
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
async def unregister_vector_store(self, vector_store_id: str) -> None:
if vector_store_id in self.cache:
await self.cache[vector_store_id].index.delete()
del self.cache[vector_store_id]
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_store_index(vector_db_id)
if not index:
raise VectorStoreNotFoundError(vector_db_id)
await index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
index = await self._get_and_cache_vector_store_index(vector_db_id)
if not index:
raise VectorStoreNotFoundError(vector_db_id)
return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete a chunk from a milvus vector store."""
index = await self._get_and_cache_vector_db_index(store_id)
index = await self._get_and_cache_vector_store_index(store_id)
if not index:
raise VectorStoreNotFoundError(store_id)

View file

@ -12,6 +12,6 @@ from .config import PGVectorVectorIOConfig
async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .pgvector import PGVectorVectorIOAdapter
impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps[Api.models], deps.get(Api.files, None))
impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -8,10 +8,7 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.schema_utils import json_schema_type
@ -22,7 +19,9 @@ class PGVectorVectorIOConfig(BaseModel):
db: str | None = Field(default="postgres")
user: str | None = Field(default="postgres")
password: str | None = Field(default="mysecretpassword")
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
persistence: KVStoreReference | None = Field(
description="Config for KV store backend (SQLite only for now)", default=None
)
@classmethod
def sample_run_config(
@ -41,8 +40,8 @@ class PGVectorVectorIOConfig(BaseModel):
"db": db,
"user": user,
"password": password,
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="pgvector_registry.db",
),
"persistence": KVStoreReference(
backend="kv_default",
namespace="vector_io::pgvector",
).model_dump(exclude_none=True),
}

View file

@ -16,26 +16,15 @@ from pydantic import BaseModel, TypeAdapter
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator, sanitize_collection_name
from .config import PGVectorVectorIOConfig
@ -43,7 +32,7 @@ from .config import PGVectorVectorIOConfig
log = get_logger(name=__name__, category="vector_io::pgvector")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
VECTOR_DBS_PREFIX = f"vector_stores:pgvector:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:pgvector:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:pgvector:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:pgvector:{VERSION}::"
@ -90,13 +79,13 @@ class PGVectorIndex(EmbeddingIndex):
def __init__(
self,
vector_db: VectorDB,
vector_store: VectorStore,
dimension: int,
conn: psycopg2.extensions.connection,
kvstore: KVStore | None = None,
distance_metric: str = "COSINE",
):
self.vector_db = vector_db
self.vector_store = vector_store
self.dimension = dimension
self.conn = conn
self.kvstore = kvstore
@ -108,9 +97,9 @@ class PGVectorIndex(EmbeddingIndex):
try:
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
# Sanitize the table name by replacing hyphens with underscores
# SQL doesn't allow hyphens in table names, and vector_db.identifier may contain hyphens
# SQL doesn't allow hyphens in table names, and vector_store.identifier may contain hyphens
# when created with patterns like "test-vector-db-{uuid4()}"
sanitized_identifier = sanitize_collection_name(self.vector_db.identifier)
sanitized_identifier = sanitize_collection_name(self.vector_store.identifier)
self.table_name = f"vs_{sanitized_identifier}"
cur.execute(
@ -133,8 +122,8 @@ class PGVectorIndex(EmbeddingIndex):
"""
)
except Exception as e:
log.exception(f"Error creating PGVectorIndex for vector_db: {self.vector_db.identifier}")
raise RuntimeError(f"Error creating PGVectorIndex for vector_db: {self.vector_db.identifier}") from e
log.exception(f"Error creating PGVectorIndex for vector_store: {self.vector_store.identifier}")
raise RuntimeError(f"Error creating PGVectorIndex for vector_store: {self.vector_store.identifier}") from e
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), (
@ -205,12 +194,7 @@ class PGVectorIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
"""
Performs keyword-based search using PostgreSQL's full-text search with ts_rank scoring.
@ -317,7 +301,7 @@ class PGVectorIndex(EmbeddingIndex):
"""Remove a chunk from the PostgreSQL table."""
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,))
cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids))
def get_pgvector_search_function(self) -> str:
return self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION[self.distance_metric]
@ -339,26 +323,21 @@ class PGVectorIndex(EmbeddingIndex):
)
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
def __init__(
self,
config: PGVectorVectorIOConfig,
inference_api: Inference,
models_api: Models,
files_api: Files | None = None,
self, config: PGVectorVectorIOConfig, inference_api: Inference, files_api: Files | None = None
) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.models_api = models_api
self.conn = None
self.cache = {}
self.vector_db_store = None
self.vector_store_table = None
self.metadata_collection_name = "openai_vector_stores_metadata"
async def initialize(self) -> None:
log.info(f"Initializing PGVector memory adapter with config: {self.config}")
self.kvstore = await kvstore_impl(self.config.kvstore)
self.kvstore = await kvstore_impl(self.config.persistence)
await self.initialize_openai_vector_stores()
try:
@ -396,71 +375,59 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db(self, vector_db: VectorDB) -> None:
async def register_vector_store(self, vector_store: VectorStore) -> None:
# Persist vector DB metadata in the KV store
assert self.kvstore is not None
# Upsert model metadata in Postgres
upsert_models(self.conn, [(vector_db.identifier, vector_db)])
upsert_models(self.conn, [(vector_store.identifier, vector_store)])
# Create and cache the PGVector index table for the vector DB
pgvector_index = PGVectorIndex(
vector_db=vector_db, dimension=vector_db.embedding_dimension, conn=self.conn, kvstore=self.kvstore
vector_store=vector_store, dimension=vector_store.embedding_dimension, conn=self.conn, kvstore=self.kvstore
)
await pgvector_index.initialize()
index = VectorDBWithIndex(
vector_db,
index=pgvector_index,
inference_api=self.inference_api,
)
self.cache[vector_db.identifier] = index
index = VectorStoreWithIndex(vector_store, index=pgvector_index, inference_api=self.inference_api)
self.cache[vector_store.identifier] = index
async def unregister_vector_db(self, vector_db_id: str) -> None:
async def unregister_vector_store(self, vector_store_id: str) -> None:
# Remove provider index and cache
if vector_db_id in self.cache:
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
if vector_store_id in self.cache:
await self.cache[vector_store_id].index.delete()
del self.cache[vector_store_id]
# Delete vector DB metadata from KV store
assert self.kvstore is not None
await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_db_id}")
await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_store_id}")
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_store_index(vector_db_id)
await index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
index = await self._get_and_cache_vector_store_index(vector_db_id)
return await index.query_chunks(query, params)
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
if vector_db_id in self.cache:
return self.cache[vector_db_id]
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex:
if vector_store_id in self.cache:
return self.cache[vector_store_id]
if self.vector_db_store is None:
raise VectorStoreNotFoundError(vector_db_id)
if self.vector_store_table is None:
raise VectorStoreNotFoundError(vector_store_id)
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db:
raise VectorStoreNotFoundError(vector_db_id)
vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
if not vector_store:
raise VectorStoreNotFoundError(vector_store_id)
index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn)
index = PGVectorIndex(vector_store, vector_store.embedding_dimension, self.conn)
await index.initialize()
self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
return self.cache[vector_db_id]
self.cache[vector_store_id] = VectorStoreWithIndex(vector_store, index, self.inference_api)
return self.cache[vector_store_id]
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete a chunk from a PostgreSQL vector store."""
index = await self._get_and_cache_vector_db_index(store_id)
index = await self._get_and_cache_vector_store_index(store_id)
if not index:
raise VectorStoreNotFoundError(store_id)

View file

@ -12,11 +12,6 @@ from .config import QdrantVectorIOConfig
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .qdrant import QdrantVectorIOAdapter
impl = QdrantVectorIOAdapter(
config,
deps[Api.inference],
deps[Api.models],
deps.get(Api.files),
)
impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -8,10 +8,7 @@ from typing import Any
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.schema_utils import json_schema_type
@ -27,14 +24,14 @@ class QdrantVectorIOConfig(BaseModel):
prefix: str | None = None
timeout: int | None = None
host: str | None = None
kvstore: KVStoreConfig
persistence: KVStoreReference
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"api_key": "${env.QDRANT_API_KEY:=}",
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="qdrant_registry.db",
),
"persistence": KVStoreReference(
backend="kv_default",
namespace="vector_io::qdrant_remote",
).model_dump(exclude_none=True),
}

View file

@ -16,8 +16,6 @@ from qdrant_client.models import PointStruct
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
@ -25,16 +23,13 @@ from llama_stack.apis.vector_io import (
VectorStoreChunkingStrategy,
VectorStoreFileObject,
)
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
@ -43,7 +38,7 @@ CHUNK_ID_KEY = "_chunk_id"
# KV store prefixes for vector databases
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:qdrant:{VERSION}::"
VECTOR_DBS_PREFIX = f"vector_stores:qdrant:{VERSION}::"
def convert_id(_id: str) -> str:
@ -99,8 +94,7 @@ class QdrantIndex(EmbeddingIndex):
chunk_ids = [convert_id(c.chunk_id) for c in chunks_for_deletion]
try:
await self.client.delete(
collection_name=self.collection_name,
points_selector=models.PointIdsList(points=chunk_ids),
collection_name=self.collection_name, points_selector=models.PointIdsList(points=chunk_ids)
)
except Exception as e:
log.error(f"Error deleting chunks from Qdrant collection {self.collection_name}: {e}")
@ -133,12 +127,7 @@ class QdrantIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Qdrant")
async def query_hybrid(
@ -156,12 +145,11 @@ class QdrantIndex(EmbeddingIndex):
await self.client.delete_collection(collection_name=self.collection_name)
class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
def __init__(
self,
config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig,
inference_api: Inference,
models_api: Models,
files_api: Files | None = None,
) -> None:
super().__init__(files_api=files_api, kvstore=None)
@ -169,27 +157,24 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.client: AsyncQdrantClient = None
self.cache = {}
self.inference_api = inference_api
self.models_api = models_api
self.vector_db_store = None
self.vector_store_table = None
self._qdrant_lock = asyncio.Lock()
async def initialize(self) -> None:
client_config = self.config.model_dump(exclude_none=True, exclude={"kvstore"})
client_config = self.config.model_dump(exclude_none=True, exclude={"persistence"})
self.client = AsyncQdrantClient(**client_config)
self.kvstore = await kvstore_impl(self.config.kvstore)
self.kvstore = await kvstore_impl(self.config.persistence)
start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
for vector_db_data in stored_vector_dbs:
vector_db = VectorDB.model_validate_json(vector_db_data)
index = VectorDBWithIndex(
vector_db,
QdrantIndex(self.client, vector_db.identifier),
self.inference_api,
for vector_store_data in stored_vector_stores:
vector_store = VectorStore.model_validate_json(vector_store_data)
index = VectorStoreWithIndex(
vector_store, QdrantIndex(self.client, vector_store.identifier), self.inference_api
)
self.cache[vector_db.identifier] = index
self.cache[vector_store.identifier] = index
self.openai_vector_stores = await self._load_openai_vector_stores()
async def shutdown(self) -> None:
@ -197,68 +182,57 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db(
self,
vector_db: VectorDB,
) -> None:
async def register_vector_store(self, vector_store: VectorStore) -> None:
assert self.kvstore is not None
key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
await self.kvstore.set(key=key, value=vector_db.model_dump_json())
key = f"{VECTOR_DBS_PREFIX}{vector_store.identifier}"
await self.kvstore.set(key=key, value=vector_store.model_dump_json())
index = VectorDBWithIndex(
vector_db=vector_db,
index=QdrantIndex(self.client, vector_db.identifier),
index = VectorStoreWithIndex(
vector_store=vector_store,
index=QdrantIndex(self.client, vector_store.identifier),
inference_api=self.inference_api,
)
self.cache[vector_db.identifier] = index
self.cache[vector_store.identifier] = index
async def unregister_vector_db(self, vector_db_id: str) -> None:
if vector_db_id in self.cache:
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
async def unregister_vector_store(self, vector_store_id: str) -> None:
if vector_store_id in self.cache:
await self.cache[vector_store_id].index.delete()
del self.cache[vector_store_id]
assert self.kvstore is not None
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_db_id}")
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_store_id}")
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
if vector_db_id in self.cache:
return self.cache[vector_db_id]
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
if vector_store_id in self.cache:
return self.cache[vector_store_id]
if self.vector_db_store is None:
raise ValueError(f"Vector DB not found {vector_db_id}")
if self.vector_store_table is None:
raise ValueError(f"Vector DB not found {vector_store_id}")
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db:
raise VectorStoreNotFoundError(vector_db_id)
vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
if not vector_store:
raise VectorStoreNotFoundError(vector_store_id)
index = VectorDBWithIndex(
vector_db=vector_db,
index=QdrantIndex(client=self.client, collection_name=vector_db.identifier),
index = VectorStoreWithIndex(
vector_store=vector_store,
index=QdrantIndex(client=self.client, collection_name=vector_store.identifier),
inference_api=self.inference_api,
)
self.cache[vector_db_id] = index
self.cache[vector_store_id] = index
return index
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_store_index(vector_db_id)
if not index:
raise VectorStoreNotFoundError(vector_db_id)
await index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
index = await self._get_and_cache_vector_store_index(vector_db_id)
if not index:
raise VectorStoreNotFoundError(vector_db_id)
@ -279,7 +253,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from a Qdrant vector store."""
index = await self._get_and_cache_vector_db_index(store_id)
index = await self._get_and_cache_vector_store_index(store_id)
if not index:
raise ValueError(f"Vector DB {store_id} not found")

View file

@ -12,11 +12,6 @@ from .config import WeaviateVectorIOConfig
async def get_adapter_impl(config: WeaviateVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .weaviate import WeaviateVectorIOAdapter
impl = WeaviateVectorIOAdapter(
config,
deps[Api.inference],
deps[Api.models],
deps.get(Api.files),
)
impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -8,10 +8,7 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.schema_utils import json_schema_type
@ -19,19 +16,17 @@ from llama_stack.schema_utils import json_schema_type
class WeaviateVectorIOConfig(BaseModel):
weaviate_api_key: str | None = Field(description="The API key for the Weaviate instance", default=None)
weaviate_cluster_url: str | None = Field(description="The URL of the Weaviate cluster", default="localhost:8080")
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
persistence: KVStoreReference | None = Field(
description="Config for KV store backend (SQLite only for now)", default=None
)
@classmethod
def sample_run_config(
cls,
__distro_dir__: str,
**kwargs: Any,
) -> dict[str, Any]:
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"weaviate_api_key": None,
"weaviate_cluster_url": "${env.WEAVIATE_CLUSTER_URL:=localhost:8080}",
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="weaviate_registry.db",
),
"persistence": KVStoreReference(
backend="kv_default",
namespace="vector_io::weaviate",
).model_dump(exclude_none=True),
}

View file

@ -16,22 +16,19 @@ from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
OpenAIVectorStoreMixin,
)
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_RRF,
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
VectorStoreWithIndex,
)
from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collection_name
@ -40,7 +37,7 @@ from .config import WeaviateVectorIOConfig
log = get_logger(name=__name__, category="vector_io::weaviate")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
VECTOR_DBS_PREFIX = f"vector_stores:weaviate:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:weaviate:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:weaviate:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:weaviate:{VERSION}::"
@ -48,12 +45,7 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten
class WeaviateIndex(EmbeddingIndex):
def __init__(
self,
client: weaviate.WeaviateClient,
collection_name: str,
kvstore: KVStore | None = None,
):
def __init__(self, client: weaviate.WeaviateClient, collection_name: str, kvstore: KVStore | None = None):
self.client = client
self.collection_name = sanitize_collection_name(collection_name, weaviate_format=True)
self.kvstore = kvstore
@ -108,9 +100,7 @@ class WeaviateIndex(EmbeddingIndex):
try:
results = collection.query.near_vector(
near_vector=embedding.tolist(),
limit=k,
return_metadata=wvc.query.MetadataQuery(distance=True),
near_vector=embedding.tolist(), limit=k, return_metadata=wvc.query.MetadataQuery(distance=True)
)
except Exception as e:
log.error(f"Weaviate client vector search failed: {e}")
@ -153,12 +143,7 @@ class WeaviateIndex(EmbeddingIndex):
collection = self.client.collections.get(sanitized_collection_name)
collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids))
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
"""
Performs BM25-based keyword search using Weaviate's built-in full-text search.
Args:
@ -175,9 +160,7 @@ class WeaviateIndex(EmbeddingIndex):
# Perform BM25 keyword search on chunk_content field
try:
results = collection.query.bm25(
query=query_string,
limit=k,
return_metadata=wvc.query.MetadataQuery(score=True),
query=query_string, limit=k, return_metadata=wvc.query.MetadataQuery(score=True)
)
except Exception as e:
log.error(f"Weaviate client keyword search failed: {e}")
@ -274,26 +257,14 @@ class WeaviateIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
class WeaviateVectorIOAdapter(
OpenAIVectorStoreMixin,
VectorIO,
NeedsRequestProviderData,
VectorDBsProtocolPrivate,
):
def __init__(
self,
config: WeaviateVectorIOConfig,
inference_api: Inference,
models_api: Models,
files_api: Files | None,
) -> None:
class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProviderData, VectorStoresProtocolPrivate):
def __init__(self, config: WeaviateVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.models_api = models_api
self.client_cache = {}
self.cache = {}
self.vector_db_store = None
self.vector_store_table = None
self.metadata_collection_name = "openai_vector_stores_metadata"
def _get_client(self) -> weaviate.WeaviateClient:
@ -301,10 +272,7 @@ class WeaviateVectorIOAdapter(
log.info("Using Weaviate locally in container")
host, port = self.config.weaviate_cluster_url.split(":")
key = "local_test"
client = weaviate.connect_to_local(
host=host,
port=port,
)
client = weaviate.connect_to_local(host=host, port=port)
else:
log.info("Using Weaviate remote cluster with URL")
key = f"{self.config.weaviate_cluster_url}::{self.config.weaviate_api_key}"
@ -320,8 +288,8 @@ class WeaviateVectorIOAdapter(
async def initialize(self) -> None:
"""Set up KV store and load existing vector DBs and OpenAI vector stores."""
# Initialize KV store for metadata if configured
if self.config.kvstore is not None:
self.kvstore = await kvstore_impl(self.config.kvstore)
if self.config.persistence is not None:
self.kvstore = await kvstore_impl(self.config.persistence)
else:
self.kvstore = None
log.info("No kvstore configured, registry will not persist across restarts")
@ -332,17 +300,11 @@ class WeaviateVectorIOAdapter(
end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored = await self.kvstore.values_in_range(start_key, end_key)
for raw in stored:
vector_db = VectorDB.model_validate_json(raw)
vector_store = VectorStore.model_validate_json(raw)
client = self._get_client()
idx = WeaviateIndex(
client=client,
collection_name=vector_db.identifier,
kvstore=self.kvstore,
)
self.cache[vector_db.identifier] = VectorDBWithIndex(
vector_db=vector_db,
index=idx,
inference_api=self.inference_api,
idx = WeaviateIndex(client=client, collection_name=vector_store.identifier, kvstore=self.kvstore)
self.cache[vector_store.identifier] = VectorStoreWithIndex(
vector_store=vector_store, index=idx, inference_api=self.inference_api
)
# Load OpenAI vector stores metadata into cache
@ -354,90 +316,74 @@ class WeaviateVectorIOAdapter(
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db(
self,
vector_db: VectorDB,
) -> None:
async def register_vector_store(self, vector_store: VectorStore) -> None:
client = self._get_client()
sanitized_collection_name = sanitize_collection_name(vector_db.identifier, weaviate_format=True)
sanitized_collection_name = sanitize_collection_name(vector_store.identifier, weaviate_format=True)
# Create collection if it doesn't exist
if not client.collections.exists(sanitized_collection_name):
client.collections.create(
name=sanitized_collection_name,
vectorizer_config=wvc.config.Configure.Vectorizer.none(),
properties=[
wvc.config.Property(
name="chunk_content",
data_type=wvc.config.DataType.TEXT,
),
wvc.config.Property(name="chunk_content", data_type=wvc.config.DataType.TEXT),
],
)
self.cache[vector_db.identifier] = VectorDBWithIndex(
vector_db,
WeaviateIndex(client=client, collection_name=sanitized_collection_name),
self.inference_api,
self.cache[vector_store.identifier] = VectorStoreWithIndex(
vector_store, WeaviateIndex(client=client, collection_name=sanitized_collection_name), self.inference_api
)
async def unregister_vector_db(self, vector_db_id: str) -> None:
async def unregister_vector_store(self, vector_store_id: str) -> None:
client = self._get_client()
sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
if vector_db_id not in self.cache or client.collections.exists(sanitized_collection_name) is False:
sanitized_collection_name = sanitize_collection_name(vector_store_id, weaviate_format=True)
if vector_store_id not in self.cache or client.collections.exists(sanitized_collection_name) is False:
return
client.collections.delete(sanitized_collection_name)
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
await self.cache[vector_store_id].index.delete()
del self.cache[vector_store_id]
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
if vector_db_id in self.cache:
return self.cache[vector_db_id]
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
if vector_store_id in self.cache:
return self.cache[vector_store_id]
if self.vector_db_store is None:
raise VectorStoreNotFoundError(vector_db_id)
if self.vector_store_table is None:
raise VectorStoreNotFoundError(vector_store_id)
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db:
raise VectorStoreNotFoundError(vector_db_id)
vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
if not vector_store:
raise VectorStoreNotFoundError(vector_store_id)
client = self._get_client()
sanitized_collection_name = sanitize_collection_name(vector_db.identifier, weaviate_format=True)
sanitized_collection_name = sanitize_collection_name(vector_store.identifier, weaviate_format=True)
if not client.collections.exists(sanitized_collection_name):
raise ValueError(f"Collection with name `{sanitized_collection_name}` not found")
index = VectorDBWithIndex(
vector_db=vector_db,
index=WeaviateIndex(client=client, collection_name=vector_db.identifier),
index = VectorStoreWithIndex(
vector_store=vector_store,
index=WeaviateIndex(client=client, collection_name=vector_store.identifier),
inference_api=self.inference_api,
)
self.cache[vector_db_id] = index
self.cache[vector_store_id] = index
return index
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_store_index(vector_db_id)
if not index:
raise VectorStoreNotFoundError(vector_db_id)
await index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
index = await self._get_and_cache_vector_store_index(vector_db_id)
if not index:
raise VectorStoreNotFoundError(vector_db_id)
return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
index = await self._get_and_cache_vector_db_index(store_id)
index = await self._get_and_cache_vector_store_index(store_id)
if not index:
raise ValueError(f"Vector DB {store_id} not found")

View file

@ -6,9 +6,12 @@
import asyncio
import base64
import platform
import struct
from typing import TYPE_CHECKING
import torch
from llama_stack.log import get_logger
if TYPE_CHECKING:
@ -24,6 +27,8 @@ from llama_stack.apis.inference import (
EMBEDDING_MODELS = {}
DARWIN = "Darwin"
log = get_logger(name=__name__, category="providers::utils")
@ -83,6 +88,13 @@ class SentenceTransformerEmbeddingMixin:
def _load_model():
from sentence_transformers import SentenceTransformer
platform_name = platform.system()
if platform_name == DARWIN:
# PyTorch's OpenMP kernels can segfault on macOS when spawned from background
# threads with the default parallel settings, so force a single-threaded CPU run.
log.debug(f"Constraining torch threads on {platform_name} to a single worker")
torch.set_num_threads(1)
return SentenceTransformer(model, trust_remote_code=True)
loaded_model = await asyncio.to_thread(_load_model)

View file

@ -15,12 +15,13 @@ from llama_stack.apis.inference import (
OpenAIMessageParam,
Order,
)
from llama_stack.core.datatypes import AccessRule, InferenceStoreConfig
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.storage.datatypes import InferenceStoreReference, StorageBackendType
from llama_stack.log import get_logger
from ..sqlstore.api import ColumnDefinition, ColumnType
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
from ..sqlstore.sqlstore import SqlStoreConfig, SqlStoreType, sqlstore_impl
from ..sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl
logger = get_logger(name=__name__, category="inference")
@ -28,33 +29,32 @@ logger = get_logger(name=__name__, category="inference")
class InferenceStore:
def __init__(
self,
config: InferenceStoreConfig | SqlStoreConfig,
reference: InferenceStoreReference,
policy: list[AccessRule],
):
# Handle backward compatibility
if not isinstance(config, InferenceStoreConfig):
# Legacy: SqlStoreConfig passed directly as config
config = InferenceStoreConfig(
sql_store_config=config,
)
self.config = config
self.sql_store_config = config.sql_store_config
self.reference = reference
self.sql_store = None
self.policy = policy
# Disable write queue for SQLite to avoid concurrency issues
self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite
# Async write queue and worker control
self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
self._worker_tasks: list[asyncio.Task[Any]] = []
self._max_write_queue_size: int = config.max_write_queue_size
self._num_writers: int = max(1, config.num_writers)
self._max_write_queue_size: int = reference.max_write_queue_size
self._num_writers: int = max(1, reference.num_writers)
async def initialize(self):
"""Create the necessary tables if they don't exist."""
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy)
base_store = sqlstore_impl(self.reference)
self.sql_store = AuthorizedSqlStore(base_store, self.policy)
# Disable write queue for SQLite to avoid concurrency issues
backend_name = self.reference.backend
backend_config = _SQLSTORE_BACKENDS.get(backend_name)
if backend_config is None:
raise ValueError(
f"Unregistered SQL backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
)
self.enable_write_queue = backend_config.type != StorageBackendType.SQL_SQLITE
await self.sql_store.create_table(
"chat_completions",
{

View file

@ -4,143 +4,20 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
from enum import Enum
from typing import Annotated, Literal
from typing import Annotated
from pydantic import BaseModel, Field, field_validator
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
class KVStoreType(Enum):
redis = "redis"
sqlite = "sqlite"
postgres = "postgres"
mongodb = "mongodb"
class CommonConfig(BaseModel):
namespace: str | None = Field(
default=None,
description="All keys will be prefixed with this namespace",
)
class RedisKVStoreConfig(CommonConfig):
type: Literal["redis"] = KVStoreType.redis.value
host: str = "localhost"
port: int = 6379
@property
def url(self) -> str:
return f"redis://{self.host}:{self.port}"
@classmethod
def pip_packages(cls) -> list[str]:
return ["redis"]
@classmethod
def sample_run_config(cls):
return {
"type": "redis",
"host": "${env.REDIS_HOST:=localhost}",
"port": "${env.REDIS_PORT:=6379}",
}
class SqliteKVStoreConfig(CommonConfig):
type: Literal["sqlite"] = KVStoreType.sqlite.value
db_path: str = Field(
default=(RUNTIME_BASE_DIR / "kvstore.db").as_posix(),
description="File path for the sqlite database",
)
@classmethod
def pip_packages(cls) -> list[str]:
return ["aiosqlite"]
@classmethod
def sample_run_config(cls, __distro_dir__: str, db_name: str = "kvstore.db"):
return {
"type": "sqlite",
"db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
}
class PostgresKVStoreConfig(CommonConfig):
type: Literal["postgres"] = KVStoreType.postgres.value
host: str = "localhost"
port: int = 5432
db: str = "llamastack"
user: str
password: str | None = None
ssl_mode: str | None = None
ca_cert_path: str | None = None
table_name: str = "llamastack_kvstore"
@classmethod
def sample_run_config(cls, table_name: str = "llamastack_kvstore", **kwargs):
return {
"type": "postgres",
"host": "${env.POSTGRES_HOST:=localhost}",
"port": "${env.POSTGRES_PORT:=5432}",
"db": "${env.POSTGRES_DB:=llamastack}",
"user": "${env.POSTGRES_USER:=llamastack}",
"password": "${env.POSTGRES_PASSWORD:=llamastack}",
"table_name": "${env.POSTGRES_TABLE_NAME:=" + table_name + "}",
}
@classmethod
@field_validator("table_name")
def validate_table_name(cls, v: str) -> str:
# PostgreSQL identifiers rules:
# - Must start with a letter or underscore
# - Can contain letters, numbers, and underscores
# - Maximum length is 63 bytes
pattern = r"^[a-zA-Z_][a-zA-Z0-9_]*$"
if not re.match(pattern, v):
raise ValueError(
"Invalid table name. Must start with letter or underscore and contain only letters, numbers, and underscores"
)
if len(v) > 63:
raise ValueError("Table name must be less than 63 characters")
return v
@classmethod
def pip_packages(cls) -> list[str]:
return ["psycopg2-binary"]
class MongoDBKVStoreConfig(CommonConfig):
type: Literal["mongodb"] = KVStoreType.mongodb.value
host: str = "localhost"
port: int = 27017
db: str = "llamastack"
user: str | None = None
password: str | None = None
collection_name: str = "llamastack_kvstore"
@classmethod
def pip_packages(cls) -> list[str]:
return ["pymongo"]
@classmethod
def sample_run_config(cls, collection_name: str = "llamastack_kvstore"):
return {
"type": "mongodb",
"host": "${env.MONGODB_HOST:=localhost}",
"port": "${env.MONGODB_PORT:=5432}",
"db": "${env.MONGODB_DB}",
"user": "${env.MONGODB_USER}",
"password": "${env.MONGODB_PASSWORD}",
"collection_name": "${env.MONGODB_COLLECTION_NAME:=" + collection_name + "}",
}
from pydantic import Field
from llama_stack.core.storage.datatypes import (
MongoDBKVStoreConfig,
PostgresKVStoreConfig,
RedisKVStoreConfig,
SqliteKVStoreConfig,
StorageBackendType,
)
KVStoreConfig = Annotated[
RedisKVStoreConfig | SqliteKVStoreConfig | PostgresKVStoreConfig | MongoDBKVStoreConfig,
Field(discriminator="type", default=KVStoreType.sqlite.value),
RedisKVStoreConfig | SqliteKVStoreConfig | PostgresKVStoreConfig | MongoDBKVStoreConfig, Field(discriminator="type")
]
@ -148,13 +25,13 @@ def get_pip_packages(store_config: dict | KVStoreConfig) -> list[str]:
"""Get pip packages for KV store config, handling both dict and object cases."""
if isinstance(store_config, dict):
store_type = store_config.get("type")
if store_type == "sqlite":
if store_type == StorageBackendType.KV_SQLITE.value:
return SqliteKVStoreConfig.pip_packages()
elif store_type == "postgres":
elif store_type == StorageBackendType.KV_POSTGRES.value:
return PostgresKVStoreConfig.pip_packages()
elif store_type == "redis":
elif store_type == StorageBackendType.KV_REDIS.value:
return RedisKVStoreConfig.pip_packages()
elif store_type == "mongodb":
elif store_type == StorageBackendType.KV_MONGODB.value:
return MongoDBKVStoreConfig.pip_packages()
else:
raise ValueError(f"Unknown KV store type: {store_type}")

View file

@ -4,9 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from __future__ import annotations
from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig, StorageBackendType
from .api import KVStore
from .config import KVStoreConfig, KVStoreType
from .config import KVStoreConfig
def kvstore_dependencies():
@ -44,20 +52,41 @@ class InmemoryKVStoreImpl(KVStore):
del self._store[key]
async def kvstore_impl(config: KVStoreConfig) -> KVStore:
if config.type == KVStoreType.redis.value:
_KVSTORE_BACKENDS: dict[str, KVStoreConfig] = {}
def register_kvstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
"""Register the set of available KV store backends for reference resolution."""
global _KVSTORE_BACKENDS
_KVSTORE_BACKENDS.clear()
for name, cfg in backends.items():
_KVSTORE_BACKENDS[name] = cfg
async def kvstore_impl(reference: KVStoreReference) -> KVStore:
backend_name = reference.backend
backend_config = _KVSTORE_BACKENDS.get(backend_name)
if backend_config is None:
raise ValueError(f"Unknown KVStore backend '{backend_name}'. Registered backends: {sorted(_KVSTORE_BACKENDS)}")
config = backend_config.model_copy()
config.namespace = reference.namespace
if config.type == StorageBackendType.KV_REDIS.value:
from .redis import RedisKVStoreImpl
impl = RedisKVStoreImpl(config)
elif config.type == KVStoreType.sqlite.value:
elif config.type == StorageBackendType.KV_SQLITE.value:
from .sqlite import SqliteKVStoreImpl
impl = SqliteKVStoreImpl(config)
elif config.type == KVStoreType.postgres.value:
elif config.type == StorageBackendType.KV_POSTGRES.value:
from .postgres import PostgresKVStoreImpl
impl = PostgresKVStoreImpl(config)
elif config.type == KVStoreType.mongodb.value:
elif config.type == StorageBackendType.KV_MONGODB.value:
from .mongodb import MongoDBKVStoreImpl
impl = MongoDBKVStoreImpl(config)

View file

@ -17,8 +17,6 @@ from pydantic import TypeAdapter
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files, OpenAIFileObject
from llama_stack.apis.models import Model, Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
@ -44,6 +42,7 @@ from llama_stack.apis.vector_io import (
VectorStoreSearchResponse,
VectorStoreSearchResponsePage,
)
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.core.id_generation import generate_object_id
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore.api import KVStore
@ -64,7 +63,7 @@ MAX_CONCURRENT_FILES_PER_BATCH = 3 # Maximum concurrent file processing within
FILE_BATCH_CHUNK_SIZE = 10 # Process files in chunks of this size
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
VECTOR_DBS_PREFIX = f"vector_stores:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:{VERSION}::"
@ -81,13 +80,14 @@ class OpenAIVectorStoreMixin(ABC):
# Implementing classes should call super().__init__() in their __init__ method
# to properly initialize the mixin attributes.
def __init__(
self, files_api: Files | None = None, kvstore: KVStore | None = None, models_api: Models | None = None
self,
files_api: Files | None = None,
kvstore: KVStore | None = None,
):
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
self.openai_file_batches: dict[str, dict[str, Any]] = {}
self.files_api = files_api
self.kvstore = kvstore
self.models_api = models_api
self._last_file_batch_cleanup_time = 0
self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}
@ -321,12 +321,12 @@ class OpenAIVectorStoreMixin(ABC):
pass
@abstractmethod
async def register_vector_db(self, vector_db: VectorDB) -> None:
async def register_vector_store(self, vector_store: VectorStore) -> None:
"""Register a vector database (provider-specific implementation)."""
pass
@abstractmethod
async def unregister_vector_db(self, vector_db_id: str) -> None:
async def unregister_vector_store(self, vector_store_id: str) -> None:
"""Unregister a vector database (provider-specific implementation)."""
pass
@ -358,7 +358,7 @@ class OpenAIVectorStoreMixin(ABC):
extra_body = params.model_extra or {}
metadata = params.metadata or {}
provider_vector_db_id = extra_body.get("provider_vector_db_id")
provider_vector_store_id = extra_body.get("provider_vector_store_id")
# Use embedding info from metadata if available, otherwise from extra_body
if metadata.get("embedding_model"):
@ -370,16 +370,6 @@ class OpenAIVectorStoreMixin(ABC):
logger.debug(
f"Using embedding config from metadata (takes precedence over extra_body): model='{embedding_model}', dimension={embedding_dimension}"
)
# Check for conflicts with extra_body
if extra_body.get("embedding_model") and extra_body["embedding_model"] != embedding_model:
raise ValueError(
f"Embedding model inconsistent between metadata ('{embedding_model}') and extra_body ('{extra_body['embedding_model']}')"
)
if extra_body.get("embedding_dimension") and extra_body["embedding_dimension"] != embedding_dimension:
raise ValueError(
f"Embedding dimension inconsistent between metadata ({embedding_dimension}) and extra_body ({extra_body['embedding_dimension']})"
)
else:
embedding_model = extra_body.get("embedding_model")
embedding_dimension = extra_body.get("embedding_dimension", EMBEDDING_DIMENSION)
@ -389,42 +379,29 @@ class OpenAIVectorStoreMixin(ABC):
# use provider_id set by router; fallback to provider's own ID when used directly via --stack-config
provider_id = extra_body.get("provider_id") or getattr(self, "__provider_id__", None)
# Derive the canonical vector_db_id (allow override, else generate)
vector_db_id = provider_vector_db_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
# Derive the canonical vector_store_id (allow override, else generate)
vector_store_id = provider_vector_store_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
if embedding_model is None:
result = await self._get_default_embedding_model_and_dimension()
if result is None:
raise ValueError(
"embedding_model is required in extra_body when creating a vector store. "
"No default embedding model could be determined automatically."
)
embedding_model, embedding_dimension = result
elif embedding_dimension is None:
# Embedding model was provided but dimension wasn't, look it up
embedding_dimension = await self._get_embedding_dimension_for_model(embedding_model)
if embedding_dimension is None:
raise ValueError(
f"Could not determine embedding dimension for model '{embedding_model}'. "
"Please provide embedding_dimension in extra_body or ensure the model metadata contains embedding_dimension."
)
raise ValueError("embedding_model is required")
if embedding_dimension is None:
raise ValueError("Embedding dimension is required")
# Register the VectorDB backing this vector store
# Register the VectorStore backing this vector store
if provider_id is None:
raise ValueError("Provider ID is required but was not provided")
vector_db = VectorDB(
identifier=vector_db_id,
# call to the provider to create any index, etc.
vector_store = VectorStore(
identifier=vector_store_id,
embedding_dimension=embedding_dimension,
embedding_model=embedding_model,
provider_id=provider_id,
provider_resource_id=vector_db_id,
vector_db_name=params.name,
provider_resource_id=vector_store_id,
vector_store_name=params.name,
)
await self.register_vector_db(vector_db)
await self.register_vector_store(vector_store)
# Create OpenAI vector store metadata
status = "completed"
@ -438,7 +415,7 @@ class OpenAIVectorStoreMixin(ABC):
total=0,
)
store_info: dict[str, Any] = {
"id": vector_db_id,
"id": vector_store_id,
"object": "vector_store",
"created_at": created_at,
"name": params.name,
@ -455,104 +432,25 @@ class OpenAIVectorStoreMixin(ABC):
# Add provider information to metadata if provided
if provider_id:
metadata["provider_id"] = provider_id
if provider_vector_db_id:
metadata["provider_vector_db_id"] = provider_vector_db_id
if provider_vector_store_id:
metadata["provider_vector_store_id"] = provider_vector_store_id
store_info["metadata"] = metadata
# Save to persistent storage (provider-specific)
await self._save_openai_vector_store(vector_db_id, store_info)
await self._save_openai_vector_store(vector_store_id, store_info)
# Store in memory cache
self.openai_vector_stores[vector_db_id] = store_info
self.openai_vector_stores[vector_store_id] = store_info
# Now that our vector store is created, attach any files that were provided
file_ids = params.file_ids or []
tasks = [self.openai_attach_file_to_vector_store(vector_db_id, file_id) for file_id in file_ids]
tasks = [self.openai_attach_file_to_vector_store(vector_store_id, file_id) for file_id in file_ids]
await asyncio.gather(*tasks)
# Get the updated store info and return it
store_info = self.openai_vector_stores[vector_db_id]
store_info = self.openai_vector_stores[vector_store_id]
return VectorStoreObject.model_validate(store_info)
async def _get_embedding_models(self) -> list[Model]:
"""Get list of embedding models from the models API."""
if not self.models_api:
return []
models_response = await self.models_api.list_models()
models_list = models_response.data if hasattr(models_response, "data") else models_response
embedding_models = []
for model in models_list:
if not isinstance(model, Model):
logger.warning(f"Non-Model object found in models list: {type(model)} - {model}")
continue
if model.model_type == "embedding":
embedding_models.append(model)
return embedding_models
async def _get_embedding_dimension_for_model(self, model_id: str) -> int | None:
"""Get embedding dimension for a specific model by looking it up in the models API.
Args:
model_id: The identifier of the embedding model (supports both prefixed and non-prefixed)
Returns:
The embedding dimension for the model, or None if not found
"""
embedding_models = await self._get_embedding_models()
for model in embedding_models:
# Check for exact match first
if model.identifier == model_id:
embedding_dimension = model.metadata.get("embedding_dimension")
if embedding_dimension is not None:
return int(embedding_dimension)
else:
logger.warning(f"Model {model_id} found but has no embedding_dimension in metadata")
return None
# Check for prefixed/unprefixed variations
# If model_id is unprefixed, check if it matches the resource_id
if model.provider_resource_id == model_id:
embedding_dimension = model.metadata.get("embedding_dimension")
if embedding_dimension is not None:
return int(embedding_dimension)
return None
async def _get_default_embedding_model_and_dimension(self) -> tuple[str, int] | None:
"""Get default embedding model from the models API.
Looks for embedding models marked with default_configured=True in metadata.
Returns None if no default embedding model is found.
Raises ValueError if multiple defaults are found.
"""
embedding_models = await self._get_embedding_models()
default_models = []
for model in embedding_models:
if model.metadata.get("default_configured") is True:
default_models.append(model.identifier)
if len(default_models) > 1:
raise ValueError(
f"Multiple embedding models marked as default_configured=True: {default_models}. "
"Only one embedding model can be marked as default."
)
if default_models:
model_id = default_models[0]
embedding_dimension = await self._get_embedding_dimension_for_model(model_id)
if embedding_dimension is None:
raise ValueError(f"Embedding model '{model_id}' has no embedding_dimension in metadata")
logger.info(f"Using default embedding model: {model_id} with dimension {embedding_dimension}")
return model_id, embedding_dimension
logger.debug("No default embedding models found")
return None
async def openai_list_vector_stores(
self,
limit: int | None = 20,
@ -660,7 +558,7 @@ class OpenAIVectorStoreMixin(ABC):
# Also delete the underlying vector DB
try:
await self.unregister_vector_db(vector_store_id)
await self.unregister_vector_store(vector_store_id)
except Exception as e:
logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")

View file

@ -23,8 +23,8 @@ from llama_stack.apis.common.content_types import (
)
from llama_stack.apis.inference import OpenAIEmbeddingsRequestWithExtraBody
from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import Api
@ -187,7 +187,7 @@ def make_overlapped_chunks(
updated_timestamp=int(time.time()),
chunk_window=chunk_window,
chunk_tokenizer=default_tokenizer,
chunk_embedding_model=None, # This will be set in `VectorDBWithIndex.insert_chunks`
chunk_embedding_model=None, # This will be set in `VectorStoreWithIndex.insert_chunks`
content_token_count=len(toks),
metadata_token_count=len(metadata_tokens),
)
@ -255,8 +255,8 @@ class EmbeddingIndex(ABC):
@dataclass
class VectorDBWithIndex:
vector_db: VectorDB
class VectorStoreWithIndex:
vector_store: VectorStore
index: EmbeddingIndex
inference_api: Api.inference
@ -269,14 +269,14 @@ class VectorDBWithIndex:
if c.embedding is None:
chunks_to_embed.append(c)
if c.chunk_metadata:
c.chunk_metadata.chunk_embedding_model = self.vector_db.embedding_model
c.chunk_metadata.chunk_embedding_dimension = self.vector_db.embedding_dimension
c.chunk_metadata.chunk_embedding_model = self.vector_store.embedding_model
c.chunk_metadata.chunk_embedding_dimension = self.vector_store.embedding_dimension
else:
_validate_embedding(c.embedding, i, self.vector_db.embedding_dimension)
_validate_embedding(c.embedding, i, self.vector_store.embedding_dimension)
if chunks_to_embed:
params = OpenAIEmbeddingsRequestWithExtraBody(
model=self.vector_db.embedding_model,
model=self.vector_store.embedding_model,
input=[c.content for c in chunks_to_embed],
)
resp = await self.inference_api.openai_embeddings(params)
@ -319,7 +319,7 @@ class VectorDBWithIndex:
return await self.index.query_keyword(query_string, k, score_threshold)
params = OpenAIEmbeddingsRequestWithExtraBody(
model=self.vector_db.embedding_model,
model=self.vector_store.embedding_model,
input=[query_string],
)
embeddings_response = await self.inference_api.openai_embeddings(params)

View file

@ -18,13 +18,13 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseObjectWithInput,
)
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqlStoreReference, StorageBackendType
from llama_stack.log import get_logger
from ..sqlstore.api import ColumnDefinition, ColumnType
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreType, sqlstore_impl
from ..sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl
logger = get_logger(name=__name__, category="openai_responses")
@ -45,39 +45,38 @@ class _OpenAIResponseObjectWithInputAndMessages(OpenAIResponseObjectWithInput):
class ResponsesStore:
def __init__(
self,
config: ResponsesStoreConfig | SqlStoreConfig,
reference: ResponsesStoreReference | SqlStoreReference,
policy: list[AccessRule],
):
# Handle backward compatibility
if not isinstance(config, ResponsesStoreConfig):
# Legacy: SqlStoreConfig passed directly as config
config = ResponsesStoreConfig(
sql_store_config=config,
)
if isinstance(reference, ResponsesStoreReference):
self.reference = reference
else:
self.reference = ResponsesStoreReference(**reference.model_dump())
self.config = config
self.sql_store_config = config.sql_store_config
if not self.sql_store_config:
self.sql_store_config = SqliteSqlStoreConfig(
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
)
self.sql_store = None
self.policy = policy
# Disable write queue for SQLite to avoid concurrency issues
self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite
self.sql_store = None
self.enable_write_queue = True
# Async write queue and worker control
self._queue: (
asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput], list[OpenAIMessageParam]]] | None
) = None
self._worker_tasks: list[asyncio.Task[Any]] = []
self._max_write_queue_size: int = config.max_write_queue_size
self._num_writers: int = max(1, config.num_writers)
self._max_write_queue_size: int = self.reference.max_write_queue_size
self._num_writers: int = max(1, self.reference.num_writers)
async def initialize(self):
"""Create the necessary tables if they don't exist."""
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy)
base_store = sqlstore_impl(self.reference)
self.sql_store = AuthorizedSqlStore(base_store, self.policy)
backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend)
if backend_config is None:
raise ValueError(
f"Unregistered SQL backend '{self.reference.backend}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
)
if backend_config.type == StorageBackendType.SQL_SQLITE:
self.enable_write_queue = False
await self.sql_store.create_table(
"openai_responses",
{

View file

@ -12,10 +12,10 @@ from llama_stack.core.access_control.conditions import ProtectedResource
from llama_stack.core.access_control.datatypes import AccessRule, Action, Scope
from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.core.storage.datatypes import StorageBackendType
from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
from .sqlstore import SqlStoreType
logger = get_logger(name=__name__, category="providers::utils")
@ -82,8 +82,8 @@ class AuthorizedSqlStore:
if not hasattr(self.sql_store, "config"):
raise ValueError("SqlStore must have a config attribute to be used with AuthorizedSqlStore")
self.database_type = self.sql_store.config.type
if self.database_type not in [SqlStoreType.postgres, SqlStoreType.sqlite]:
self.database_type = self.sql_store.config.type.value
if self.database_type not in [StorageBackendType.SQL_POSTGRES.value, StorageBackendType.SQL_SQLITE.value]:
raise ValueError(f"Unsupported database type: {self.database_type}")
def _validate_sql_optimized_policy(self) -> None:
@ -220,9 +220,9 @@ class AuthorizedSqlStore:
Returns:
SQL expression to extract JSON value
"""
if self.database_type == SqlStoreType.postgres:
if self.database_type == StorageBackendType.SQL_POSTGRES.value:
return f"{column}->'{path}'"
elif self.database_type == SqlStoreType.sqlite:
elif self.database_type == StorageBackendType.SQL_SQLITE.value:
return f"JSON_EXTRACT({column}, '$.{path}')"
else:
raise ValueError(f"Unsupported database type: {self.database_type}")
@ -237,9 +237,9 @@ class AuthorizedSqlStore:
Returns:
SQL expression to extract JSON value as text
"""
if self.database_type == SqlStoreType.postgres:
if self.database_type == StorageBackendType.SQL_POSTGRES.value:
return f"{column}->>'{path}'"
elif self.database_type == SqlStoreType.sqlite:
elif self.database_type == StorageBackendType.SQL_SQLITE.value:
return f"JSON_EXTRACT({column}, '$.{path}')"
else:
raise ValueError(f"Unsupported database type: {self.database_type}")
@ -248,10 +248,10 @@ class AuthorizedSqlStore:
"""Get the SQL conditions for public access."""
# Public records are records that have no owner_principal or access_attributes
conditions = ["owner_principal = ''"]
if self.database_type == SqlStoreType.postgres:
if self.database_type == StorageBackendType.SQL_POSTGRES.value:
# Postgres stores JSON null as 'null'
conditions.append("access_attributes::text = 'null'")
elif self.database_type == SqlStoreType.sqlite:
elif self.database_type == StorageBackendType.SQL_SQLITE.value:
conditions.append("access_attributes = 'null'")
else:
raise ValueError(f"Unsupported database type: {self.database_type}")

View file

@ -26,10 +26,10 @@ from sqlalchemy.ext.asyncio.engine import AsyncEngine
from sqlalchemy.sql.elements import ColumnElement
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.core.storage.datatypes import SqlAlchemySqlStoreConfig
from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, SqlStore
from .sqlstore import SqlAlchemySqlStoreConfig
logger = get_logger(name=__name__, category="providers::utils")

View file

@ -4,90 +4,28 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import abstractmethod
from enum import StrEnum
from pathlib import Path
from typing import Annotated, Literal
from typing import Annotated, cast
from pydantic import BaseModel, Field
from pydantic import Field
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.core.storage.datatypes import (
PostgresSqlStoreConfig,
SqliteSqlStoreConfig,
SqlStoreReference,
StorageBackendConfig,
StorageBackendType,
)
from .api import SqlStore
sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]
class SqlStoreType(StrEnum):
sqlite = "sqlite"
postgres = "postgres"
class SqlAlchemySqlStoreConfig(BaseModel):
@property
@abstractmethod
def engine_str(self) -> str: ...
# TODO: move this when we have a better way to specify dependencies with internal APIs
@classmethod
def pip_packages(cls) -> list[str]:
return ["sqlalchemy[asyncio]"]
class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
type: Literal[SqlStoreType.sqlite] = SqlStoreType.sqlite
db_path: str = Field(
default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
description="Database path, e.g. ~/.llama/distributions/ollama/sqlstore.db",
)
@property
def engine_str(self) -> str:
return "sqlite+aiosqlite:///" + Path(self.db_path).expanduser().as_posix()
@classmethod
def sample_run_config(cls, __distro_dir__: str, db_name: str = "sqlstore.db"):
return {
"type": "sqlite",
"db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
}
@classmethod
def pip_packages(cls) -> list[str]:
return super().pip_packages() + ["aiosqlite"]
class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
type: Literal[SqlStoreType.postgres] = SqlStoreType.postgres
host: str = "localhost"
port: int = 5432
db: str = "llamastack"
user: str
password: str | None = None
@property
def engine_str(self) -> str:
return f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}"
@classmethod
def pip_packages(cls) -> list[str]:
return super().pip_packages() + ["asyncpg"]
@classmethod
def sample_run_config(cls, **kwargs):
return {
"type": "postgres",
"host": "${env.POSTGRES_HOST:=localhost}",
"port": "${env.POSTGRES_PORT:=5432}",
"db": "${env.POSTGRES_DB:=llamastack}",
"user": "${env.POSTGRES_USER:=llamastack}",
"password": "${env.POSTGRES_PASSWORD:=llamastack}",
}
_SQLSTORE_BACKENDS: dict[str, StorageBackendConfig] = {}
SqlStoreConfig = Annotated[
SqliteSqlStoreConfig | PostgresSqlStoreConfig,
Field(discriminator="type", default=SqlStoreType.sqlite.value),
Field(discriminator="type"),
]
@ -95,9 +33,9 @@ def get_pip_packages(store_config: dict | SqlStoreConfig) -> list[str]:
"""Get pip packages for SQL store config, handling both dict and object cases."""
if isinstance(store_config, dict):
store_type = store_config.get("type")
if store_type == "sqlite":
if store_type == StorageBackendType.SQL_SQLITE.value:
return SqliteSqlStoreConfig.pip_packages()
elif store_type == "postgres":
elif store_type == StorageBackendType.SQL_POSTGRES.value:
return PostgresSqlStoreConfig.pip_packages()
else:
raise ValueError(f"Unknown SQL store type: {store_type}")
@ -105,12 +43,28 @@ def get_pip_packages(store_config: dict | SqlStoreConfig) -> list[str]:
return store_config.pip_packages()
def sqlstore_impl(config: SqlStoreConfig) -> SqlStore:
if config.type in [SqlStoreType.sqlite, SqlStoreType.postgres]:
def sqlstore_impl(reference: SqlStoreReference) -> SqlStore:
backend_name = reference.backend
backend_config = _SQLSTORE_BACKENDS.get(backend_name)
if backend_config is None:
raise ValueError(
f"Unknown SQL store backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
)
if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig):
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
impl = SqlAlchemySqlStoreImpl(config)
config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy()
return SqlAlchemySqlStoreImpl(config)
else:
raise ValueError(f"Unknown sqlstore type {config.type}")
raise ValueError(f"Unknown sqlstore type {backend_config.type}")
return impl
def register_sqlstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
"""Register the set of available SQL store backends for reference resolution."""
global _SQLSTORE_BACKENDS
_SQLSTORE_BACKENDS.clear()
for name, cfg in backends.items():
_SQLSTORE_BACKENDS[name] = cfg

View file

@ -70,7 +70,7 @@ def trace_protocol[T](cls: type[T]) -> type[T]:
"__class__": class_name,
"__method__": method_name,
"__type__": span_type,
"__args__": str(combined_args),
"__args__": json.dumps(combined_args),
}
return class_name, method_name, span_attributes
@ -82,8 +82,8 @@ def trace_protocol[T](cls: type[T]) -> type[T]:
class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
count = 0
try:
count = 0
async for item in method(self, *args, **kwargs):
yield item
count += 1