mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-06 18:40:57 +00:00
chore: Updating how default embedding model is set in stack (#3818)
# What does this PR do?
Refactor setting default vector store provider and embedding model to
use an optional `vector_stores` config in the `StackRunConfig` and clean
up code to do so (had to add back in some pieces of VectorDB). Also
added remote Qdrant and Weaviate to starter distro (based on other PR
where inference providers were added for UX).
New config is simply (default for Starter distro):
```yaml
vector_stores:
default_provider_id: faiss
default_embedding_model:
provider_id: sentence-transformers
model_id: nomic-ai/nomic-embed-text-v1.5
```
## Test Plan
CI and Unit tests.
---------
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
parent
2c43285e22
commit
48581bf651
48 changed files with 973 additions and 818 deletions
|
|
@ -121,6 +121,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
|
|||
|
||||
models = "models"
|
||||
shields = "shields"
|
||||
vector_dbs = "vector_dbs" # only used for routing
|
||||
datasets = "datasets"
|
||||
scoring_functions = "scoring_functions"
|
||||
benchmarks = "benchmarks"
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Literal
|
||||
from typing import Literal, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -59,3 +59,35 @@ class ListVectorDBsResponse(BaseModel):
|
|||
"""
|
||||
|
||||
data: list[VectorDB]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class VectorDBs(Protocol):
|
||||
"""Internal protocol for vector_dbs routing - no public API endpoints."""
|
||||
|
||||
async def list_vector_dbs(self) -> ListVectorDBsResponse:
|
||||
"""Internal method to list vector databases."""
|
||||
...
|
||||
|
||||
async def get_vector_db(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
) -> VectorDB:
|
||||
"""Internal method to get a vector database by ID."""
|
||||
...
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
embedding_model: str,
|
||||
embedding_dimension: int | None = 384,
|
||||
provider_id: str | None = None,
|
||||
vector_db_name: str | None = None,
|
||||
provider_vector_db_id: str | None = None,
|
||||
) -> VectorDB:
|
||||
"""Internal method to register a vector database."""
|
||||
...
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
"""Internal method to unregister a vector database."""
|
||||
...
|
||||
|
|
|
|||
|
|
@ -354,6 +354,26 @@ class AuthenticationRequiredError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
class QualifiedModel(BaseModel):
|
||||
"""A qualified model identifier, consisting of a provider ID and a model ID."""
|
||||
|
||||
provider_id: str
|
||||
model_id: str
|
||||
|
||||
|
||||
class VectorStoresConfig(BaseModel):
|
||||
"""Configuration for vector stores in the stack."""
|
||||
|
||||
default_provider_id: str | None = Field(
|
||||
default=None,
|
||||
description="ID of the vector_io provider to use as default when multiple providers are available and none is specified.",
|
||||
)
|
||||
default_embedding_model: QualifiedModel | None = Field(
|
||||
default=None,
|
||||
description="Default embedding model configuration for vector stores.",
|
||||
)
|
||||
|
||||
|
||||
class QuotaPeriod(StrEnum):
|
||||
DAY = "day"
|
||||
|
||||
|
|
@ -499,6 +519,11 @@ can be instantiated multiple times (with different configs) if necessary.
|
|||
description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
|
||||
)
|
||||
|
||||
vector_stores: VectorStoresConfig | None = Field(
|
||||
default=None,
|
||||
description="Configuration for vector stores, including default embedding model",
|
||||
)
|
||||
|
||||
@field_validator("external_providers_dir")
|
||||
@classmethod
|
||||
def validate_external_providers_dir(cls, v):
|
||||
|
|
|
|||
|
|
@ -63,6 +63,10 @@ def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
|
|||
routing_table_api=Api.tool_groups,
|
||||
router_api=Api.tool_runtime,
|
||||
),
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.vector_dbs,
|
||||
router_api=Api.vector_io,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
|
|||
from llama_stack.apis.shields import Shields
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDBs
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.core.client import get_client_impl
|
||||
|
|
@ -81,6 +82,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
|
|||
Api.inspect: Inspect,
|
||||
Api.batches: Batches,
|
||||
Api.vector_io: VectorIO,
|
||||
Api.vector_dbs: VectorDBs,
|
||||
Api.models: Models,
|
||||
Api.safety: Safety,
|
||||
Api.shields: Shields,
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ async def get_routing_table_impl(
|
|||
from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
|
||||
from ..routing_tables.shields import ShieldsRoutingTable
|
||||
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
|
||||
from ..routing_tables.vector_dbs import VectorDBsRoutingTable
|
||||
|
||||
api_to_tables = {
|
||||
"models": ModelsRoutingTable,
|
||||
|
|
@ -37,6 +38,7 @@ async def get_routing_table_impl(
|
|||
"scoring_functions": ScoringFunctionsRoutingTable,
|
||||
"benchmarks": BenchmarksRoutingTable,
|
||||
"tool_groups": ToolGroupsRoutingTable,
|
||||
"vector_dbs": VectorDBsRoutingTable,
|
||||
}
|
||||
|
||||
if api.value not in api_to_tables:
|
||||
|
|
@ -91,6 +93,9 @@ async def get_auto_router_impl(
|
|||
await inference_store.initialize()
|
||||
api_to_dep_impl["store"] = inference_store
|
||||
|
||||
elif api == Api.vector_io:
|
||||
api_to_dep_impl["vector_stores_config"] = run_config.vector_stores
|
||||
|
||||
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ from llama_stack.apis.vector_io import (
|
|||
VectorStoreObject,
|
||||
VectorStoreSearchResponsePage,
|
||||
)
|
||||
from llama_stack.core.datatypes import VectorStoresConfig
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
|
||||
|
||||
|
|
@ -43,9 +44,11 @@ class VectorIORouter(VectorIO):
|
|||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
vector_stores_config: VectorStoresConfig | None = None,
|
||||
) -> None:
|
||||
logger.debug("Initializing VectorIORouter")
|
||||
self.routing_table = routing_table
|
||||
self.vector_stores_config = vector_stores_config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("VectorIORouter.initialize")
|
||||
|
|
@ -122,6 +125,17 @@ class VectorIORouter(VectorIO):
|
|||
embedding_dimension = extra.get("embedding_dimension")
|
||||
provider_id = extra.get("provider_id")
|
||||
|
||||
# Use default embedding model if not specified
|
||||
if (
|
||||
embedding_model is None
|
||||
and self.vector_stores_config
|
||||
and self.vector_stores_config.default_embedding_model is not None
|
||||
):
|
||||
# Construct the full model ID with provider prefix
|
||||
embedding_provider_id = self.vector_stores_config.default_embedding_model.provider_id
|
||||
model_id = self.vector_stores_config.default_embedding_model.model_id
|
||||
embedding_model = f"{embedding_provider_id}/{model_id}"
|
||||
|
||||
if embedding_model is not None and embedding_dimension is None:
|
||||
embedding_dimension = await self._get_embedding_model_dimension(embedding_model)
|
||||
|
||||
|
|
@ -132,11 +146,24 @@ class VectorIORouter(VectorIO):
|
|||
raise ValueError("No vector_io providers available")
|
||||
if num_providers > 1:
|
||||
available_providers = list(self.routing_table.impls_by_provider_id.keys())
|
||||
raise ValueError(
|
||||
f"Multiple vector_io providers available. Please specify provider_id in extra_body. "
|
||||
f"Available providers: {available_providers}"
|
||||
)
|
||||
provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]
|
||||
# Use default configured provider
|
||||
if self.vector_stores_config and self.vector_stores_config.default_provider_id:
|
||||
default_provider = self.vector_stores_config.default_provider_id
|
||||
if default_provider in available_providers:
|
||||
provider_id = default_provider
|
||||
logger.debug(f"Using configured default vector store provider: {provider_id}")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Configured default vector store provider '{default_provider}' not found. "
|
||||
f"Available providers: {available_providers}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Multiple vector_io providers available. Please specify provider_id in extra_body. "
|
||||
f"Available providers: {available_providers}"
|
||||
)
|
||||
else:
|
||||
provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]
|
||||
|
||||
vector_db_id = f"vs_{uuid.uuid4()}"
|
||||
registered_vector_db = await self.routing_table.register_vector_db(
|
||||
|
|
@ -243,8 +270,7 @@ class VectorIORouter(VectorIO):
|
|||
vector_store_id: str,
|
||||
) -> VectorStoreDeleteResponse:
|
||||
logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
|
||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_delete_vector_store(vector_store_id)
|
||||
return await self.routing_table.openai_delete_vector_store(vector_store_id)
|
||||
|
||||
async def openai_search_vector_store(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -134,12 +134,15 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
from .scoring_functions import ScoringFunctionsRoutingTable
|
||||
from .shields import ShieldsRoutingTable
|
||||
from .toolgroups import ToolGroupsRoutingTable
|
||||
from .vector_dbs import VectorDBsRoutingTable
|
||||
|
||||
def apiname_object():
|
||||
if isinstance(self, ModelsRoutingTable):
|
||||
return ("Inference", "model")
|
||||
elif isinstance(self, ShieldsRoutingTable):
|
||||
return ("Safety", "shield")
|
||||
elif isinstance(self, VectorDBsRoutingTable):
|
||||
return ("VectorIO", "vector_db")
|
||||
elif isinstance(self, DatasetsRoutingTable):
|
||||
return ("DatasetIO", "dataset")
|
||||
elif isinstance(self, ScoringFunctionsRoutingTable):
|
||||
|
|
|
|||
323
llama_stack/core/routing_tables/vector_dbs.py
Normal file
323
llama_stack/core/routing_tables/vector_dbs.py
Normal file
|
|
@ -0,0 +1,323 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
|
||||
# Removed VectorDBs import to avoid exposing public API
|
||||
from llama_stack.apis.vector_io.vector_io import (
|
||||
OpenAICreateVectorStoreRequestWithExtraBody,
|
||||
SearchRankingOptions,
|
||||
VectorStoreChunkingStrategy,
|
||||
VectorStoreDeleteResponse,
|
||||
VectorStoreFileContentsResponse,
|
||||
VectorStoreFileDeleteResponse,
|
||||
VectorStoreFileObject,
|
||||
VectorStoreFileStatus,
|
||||
VectorStoreObject,
|
||||
VectorStoreSearchResponsePage,
|
||||
)
|
||||
from llama_stack.core.datatypes import (
|
||||
VectorDBWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .common import CommonRoutingTableImpl, lookup_model
|
||||
|
||||
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||
|
||||
|
||||
class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
||||
"""Internal routing table for vector_db operations.
|
||||
|
||||
Does not inherit from VectorDBs to avoid exposing public API endpoints.
|
||||
Only provides internal routing functionality for VectorIORouter.
|
||||
"""
|
||||
|
||||
# Internal methods only - no public API exposure
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
embedding_model: str,
|
||||
embedding_dimension: int | None = 384,
|
||||
provider_id: str | None = None,
|
||||
provider_vector_db_id: str | None = None,
|
||||
vector_db_name: str | None = None,
|
||||
) -> Any:
|
||||
if provider_id is None:
|
||||
if len(self.impls_by_provider_id) > 0:
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
if len(self.impls_by_provider_id) > 1:
|
||||
logger.warning(
|
||||
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
|
||||
)
|
||||
else:
|
||||
raise ValueError("No provider available. Please configure a vector_io provider.")
|
||||
model = await lookup_model(self, embedding_model)
|
||||
if model is None:
|
||||
raise ModelNotFoundError(embedding_model)
|
||||
if model.model_type != ModelType.embedding:
|
||||
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
|
||||
if "embedding_dimension" not in model.metadata:
|
||||
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
|
||||
|
||||
try:
|
||||
provider = self.impls_by_provider_id[provider_id]
|
||||
except KeyError:
|
||||
available_providers = list(self.impls_by_provider_id.keys())
|
||||
raise ValueError(
|
||||
f"Provider '{provider_id}' not found in routing table. Available providers: {available_providers}"
|
||||
) from None
|
||||
logger.warning(
|
||||
"VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly."
|
||||
)
|
||||
request = OpenAICreateVectorStoreRequestWithExtraBody(
|
||||
name=vector_db_name or vector_db_id,
|
||||
embedding_model=embedding_model,
|
||||
embedding_dimension=model.metadata["embedding_dimension"],
|
||||
provider_id=provider_id,
|
||||
provider_vector_db_id=provider_vector_db_id,
|
||||
)
|
||||
vector_store = await provider.openai_create_vector_store(request)
|
||||
|
||||
vector_store_id = vector_store.id
|
||||
actual_provider_vector_db_id = provider_vector_db_id or vector_store_id
|
||||
logger.warning(
|
||||
f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name"
|
||||
)
|
||||
|
||||
vector_db_data = {
|
||||
"identifier": vector_store_id,
|
||||
"type": ResourceType.vector_db.value,
|
||||
"provider_id": provider_id,
|
||||
"provider_resource_id": actual_provider_vector_db_id,
|
||||
"embedding_model": embedding_model,
|
||||
"embedding_dimension": model.metadata["embedding_dimension"],
|
||||
"vector_db_name": vector_store.name,
|
||||
}
|
||||
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
|
||||
await self.register_object(vector_db)
|
||||
return vector_db
|
||||
|
||||
async def openai_retrieve_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreObject:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store(vector_store_id)
|
||||
|
||||
async def openai_update_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
name: str | None = None,
|
||||
expires_after: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> VectorStoreObject:
|
||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_update_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
name=name,
|
||||
expires_after=expires_after,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
async def openai_delete_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreDeleteResponse:
|
||||
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
result = await provider.openai_delete_vector_store(vector_store_id)
|
||||
await self.unregister_vector_db(vector_store_id)
|
||||
return result
|
||||
|
||||
async def unregister_vector_db(self, vector_store_id: str) -> None:
|
||||
"""Remove the vector store from the routing table registry."""
|
||||
try:
|
||||
vector_db_obj = await self.get_object_by_identifier("vector_db", vector_store_id)
|
||||
if vector_db_obj:
|
||||
await self.unregister_object(vector_db_obj)
|
||||
except Exception as e:
|
||||
# Log the error but don't fail the operation
|
||||
logger.warning(f"Failed to unregister vector store {vector_store_id} from routing table: {e}")
|
||||
|
||||
async def openai_search_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
query: str | list[str],
|
||||
filters: dict[str, Any] | None = None,
|
||||
max_num_results: int | None = 10,
|
||||
ranking_options: SearchRankingOptions | None = None,
|
||||
rewrite_query: bool | None = False,
|
||||
search_mode: str | None = "vector",
|
||||
) -> VectorStoreSearchResponsePage:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_search_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
query=query,
|
||||
filters=filters,
|
||||
max_num_results=max_num_results,
|
||||
ranking_options=ranking_options,
|
||||
rewrite_query=rewrite_query,
|
||||
search_mode=search_mode,
|
||||
)
|
||||
|
||||
async def openai_attach_file_to_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
attributes: dict[str, Any] | None = None,
|
||||
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
||||
) -> VectorStoreFileObject:
|
||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_attach_file_to_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
attributes=attributes,
|
||||
chunking_strategy=chunking_strategy,
|
||||
)
|
||||
|
||||
async def openai_list_files_in_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
limit: int | None = 20,
|
||||
order: str | None = "desc",
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
filter: VectorStoreFileStatus | None = None,
|
||||
) -> list[VectorStoreFileObject]:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_list_files_in_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
limit=limit,
|
||||
order=order,
|
||||
after=after,
|
||||
before=before,
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
async def openai_retrieve_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileObject:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
||||
async def openai_retrieve_vector_store_file_contents(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileContentsResponse:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store_file_contents(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
||||
async def openai_update_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
attributes: dict[str, Any],
|
||||
) -> VectorStoreFileObject:
|
||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_update_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
attributes=attributes,
|
||||
)
|
||||
|
||||
async def openai_delete_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileDeleteResponse:
|
||||
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_delete_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
||||
async def openai_create_vector_store_file_batch(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_ids: list[str],
|
||||
attributes: dict[str, Any] | None = None,
|
||||
chunking_strategy: Any | None = None,
|
||||
):
|
||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_create_vector_store_file_batch(
|
||||
vector_store_id=vector_store_id,
|
||||
file_ids=file_ids,
|
||||
attributes=attributes,
|
||||
chunking_strategy=chunking_strategy,
|
||||
)
|
||||
|
||||
async def openai_retrieve_vector_store_file_batch(
|
||||
self,
|
||||
batch_id: str,
|
||||
vector_store_id: str,
|
||||
):
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store_file_batch(
|
||||
batch_id=batch_id,
|
||||
vector_store_id=vector_store_id,
|
||||
)
|
||||
|
||||
async def openai_list_files_in_vector_store_file_batch(
|
||||
self,
|
||||
batch_id: str,
|
||||
vector_store_id: str,
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
filter: str | None = None,
|
||||
limit: int | None = 20,
|
||||
order: str | None = "desc",
|
||||
):
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_list_files_in_vector_store_file_batch(
|
||||
batch_id=batch_id,
|
||||
vector_store_id=vector_store_id,
|
||||
after=after,
|
||||
before=before,
|
||||
filter=filter,
|
||||
limit=limit,
|
||||
order=order,
|
||||
)
|
||||
|
||||
async def openai_cancel_vector_store_file_batch(
|
||||
self,
|
||||
batch_id: str,
|
||||
vector_store_id: str,
|
||||
):
|
||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_cancel_vector_store_file_batch(
|
||||
batch_id=batch_id,
|
||||
vector_store_id=vector_store_id,
|
||||
)
|
||||
|
|
@ -35,7 +35,7 @@ from llama_stack.apis.telemetry import Telemetry
|
|||
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig, VectorStoresConfig
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
|
||||
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
|
||||
|
|
@ -108,30 +108,6 @@ REGISTRY_REFRESH_TASK = None
|
|||
TEST_RECORDING_CONTEXT = None
|
||||
|
||||
|
||||
async def validate_default_embedding_model(impls: dict[Api, Any]):
|
||||
"""Validate that at most one embedding model is marked as default."""
|
||||
if Api.models not in impls:
|
||||
return
|
||||
|
||||
models_impl = impls[Api.models]
|
||||
response = await models_impl.list_models()
|
||||
models_list = response.data if hasattr(response, "data") else response
|
||||
|
||||
default_embedding_models = []
|
||||
for model in models_list:
|
||||
if model.model_type == "embedding" and model.metadata.get("default_configured") is True:
|
||||
default_embedding_models.append(model.identifier)
|
||||
|
||||
if len(default_embedding_models) > 1:
|
||||
raise ValueError(
|
||||
f"Multiple embedding models marked as default_configured=True: {default_embedding_models}. "
|
||||
"Only one embedding model can be marked as default."
|
||||
)
|
||||
|
||||
if default_embedding_models:
|
||||
logger.info(f"Default embedding model configured: {default_embedding_models[0]}")
|
||||
|
||||
|
||||
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
||||
for rsrc, api, register_method, list_method in RESOURCES:
|
||||
objects = getattr(run_config, rsrc)
|
||||
|
|
@ -162,7 +138,41 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
|||
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
|
||||
)
|
||||
|
||||
await validate_default_embedding_model(impls)
|
||||
|
||||
async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig | None, impls: dict[Api, Any]):
|
||||
"""Validate vector stores configuration."""
|
||||
if vector_stores_config is None:
|
||||
return
|
||||
|
||||
default_embedding_model = vector_stores_config.default_embedding_model
|
||||
if default_embedding_model is None:
|
||||
return
|
||||
|
||||
provider_id = default_embedding_model.provider_id
|
||||
model_id = default_embedding_model.model_id
|
||||
default_model_id = f"{provider_id}/{model_id}"
|
||||
|
||||
if Api.models not in impls:
|
||||
raise ValueError(f"Models API is not available but vector_stores config requires model '{default_model_id}'")
|
||||
|
||||
models_impl = impls[Api.models]
|
||||
response = await models_impl.list_models()
|
||||
models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}
|
||||
|
||||
default_model = models_list.get(default_model_id)
|
||||
if default_model is None:
|
||||
raise ValueError(f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}")
|
||||
|
||||
embedding_dimension = default_model.metadata.get("embedding_dimension")
|
||||
if embedding_dimension is None:
|
||||
raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")
|
||||
|
||||
try:
|
||||
int(embedding_dimension)
|
||||
except ValueError as err:
|
||||
raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err
|
||||
|
||||
logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
|
||||
|
||||
|
||||
class EnvVarError(Exception):
|
||||
|
|
@ -400,8 +410,8 @@ class Stack:
|
|||
await impls[Api.conversations].initialize()
|
||||
|
||||
await register_resources(self.run_config, impls)
|
||||
|
||||
await refresh_registry_once(impls)
|
||||
await validate_vector_stores_config(self.run_config.vector_stores, impls)
|
||||
self.impls = impls
|
||||
|
||||
def create_registry_refresh_task(self):
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ distribution_spec:
|
|||
- provider_type: inline::milvus
|
||||
- provider_type: remote::chromadb
|
||||
- provider_type: remote::pgvector
|
||||
- provider_type: remote::qdrant
|
||||
- provider_type: remote::weaviate
|
||||
files:
|
||||
- provider_type: inline::localfs
|
||||
safety:
|
||||
|
|
|
|||
|
|
@ -128,6 +128,21 @@ providers:
|
|||
persistence:
|
||||
namespace: vector_io::pgvector
|
||||
backend: kv_default
|
||||
- provider_id: ${env.QDRANT_URL:+qdrant}
|
||||
provider_type: remote::qdrant
|
||||
config:
|
||||
api_key: ${env.QDRANT_API_KEY:=}
|
||||
persistence:
|
||||
namespace: vector_io::qdrant_remote
|
||||
backend: kv_default
|
||||
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
|
||||
provider_type: remote::weaviate
|
||||
config:
|
||||
weaviate_api_key: null
|
||||
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
|
||||
persistence:
|
||||
namespace: vector_io::weaviate
|
||||
backend: kv_default
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
|
|
@ -253,3 +268,8 @@ server:
|
|||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
vector_stores:
|
||||
default_provider_id: faiss
|
||||
default_embedding_model:
|
||||
provider_id: sentence-transformers
|
||||
model_id: nomic-ai/nomic-embed-text-v1.5
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ distribution_spec:
|
|||
- provider_type: inline::milvus
|
||||
- provider_type: remote::chromadb
|
||||
- provider_type: remote::pgvector
|
||||
- provider_type: remote::qdrant
|
||||
- provider_type: remote::weaviate
|
||||
files:
|
||||
- provider_type: inline::localfs
|
||||
safety:
|
||||
|
|
|
|||
|
|
@ -128,6 +128,21 @@ providers:
|
|||
persistence:
|
||||
namespace: vector_io::pgvector
|
||||
backend: kv_default
|
||||
- provider_id: ${env.QDRANT_URL:+qdrant}
|
||||
provider_type: remote::qdrant
|
||||
config:
|
||||
api_key: ${env.QDRANT_API_KEY:=}
|
||||
persistence:
|
||||
namespace: vector_io::qdrant_remote
|
||||
backend: kv_default
|
||||
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
|
||||
provider_type: remote::weaviate
|
||||
config:
|
||||
weaviate_api_key: null
|
||||
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
|
||||
persistence:
|
||||
namespace: vector_io::weaviate
|
||||
backend: kv_default
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
|
|
@ -256,3 +271,8 @@ server:
|
|||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
vector_stores:
|
||||
default_provider_id: faiss
|
||||
default_embedding_model:
|
||||
provider_id: sentence-transformers
|
||||
model_id: nomic-ai/nomic-embed-text-v1.5
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ distribution_spec:
|
|||
- provider_type: inline::milvus
|
||||
- provider_type: remote::chromadb
|
||||
- provider_type: remote::pgvector
|
||||
- provider_type: remote::qdrant
|
||||
- provider_type: remote::weaviate
|
||||
files:
|
||||
- provider_type: inline::localfs
|
||||
safety:
|
||||
|
|
|
|||
|
|
@ -128,6 +128,21 @@ providers:
|
|||
persistence:
|
||||
namespace: vector_io::pgvector
|
||||
backend: kv_default
|
||||
- provider_id: ${env.QDRANT_URL:+qdrant}
|
||||
provider_type: remote::qdrant
|
||||
config:
|
||||
api_key: ${env.QDRANT_API_KEY:=}
|
||||
persistence:
|
||||
namespace: vector_io::qdrant_remote
|
||||
backend: kv_default
|
||||
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
|
||||
provider_type: remote::weaviate
|
||||
config:
|
||||
weaviate_api_key: null
|
||||
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
|
||||
persistence:
|
||||
namespace: vector_io::weaviate
|
||||
backend: kv_default
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
|
|
@ -253,3 +268,8 @@ server:
|
|||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
vector_stores:
|
||||
default_provider_id: faiss
|
||||
default_embedding_model:
|
||||
provider_id: sentence-transformers
|
||||
model_id: nomic-ai/nomic-embed-text-v1.5
|
||||
|
|
|
|||
|
|
@ -11,8 +11,10 @@ from llama_stack.core.datatypes import (
|
|||
BuildProvider,
|
||||
Provider,
|
||||
ProviderSpec,
|
||||
QualifiedModel,
|
||||
ShieldInput,
|
||||
ToolGroupInput,
|
||||
VectorStoresConfig,
|
||||
)
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
|
||||
|
|
@ -31,6 +33,8 @@ from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOC
|
|||
from llama_stack.providers.remote.vector_io.pgvector.config import (
|
||||
PGVectorVectorIOConfig,
|
||||
)
|
||||
from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
|
||||
|
||||
|
||||
|
|
@ -113,6 +117,8 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
|||
BuildProvider(provider_type="inline::milvus"),
|
||||
BuildProvider(provider_type="remote::chromadb"),
|
||||
BuildProvider(provider_type="remote::pgvector"),
|
||||
BuildProvider(provider_type="remote::qdrant"),
|
||||
BuildProvider(provider_type="remote::weaviate"),
|
||||
],
|
||||
"files": [BuildProvider(provider_type="inline::localfs")],
|
||||
"safety": [
|
||||
|
|
@ -221,12 +227,35 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
|||
password="${env.PGVECTOR_PASSWORD:=}",
|
||||
),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.QDRANT_URL:+qdrant}",
|
||||
provider_type="remote::qdrant",
|
||||
config=QdrantVectorIOConfig.sample_run_config(
|
||||
f"~/.llama/distributions/{name}",
|
||||
url="${env.QDRANT_URL:=}",
|
||||
),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.WEAVIATE_CLUSTER_URL:+weaviate}",
|
||||
provider_type="remote::weaviate",
|
||||
config=WeaviateVectorIOConfig.sample_run_config(
|
||||
f"~/.llama/distributions/{name}",
|
||||
cluster_url="${env.WEAVIATE_CLUSTER_URL:=}",
|
||||
),
|
||||
),
|
||||
],
|
||||
"files": [files_provider],
|
||||
},
|
||||
default_models=[],
|
||||
default_tool_groups=default_tool_groups,
|
||||
default_shields=default_shields,
|
||||
vector_stores_config=VectorStoresConfig(
|
||||
default_provider_id="faiss",
|
||||
default_embedding_model=QualifiedModel(
|
||||
provider_id="sentence-transformers",
|
||||
model_id="nomic-ai/nomic-embed-text-v1.5",
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
run_config_env_vars={
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from llama_stack.core.datatypes import (
|
|||
ShieldInput,
|
||||
TelemetryConfig,
|
||||
ToolGroupInput,
|
||||
VectorStoresConfig,
|
||||
)
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
|
|
@ -186,6 +187,7 @@ class RunConfigSettings(BaseModel):
|
|||
default_tool_groups: list[ToolGroupInput] | None = None
|
||||
default_datasets: list[DatasetInput] | None = None
|
||||
default_benchmarks: list[BenchmarkInput] | None = None
|
||||
vector_stores_config: VectorStoresConfig | None = None
|
||||
telemetry: TelemetryConfig = Field(default_factory=lambda: TelemetryConfig(enabled=True))
|
||||
storage_backends: dict[str, Any] | None = None
|
||||
storage_stores: dict[str, Any] | None = None
|
||||
|
|
@ -263,7 +265,7 @@ class RunConfigSettings(BaseModel):
|
|||
)
|
||||
|
||||
# Return a dict that matches StackRunConfig structure
|
||||
return {
|
||||
config = {
|
||||
"version": LLAMA_STACK_RUN_CONFIG_VERSION,
|
||||
"image_name": name,
|
||||
"container_image": container_image,
|
||||
|
|
@ -283,6 +285,11 @@ class RunConfigSettings(BaseModel):
|
|||
"telemetry": self.telemetry.model_dump(exclude_none=True) if self.telemetry else None,
|
||||
}
|
||||
|
||||
if self.vector_stores_config:
|
||||
config["vector_stores"] = self.vector_stores_config.model_dump(exclude_none=True)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
class DistributionTemplate(BaseModel):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -59,7 +59,6 @@ class SentenceTransformersInferenceImpl(
|
|||
provider_id=self.__provider_id__,
|
||||
metadata={
|
||||
"embedding_dimension": 768,
|
||||
"default_configured": True,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
|
|
|
|||
|
|
@ -12,15 +12,8 @@ from .config import ChromaVectorIOConfig
|
|||
|
||||
|
||||
async def get_provider_impl(config: ChromaVectorIOConfig, deps: dict[Api, Any]):
|
||||
from llama_stack.providers.remote.vector_io.chroma.chroma import (
|
||||
ChromaVectorIOAdapter,
|
||||
)
|
||||
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaVectorIOAdapter
|
||||
|
||||
impl = ChromaVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.models],
|
||||
deps.get(Api.files),
|
||||
)
|
||||
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -16,11 +16,6 @@ async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
|
|||
|
||||
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = FaissVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.models],
|
||||
deps.get(Api.files),
|
||||
)
|
||||
impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -17,27 +17,14 @@ from numpy.typing import NDArray
|
|||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
VectorDBsProtocolPrivate,
|
||||
)
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
|
||||
|
||||
from .config import FaissVectorIOConfig
|
||||
|
||||
|
|
@ -155,12 +142,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
|
||||
await self._save_index()
|
||||
|
||||
async def query_vector(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
|
||||
chunks = []
|
||||
scores = []
|
||||
|
|
@ -175,12 +157,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
raise NotImplementedError(
|
||||
"Keyword search is not supported - underlying DB FAISS does not support this search mode"
|
||||
)
|
||||
|
|
@ -200,17 +177,10 @@ class FaissIndex(EmbeddingIndex):
|
|||
|
||||
|
||||
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(
|
||||
self,
|
||||
config: FaissVectorIOConfig,
|
||||
inference_api: Inference,
|
||||
models_api: Models,
|
||||
files_api: Files | None,
|
||||
) -> None:
|
||||
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.models_api = models_api
|
||||
self.cache: dict[str, VectorDBWithIndex] = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
|
@ -252,17 +222,11 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
except Exception as e:
|
||||
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
) -> None:
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
assert self.kvstore is not None
|
||||
|
||||
key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
|
||||
await self.kvstore.set(
|
||||
key=key,
|
||||
value=vector_db.model_dump_json(),
|
||||
)
|
||||
await self.kvstore.set(key=key, value=vector_db.model_dump_json())
|
||||
|
||||
# Store in cache
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(
|
||||
|
|
@ -285,12 +249,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
del self.cache[vector_db_id]
|
||||
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_db_id}")
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = self.cache.get(vector_db_id)
|
||||
if index is None:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found. found: {self.cache.keys()}")
|
||||
|
|
@ -298,10 +257,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = self.cache.get(vector_db_id)
|
||||
if index is None:
|
||||
|
|
|
|||
|
|
@ -14,11 +14,6 @@ from .config import MilvusVectorIOConfig
|
|||
async def get_provider_impl(config: MilvusVectorIOConfig, deps: dict[Api, Any]):
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
|
||||
|
||||
impl = MilvusVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.models],
|
||||
deps.get(Api.files),
|
||||
)
|
||||
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -15,11 +15,6 @@ async def get_provider_impl(config: QdrantVectorIOConfig, deps: dict[Api, Any]):
|
|||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
|
||||
|
||||
assert isinstance(config, QdrantVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = QdrantVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.models],
|
||||
deps.get(Api.files),
|
||||
)
|
||||
impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -15,11 +15,6 @@ async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
|
|||
from .sqlite_vec import SQLiteVecVectorIOAdapter
|
||||
|
||||
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = SQLiteVecVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.models],
|
||||
deps.get(Api.files),
|
||||
)
|
||||
impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -17,13 +17,8 @@ from numpy.typing import NDArray
|
|||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
|
@ -175,32 +170,18 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
|
||||
# Insert vector embeddings
|
||||
embedding_data = [
|
||||
(
|
||||
(
|
||||
chunk.chunk_id,
|
||||
serialize_vector(emb.tolist()),
|
||||
)
|
||||
)
|
||||
((chunk.chunk_id, serialize_vector(emb.tolist())))
|
||||
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
|
||||
]
|
||||
cur.executemany(
|
||||
f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);",
|
||||
embedding_data,
|
||||
)
|
||||
cur.executemany(f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);", embedding_data)
|
||||
|
||||
# Insert FTS content
|
||||
fts_data = [(chunk.chunk_id, chunk.content) for chunk in batch_chunks]
|
||||
# DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
|
||||
cur.executemany(
|
||||
f"DELETE FROM [{self.fts_table}] WHERE id = ?;",
|
||||
[(row[0],) for row in fts_data],
|
||||
)
|
||||
cur.executemany(f"DELETE FROM [{self.fts_table}] WHERE id = ?;", [(row[0],) for row in fts_data])
|
||||
|
||||
# INSERT new entries
|
||||
cur.executemany(
|
||||
f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);",
|
||||
fts_data,
|
||||
)
|
||||
cur.executemany(f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);", fts_data)
|
||||
|
||||
connection.commit()
|
||||
|
||||
|
|
@ -216,12 +197,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
# Run batch insertion in a background thread
|
||||
await asyncio.to_thread(_execute_all_batch_inserts)
|
||||
|
||||
async def query_vector(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Performs vector-based search using a virtual table for vector similarity.
|
||||
"""
|
||||
|
|
@ -261,12 +237,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
scores.append(score)
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search.
|
||||
"""
|
||||
|
|
@ -410,17 +381,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
inference_api: Inference,
|
||||
models_api: Models,
|
||||
files_api: Files | None,
|
||||
) -> None:
|
||||
def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.models_api = models_api
|
||||
self.cache: dict[str, VectorDBWithIndex] = {}
|
||||
self.vector_db_store = None
|
||||
|
||||
|
|
@ -433,9 +397,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
for db_json in stored_vector_dbs:
|
||||
vector_db = VectorDB.model_validate_json(db_json)
|
||||
index = await SQLiteVecIndex.create(
|
||||
vector_db.embedding_dimension,
|
||||
self.config.db_path,
|
||||
vector_db.identifier,
|
||||
vector_db.embedding_dimension, self.config.db_path, vector_db.identifier
|
||||
)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
|
||||
|
|
@ -450,11 +412,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
return [v.vector_db for v in self.cache.values()]
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
index = await SQLiteVecIndex.create(
|
||||
vector_db.embedding_dimension,
|
||||
self.config.db_path,
|
||||
vector_db.identifier,
|
||||
)
|
||||
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
|
||||
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
|
||||
|
|
|
|||
|
|
@ -12,11 +12,6 @@ from .config import ChromaVectorIOConfig
|
|||
async def get_adapter_impl(config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec]):
|
||||
from .chroma import ChromaVectorIOAdapter
|
||||
|
||||
impl = ChromaVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.models],
|
||||
deps.get(Api.files),
|
||||
)
|
||||
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -12,24 +12,16 @@ import chromadb
|
|||
from numpy.typing import NDArray
|
||||
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
|
||||
|
||||
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
|
||||
|
||||
|
|
@ -68,19 +60,13 @@ class ChromaIndex(EmbeddingIndex):
|
|||
|
||||
ids = [f"{c.metadata.get('document_id', '')}:{c.chunk_id}" for c in chunks]
|
||||
await maybe_await(
|
||||
self.collection.add(
|
||||
documents=[chunk.model_dump_json() for chunk in chunks],
|
||||
embeddings=embeddings,
|
||||
ids=ids,
|
||||
)
|
||||
self.collection.add(documents=[chunk.model_dump_json() for chunk in chunks], embeddings=embeddings, ids=ids)
|
||||
)
|
||||
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
results = await maybe_await(
|
||||
self.collection.query(
|
||||
query_embeddings=[embedding.tolist()],
|
||||
n_results=k,
|
||||
include=["documents", "distances"],
|
||||
query_embeddings=[embedding.tolist()], n_results=k, include=["documents", "distances"]
|
||||
)
|
||||
)
|
||||
distances = results["distances"][0]
|
||||
|
|
@ -108,12 +94,7 @@ class ChromaIndex(EmbeddingIndex):
|
|||
async def delete(self):
|
||||
await maybe_await(self.client.delete_collection(self.collection.name))
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in Chroma")
|
||||
|
||||
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
|
|
@ -137,15 +118,13 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
def __init__(
|
||||
self,
|
||||
config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
|
||||
inference_api: Api.inference,
|
||||
models_apis: Api.models,
|
||||
inference_api: Inference,
|
||||
files_api: Files | None,
|
||||
) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.models_api = models_apis
|
||||
self.client = None
|
||||
self.cache = {}
|
||||
self.vector_db_store = None
|
||||
|
|
@ -172,14 +151,10 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
) -> None:
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
collection = await maybe_await(
|
||||
self.client.get_or_create_collection(
|
||||
name=vector_db.identifier,
|
||||
metadata={"vector_db": vector_db.model_dump_json()},
|
||||
name=vector_db.identifier, metadata={"vector_db": vector_db.model_dump_json()}
|
||||
)
|
||||
)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(
|
||||
|
|
@ -194,12 +169,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if index is None:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
|
||||
|
|
@ -207,10 +177,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -13,12 +13,6 @@ async def get_adapter_impl(config: MilvusVectorIOConfig, deps: dict[Api, Provide
|
|||
from .milvus import MilvusVectorIOAdapter
|
||||
|
||||
assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = MilvusVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.models],
|
||||
deps.get(Api.files),
|
||||
)
|
||||
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -14,13 +14,8 @@ from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusC
|
|||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
|
||||
|
|
@ -74,46 +69,23 @@ class MilvusIndex(EmbeddingIndex):
|
|||
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
|
||||
# Create schema for vector search
|
||||
schema = self.client.create_schema()
|
||||
schema.add_field(
|
||||
field_name="chunk_id",
|
||||
datatype=DataType.VARCHAR,
|
||||
is_primary=True,
|
||||
max_length=100,
|
||||
)
|
||||
schema.add_field(field_name="chunk_id", datatype=DataType.VARCHAR, is_primary=True, max_length=100)
|
||||
schema.add_field(
|
||||
field_name="content",
|
||||
datatype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
enable_analyzer=True, # Enable text analysis for BM25
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="vector",
|
||||
datatype=DataType.FLOAT_VECTOR,
|
||||
dim=len(embeddings[0]),
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="chunk_content",
|
||||
datatype=DataType.JSON,
|
||||
)
|
||||
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=len(embeddings[0]))
|
||||
schema.add_field(field_name="chunk_content", datatype=DataType.JSON)
|
||||
# Add sparse vector field for BM25 (required by the function)
|
||||
schema.add_field(
|
||||
field_name="sparse",
|
||||
datatype=DataType.SPARSE_FLOAT_VECTOR,
|
||||
)
|
||||
schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR)
|
||||
|
||||
# Create indexes
|
||||
index_params = self.client.prepare_index_params()
|
||||
index_params.add_index(
|
||||
field_name="vector",
|
||||
index_type="FLAT",
|
||||
metric_type="COSINE",
|
||||
)
|
||||
index_params.add_index(field_name="vector", index_type="FLAT", metric_type="COSINE")
|
||||
# Add index for sparse field (required by BM25 function)
|
||||
index_params.add_index(
|
||||
field_name="sparse",
|
||||
index_type="SPARSE_INVERTED_INDEX",
|
||||
metric_type="BM25",
|
||||
)
|
||||
index_params.add_index(field_name="sparse", index_type="SPARSE_INVERTED_INDEX", metric_type="BM25")
|
||||
|
||||
# Add BM25 function for full-text search
|
||||
bm25_function = Function(
|
||||
|
|
@ -144,11 +116,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
}
|
||||
)
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
self.client.insert,
|
||||
self.collection_name,
|
||||
data=data,
|
||||
)
|
||||
await asyncio.to_thread(self.client.insert, self.collection_name, data=data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
|
||||
raise e
|
||||
|
|
@ -167,12 +135,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
scores = [res["distance"] for res in search_res[0]]
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Perform BM25-based keyword search using Milvus's built-in full-text search.
|
||||
"""
|
||||
|
|
@ -210,12 +173,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
# Fallback to simple text search
|
||||
return await self._fallback_keyword_search(query_string, k, score_threshold)
|
||||
|
||||
async def _fallback_keyword_search(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def _fallback_keyword_search(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Fallback to simple text search when BM25 search is not available.
|
||||
"""
|
||||
|
|
@ -308,7 +266,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self,
|
||||
config: RemoteMilvusVectorIOConfig | InlineMilvusVectorIOConfig,
|
||||
inference_api: Inference,
|
||||
models_api: Models,
|
||||
files_api: Files | None,
|
||||
) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
|
|
@ -316,7 +273,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self.cache = {}
|
||||
self.client = None
|
||||
self.inference_api = inference_api
|
||||
self.models_api = models_api
|
||||
self.vector_db_store = None
|
||||
self.metadata_collection_name = "openai_vector_stores_metadata"
|
||||
|
||||
|
|
@ -355,10 +311,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
) -> None:
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
||||
consistency_level = self.config.consistency_level
|
||||
else:
|
||||
|
|
@ -395,12 +348,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
|
@ -408,10 +356,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
|
|
|
|||
|
|
@ -12,6 +12,6 @@ from .config import PGVectorVectorIOConfig
|
|||
async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec]):
|
||||
from .pgvector import PGVectorVectorIOAdapter
|
||||
|
||||
impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps[Api.models], deps.get(Api.files, None))
|
||||
impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -16,26 +16,15 @@ from pydantic import BaseModel, TypeAdapter
|
|||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator, sanitize_collection_name
|
||||
|
||||
from .config import PGVectorVectorIOConfig
|
||||
|
|
@ -205,12 +194,7 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Performs keyword-based search using PostgreSQL's full-text search with ts_rank scoring.
|
||||
|
||||
|
|
@ -317,7 +301,7 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
"""Remove a chunk from the PostgreSQL table."""
|
||||
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,))
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids))
|
||||
|
||||
def get_pgvector_search_function(self) -> str:
|
||||
return self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION[self.distance_metric]
|
||||
|
|
@ -341,16 +325,11 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
|
||||
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(
|
||||
self,
|
||||
config: PGVectorVectorIOConfig,
|
||||
inference_api: Inference,
|
||||
models_api: Models,
|
||||
files_api: Files | None = None,
|
||||
self, config: PGVectorVectorIOConfig, inference_api: Inference, files_api: Files | None = None
|
||||
) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.models_api = models_api
|
||||
self.conn = None
|
||||
self.cache = {}
|
||||
self.vector_db_store = None
|
||||
|
|
@ -407,11 +386,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
vector_db=vector_db, dimension=vector_db.embedding_dimension, conn=self.conn, kvstore=self.kvstore
|
||||
)
|
||||
await pgvector_index.initialize()
|
||||
index = VectorDBWithIndex(
|
||||
vector_db,
|
||||
index=pgvector_index,
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
index = VectorDBWithIndex(vector_db, index=pgvector_index, inference_api=self.inference_api)
|
||||
self.cache[vector_db.identifier] = index
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
|
|
@ -424,20 +399,12 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
assert self.kvstore is not None
|
||||
await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_db_id}")
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
return await index.query_chunks(query, params)
|
||||
|
|
|
|||
|
|
@ -12,11 +12,6 @@ from .config import QdrantVectorIOConfig
|
|||
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
|
||||
from .qdrant import QdrantVectorIOAdapter
|
||||
|
||||
impl = QdrantVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.models],
|
||||
deps.get(Api.files),
|
||||
)
|
||||
impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ from qdrant_client.models import PointStruct
|
|||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
|
|
@ -30,11 +29,7 @@ from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
|||
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
|
||||
|
||||
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
|
||||
|
||||
|
|
@ -99,8 +94,7 @@ class QdrantIndex(EmbeddingIndex):
|
|||
chunk_ids = [convert_id(c.chunk_id) for c in chunks_for_deletion]
|
||||
try:
|
||||
await self.client.delete(
|
||||
collection_name=self.collection_name,
|
||||
points_selector=models.PointIdsList(points=chunk_ids),
|
||||
collection_name=self.collection_name, points_selector=models.PointIdsList(points=chunk_ids)
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting chunks from Qdrant collection {self.collection_name}: {e}")
|
||||
|
|
@ -133,12 +127,7 @@ class QdrantIndex(EmbeddingIndex):
|
|||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in Qdrant")
|
||||
|
||||
async def query_hybrid(
|
||||
|
|
@ -161,7 +150,6 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self,
|
||||
config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig,
|
||||
inference_api: Inference,
|
||||
models_api: Models,
|
||||
files_api: Files | None = None,
|
||||
) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
|
|
@ -169,7 +157,6 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self.client: AsyncQdrantClient = None
|
||||
self.cache = {}
|
||||
self.inference_api = inference_api
|
||||
self.models_api = models_api
|
||||
self.vector_db_store = None
|
||||
self._qdrant_lock = asyncio.Lock()
|
||||
|
||||
|
|
@ -184,11 +171,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
|
||||
for vector_db_data in stored_vector_dbs:
|
||||
vector_db = VectorDB.model_validate_json(vector_db_data)
|
||||
index = VectorDBWithIndex(
|
||||
vector_db,
|
||||
QdrantIndex(self.client, vector_db.identifier),
|
||||
self.inference_api,
|
||||
)
|
||||
index = VectorDBWithIndex(vector_db, QdrantIndex(self.client, vector_db.identifier), self.inference_api)
|
||||
self.cache[vector_db.identifier] = index
|
||||
self.openai_vector_stores = await self._load_openai_vector_stores()
|
||||
|
||||
|
|
@ -197,18 +180,13 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
) -> None:
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
assert self.kvstore is not None
|
||||
key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
|
||||
await self.kvstore.set(key=key, value=vector_db.model_dump_json())
|
||||
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=QdrantIndex(self.client, vector_db.identifier),
|
||||
inference_api=self.inference_api,
|
||||
vector_db=vector_db, index=QdrantIndex(self.client, vector_db.identifier), inference_api=self.inference_api
|
||||
)
|
||||
|
||||
self.cache[vector_db.identifier] = index
|
||||
|
|
@ -240,12 +218,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self.cache[vector_db_id] = index
|
||||
return index
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
|
@ -253,10 +226,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
|
|
|
|||
|
|
@ -12,11 +12,6 @@ from .config import WeaviateVectorIOConfig
|
|||
async def get_adapter_impl(config: WeaviateVectorIOConfig, deps: dict[Api, ProviderSpec]):
|
||||
from .weaviate import WeaviateVectorIOAdapter
|
||||
|
||||
impl = WeaviateVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.models],
|
||||
deps.get(Api.files),
|
||||
)
|
||||
impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -21,11 +21,7 @@ class WeaviateVectorIOConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
__distro_dir__: str,
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"weaviate_api_key": None,
|
||||
"weaviate_cluster_url": "${env.WEAVIATE_CLUSTER_URL:=localhost:8080}",
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ from llama_stack.apis.common.content_types import InterleavedContent
|
|||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
|
|
@ -24,9 +23,7 @@ from llama_stack.log import get_logger
|
|||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
|
||||
OpenAIVectorStoreMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
RERANKER_TYPE_RRF,
|
||||
ChunkForDeletion,
|
||||
|
|
@ -48,12 +45,7 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten
|
|||
|
||||
|
||||
class WeaviateIndex(EmbeddingIndex):
|
||||
def __init__(
|
||||
self,
|
||||
client: weaviate.WeaviateClient,
|
||||
collection_name: str,
|
||||
kvstore: KVStore | None = None,
|
||||
):
|
||||
def __init__(self, client: weaviate.WeaviateClient, collection_name: str, kvstore: KVStore | None = None):
|
||||
self.client = client
|
||||
self.collection_name = sanitize_collection_name(collection_name, weaviate_format=True)
|
||||
self.kvstore = kvstore
|
||||
|
|
@ -108,9 +100,7 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
|
||||
try:
|
||||
results = collection.query.near_vector(
|
||||
near_vector=embedding.tolist(),
|
||||
limit=k,
|
||||
return_metadata=wvc.query.MetadataQuery(distance=True),
|
||||
near_vector=embedding.tolist(), limit=k, return_metadata=wvc.query.MetadataQuery(distance=True)
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Weaviate client vector search failed: {e}")
|
||||
|
|
@ -153,12 +143,7 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
collection = self.client.collections.get(sanitized_collection_name)
|
||||
collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids))
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Performs BM25-based keyword search using Weaviate's built-in full-text search.
|
||||
Args:
|
||||
|
|
@ -175,9 +160,7 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
# Perform BM25 keyword search on chunk_content field
|
||||
try:
|
||||
results = collection.query.bm25(
|
||||
query=query_string,
|
||||
limit=k,
|
||||
return_metadata=wvc.query.MetadataQuery(score=True),
|
||||
query=query_string, limit=k, return_metadata=wvc.query.MetadataQuery(score=True)
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Weaviate client keyword search failed: {e}")
|
||||
|
|
@ -274,23 +257,11 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class WeaviateVectorIOAdapter(
|
||||
OpenAIVectorStoreMixin,
|
||||
VectorIO,
|
||||
NeedsRequestProviderData,
|
||||
VectorDBsProtocolPrivate,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
config: WeaviateVectorIOConfig,
|
||||
inference_api: Inference,
|
||||
models_api: Models,
|
||||
files_api: Files | None,
|
||||
) -> None:
|
||||
class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProviderData, VectorDBsProtocolPrivate):
|
||||
def __init__(self, config: WeaviateVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.models_api = models_api
|
||||
self.client_cache = {}
|
||||
self.cache = {}
|
||||
self.vector_db_store = None
|
||||
|
|
@ -301,10 +272,7 @@ class WeaviateVectorIOAdapter(
|
|||
log.info("Using Weaviate locally in container")
|
||||
host, port = self.config.weaviate_cluster_url.split(":")
|
||||
key = "local_test"
|
||||
client = weaviate.connect_to_local(
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
client = weaviate.connect_to_local(host=host, port=port)
|
||||
else:
|
||||
log.info("Using Weaviate remote cluster with URL")
|
||||
key = f"{self.config.weaviate_cluster_url}::{self.config.weaviate_api_key}"
|
||||
|
|
@ -334,15 +302,9 @@ class WeaviateVectorIOAdapter(
|
|||
for raw in stored:
|
||||
vector_db = VectorDB.model_validate_json(raw)
|
||||
client = self._get_client()
|
||||
idx = WeaviateIndex(
|
||||
client=client,
|
||||
collection_name=vector_db.identifier,
|
||||
kvstore=self.kvstore,
|
||||
)
|
||||
idx = WeaviateIndex(client=client, collection_name=vector_db.identifier, kvstore=self.kvstore)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=idx,
|
||||
inference_api=self.inference_api,
|
||||
vector_db=vector_db, index=idx, inference_api=self.inference_api
|
||||
)
|
||||
|
||||
# Load OpenAI vector stores metadata into cache
|
||||
|
|
@ -354,10 +316,7 @@ class WeaviateVectorIOAdapter(
|
|||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
) -> None:
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
client = self._get_client()
|
||||
sanitized_collection_name = sanitize_collection_name(vector_db.identifier, weaviate_format=True)
|
||||
# Create collection if it doesn't exist
|
||||
|
|
@ -366,17 +325,12 @@ class WeaviateVectorIOAdapter(
|
|||
name=sanitized_collection_name,
|
||||
vectorizer_config=wvc.config.Configure.Vectorizer.none(),
|
||||
properties=[
|
||||
wvc.config.Property(
|
||||
name="chunk_content",
|
||||
data_type=wvc.config.DataType.TEXT,
|
||||
),
|
||||
wvc.config.Property(name="chunk_content", data_type=wvc.config.DataType.TEXT),
|
||||
],
|
||||
)
|
||||
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(
|
||||
vector_db,
|
||||
WeaviateIndex(client=client, collection_name=sanitized_collection_name),
|
||||
self.inference_api,
|
||||
vector_db, WeaviateIndex(client=client, collection_name=sanitized_collection_name), self.inference_api
|
||||
)
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
|
|
@ -412,12 +366,7 @@ class WeaviateVectorIOAdapter(
|
|||
self.cache[vector_db_id] = index
|
||||
return index
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
|
@ -425,10 +374,7 @@ class WeaviateVectorIOAdapter(
|
|||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ from pydantic import TypeAdapter
|
|||
|
||||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files, OpenAIFileObject
|
||||
from llama_stack.apis.models import Model, Models
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
|
|
@ -81,13 +80,14 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
# Implementing classes should call super().__init__() in their __init__ method
|
||||
# to properly initialize the mixin attributes.
|
||||
def __init__(
|
||||
self, files_api: Files | None = None, kvstore: KVStore | None = None, models_api: Models | None = None
|
||||
self,
|
||||
files_api: Files | None = None,
|
||||
kvstore: KVStore | None = None,
|
||||
):
|
||||
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
|
||||
self.openai_file_batches: dict[str, dict[str, Any]] = {}
|
||||
self.files_api = files_api
|
||||
self.kvstore = kvstore
|
||||
self.models_api = models_api
|
||||
self._last_file_batch_cleanup_time = 0
|
||||
self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}
|
||||
|
||||
|
|
@ -393,21 +393,7 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
vector_db_id = provider_vector_db_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
|
||||
|
||||
if embedding_model is None:
|
||||
result = await self._get_default_embedding_model_and_dimension()
|
||||
if result is None:
|
||||
raise ValueError(
|
||||
"embedding_model is required in extra_body when creating a vector store. "
|
||||
"No default embedding model could be determined automatically."
|
||||
)
|
||||
embedding_model, embedding_dimension = result
|
||||
elif embedding_dimension is None:
|
||||
# Embedding model was provided but dimension wasn't, look it up
|
||||
embedding_dimension = await self._get_embedding_dimension_for_model(embedding_model)
|
||||
if embedding_dimension is None:
|
||||
raise ValueError(
|
||||
f"Could not determine embedding dimension for model '{embedding_model}'. "
|
||||
"Please provide embedding_dimension in extra_body or ensure the model metadata contains embedding_dimension."
|
||||
)
|
||||
raise ValueError("embedding_model is required")
|
||||
|
||||
if embedding_dimension is None:
|
||||
raise ValueError("Embedding dimension is required")
|
||||
|
|
@ -474,85 +460,6 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
store_info = self.openai_vector_stores[vector_db_id]
|
||||
return VectorStoreObject.model_validate(store_info)
|
||||
|
||||
async def _get_embedding_models(self) -> list[Model]:
|
||||
"""Get list of embedding models from the models API."""
|
||||
if not self.models_api:
|
||||
return []
|
||||
|
||||
models_response = await self.models_api.list_models()
|
||||
models_list = models_response.data if hasattr(models_response, "data") else models_response
|
||||
|
||||
embedding_models = []
|
||||
for model in models_list:
|
||||
if not isinstance(model, Model):
|
||||
logger.warning(f"Non-Model object found in models list: {type(model)} - {model}")
|
||||
continue
|
||||
if model.model_type == "embedding":
|
||||
embedding_models.append(model)
|
||||
|
||||
return embedding_models
|
||||
|
||||
async def _get_embedding_dimension_for_model(self, model_id: str) -> int | None:
|
||||
"""Get embedding dimension for a specific model by looking it up in the models API.
|
||||
|
||||
Args:
|
||||
model_id: The identifier of the embedding model (supports both prefixed and non-prefixed)
|
||||
|
||||
Returns:
|
||||
The embedding dimension for the model, or None if not found
|
||||
"""
|
||||
embedding_models = await self._get_embedding_models()
|
||||
|
||||
for model in embedding_models:
|
||||
# Check for exact match first
|
||||
if model.identifier == model_id:
|
||||
embedding_dimension = model.metadata.get("embedding_dimension")
|
||||
if embedding_dimension is not None:
|
||||
return int(embedding_dimension)
|
||||
else:
|
||||
logger.warning(f"Model {model_id} found but has no embedding_dimension in metadata")
|
||||
return None
|
||||
|
||||
# Check for prefixed/unprefixed variations
|
||||
# If model_id is unprefixed, check if it matches the resource_id
|
||||
if model.provider_resource_id == model_id:
|
||||
embedding_dimension = model.metadata.get("embedding_dimension")
|
||||
if embedding_dimension is not None:
|
||||
return int(embedding_dimension)
|
||||
|
||||
return None
|
||||
|
||||
async def _get_default_embedding_model_and_dimension(self) -> tuple[str, int] | None:
|
||||
"""Get default embedding model from the models API.
|
||||
|
||||
Looks for embedding models marked with default_configured=True in metadata.
|
||||
Returns None if no default embedding model is found.
|
||||
Raises ValueError if multiple defaults are found.
|
||||
"""
|
||||
embedding_models = await self._get_embedding_models()
|
||||
|
||||
default_models = []
|
||||
for model in embedding_models:
|
||||
if model.metadata.get("default_configured") is True:
|
||||
default_models.append(model.identifier)
|
||||
|
||||
if len(default_models) > 1:
|
||||
raise ValueError(
|
||||
f"Multiple embedding models marked as default_configured=True: {default_models}. "
|
||||
"Only one embedding model can be marked as default."
|
||||
)
|
||||
|
||||
if default_models:
|
||||
model_id = default_models[0]
|
||||
embedding_dimension = await self._get_embedding_dimension_for_model(model_id)
|
||||
if embedding_dimension is None:
|
||||
raise ValueError(f"Embedding model '{model_id}' has no embedding_dimension in metadata")
|
||||
logger.info(f"Using default embedding model: {model_id} with dimension {embedding_dimension}")
|
||||
return model_id, embedding_dimension
|
||||
|
||||
logger.debug("No default embedding models found")
|
||||
return None
|
||||
|
||||
async def openai_list_vector_stores(
|
||||
self,
|
||||
limit: int | None = 20,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue