chore: Updating how default embedding model is set in stack (#3818)

# What does this PR do?

Refactor setting default vector store provider and embedding model to
use an optional `vector_stores` config in the `StackRunConfig` and clean
up code to do so (had to add back in some pieces of VectorDB). Also
added remote Qdrant and Weaviate to starter distro (based on other PR
where inference providers were added for UX).

New config is simply (default for Starter distro):

```yaml
vector_stores:
  default_provider_id: faiss
  default_embedding_model:
    provider_id: sentence-transformers
    model_id: nomic-ai/nomic-embed-text-v1.5
```

## Test Plan
CI and Unit tests.

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
Francisco Arceo 2025-10-20 17:22:45 -04:00 committed by GitHub
parent 2c43285e22
commit 48581bf651
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
48 changed files with 973 additions and 818 deletions

View file

@ -169,9 +169,7 @@ jobs:
run: |
uv run --no-sync \
pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \
tests/integration/vector_io \
--embedding-model inline::sentence-transformers/nomic-ai/nomic-embed-text-v1.5 \
--embedding-dimension 768
tests/integration/vector_io
- name: Check Storage and Memory Available After Tests
if: ${{ always() }}

View file

@ -88,18 +88,19 @@ Llama Stack provides OpenAI-compatible RAG capabilities through:
To enable automatic vector store creation without specifying embedding models, configure a default embedding model in your run.yaml like so:
```yaml
models:
- model_id: nomic-ai/nomic-embed-text-v1.5
provider_id: inline::sentence-transformers
metadata:
embedding_dimension: 768
default_configured: true
vector_stores:
default_provider_id: faiss
default_embedding_model:
provider_id: sentence-transformers
model_id: nomic-ai/nomic-embed-text-v1.5
```
With this configuration:
- `client.vector_stores.create()` works without requiring embedding model parameters
- The system automatically uses the default model and its embedding dimension for any newly created vector store
- Only one model can be marked as `default_configured: true`
- `client.vector_stores.create()` works without requiring embedding model or provider parameters
- The system automatically uses the default vector store provider (`faiss`) when multiple providers are available
- The system automatically uses the default embedding model (`sentence-transformers/nomic-ai/nomic-embed-text-v1.5`) for any newly created vector store
- The `default_provider_id` specifies which vector storage backend to use
- The `default_embedding_model` specifies both the inference provider and model for embeddings
## Vector Store Operations
@ -108,14 +109,15 @@ With this configuration:
You can create vector stores with automatic or explicit embedding model selection:
```python
# Automatic - uses default configured embedding model
# Automatic - uses default configured embedding model and vector store provider
vs = client.vector_stores.create()
# Explicit - specify embedding model when you need a specific one
# Explicit - specify embedding model and/or provider when you need specific ones
vs = client.vector_stores.create(
extra_body={
"embedding_model": "nomic-ai/nomic-embed-text-v1.5",
"embedding_dimension": 768
"provider_id": "faiss", # Optional: specify vector store provider
"embedding_model": "sentence-transformers/nomic-ai/nomic-embed-text-v1.5",
"embedding_dimension": 768 # Optional: will be auto-detected if not provided
}
)
```

View file

@ -121,6 +121,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
models = "models"
shields = "shields"
vector_dbs = "vector_dbs" # only used for routing
datasets = "datasets"
scoring_functions = "scoring_functions"
benchmarks = "benchmarks"

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Literal
from typing import Literal, Protocol, runtime_checkable
from pydantic import BaseModel
@ -59,3 +59,35 @@ class ListVectorDBsResponse(BaseModel):
"""
data: list[VectorDB]
@runtime_checkable
class VectorDBs(Protocol):
"""Internal protocol for vector_dbs routing - no public API endpoints."""
async def list_vector_dbs(self) -> ListVectorDBsResponse:
"""Internal method to list vector databases."""
...
async def get_vector_db(
self,
vector_db_id: str,
) -> VectorDB:
"""Internal method to get a vector database by ID."""
...
async def register_vector_db(
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
vector_db_name: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorDB:
"""Internal method to register a vector database."""
...
async def unregister_vector_db(self, vector_db_id: str) -> None:
"""Internal method to unregister a vector database."""
...

View file

@ -354,6 +354,26 @@ class AuthenticationRequiredError(Exception):
pass
class QualifiedModel(BaseModel):
"""A qualified model identifier, consisting of a provider ID and a model ID."""
provider_id: str
model_id: str
class VectorStoresConfig(BaseModel):
"""Configuration for vector stores in the stack."""
default_provider_id: str | None = Field(
default=None,
description="ID of the vector_io provider to use as default when multiple providers are available and none is specified.",
)
default_embedding_model: QualifiedModel | None = Field(
default=None,
description="Default embedding model configuration for vector stores.",
)
class QuotaPeriod(StrEnum):
DAY = "day"
@ -499,6 +519,11 @@ can be instantiated multiple times (with different configs) if necessary.
description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
)
vector_stores: VectorStoresConfig | None = Field(
default=None,
description="Configuration for vector stores, including default embedding model",
)
@field_validator("external_providers_dir")
@classmethod
def validate_external_providers_dir(cls, v):

View file

@ -63,6 +63,10 @@ def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
routing_table_api=Api.tool_groups,
router_api=Api.tool_runtime,
),
AutoRoutedApiInfo(
routing_table_api=Api.vector_dbs,
router_api=Api.vector_io,
),
]

View file

@ -29,6 +29,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.core.client import get_client_impl
@ -81,6 +82,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
Api.inspect: Inspect,
Api.batches: Batches,
Api.vector_io: VectorIO,
Api.vector_dbs: VectorDBs,
Api.models: Models,
Api.safety: Safety,
Api.shields: Shields,

View file

@ -29,6 +29,7 @@ async def get_routing_table_impl(
from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
from ..routing_tables.shields import ShieldsRoutingTable
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
from ..routing_tables.vector_dbs import VectorDBsRoutingTable
api_to_tables = {
"models": ModelsRoutingTable,
@ -37,6 +38,7 @@ async def get_routing_table_impl(
"scoring_functions": ScoringFunctionsRoutingTable,
"benchmarks": BenchmarksRoutingTable,
"tool_groups": ToolGroupsRoutingTable,
"vector_dbs": VectorDBsRoutingTable,
}
if api.value not in api_to_tables:
@ -91,6 +93,9 @@ async def get_auto_router_impl(
await inference_store.initialize()
api_to_dep_impl["store"] = inference_store
elif api == Api.vector_io:
api_to_dep_impl["vector_stores_config"] = run_config.vector_stores
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
await impl.initialize()
return impl

View file

@ -31,6 +31,7 @@ from llama_stack.apis.vector_io import (
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
@ -43,9 +44,11 @@ class VectorIORouter(VectorIO):
def __init__(
self,
routing_table: RoutingTable,
vector_stores_config: VectorStoresConfig | None = None,
) -> None:
logger.debug("Initializing VectorIORouter")
self.routing_table = routing_table
self.vector_stores_config = vector_stores_config
async def initialize(self) -> None:
logger.debug("VectorIORouter.initialize")
@ -122,6 +125,17 @@ class VectorIORouter(VectorIO):
embedding_dimension = extra.get("embedding_dimension")
provider_id = extra.get("provider_id")
# Use default embedding model if not specified
if (
embedding_model is None
and self.vector_stores_config
and self.vector_stores_config.default_embedding_model is not None
):
# Construct the full model ID with provider prefix
embedding_provider_id = self.vector_stores_config.default_embedding_model.provider_id
model_id = self.vector_stores_config.default_embedding_model.model_id
embedding_model = f"{embedding_provider_id}/{model_id}"
if embedding_model is not None and embedding_dimension is None:
embedding_dimension = await self._get_embedding_model_dimension(embedding_model)
@ -132,11 +146,24 @@ class VectorIORouter(VectorIO):
raise ValueError("No vector_io providers available")
if num_providers > 1:
available_providers = list(self.routing_table.impls_by_provider_id.keys())
raise ValueError(
f"Multiple vector_io providers available. Please specify provider_id in extra_body. "
f"Available providers: {available_providers}"
)
provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]
# Use default configured provider
if self.vector_stores_config and self.vector_stores_config.default_provider_id:
default_provider = self.vector_stores_config.default_provider_id
if default_provider in available_providers:
provider_id = default_provider
logger.debug(f"Using configured default vector store provider: {provider_id}")
else:
raise ValueError(
f"Configured default vector store provider '{default_provider}' not found. "
f"Available providers: {available_providers}"
)
else:
raise ValueError(
f"Multiple vector_io providers available. Please specify provider_id in extra_body. "
f"Available providers: {available_providers}"
)
else:
provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]
vector_db_id = f"vs_{uuid.uuid4()}"
registered_vector_db = await self.routing_table.register_vector_db(
@ -243,8 +270,7 @@ class VectorIORouter(VectorIO):
vector_store_id: str,
) -> VectorStoreDeleteResponse:
logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_delete_vector_store(vector_store_id)
return await self.routing_table.openai_delete_vector_store(vector_store_id)
async def openai_search_vector_store(
self,

View file

@ -134,12 +134,15 @@ class CommonRoutingTableImpl(RoutingTable):
from .scoring_functions import ScoringFunctionsRoutingTable
from .shields import ShieldsRoutingTable
from .toolgroups import ToolGroupsRoutingTable
from .vector_dbs import VectorDBsRoutingTable
def apiname_object():
if isinstance(self, ModelsRoutingTable):
return ("Inference", "model")
elif isinstance(self, ShieldsRoutingTable):
return ("Safety", "shield")
elif isinstance(self, VectorDBsRoutingTable):
return ("VectorIO", "vector_db")
elif isinstance(self, DatasetsRoutingTable):
return ("DatasetIO", "dataset")
elif isinstance(self, ScoringFunctionsRoutingTable):

View file

@ -0,0 +1,323 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import TypeAdapter
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
# Removed VectorDBs import to avoid exposing public API
from llama_stack.apis.vector_io.vector_io import (
OpenAICreateVectorStoreRequestWithExtraBody,
SearchRankingOptions,
VectorStoreChunkingStrategy,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreFileStatus,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.core.datatypes import (
VectorDBWithOwner,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core::routing_tables")
class VectorDBsRoutingTable(CommonRoutingTableImpl):
"""Internal routing table for vector_db operations.
Does not inherit from VectorDBs to avoid exposing public API endpoints.
Only provides internal routing functionality for VectorIORouter.
"""
# Internal methods only - no public API exposure
async def register_vector_db(
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
vector_db_name: str | None = None,
) -> Any:
if provider_id is None:
if len(self.impls_by_provider_id) > 0:
provider_id = list(self.impls_by_provider_id.keys())[0]
if len(self.impls_by_provider_id) > 1:
logger.warning(
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
)
else:
raise ValueError("No provider available. Please configure a vector_io provider.")
model = await lookup_model(self, embedding_model)
if model is None:
raise ModelNotFoundError(embedding_model)
if model.model_type != ModelType.embedding:
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
if "embedding_dimension" not in model.metadata:
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
try:
provider = self.impls_by_provider_id[provider_id]
except KeyError:
available_providers = list(self.impls_by_provider_id.keys())
raise ValueError(
f"Provider '{provider_id}' not found in routing table. Available providers: {available_providers}"
) from None
logger.warning(
"VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly."
)
request = OpenAICreateVectorStoreRequestWithExtraBody(
name=vector_db_name or vector_db_id,
embedding_model=embedding_model,
embedding_dimension=model.metadata["embedding_dimension"],
provider_id=provider_id,
provider_vector_db_id=provider_vector_db_id,
)
vector_store = await provider.openai_create_vector_store(request)
vector_store_id = vector_store.id
actual_provider_vector_db_id = provider_vector_db_id or vector_store_id
logger.warning(
f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name"
)
vector_db_data = {
"identifier": vector_store_id,
"type": ResourceType.vector_db.value,
"provider_id": provider_id,
"provider_resource_id": actual_provider_vector_db_id,
"embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"],
"vector_db_name": vector_store.name,
}
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
await self.register_object(vector_db)
return vector_db
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store(vector_store_id)
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store(
vector_store_id=vector_store_id,
name=name,
expires_after=expires_after,
metadata=metadata,
)
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
result = await provider.openai_delete_vector_store(vector_store_id)
await self.unregister_vector_db(vector_store_id)
return result
async def unregister_vector_db(self, vector_store_id: str) -> None:
"""Remove the vector store from the routing table registry."""
try:
vector_db_obj = await self.get_object_by_identifier("vector_db", vector_store_id)
if vector_db_obj:
await self.unregister_object(vector_db_obj)
except Exception as e:
# Log the error but don't fail the operation
logger.warning(f"Failed to unregister vector store {vector_store_id} from routing table: {e}")
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=filters,
max_num_results=max_num_results,
ranking_options=ranking_options,
rewrite_query=rewrite_query,
search_mode=search_mode,
)
async def openai_attach_file_to_vector_store(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_attach_file_to_vector_store(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
chunking_strategy=chunking_strategy,
)
async def openai_list_files_in_vector_store(
self,
vector_store_id: str,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
filter: VectorStoreFileStatus | None = None,
) -> list[VectorStoreFileObject]:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store(
vector_store_id=vector_store_id,
limit=limit,
order=order,
after=after,
before=before,
filter=filter,
)
async def openai_retrieve_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_retrieve_vector_store_file_contents(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_update_vector_store_file(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any],
) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
)
async def openai_delete_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileDeleteResponse:
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_delete_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_create_vector_store_file_batch(
self,
vector_store_id: str,
file_ids: list[str],
attributes: dict[str, Any] | None = None,
chunking_strategy: Any | None = None,
):
await self.assert_action_allowed("update", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_create_vector_store_file_batch(
vector_store_id=vector_store_id,
file_ids=file_ids,
attributes=attributes,
chunking_strategy=chunking_strategy,
)
async def openai_retrieve_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
):
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
)
async def openai_list_files_in_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
after: str | None = None,
before: str | None = None,
filter: str | None = None,
limit: int | None = 20,
order: str | None = "desc",
):
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
after=after,
before=before,
filter=filter,
limit=limit,
order=order,
)
async def openai_cancel_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
):
await self.assert_action_allowed("update", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_cancel_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
)

View file

@ -35,7 +35,7 @@ from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.datatypes import Provider, StackRunConfig, VectorStoresConfig
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
@ -108,30 +108,6 @@ REGISTRY_REFRESH_TASK = None
TEST_RECORDING_CONTEXT = None
async def validate_default_embedding_model(impls: dict[Api, Any]):
"""Validate that at most one embedding model is marked as default."""
if Api.models not in impls:
return
models_impl = impls[Api.models]
response = await models_impl.list_models()
models_list = response.data if hasattr(response, "data") else response
default_embedding_models = []
for model in models_list:
if model.model_type == "embedding" and model.metadata.get("default_configured") is True:
default_embedding_models.append(model.identifier)
if len(default_embedding_models) > 1:
raise ValueError(
f"Multiple embedding models marked as default_configured=True: {default_embedding_models}. "
"Only one embedding model can be marked as default."
)
if default_embedding_models:
logger.info(f"Default embedding model configured: {default_embedding_models[0]}")
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
for rsrc, api, register_method, list_method in RESOURCES:
objects = getattr(run_config, rsrc)
@ -162,7 +138,41 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
)
await validate_default_embedding_model(impls)
async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig | None, impls: dict[Api, Any]):
"""Validate vector stores configuration."""
if vector_stores_config is None:
return
default_embedding_model = vector_stores_config.default_embedding_model
if default_embedding_model is None:
return
provider_id = default_embedding_model.provider_id
model_id = default_embedding_model.model_id
default_model_id = f"{provider_id}/{model_id}"
if Api.models not in impls:
raise ValueError(f"Models API is not available but vector_stores config requires model '{default_model_id}'")
models_impl = impls[Api.models]
response = await models_impl.list_models()
models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}
default_model = models_list.get(default_model_id)
if default_model is None:
raise ValueError(f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}")
embedding_dimension = default_model.metadata.get("embedding_dimension")
if embedding_dimension is None:
raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")
try:
int(embedding_dimension)
except ValueError as err:
raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err
logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
class EnvVarError(Exception):
@ -400,8 +410,8 @@ class Stack:
await impls[Api.conversations].initialize()
await register_resources(self.run_config, impls)
await refresh_registry_once(impls)
await validate_vector_stores_config(self.run_config.vector_stores, impls)
self.impls = impls
def create_registry_refresh_task(self):

View file

@ -25,6 +25,8 @@ distribution_spec:
- provider_type: inline::milvus
- provider_type: remote::chromadb
- provider_type: remote::pgvector
- provider_type: remote::qdrant
- provider_type: remote::weaviate
files:
- provider_type: inline::localfs
safety:

View file

@ -128,6 +128,21 @@ providers:
persistence:
namespace: vector_io::pgvector
backend: kv_default
- provider_id: ${env.QDRANT_URL:+qdrant}
provider_type: remote::qdrant
config:
api_key: ${env.QDRANT_API_KEY:=}
persistence:
namespace: vector_io::qdrant_remote
backend: kv_default
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
provider_type: remote::weaviate
config:
weaviate_api_key: null
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
persistence:
namespace: vector_io::weaviate
backend: kv_default
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
@ -253,3 +268,8 @@ server:
port: 8321
telemetry:
enabled: true
vector_stores:
default_provider_id: faiss
default_embedding_model:
provider_id: sentence-transformers
model_id: nomic-ai/nomic-embed-text-v1.5

View file

@ -26,6 +26,8 @@ distribution_spec:
- provider_type: inline::milvus
- provider_type: remote::chromadb
- provider_type: remote::pgvector
- provider_type: remote::qdrant
- provider_type: remote::weaviate
files:
- provider_type: inline::localfs
safety:

View file

@ -128,6 +128,21 @@ providers:
persistence:
namespace: vector_io::pgvector
backend: kv_default
- provider_id: ${env.QDRANT_URL:+qdrant}
provider_type: remote::qdrant
config:
api_key: ${env.QDRANT_API_KEY:=}
persistence:
namespace: vector_io::qdrant_remote
backend: kv_default
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
provider_type: remote::weaviate
config:
weaviate_api_key: null
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
persistence:
namespace: vector_io::weaviate
backend: kv_default
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
@ -256,3 +271,8 @@ server:
port: 8321
telemetry:
enabled: true
vector_stores:
default_provider_id: faiss
default_embedding_model:
provider_id: sentence-transformers
model_id: nomic-ai/nomic-embed-text-v1.5

View file

@ -26,6 +26,8 @@ distribution_spec:
- provider_type: inline::milvus
- provider_type: remote::chromadb
- provider_type: remote::pgvector
- provider_type: remote::qdrant
- provider_type: remote::weaviate
files:
- provider_type: inline::localfs
safety:

View file

@ -128,6 +128,21 @@ providers:
persistence:
namespace: vector_io::pgvector
backend: kv_default
- provider_id: ${env.QDRANT_URL:+qdrant}
provider_type: remote::qdrant
config:
api_key: ${env.QDRANT_API_KEY:=}
persistence:
namespace: vector_io::qdrant_remote
backend: kv_default
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
provider_type: remote::weaviate
config:
weaviate_api_key: null
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
persistence:
namespace: vector_io::weaviate
backend: kv_default
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
@ -253,3 +268,8 @@ server:
port: 8321
telemetry:
enabled: true
vector_stores:
default_provider_id: faiss
default_embedding_model:
provider_id: sentence-transformers
model_id: nomic-ai/nomic-embed-text-v1.5

View file

@ -11,8 +11,10 @@ from llama_stack.core.datatypes import (
BuildProvider,
Provider,
ProviderSpec,
QualifiedModel,
ShieldInput,
ToolGroupInput,
VectorStoresConfig,
)
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
@ -31,6 +33,8 @@ from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOC
from llama_stack.providers.remote.vector_io.pgvector.config import (
PGVectorVectorIOConfig,
)
from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOConfig
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
@ -113,6 +117,8 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
BuildProvider(provider_type="inline::milvus"),
BuildProvider(provider_type="remote::chromadb"),
BuildProvider(provider_type="remote::pgvector"),
BuildProvider(provider_type="remote::qdrant"),
BuildProvider(provider_type="remote::weaviate"),
],
"files": [BuildProvider(provider_type="inline::localfs")],
"safety": [
@ -221,12 +227,35 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
password="${env.PGVECTOR_PASSWORD:=}",
),
),
Provider(
provider_id="${env.QDRANT_URL:+qdrant}",
provider_type="remote::qdrant",
config=QdrantVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}",
url="${env.QDRANT_URL:=}",
),
),
Provider(
provider_id="${env.WEAVIATE_CLUSTER_URL:+weaviate}",
provider_type="remote::weaviate",
config=WeaviateVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}",
cluster_url="${env.WEAVIATE_CLUSTER_URL:=}",
),
),
],
"files": [files_provider],
},
default_models=[],
default_tool_groups=default_tool_groups,
default_shields=default_shields,
vector_stores_config=VectorStoresConfig(
default_provider_id="faiss",
default_embedding_model=QualifiedModel(
provider_id="sentence-transformers",
model_id="nomic-ai/nomic-embed-text-v1.5",
),
),
),
},
run_config_env_vars={

View file

@ -27,6 +27,7 @@ from llama_stack.core.datatypes import (
ShieldInput,
TelemetryConfig,
ToolGroupInput,
VectorStoresConfig,
)
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.storage.datatypes import (
@ -186,6 +187,7 @@ class RunConfigSettings(BaseModel):
default_tool_groups: list[ToolGroupInput] | None = None
default_datasets: list[DatasetInput] | None = None
default_benchmarks: list[BenchmarkInput] | None = None
vector_stores_config: VectorStoresConfig | None = None
telemetry: TelemetryConfig = Field(default_factory=lambda: TelemetryConfig(enabled=True))
storage_backends: dict[str, Any] | None = None
storage_stores: dict[str, Any] | None = None
@ -263,7 +265,7 @@ class RunConfigSettings(BaseModel):
)
# Return a dict that matches StackRunConfig structure
return {
config = {
"version": LLAMA_STACK_RUN_CONFIG_VERSION,
"image_name": name,
"container_image": container_image,
@ -283,6 +285,11 @@ class RunConfigSettings(BaseModel):
"telemetry": self.telemetry.model_dump(exclude_none=True) if self.telemetry else None,
}
if self.vector_stores_config:
config["vector_stores"] = self.vector_stores_config.model_dump(exclude_none=True)
return config
class DistributionTemplate(BaseModel):
"""

View file

@ -59,7 +59,6 @@ class SentenceTransformersInferenceImpl(
provider_id=self.__provider_id__,
metadata={
"embedding_dimension": 768,
"default_configured": True,
},
model_type=ModelType.embedding,
),

View file

@ -12,15 +12,8 @@ from .config import ChromaVectorIOConfig
async def get_provider_impl(config: ChromaVectorIOConfig, deps: dict[Api, Any]):
from llama_stack.providers.remote.vector_io.chroma.chroma import (
ChromaVectorIOAdapter,
)
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaVectorIOAdapter
impl = ChromaVectorIOAdapter(
config,
deps[Api.inference],
deps[Api.models],
deps.get(Api.files),
)
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -16,11 +16,6 @@ async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
impl = FaissVectorIOAdapter(
config,
deps[Api.inference],
deps[Api.models],
deps.get(Api.files),
)
impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -17,27 +17,14 @@ from numpy.typing import NDArray
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
VectorDBsProtocolPrivate,
)
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
from .config import FaissVectorIOConfig
@ -155,12 +142,7 @@ class FaissIndex(EmbeddingIndex):
await self._save_index()
async def query_vector(
self,
embedding: NDArray,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
chunks = []
scores = []
@ -175,12 +157,7 @@ class FaissIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
raise NotImplementedError(
"Keyword search is not supported - underlying DB FAISS does not support this search mode"
)
@ -200,17 +177,10 @@ class FaissIndex(EmbeddingIndex):
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__(
self,
config: FaissVectorIOConfig,
inference_api: Inference,
models_api: Models,
files_api: Files | None,
) -> None:
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.models_api = models_api
self.cache: dict[str, VectorDBWithIndex] = {}
async def initialize(self) -> None:
@ -252,17 +222,11 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
async def register_vector_db(
self,
vector_db: VectorDB,
) -> None:
async def register_vector_db(self, vector_db: VectorDB) -> None:
assert self.kvstore is not None
key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
await self.kvstore.set(
key=key,
value=vector_db.model_dump_json(),
)
await self.kvstore.set(key=key, value=vector_db.model_dump_json())
# Store in cache
self.cache[vector_db.identifier] = VectorDBWithIndex(
@ -285,12 +249,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
del self.cache[vector_db_id]
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_db_id}")
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = self.cache.get(vector_db_id)
if index is None:
raise ValueError(f"Vector DB {vector_db_id} not found. found: {self.cache.keys()}")
@ -298,10 +257,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
await index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
index = self.cache.get(vector_db_id)
if index is None:

View file

@ -14,11 +14,6 @@ from .config import MilvusVectorIOConfig
async def get_provider_impl(config: MilvusVectorIOConfig, deps: dict[Api, Any]):
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
impl = MilvusVectorIOAdapter(
config,
deps[Api.inference],
deps[Api.models],
deps.get(Api.files),
)
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -15,11 +15,6 @@ async def get_provider_impl(config: QdrantVectorIOConfig, deps: dict[Api, Any]):
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
assert isinstance(config, QdrantVectorIOConfig), f"Unexpected config type: {type(config)}"
impl = QdrantVectorIOAdapter(
config,
deps[Api.inference],
deps[Api.models],
deps.get(Api.files),
)
impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -15,11 +15,6 @@ async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
from .sqlite_vec import SQLiteVecVectorIOAdapter
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
impl = SQLiteVecVectorIOAdapter(
config,
deps[Api.inference],
deps[Api.models],
deps.get(Api.files),
)
impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -17,13 +17,8 @@ from numpy.typing import NDArray
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
@ -175,32 +170,18 @@ class SQLiteVecIndex(EmbeddingIndex):
# Insert vector embeddings
embedding_data = [
(
(
chunk.chunk_id,
serialize_vector(emb.tolist()),
)
)
((chunk.chunk_id, serialize_vector(emb.tolist())))
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
]
cur.executemany(
f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);",
embedding_data,
)
cur.executemany(f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);", embedding_data)
# Insert FTS content
fts_data = [(chunk.chunk_id, chunk.content) for chunk in batch_chunks]
# DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
cur.executemany(
f"DELETE FROM [{self.fts_table}] WHERE id = ?;",
[(row[0],) for row in fts_data],
)
cur.executemany(f"DELETE FROM [{self.fts_table}] WHERE id = ?;", [(row[0],) for row in fts_data])
# INSERT new entries
cur.executemany(
f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);",
fts_data,
)
cur.executemany(f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);", fts_data)
connection.commit()
@ -216,12 +197,7 @@ class SQLiteVecIndex(EmbeddingIndex):
# Run batch insertion in a background thread
await asyncio.to_thread(_execute_all_batch_inserts)
async def query_vector(
self,
embedding: NDArray,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
"""
Performs vector-based search using a virtual table for vector similarity.
"""
@ -261,12 +237,7 @@ class SQLiteVecIndex(EmbeddingIndex):
scores.append(score)
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
"""
Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search.
"""
@ -410,17 +381,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
"""
def __init__(
self,
config,
inference_api: Inference,
models_api: Models,
files_api: Files | None,
) -> None:
def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.models_api = models_api
self.cache: dict[str, VectorDBWithIndex] = {}
self.vector_db_store = None
@ -433,9 +397,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
for db_json in stored_vector_dbs:
vector_db = VectorDB.model_validate_json(db_json)
index = await SQLiteVecIndex.create(
vector_db.embedding_dimension,
self.config.db_path,
vector_db.identifier,
vector_db.embedding_dimension, self.config.db_path, vector_db.identifier
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
@ -450,11 +412,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
return [v.vector_db for v in self.cache.values()]
async def register_vector_db(self, vector_db: VectorDB) -> None:
index = await SQLiteVecIndex.create(
vector_db.embedding_dimension,
self.config.db_path,
vector_db.identifier,
)
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:

View file

@ -12,11 +12,6 @@ from .config import ChromaVectorIOConfig
async def get_adapter_impl(config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .chroma import ChromaVectorIOAdapter
impl = ChromaVectorIOAdapter(
config,
deps[Api.inference],
deps[Api.models],
deps.get(Api.files),
)
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -12,24 +12,16 @@ import chromadb
from numpy.typing import NDArray
from llama_stack.apis.files import Files
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
@ -68,19 +60,13 @@ class ChromaIndex(EmbeddingIndex):
ids = [f"{c.metadata.get('document_id', '')}:{c.chunk_id}" for c in chunks]
await maybe_await(
self.collection.add(
documents=[chunk.model_dump_json() for chunk in chunks],
embeddings=embeddings,
ids=ids,
)
self.collection.add(documents=[chunk.model_dump_json() for chunk in chunks], embeddings=embeddings, ids=ids)
)
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
results = await maybe_await(
self.collection.query(
query_embeddings=[embedding.tolist()],
n_results=k,
include=["documents", "distances"],
query_embeddings=[embedding.tolist()], n_results=k, include=["documents", "distances"]
)
)
distances = results["distances"][0]
@ -108,12 +94,7 @@ class ChromaIndex(EmbeddingIndex):
async def delete(self):
await maybe_await(self.client.delete_collection(self.collection.name))
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Chroma")
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
@ -137,15 +118,13 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
def __init__(
self,
config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
inference_api: Api.inference,
models_apis: Api.models,
inference_api: Inference,
files_api: Files | None,
) -> None:
super().__init__(files_api=files_api, kvstore=None)
log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
self.config = config
self.inference_api = inference_api
self.models_api = models_apis
self.client = None
self.cache = {}
self.vector_db_store = None
@ -172,14 +151,10 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db(
self,
vector_db: VectorDB,
) -> None:
async def register_vector_db(self, vector_db: VectorDB) -> None:
collection = await maybe_await(
self.client.get_or_create_collection(
name=vector_db.identifier,
metadata={"vector_db": vector_db.model_dump_json()},
name=vector_db.identifier, metadata={"vector_db": vector_db.model_dump_json()}
)
)
self.cache[vector_db.identifier] = VectorDBWithIndex(
@ -194,12 +169,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if index is None:
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
@ -207,10 +177,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
await index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)

View file

@ -13,12 +13,6 @@ async def get_adapter_impl(config: MilvusVectorIOConfig, deps: dict[Api, Provide
from .milvus import MilvusVectorIOAdapter
assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"
impl = MilvusVectorIOAdapter(
config,
deps[Api.inference],
deps[Api.models],
deps.get(Api.files),
)
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -14,13 +14,8 @@ from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusC
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
@ -74,46 +69,23 @@ class MilvusIndex(EmbeddingIndex):
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
# Create schema for vector search
schema = self.client.create_schema()
schema.add_field(
field_name="chunk_id",
datatype=DataType.VARCHAR,
is_primary=True,
max_length=100,
)
schema.add_field(field_name="chunk_id", datatype=DataType.VARCHAR, is_primary=True, max_length=100)
schema.add_field(
field_name="content",
datatype=DataType.VARCHAR,
max_length=65535,
enable_analyzer=True, # Enable text analysis for BM25
)
schema.add_field(
field_name="vector",
datatype=DataType.FLOAT_VECTOR,
dim=len(embeddings[0]),
)
schema.add_field(
field_name="chunk_content",
datatype=DataType.JSON,
)
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=len(embeddings[0]))
schema.add_field(field_name="chunk_content", datatype=DataType.JSON)
# Add sparse vector field for BM25 (required by the function)
schema.add_field(
field_name="sparse",
datatype=DataType.SPARSE_FLOAT_VECTOR,
)
schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR)
# Create indexes
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="vector",
index_type="FLAT",
metric_type="COSINE",
)
index_params.add_index(field_name="vector", index_type="FLAT", metric_type="COSINE")
# Add index for sparse field (required by BM25 function)
index_params.add_index(
field_name="sparse",
index_type="SPARSE_INVERTED_INDEX",
metric_type="BM25",
)
index_params.add_index(field_name="sparse", index_type="SPARSE_INVERTED_INDEX", metric_type="BM25")
# Add BM25 function for full-text search
bm25_function = Function(
@ -144,11 +116,7 @@ class MilvusIndex(EmbeddingIndex):
}
)
try:
await asyncio.to_thread(
self.client.insert,
self.collection_name,
data=data,
)
await asyncio.to_thread(self.client.insert, self.collection_name, data=data)
except Exception as e:
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
raise e
@ -167,12 +135,7 @@ class MilvusIndex(EmbeddingIndex):
scores = [res["distance"] for res in search_res[0]]
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
"""
Perform BM25-based keyword search using Milvus's built-in full-text search.
"""
@ -210,12 +173,7 @@ class MilvusIndex(EmbeddingIndex):
# Fallback to simple text search
return await self._fallback_keyword_search(query_string, k, score_threshold)
async def _fallback_keyword_search(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def _fallback_keyword_search(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
"""
Fallback to simple text search when BM25 search is not available.
"""
@ -308,7 +266,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self,
config: RemoteMilvusVectorIOConfig | InlineMilvusVectorIOConfig,
inference_api: Inference,
models_api: Models,
files_api: Files | None,
) -> None:
super().__init__(files_api=files_api, kvstore=None)
@ -316,7 +273,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.cache = {}
self.client = None
self.inference_api = inference_api
self.models_api = models_api
self.vector_db_store = None
self.metadata_collection_name = "openai_vector_stores_metadata"
@ -355,10 +311,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db(
self,
vector_db: VectorDB,
) -> None:
async def register_vector_db(self, vector_db: VectorDB) -> None:
if isinstance(self.config, RemoteMilvusVectorIOConfig):
consistency_level = self.config.consistency_level
else:
@ -395,12 +348,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise VectorStoreNotFoundError(vector_db_id)
@ -408,10 +356,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
await index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:

View file

@ -12,6 +12,6 @@ from .config import PGVectorVectorIOConfig
async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .pgvector import PGVectorVectorIOAdapter
impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps[Api.models], deps.get(Api.files, None))
impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -16,26 +16,15 @@ from pydantic import BaseModel, TypeAdapter
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator, sanitize_collection_name
from .config import PGVectorVectorIOConfig
@ -205,12 +194,7 @@ class PGVectorIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
"""
Performs keyword-based search using PostgreSQL's full-text search with ts_rank scoring.
@ -317,7 +301,7 @@ class PGVectorIndex(EmbeddingIndex):
"""Remove a chunk from the PostgreSQL table."""
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,))
cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids))
def get_pgvector_search_function(self) -> str:
return self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION[self.distance_metric]
@ -341,16 +325,11 @@ class PGVectorIndex(EmbeddingIndex):
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__(
self,
config: PGVectorVectorIOConfig,
inference_api: Inference,
models_api: Models,
files_api: Files | None = None,
self, config: PGVectorVectorIOConfig, inference_api: Inference, files_api: Files | None = None
) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.models_api = models_api
self.conn = None
self.cache = {}
self.vector_db_store = None
@ -407,11 +386,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
vector_db=vector_db, dimension=vector_db.embedding_dimension, conn=self.conn, kvstore=self.kvstore
)
await pgvector_index.initialize()
index = VectorDBWithIndex(
vector_db,
index=pgvector_index,
inference_api=self.inference_api,
)
index = VectorDBWithIndex(vector_db, index=pgvector_index, inference_api=self.inference_api)
self.cache[vector_db.identifier] = index
async def unregister_vector_db(self, vector_db_id: str) -> None:
@ -424,20 +399,12 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
assert self.kvstore is not None
await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_db_id}")
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
await index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
return await index.query_chunks(query, params)

View file

@ -12,11 +12,6 @@ from .config import QdrantVectorIOConfig
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .qdrant import QdrantVectorIOAdapter
impl = QdrantVectorIOAdapter(
config,
deps[Api.inference],
deps[Api.models],
deps.get(Api.files),
)
impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -16,7 +16,6 @@ from qdrant_client.models import PointStruct
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
@ -30,11 +29,7 @@ from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
@ -99,8 +94,7 @@ class QdrantIndex(EmbeddingIndex):
chunk_ids = [convert_id(c.chunk_id) for c in chunks_for_deletion]
try:
await self.client.delete(
collection_name=self.collection_name,
points_selector=models.PointIdsList(points=chunk_ids),
collection_name=self.collection_name, points_selector=models.PointIdsList(points=chunk_ids)
)
except Exception as e:
log.error(f"Error deleting chunks from Qdrant collection {self.collection_name}: {e}")
@ -133,12 +127,7 @@ class QdrantIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Qdrant")
async def query_hybrid(
@ -161,7 +150,6 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self,
config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig,
inference_api: Inference,
models_api: Models,
files_api: Files | None = None,
) -> None:
super().__init__(files_api=files_api, kvstore=None)
@ -169,7 +157,6 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.client: AsyncQdrantClient = None
self.cache = {}
self.inference_api = inference_api
self.models_api = models_api
self.vector_db_store = None
self._qdrant_lock = asyncio.Lock()
@ -184,11 +171,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
for vector_db_data in stored_vector_dbs:
vector_db = VectorDB.model_validate_json(vector_db_data)
index = VectorDBWithIndex(
vector_db,
QdrantIndex(self.client, vector_db.identifier),
self.inference_api,
)
index = VectorDBWithIndex(vector_db, QdrantIndex(self.client, vector_db.identifier), self.inference_api)
self.cache[vector_db.identifier] = index
self.openai_vector_stores = await self._load_openai_vector_stores()
@ -197,18 +180,13 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db(
self,
vector_db: VectorDB,
) -> None:
async def register_vector_db(self, vector_db: VectorDB) -> None:
assert self.kvstore is not None
key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
await self.kvstore.set(key=key, value=vector_db.model_dump_json())
index = VectorDBWithIndex(
vector_db=vector_db,
index=QdrantIndex(self.client, vector_db.identifier),
inference_api=self.inference_api,
vector_db=vector_db, index=QdrantIndex(self.client, vector_db.identifier), inference_api=self.inference_api
)
self.cache[vector_db.identifier] = index
@ -240,12 +218,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.cache[vector_db_id] = index
return index
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise VectorStoreNotFoundError(vector_db_id)
@ -253,10 +226,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
await index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:

View file

@ -12,11 +12,6 @@ from .config import WeaviateVectorIOConfig
async def get_adapter_impl(config: WeaviateVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .weaviate import WeaviateVectorIOAdapter
impl = WeaviateVectorIOAdapter(
config,
deps[Api.inference],
deps[Api.models],
deps.get(Api.files),
)
impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -21,11 +21,7 @@ class WeaviateVectorIOConfig(BaseModel):
)
@classmethod
def sample_run_config(
cls,
__distro_dir__: str,
**kwargs: Any,
) -> dict[str, Any]:
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"weaviate_api_key": None,
"weaviate_cluster_url": "${env.WEAVIATE_CLUSTER_URL:=localhost:8080}",

View file

@ -16,7 +16,6 @@ from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.core.request_headers import NeedsRequestProviderData
@ -24,9 +23,7 @@ from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
OpenAIVectorStoreMixin,
)
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_RRF,
ChunkForDeletion,
@ -48,12 +45,7 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten
class WeaviateIndex(EmbeddingIndex):
def __init__(
self,
client: weaviate.WeaviateClient,
collection_name: str,
kvstore: KVStore | None = None,
):
def __init__(self, client: weaviate.WeaviateClient, collection_name: str, kvstore: KVStore | None = None):
self.client = client
self.collection_name = sanitize_collection_name(collection_name, weaviate_format=True)
self.kvstore = kvstore
@ -108,9 +100,7 @@ class WeaviateIndex(EmbeddingIndex):
try:
results = collection.query.near_vector(
near_vector=embedding.tolist(),
limit=k,
return_metadata=wvc.query.MetadataQuery(distance=True),
near_vector=embedding.tolist(), limit=k, return_metadata=wvc.query.MetadataQuery(distance=True)
)
except Exception as e:
log.error(f"Weaviate client vector search failed: {e}")
@ -153,12 +143,7 @@ class WeaviateIndex(EmbeddingIndex):
collection = self.client.collections.get(sanitized_collection_name)
collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids))
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
"""
Performs BM25-based keyword search using Weaviate's built-in full-text search.
Args:
@ -175,9 +160,7 @@ class WeaviateIndex(EmbeddingIndex):
# Perform BM25 keyword search on chunk_content field
try:
results = collection.query.bm25(
query=query_string,
limit=k,
return_metadata=wvc.query.MetadataQuery(score=True),
query=query_string, limit=k, return_metadata=wvc.query.MetadataQuery(score=True)
)
except Exception as e:
log.error(f"Weaviate client keyword search failed: {e}")
@ -274,23 +257,11 @@ class WeaviateIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
class WeaviateVectorIOAdapter(
OpenAIVectorStoreMixin,
VectorIO,
NeedsRequestProviderData,
VectorDBsProtocolPrivate,
):
def __init__(
self,
config: WeaviateVectorIOConfig,
inference_api: Inference,
models_api: Models,
files_api: Files | None,
) -> None:
class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProviderData, VectorDBsProtocolPrivate):
def __init__(self, config: WeaviateVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.models_api = models_api
self.client_cache = {}
self.cache = {}
self.vector_db_store = None
@ -301,10 +272,7 @@ class WeaviateVectorIOAdapter(
log.info("Using Weaviate locally in container")
host, port = self.config.weaviate_cluster_url.split(":")
key = "local_test"
client = weaviate.connect_to_local(
host=host,
port=port,
)
client = weaviate.connect_to_local(host=host, port=port)
else:
log.info("Using Weaviate remote cluster with URL")
key = f"{self.config.weaviate_cluster_url}::{self.config.weaviate_api_key}"
@ -334,15 +302,9 @@ class WeaviateVectorIOAdapter(
for raw in stored:
vector_db = VectorDB.model_validate_json(raw)
client = self._get_client()
idx = WeaviateIndex(
client=client,
collection_name=vector_db.identifier,
kvstore=self.kvstore,
)
idx = WeaviateIndex(client=client, collection_name=vector_db.identifier, kvstore=self.kvstore)
self.cache[vector_db.identifier] = VectorDBWithIndex(
vector_db=vector_db,
index=idx,
inference_api=self.inference_api,
vector_db=vector_db, index=idx, inference_api=self.inference_api
)
# Load OpenAI vector stores metadata into cache
@ -354,10 +316,7 @@ class WeaviateVectorIOAdapter(
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def register_vector_db(
self,
vector_db: VectorDB,
) -> None:
async def register_vector_db(self, vector_db: VectorDB) -> None:
client = self._get_client()
sanitized_collection_name = sanitize_collection_name(vector_db.identifier, weaviate_format=True)
# Create collection if it doesn't exist
@ -366,17 +325,12 @@ class WeaviateVectorIOAdapter(
name=sanitized_collection_name,
vectorizer_config=wvc.config.Configure.Vectorizer.none(),
properties=[
wvc.config.Property(
name="chunk_content",
data_type=wvc.config.DataType.TEXT,
),
wvc.config.Property(name="chunk_content", data_type=wvc.config.DataType.TEXT),
],
)
self.cache[vector_db.identifier] = VectorDBWithIndex(
vector_db,
WeaviateIndex(client=client, collection_name=sanitized_collection_name),
self.inference_api,
vector_db, WeaviateIndex(client=client, collection_name=sanitized_collection_name), self.inference_api
)
async def unregister_vector_db(self, vector_db_id: str) -> None:
@ -412,12 +366,7 @@ class WeaviateVectorIOAdapter(
self.cache[vector_db_id] = index
return index
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise VectorStoreNotFoundError(vector_db_id)
@ -425,10 +374,7 @@ class WeaviateVectorIOAdapter(
await index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:

View file

@ -17,7 +17,6 @@ from pydantic import TypeAdapter
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files, OpenAIFileObject
from llama_stack.apis.models import Model, Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
@ -81,13 +80,14 @@ class OpenAIVectorStoreMixin(ABC):
# Implementing classes should call super().__init__() in their __init__ method
# to properly initialize the mixin attributes.
def __init__(
self, files_api: Files | None = None, kvstore: KVStore | None = None, models_api: Models | None = None
self,
files_api: Files | None = None,
kvstore: KVStore | None = None,
):
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
self.openai_file_batches: dict[str, dict[str, Any]] = {}
self.files_api = files_api
self.kvstore = kvstore
self.models_api = models_api
self._last_file_batch_cleanup_time = 0
self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}
@ -393,21 +393,7 @@ class OpenAIVectorStoreMixin(ABC):
vector_db_id = provider_vector_db_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
if embedding_model is None:
result = await self._get_default_embedding_model_and_dimension()
if result is None:
raise ValueError(
"embedding_model is required in extra_body when creating a vector store. "
"No default embedding model could be determined automatically."
)
embedding_model, embedding_dimension = result
elif embedding_dimension is None:
# Embedding model was provided but dimension wasn't, look it up
embedding_dimension = await self._get_embedding_dimension_for_model(embedding_model)
if embedding_dimension is None:
raise ValueError(
f"Could not determine embedding dimension for model '{embedding_model}'. "
"Please provide embedding_dimension in extra_body or ensure the model metadata contains embedding_dimension."
)
raise ValueError("embedding_model is required")
if embedding_dimension is None:
raise ValueError("Embedding dimension is required")
@ -474,85 +460,6 @@ class OpenAIVectorStoreMixin(ABC):
store_info = self.openai_vector_stores[vector_db_id]
return VectorStoreObject.model_validate(store_info)
async def _get_embedding_models(self) -> list[Model]:
"""Get list of embedding models from the models API."""
if not self.models_api:
return []
models_response = await self.models_api.list_models()
models_list = models_response.data if hasattr(models_response, "data") else models_response
embedding_models = []
for model in models_list:
if not isinstance(model, Model):
logger.warning(f"Non-Model object found in models list: {type(model)} - {model}")
continue
if model.model_type == "embedding":
embedding_models.append(model)
return embedding_models
async def _get_embedding_dimension_for_model(self, model_id: str) -> int | None:
"""Get embedding dimension for a specific model by looking it up in the models API.
Args:
model_id: The identifier of the embedding model (supports both prefixed and non-prefixed)
Returns:
The embedding dimension for the model, or None if not found
"""
embedding_models = await self._get_embedding_models()
for model in embedding_models:
# Check for exact match first
if model.identifier == model_id:
embedding_dimension = model.metadata.get("embedding_dimension")
if embedding_dimension is not None:
return int(embedding_dimension)
else:
logger.warning(f"Model {model_id} found but has no embedding_dimension in metadata")
return None
# Check for prefixed/unprefixed variations
# If model_id is unprefixed, check if it matches the resource_id
if model.provider_resource_id == model_id:
embedding_dimension = model.metadata.get("embedding_dimension")
if embedding_dimension is not None:
return int(embedding_dimension)
return None
async def _get_default_embedding_model_and_dimension(self) -> tuple[str, int] | None:
"""Get default embedding model from the models API.
Looks for embedding models marked with default_configured=True in metadata.
Returns None if no default embedding model is found.
Raises ValueError if multiple defaults are found.
"""
embedding_models = await self._get_embedding_models()
default_models = []
for model in embedding_models:
if model.metadata.get("default_configured") is True:
default_models.append(model.identifier)
if len(default_models) > 1:
raise ValueError(
f"Multiple embedding models marked as default_configured=True: {default_models}. "
"Only one embedding model can be marked as default."
)
if default_models:
model_id = default_models[0]
embedding_dimension = await self._get_embedding_dimension_for_model(model_id)
if embedding_dimension is None:
raise ValueError(f"Embedding model '{model_id}' has no embedding_dimension in metadata")
logger.info(f"Using default embedding model: {model_id} with dimension {embedding_dimension}")
return model_id, embedding_dimension
logger.debug("No default embedding models found")
return None
async def openai_list_vector_stores(
self,
limit: int | None = 20,

View file

@ -317,3 +317,72 @@ def pytest_ignore_collect(path: str, config: pytest.Config) -> bool:
if p.is_relative_to(rp):
return False
return True
def get_vector_io_provider_ids(client):
"""Get all available vector_io provider IDs."""
providers = [p for p in client.providers.list() if p.api == "vector_io"]
return [p.provider_id for p in providers]
def vector_provider_wrapper(func):
"""Decorator to run a test against all available vector_io providers."""
import functools
import os
@functools.wraps(func)
def wrapper(*args, **kwargs):
# Get the vector_io_provider_id from the test arguments
import inspect
sig = inspect.signature(func)
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
vector_io_provider_id = bound_args.arguments.get("vector_io_provider_id")
if not vector_io_provider_id:
pytest.skip("No vector_io_provider_id provided")
# Get client_with_models to check available providers
client_with_models = bound_args.arguments.get("client_with_models")
if client_with_models:
available_providers = get_vector_io_provider_ids(client_with_models)
if vector_io_provider_id not in available_providers:
pytest.skip(f"Provider '{vector_io_provider_id}' not available. Available: {available_providers}")
return func(*args, **kwargs)
# For replay tests, only use providers that are available in ci-tests environment
if os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE") == "replay":
all_providers = ["faiss", "sqlite-vec"]
else:
# For live tests, try all providers (they'll skip if not available)
all_providers = [
"faiss",
"sqlite-vec",
"milvus",
"chromadb",
"pgvector",
"weaviate",
"qdrant",
]
return pytest.mark.parametrize("vector_io_provider_id", all_providers)(wrapper)
@pytest.fixture
def vector_io_provider_id(request, client_with_models):
"""Fixture that provides a specific vector_io provider ID, skipping if not available."""
if hasattr(request, "param"):
requested_provider = request.param
available_providers = get_vector_io_provider_ids(client_with_models)
if requested_provider not in available_providers:
pytest.skip(f"Provider '{requested_provider}' not available. Available: {available_providers}")
return requested_provider
else:
provider_ids = get_vector_io_provider_ids(client_with_models)
if not provider_ids:
pytest.skip("No vector_io providers available")
return provider_ids[0]

View file

@ -21,6 +21,7 @@ from llama_stack_client import LlamaStackClient
from openai import OpenAI
from llama_stack import LlamaStackAsLibraryClient
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.core.stack import run_config_from_adhoc_config_spec
from llama_stack.env import get_env_or_fail
@ -236,6 +237,13 @@ def instantiate_llama_stack_client(session):
if "=" in config:
run_config = run_config_from_adhoc_config_spec(config)
# --stack-config bypasses template so need this to set default embedding model
if "vector_io" in config and "inference" in config:
run_config.vector_stores = VectorStoresConfig(
embedding_model_id="inline::sentence-transformers/nomic-ai/nomic-embed-text-v1.5"
)
run_config_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml")
with open(run_config_file.name, "w") as f:
yaml.dump(run_config.model_dump(mode="json"), f)

View file

@ -8,14 +8,15 @@ import time
from io import BytesIO
import pytest
from llama_stack_client import BadRequestError, NotFoundError
from llama_stack_client import BadRequestError
from openai import BadRequestError as OpenAIBadRequestError
from openai import NotFoundError as OpenAINotFoundError
from llama_stack.apis.vector_io import Chunk
from llama_stack.core.library_client import LlamaStackAsLibraryClient
from llama_stack.log import get_logger
from ..conftest import vector_provider_wrapper
logger = get_logger(name=__name__, category="vector_io")
@ -133,8 +134,9 @@ def compat_client_with_empty_stores(compat_client):
clear_files()
@vector_provider_wrapper
def test_openai_create_vector_store(
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
"""Test creating a vector store using OpenAI API."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
@ -146,6 +148,7 @@ def test_openai_create_vector_store(
metadata={"purpose": "testing", "environment": "integration"},
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -159,14 +162,18 @@ def test_openai_create_vector_store(
assert hasattr(vector_store, "created_at")
def test_openai_create_vector_store_default(compat_client_with_empty_stores, client_with_models):
@vector_provider_wrapper
def test_openai_create_vector_store_default(compat_client_with_empty_stores, client_with_models, vector_io_provider_id):
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
vector_store = compat_client_with_empty_stores.vector_stores.create()
vector_store = compat_client_with_empty_stores.vector_stores.create(
extra_body={"provider_id": vector_io_provider_id}
)
assert vector_store.id
@vector_provider_wrapper
def test_openai_list_vector_stores(
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
"""Test listing vector stores using OpenAI API."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
@ -179,6 +186,7 @@ def test_openai_list_vector_stores(
metadata={"type": "test"},
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
store2 = client.vector_stores.create(
@ -186,6 +194,7 @@ def test_openai_list_vector_stores(
metadata={"type": "test"},
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -206,8 +215,9 @@ def test_openai_list_vector_stores(
assert len(limited_response.data) == 1
@vector_provider_wrapper
def test_openai_retrieve_vector_store(
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
"""Test retrieving a specific vector store using OpenAI API."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
@ -220,6 +230,7 @@ def test_openai_retrieve_vector_store(
metadata={"purpose": "retrieval_test"},
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -233,8 +244,9 @@ def test_openai_retrieve_vector_store(
assert retrieved_store.object == "vector_store"
@vector_provider_wrapper
def test_openai_update_vector_store(
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
"""Test modifying a vector store using OpenAI API."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
@ -247,6 +259,7 @@ def test_openai_update_vector_store(
metadata={"version": "1.0"},
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
time.sleep(1)
@ -264,8 +277,9 @@ def test_openai_update_vector_store(
assert modified_store.last_active_at > created_store.last_active_at
@vector_provider_wrapper
def test_openai_delete_vector_store(
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
"""Test deleting a vector store using OpenAI API."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
@ -278,6 +292,7 @@ def test_openai_delete_vector_store(
metadata={"purpose": "deletion_test"},
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -294,8 +309,9 @@ def test_openai_delete_vector_store(
client.vector_stores.retrieve(vector_store_id=created_store.id)
@vector_provider_wrapper
def test_openai_vector_store_search_empty(
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
"""Test searching an empty vector store using OpenAI API."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
@ -308,6 +324,7 @@ def test_openai_vector_store_search_empty(
metadata={"purpose": "search_testing"},
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -323,8 +340,14 @@ def test_openai_vector_store_search_empty(
assert search_response.has_more is False
@vector_provider_wrapper
def test_openai_vector_store_with_chunks(
compat_client_with_empty_stores, client_with_models, sample_chunks, embedding_model_id, embedding_dimension
compat_client_with_empty_stores,
client_with_models,
sample_chunks,
embedding_model_id,
embedding_dimension,
vector_io_provider_id,
):
"""Test vector store functionality with actual chunks using both OpenAI and native APIs."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
@ -338,6 +361,7 @@ def test_openai_vector_store_with_chunks(
metadata={"purpose": "chunks_testing"},
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -380,6 +404,7 @@ def test_openai_vector_store_with_chunks(
("What inspires neural networks?", "doc4", "ai"),
],
)
@vector_provider_wrapper
def test_openai_vector_store_search_relevance(
compat_client_with_empty_stores,
client_with_models,
@ -387,6 +412,7 @@ def test_openai_vector_store_search_relevance(
test_case,
embedding_model_id,
embedding_dimension,
vector_io_provider_id,
):
"""Test that OpenAI vector store search returns relevant results for different queries."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
@ -402,6 +428,7 @@ def test_openai_vector_store_search_relevance(
metadata={"purpose": "relevance_testing"},
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -430,8 +457,14 @@ def test_openai_vector_store_search_relevance(
assert top_result.score > 0
@vector_provider_wrapper
def test_openai_vector_store_search_with_ranking_options(
compat_client_with_empty_stores, client_with_models, sample_chunks, embedding_model_id, embedding_dimension
compat_client_with_empty_stores,
client_with_models,
sample_chunks,
embedding_model_id,
embedding_dimension,
vector_io_provider_id,
):
"""Test OpenAI vector store search with ranking options."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
@ -445,6 +478,7 @@ def test_openai_vector_store_search_with_ranking_options(
metadata={"purpose": "ranking_testing"},
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -483,8 +517,14 @@ def test_openai_vector_store_search_with_ranking_options(
assert result.score >= threshold
@vector_provider_wrapper
def test_openai_vector_store_search_with_high_score_filter(
compat_client_with_empty_stores, client_with_models, sample_chunks, embedding_model_id, embedding_dimension
compat_client_with_empty_stores,
client_with_models,
sample_chunks,
embedding_model_id,
embedding_dimension,
vector_io_provider_id,
):
"""Test that searching with text very similar to a document and high score threshold returns only that document."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
@ -498,6 +538,7 @@ def test_openai_vector_store_search_with_high_score_filter(
metadata={"purpose": "high_score_filtering"},
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -542,8 +583,14 @@ def test_openai_vector_store_search_with_high_score_filter(
assert "python" in top_content.lower() or "programming" in top_content.lower()
@vector_provider_wrapper
def test_openai_vector_store_search_with_max_num_results(
compat_client_with_empty_stores, client_with_models, sample_chunks, embedding_model_id, embedding_dimension
compat_client_with_empty_stores,
client_with_models,
sample_chunks,
embedding_model_id,
embedding_dimension,
vector_io_provider_id,
):
"""Test OpenAI vector store search with max_num_results."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
@ -557,6 +604,7 @@ def test_openai_vector_store_search_with_max_num_results(
metadata={"purpose": "max_num_results_testing"},
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -577,8 +625,9 @@ def test_openai_vector_store_search_with_max_num_results(
assert len(search_response.data) == 2
@vector_provider_wrapper
def test_openai_vector_store_attach_file(
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
"""Test OpenAI vector store attach file."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
@ -591,6 +640,7 @@ def test_openai_vector_store_attach_file(
name="test_store",
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -637,8 +687,9 @@ def test_openai_vector_store_attach_file(
assert "foobazbar" in top_content.lower()
@vector_provider_wrapper
def test_openai_vector_store_attach_files_on_creation(
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
"""Test OpenAI vector store attach files on creation."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
@ -668,6 +719,7 @@ def test_openai_vector_store_attach_files_on_creation(
file_ids=file_ids,
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -700,8 +752,9 @@ def test_openai_vector_store_attach_files_on_creation(
assert updated_vector_store.file_counts.failed == 0
@vector_provider_wrapper
def test_openai_vector_store_list_files(
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
"""Test OpenAI vector store list files."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
@ -714,6 +767,7 @@ def test_openai_vector_store_list_files(
name="test_store",
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -773,8 +827,9 @@ def test_openai_vector_store_list_files(
assert updated_vector_store.file_counts.in_progress == 0
@vector_provider_wrapper
def test_openai_vector_store_list_files_invalid_vector_store(
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
"""Test OpenAI vector store list files with invalid vector store ID."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
@ -783,14 +838,15 @@ def test_openai_vector_store_list_files_invalid_vector_store(
if isinstance(compat_client, LlamaStackAsLibraryClient):
errors = ValueError
else:
errors = (NotFoundError, OpenAINotFoundError)
errors = (BadRequestError, OpenAIBadRequestError)
with pytest.raises(errors):
compat_client.vector_stores.files.list(vector_store_id="abc123")
@vector_provider_wrapper
def test_openai_vector_store_retrieve_file_contents(
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
"""Test OpenAI vector store retrieve file contents."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
@ -803,6 +859,7 @@ def test_openai_vector_store_retrieve_file_contents(
name="test_store",
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -848,8 +905,9 @@ def test_openai_vector_store_retrieve_file_contents(
assert file_contents.attributes == attributes
@vector_provider_wrapper
def test_openai_vector_store_delete_file(
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
"""Test OpenAI vector store delete file."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
@ -862,6 +920,7 @@ def test_openai_vector_store_delete_file(
name="test_store",
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -912,8 +971,9 @@ def test_openai_vector_store_delete_file(
assert updated_vector_store.file_counts.in_progress == 0
@vector_provider_wrapper
def test_openai_vector_store_delete_file_removes_from_vector_store(
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
"""Test OpenAI vector store delete file removes from vector store."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
@ -926,6 +986,7 @@ def test_openai_vector_store_delete_file_removes_from_vector_store(
name="test_store",
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -962,8 +1023,9 @@ def test_openai_vector_store_delete_file_removes_from_vector_store(
assert not search_response.data
@vector_provider_wrapper
def test_openai_vector_store_update_file(
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
"""Test OpenAI vector store update file."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
@ -976,6 +1038,7 @@ def test_openai_vector_store_update_file(
name="test_store",
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -1017,8 +1080,9 @@ def test_openai_vector_store_update_file(
assert retrieved_file.attributes["foo"] == "baz"
@vector_provider_wrapper
def test_create_vector_store_files_duplicate_vector_store_name(
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
"""
This test confirms that client.vector_stores.create() creates a unique ID
@ -1044,6 +1108,7 @@ def test_create_vector_store_files_duplicate_vector_store_name(
name="test_store_with_files",
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
assert vector_store.file_counts.completed == 0
@ -1056,6 +1121,7 @@ def test_create_vector_store_files_duplicate_vector_store_name(
name="test_store_with_files",
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -1086,8 +1152,15 @@ def test_create_vector_store_files_duplicate_vector_store_name(
@pytest.mark.parametrize("search_mode", ["vector", "keyword", "hybrid"])
@vector_provider_wrapper
def test_openai_vector_store_search_modes(
llama_stack_client, client_with_models, sample_chunks, search_mode, embedding_model_id, embedding_dimension
llama_stack_client,
client_with_models,
sample_chunks,
search_mode,
embedding_model_id,
embedding_dimension,
vector_io_provider_id,
):
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_models, search_mode)
@ -1097,6 +1170,7 @@ def test_openai_vector_store_search_modes(
metadata={"purpose": "search_mode_testing"},
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -1115,8 +1189,9 @@ def test_openai_vector_store_search_modes(
assert search_response is not None
@vector_provider_wrapper
def test_openai_vector_store_file_batch_create_and_retrieve(
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
"""Test creating and retrieving a vector store file batch."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
@ -1128,6 +1203,7 @@ def test_openai_vector_store_file_batch_create_and_retrieve(
name="batch_test_store",
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -1178,8 +1254,9 @@ def test_openai_vector_store_file_batch_create_and_retrieve(
assert retrieved_batch.status == "completed" # Should be completed after processing
@vector_provider_wrapper
def test_openai_vector_store_file_batch_list_files(
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
"""Test listing files in a vector store file batch."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
@ -1191,6 +1268,7 @@ def test_openai_vector_store_file_batch_list_files(
name="batch_list_test_store",
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -1271,8 +1349,9 @@ def test_openai_vector_store_file_batch_list_files(
assert first_page_ids.isdisjoint(second_page_ids)
@vector_provider_wrapper
def test_openai_vector_store_file_batch_cancel(
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
"""Test cancelling a vector store file batch."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
@ -1284,6 +1363,7 @@ def test_openai_vector_store_file_batch_cancel(
name="batch_cancel_test_store",
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -1326,8 +1406,9 @@ def test_openai_vector_store_file_batch_cancel(
assert final_batch.status in ["completed", "cancelled"]
@vector_provider_wrapper
def test_openai_vector_store_file_batch_retrieve_contents(
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
"""Test retrieving file contents after file batch processing."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
@ -1339,6 +1420,7 @@ def test_openai_vector_store_file_batch_retrieve_contents(
name="batch_contents_test_store",
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -1399,8 +1481,9 @@ def test_openai_vector_store_file_batch_retrieve_contents(
assert file_data[i][1].decode("utf-8") in content_text
@vector_provider_wrapper
def test_openai_vector_store_file_batch_error_handling(
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
"""Test error handling for file batch operations."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
@ -1412,6 +1495,7 @@ def test_openai_vector_store_file_batch_error_handling(
name="batch_error_test_store",
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -1443,11 +1527,11 @@ def test_openai_vector_store_file_batch_error_handling(
batch_id="non_existent_batch_id",
)
# Test operations on non-existent vector store (returns NotFoundError)
# Test operations on non-existent vector store (returns BadRequestError)
if isinstance(compat_client, LlamaStackAsLibraryClient):
vector_store_errors = ValueError
else:
vector_store_errors = (NotFoundError, OpenAINotFoundError)
vector_store_errors = (BadRequestError, OpenAIBadRequestError)
with pytest.raises(vector_store_errors): # Should raise an error for non-existent vector store
compat_client.vector_stores.file_batches.create(
@ -1456,8 +1540,9 @@ def test_openai_vector_store_file_batch_error_handling(
)
@vector_provider_wrapper
def test_openai_vector_store_embedding_config_from_metadata(
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
"""Test that embedding configuration works from metadata source."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
@ -1471,6 +1556,9 @@ def test_openai_vector_store_embedding_config_from_metadata(
"embedding_dimension": str(embedding_dimension),
"test_source": "metadata",
},
extra_body={
"provider_id": vector_io_provider_id,
},
)
assert vector_store_metadata is not None
@ -1489,6 +1577,7 @@ def test_openai_vector_store_embedding_config_from_metadata(
extra_body={
"embedding_model": embedding_model_id,
"embedding_dimension": int(embedding_dimension), # Ensure same type/value
"provider_id": vector_io_provider_id,
},
)

View file

@ -8,6 +8,8 @@ import pytest
from llama_stack.apis.vector_io import Chunk
from ..conftest import vector_provider_wrapper
@pytest.fixture(scope="session")
def sample_chunks():
@ -46,12 +48,13 @@ def client_with_empty_registry(client_with_models):
clear_registry()
def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embedding_dimension):
@vector_provider_wrapper
def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id):
vector_db_name = "test_vector_db"
create_response = client_with_empty_registry.vector_stores.create(
name=vector_db_name,
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -65,12 +68,13 @@ def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embe
assert response.id.startswith("vs_")
def test_vector_db_register(client_with_empty_registry, embedding_model_id, embedding_dimension):
@vector_provider_wrapper
def test_vector_db_register(client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id):
vector_db_name = "test_vector_db"
response = client_with_empty_registry.vector_stores.create(
name=vector_db_name,
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -100,12 +104,15 @@ def test_vector_db_register(client_with_empty_registry, embedding_model_id, embe
("How does machine learning improve over time?", "doc2"),
],
)
def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case):
@vector_provider_wrapper
def test_insert_chunks(
client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case, vector_io_provider_id
):
vector_db_name = "test_vector_db"
create_response = client_with_empty_registry.vector_stores.create(
name=vector_db_name,
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -135,7 +142,10 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding
assert top_match.metadata["document_id"] == expected_doc_id, f"Query '{query}' should match {expected_doc_id}"
def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, embedding_model_id, embedding_dimension):
@vector_provider_wrapper
def test_insert_chunks_with_precomputed_embeddings(
client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id
):
vector_io_provider_params_dict = {
"inline::milvus": {"score_threshold": -1.0},
"inline::qdrant": {"score_threshold": -1.0},
@ -145,7 +155,7 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e
register_response = client_with_empty_registry.vector_stores.create(
name=vector_db_name,
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -181,8 +191,9 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e
# expect this test to fail
@vector_provider_wrapper
def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
client_with_empty_registry, embedding_model_id, embedding_dimension
client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id
):
vector_io_provider_params_dict = {
"inline::milvus": {"score_threshold": 0.0},
@ -194,6 +205,7 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
name=vector_db_name,
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)
@ -226,33 +238,44 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
assert response.chunks[0].metadata["source"] == "precomputed"
def test_auto_extract_embedding_dimension(client_with_empty_registry, embedding_model_id):
@vector_provider_wrapper
def test_auto_extract_embedding_dimension(
client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id
):
# This test specifically tests embedding model override, so we keep embedding_model
vs = client_with_empty_registry.vector_stores.create(
name="test_auto_extract", extra_body={"embedding_model": embedding_model_id}
name="test_auto_extract",
extra_body={"embedding_model": embedding_model_id, "provider_id": vector_io_provider_id},
)
assert vs.id is not None
def test_provider_auto_selection_single_provider(client_with_empty_registry, embedding_model_id):
@vector_provider_wrapper
def test_provider_auto_selection_single_provider(
client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id
):
providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
if len(providers) != 1:
pytest.skip(f"Test requires exactly one vector_io provider, found {len(providers)}")
vs = client_with_empty_registry.vector_stores.create(
name="test_auto_provider", extra_body={"embedding_model": embedding_model_id}
)
# Test that when only one provider is available, it's auto-selected (no provider_id needed)
vs = client_with_empty_registry.vector_stores.create(name="test_auto_provider")
assert vs.id is not None
def test_provider_id_override(client_with_empty_registry, embedding_model_id):
@vector_provider_wrapper
def test_provider_id_override(
client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id
):
providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
if len(providers) != 1:
pytest.skip(f"Test requires exactly one vector_io provider, found {len(providers)}")
provider_id = providers[0].provider_id
# Test explicit provider_id specification (using default embedding model)
vs = client_with_empty_registry.vector_stores.create(
name="test_provider_override", extra_body={"embedding_model": embedding_model_id, "provider_id": provider_id}
name="test_provider_override", extra_body={"provider_id": provider_id}
)
assert vs.id is not None
assert vs.metadata.get("provider_id") == provider_id

View file

@ -4,90 +4,64 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Unit tests for Stack validation functions.
"""
"""Unit tests for Stack validation functions."""
from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.models import Model, ModelType
from llama_stack.core.stack import validate_default_embedding_model
from llama_stack.apis.models import ListModelsResponse, Model, ModelType
from llama_stack.core.datatypes import QualifiedModel, StackRunConfig, StorageConfig, VectorStoresConfig
from llama_stack.core.stack import validate_vector_stores_config
from llama_stack.providers.datatypes import Api
class TestStackValidation:
"""Test Stack validation functions."""
class TestVectorStoresValidation:
async def test_validate_missing_model(self):
"""Test validation fails when model not found."""
run_config = StackRunConfig(
image_name="test",
providers={},
storage=StorageConfig(backends={}, stores={}),
vector_stores=VectorStoresConfig(
default_provider_id="faiss",
default_embedding_model=QualifiedModel(
provider_id="p",
model_id="missing",
),
),
)
mock_models = AsyncMock()
mock_models.list_models.return_value = ListModelsResponse(data=[])
@pytest.mark.parametrize(
"models,should_raise",
[
([], False), # No models
(
[
Model(
identifier="emb1",
model_type=ModelType.embedding,
metadata={"default_configured": True},
provider_id="p",
provider_resource_id="emb1",
)
],
False,
), # Single default
(
[
Model(
identifier="emb1",
model_type=ModelType.embedding,
metadata={"default_configured": True},
provider_id="p",
provider_resource_id="emb1",
),
Model(
identifier="emb2",
model_type=ModelType.embedding,
metadata={"default_configured": True},
provider_id="p",
provider_resource_id="emb2",
),
],
True,
), # Multiple defaults
(
[
Model(
identifier="emb1",
model_type=ModelType.embedding,
metadata={"default_configured": True},
provider_id="p",
provider_resource_id="emb1",
),
Model(
identifier="llm1",
model_type=ModelType.llm,
metadata={"default_configured": True},
provider_id="p",
provider_resource_id="llm1",
),
],
False,
), # Ignores non-embedding
],
)
async def test_validate_default_embedding_model(self, models, should_raise):
"""Test validation with various model configurations."""
mock_models_impl = AsyncMock()
mock_models_impl.list_models.return_value = models
impls = {Api.models: mock_models_impl}
with pytest.raises(ValueError, match="not found"):
await validate_vector_stores_config(run_config.vector_stores, {Api.models: mock_models})
if should_raise:
with pytest.raises(ValueError, match="Multiple embedding models marked as default_configured=True"):
await validate_default_embedding_model(impls)
else:
await validate_default_embedding_model(impls)
async def test_validate_success(self):
"""Test validation passes with valid model."""
run_config = StackRunConfig(
image_name="test",
providers={},
storage=StorageConfig(backends={}, stores={}),
vector_stores=VectorStoresConfig(
default_provider_id="faiss",
default_embedding_model=QualifiedModel(
provider_id="p",
model_id="valid",
),
),
)
mock_models = AsyncMock()
mock_models.list_models.return_value = ListModelsResponse(
data=[
Model(
identifier="p/valid", # Must match provider_id/model_id format
model_type=ModelType.embedding,
metadata={"embedding_dimension": 768},
provider_id="p",
provider_resource_id="valid",
)
]
)
async def test_validate_default_embedding_model_no_models_api(self):
"""Test validation when models API is not available."""
await validate_default_embedding_model({})
await validate_vector_stores_config(run_config.vector_stores, {Api.models: mock_models})

View file

@ -146,7 +146,6 @@ async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inf
config=config,
inference_api=mock_inference_api,
files_api=None,
models_api=None,
)
collection_id = f"sqlite_test_collection_{np.random.randint(1e6)}"
await adapter.initialize()
@ -185,7 +184,6 @@ async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding
config=config,
inference_api=mock_inference_api,
files_api=None,
models_api=None,
)
await adapter.initialize()
await adapter.register_vector_db(

View file

@ -11,7 +11,6 @@ import numpy as np
import pytest
from llama_stack.apis.files import Files
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.datatypes import HealthStatus
@ -76,12 +75,6 @@ def mock_files_api():
return mock_api
@pytest.fixture
def mock_models_api():
mock_api = MagicMock(spec=Models)
return mock_api
@pytest.fixture
def faiss_config():
config = MagicMock(spec=FaissVectorIOConfig)
@ -117,7 +110,7 @@ async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_
assert response.chunks[1] == sample_chunks[1]
async def test_health_success(mock_models_api):
async def test_health_success():
"""Test that the health check returns OK status when faiss is working correctly."""
# Create a fresh instance of FaissVectorIOAdapter for testing
config = MagicMock()
@ -126,9 +119,7 @@ async def test_health_success(mock_models_api):
with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat:
mock_index_flat.return_value = MagicMock()
adapter = FaissVectorIOAdapter(
config=config, inference_api=inference_api, models_api=mock_models_api, files_api=files_api
)
adapter = FaissVectorIOAdapter(config=config, inference_api=inference_api, files_api=files_api)
# Calling the health method directly
response = await adapter.health()
@ -142,7 +133,7 @@ async def test_health_success(mock_models_api):
mock_index_flat.assert_called_once_with(128) # VECTOR_DIMENSION is 128
async def test_health_failure(mock_models_api):
async def test_health_failure():
"""Test that the health check returns ERROR status when faiss encounters an error."""
# Create a fresh instance of FaissVectorIOAdapter for testing
config = MagicMock()
@ -152,9 +143,7 @@ async def test_health_failure(mock_models_api):
with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat:
mock_index_flat.side_effect = Exception("Test error")
adapter = FaissVectorIOAdapter(
config=config, inference_api=inference_api, models_api=mock_models_api, files_api=files_api
)
adapter = FaissVectorIOAdapter(config=config, inference_api=inference_api, files_api=files_api)
# Calling the health method directly
response = await adapter.health()

View file

@ -6,13 +6,12 @@
import json
import time
from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import AsyncMock, patch
import numpy as np
import pytest
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
@ -996,96 +995,6 @@ async def test_max_concurrent_files_per_batch(vector_io_adapter):
assert batch.file_counts.in_progress == 8
async def test_get_default_embedding_model_success(vector_io_adapter):
"""Test successful default embedding model detection."""
# Mock models API with a default model
mock_models_api = Mock()
mock_models_api.list_models = AsyncMock(
return_value=Mock(
data=[
Model(
identifier="nomic-embed-text-v1.5",
model_type=ModelType.embedding,
provider_id="test-provider",
metadata={
"embedding_dimension": 768,
"default_configured": True,
},
)
]
)
)
vector_io_adapter.models_api = mock_models_api
result = await vector_io_adapter._get_default_embedding_model_and_dimension()
assert result is not None
model_id, dimension = result
assert model_id == "nomic-embed-text-v1.5"
assert dimension == 768
async def test_get_default_embedding_model_multiple_defaults_error(vector_io_adapter):
"""Test error when multiple models are marked as default."""
mock_models_api = Mock()
mock_models_api.list_models = AsyncMock(
return_value=Mock(
data=[
Model(
identifier="model1",
model_type=ModelType.embedding,
provider_id="test-provider",
metadata={"embedding_dimension": 768, "default_configured": True},
),
Model(
identifier="model2",
model_type=ModelType.embedding,
provider_id="test-provider",
metadata={"embedding_dimension": 512, "default_configured": True},
),
]
)
)
vector_io_adapter.models_api = mock_models_api
with pytest.raises(ValueError, match="Multiple embedding models marked as default_configured=True"):
await vector_io_adapter._get_default_embedding_model_and_dimension()
async def test_openai_create_vector_store_uses_default_model(vector_io_adapter):
"""Test that vector store creation uses default embedding model when none specified."""
# Mock models API and dependencies
mock_models_api = Mock()
mock_models_api.list_models = AsyncMock(
return_value=Mock(
data=[
Model(
identifier="default-model",
model_type=ModelType.embedding,
provider_id="test-provider",
metadata={"embedding_dimension": 512, "default_configured": True},
)
]
)
)
vector_io_adapter.models_api = mock_models_api
vector_io_adapter.register_vector_db = AsyncMock()
vector_io_adapter.__provider_id__ = "test-provider"
# Create vector store without specifying embedding model
params = OpenAICreateVectorStoreRequestWithExtraBody(name="test-store")
result = await vector_io_adapter.openai_create_vector_store(params)
# Verify the vector store was created with default model
assert result.name == "test-store"
vector_io_adapter.register_vector_db.assert_called_once()
call_args = vector_io_adapter.register_vector_db.call_args[0][0]
assert call_args.embedding_model == "default-model"
assert call_args.embedding_dimension == 512
async def test_embedding_config_from_metadata(vector_io_adapter):
"""Test that embedding configuration is correctly extracted from metadata."""
@ -1253,5 +1162,5 @@ async def test_embedding_config_required_model_missing(vector_io_adapter):
# Test with no embedding model provided
params = OpenAICreateVectorStoreRequestWithExtraBody(name="test_store", metadata={})
with pytest.raises(ValueError, match="embedding_model is required in extra_body when creating a vector store"):
with pytest.raises(ValueError, match="embedding_model is required"):
await vector_io_adapter.openai_create_vector_store(params)