Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-12 12:06:04 +00:00)
update resolver to only pass vector_stores section of run config

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

Using Router only from VectorDBs
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

removing model_api from vector store providers
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

fix test
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

updating integration tests
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

special handling for replay mode for available providers
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Parent: 24a1430c8b
Commit: accc4c437e
46 changed files with 397 additions and 702 deletions
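In short: `llama stack build` gains a `--single-provider` flag that filters an existing distribution down to one provider for the listed APIs and writes a `<distro>-filtered-run.yaml` into the build directory, and the vector-io CI job now builds and tests against that filtered config. A rough sketch of the new flow, based on the workflow and CLI changes in the diff below (the provider value is illustrative; CI substitutes `matrix.vector-io-provider`):

    # Build the starter distro, keeping only one vector_io provider
    # ("inline::faiss" here is an illustrative choice, not part of this commit)
    uv run --no-sync llama stack build --distro starter --image-type venv \
        --single-provider "vector_io=inline::faiss"

    # The build writes ~/.llama/distributions/starter/starter-filtered-run.yaml,
    # which the integration tests then point at directly
    uv run --no-sync pytest -sv \
        --stack-config ~/.llama/distributions/starter/starter-filtered-run.yaml \
        tests/integration/vector_io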
@@ -144,7 +144,7 @@ jobs:
- name: Build Llama Stack
run: |
uv run --no-sync llama stack build --template ci-tests --image-type venv
uv run --no-sync llama stack build --distro starter --image-type venv --single-provider "vector_io=${{ matrix.vector-io-provider }}"
- name: Check Storage and Memory Available Before Tests
if: ${{ always() }}

@@ -168,7 +168,7 @@ jobs:
WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }}
run: |
uv run --no-sync \
pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \
pytest -sv --stack-config ~/.llama/distributions/starter/starter-filtered-run.yaml \
tests/integration/vector_io
- name: Check Storage and Memory Available After Tests
@@ -50,6 +50,85 @@ from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
DISTRIBS_PATH = Path(__file__).parent.parent.parent / "distributions"

def _apply_single_provider_filter(build_config: BuildConfig, single_provider_arg: str) -> BuildConfig:
"""Filter a distribution to only include specified providers for certain APIs."""
# Parse the single-provider argument using the same logic as --providers
provider_filters: dict[str, str] = {}
for api_provider in single_provider_arg.split(","):
if "=" not in api_provider:
cprint(
"Could not parse `--single-provider`. Please ensure the list is in the format api1=provider1,api2=provider2",
color="red",
file=sys.stderr,
)
sys.exit(1)
api, provider_type = api_provider.split("=")
provider_filters[api] = provider_type

# Create a copy of the build config to modify
filtered_build_config = BuildConfig(
image_type=build_config.image_type,
image_name=build_config.image_name,
external_providers_dir=build_config.external_providers_dir,
external_apis_dir=build_config.external_apis_dir,
distribution_spec=DistributionSpec(
providers={},
description=build_config.distribution_spec.description,
),
)

# Copy all providers, but filter the specified APIs
for api, providers in build_config.distribution_spec.providers.items():
if api in provider_filters:
target_provider_type = provider_filters[api]
filtered_providers = [p for p in providers if p.provider_type == target_provider_type]
if not filtered_providers:
cprint(
f"Provider {target_provider_type} not found in distribution for API {api}",
color="red",
file=sys.stderr,
)
sys.exit(1)
filtered_build_config.distribution_spec.providers[api] = filtered_providers
else:
# Keep all providers for unfiltered APIs
filtered_build_config.distribution_spec.providers[api] = providers

return filtered_build_config

def _generate_filtered_run_config(
build_config: BuildConfig,
build_dir: Path,
distro_name: str,
) -> Path:
"""
Generate a filtered run.yaml by starting with the original distribution's run.yaml
and filtering the providers according to the build_config.
"""
# Load the original distribution's run.yaml
distro_resource = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"

with importlib.resources.as_file(distro_resource) as path:
with open(path) as f:
original_config = yaml.safe_load(f)

# Apply provider filtering to the loaded config
for api, providers in build_config.distribution_spec.providers.items():
if api in original_config.get("providers", {}):
# Filter this API to only include the providers from build_config
provider_types = {p.provider_type for p in providers}
filtered_providers = [p for p in original_config["providers"][api] if p["provider_type"] in provider_types]
original_config["providers"][api] = filtered_providers

# Write the filtered config
run_config_file = build_dir / f"{distro_name}-filtered-run.yaml"
with open(run_config_file, "w") as f:
yaml.dump(original_config, f, sort_keys=False)

return run_config_file

@lru_cache
def available_distros_specs() -> dict[str, BuildConfig]:
import yaml

@@ -93,6 +172,11 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
)
sys.exit(1)
build_config = available_distros[distro_name]

# Apply single-provider filtering if specified
if args.single_provider:
build_config = _apply_single_provider_filter(build_config, args.single_provider)

if args.image_type:
build_config.image_type = args.image_type
else:

@@ -245,6 +329,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
image_name=image_name,
config_path=args.config,
distro_name=distro_name,
is_filtered=bool(args.single_provider),
)

except (Exception, RuntimeError) as exc:

@@ -363,6 +448,7 @@ def _run_stack_build_command_from_build_config(
image_name: str | None = None,
distro_name: str | None = None,
config_path: str | None = None,
is_filtered: bool = False,
) -> Path | Traversable:
image_name = image_name or build_config.image_name
if build_config.image_type == LlamaStackImageType.CONTAINER.value:

@@ -435,12 +521,19 @@ def _run_stack_build_command_from_build_config(
raise RuntimeError(f"Failed to build image {image_name}")

if distro_name:
# copy run.yaml from distribution to build_dir instead of generating it again
distro_path = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
run_config_file = build_dir / f"{distro_name}-run.yaml"
# If single-provider filtering was applied, generate a filtered run config
# Otherwise, copy run.yaml from distribution as before
if is_filtered:
run_config_file = _generate_filtered_run_config(build_config, build_dir, distro_name)
distro_path = run_config_file  # Use the generated file as the distro_path
else:
# copy run.yaml from distribution to build_dir instead of generating it again
distro_resource = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
run_config_file = build_dir / f"{distro_name}-run.yaml"

with importlib.resources.as_file(distro_path) as path:
shutil.copy(path, run_config_file)
with importlib.resources.as_file(distro_resource) as path:
shutil.copy(path, run_config_file)
distro_path = run_config_file  # Update distro_path to point to the copied file

cprint("Build Successful!", color="green", file=sys.stderr)
cprint(f"You can find the newly-built distribution here: {run_config_file}", color="blue", file=sys.stderr)
@@ -92,6 +92,13 @@ the build. If not specified, currently active environment will be used if found.
help="Build a config for a list of providers and only those providers. This list is formatted like: api1=provider1,api2=provider2. Where there can be multiple providers per API.",
)

self.parser.add_argument(
"--single-provider",
type=str,
default=None,
help="Limit a distribution to a single provider for specific APIs. Format: api1=provider1,api2=provider2. Use with --distro to filter an existing distribution.",
)

def _run_stack_build_command(self, args: argparse.Namespace) -> None:
# always keep implementation completely silo-ed away from CLI so CLI
# can be fast to load and reduces dependencies
@@ -409,10 +409,6 @@ async def instantiate_provider(
if "telemetry_enabled" in inspect.signature(getattr(module, method)).parameters and run_config.telemetry:
args.append(run_config.telemetry.enabled)

# vector_io providers need access to run_config.vector_stores
if provider_spec.api == Api.vector_io and "run_config" in inspect.signature(getattr(module, method)).parameters:
args.append(run_config)

fn = getattr(module, method)
impl = await fn(*args)
impl.__provider_id__ = provider.provider_id
@@ -84,6 +84,9 @@ async def get_auto_router_impl(
await inference_store.initialize()
api_to_dep_impl["store"] = inference_store

if api == Api.vector_io and run_config.vector_stores:
api_to_dep_impl["vector_stores_config"] = run_config.vector_stores

impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
await impl.initialize()
return impl
@@ -31,6 +31,7 @@ from llama_stack.apis.vector_io import (
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable

@@ -43,9 +44,11 @@ class VectorIORouter(VectorIO):
def __init__(
self,
routing_table: RoutingTable,
vector_stores_config: VectorStoresConfig | None = None,
) -> None:
logger.debug("Initializing VectorIORouter")
self.routing_table = routing_table
self.vector_stores_config = vector_stores_config

async def initialize(self) -> None:
logger.debug("VectorIORouter.initialize")

@@ -122,6 +125,10 @@ class VectorIORouter(VectorIO):
embedding_dimension = extra.get("embedding_dimension")
provider_id = extra.get("provider_id")

if embedding_model is None and self.vector_stores_config is not None:
embedding_model = self.vector_stores_config.default_embedding_model_id
logger.debug(f"Using default embedding model: {embedding_model}")

if embedding_model is not None and embedding_dimension is None:
embedding_dimension = await self._get_embedding_model_dimension(embedding_model)
@@ -25,6 +25,8 @@ distribution_spec:
- provider_type: inline::milvus
- provider_type: remote::chromadb
- provider_type: remote::pgvector
- provider_type: remote::qdrant
- provider_type: remote::weaviate
files:
- provider_type: inline::localfs
safety:
@@ -128,6 +128,21 @@ providers:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/pgvector_registry.db
- provider_id: ${env.QDRANT_URL:+qdrant}
provider_type: remote::qdrant
config:
api_key: ${env.QDRANT_API_KEY:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/qdrant_registry.db
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
provider_type: remote::weaviate
config:
weaviate_api_key: null
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/weaviate_registry.db
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
@@ -26,6 +26,8 @@ distribution_spec:
- provider_type: inline::milvus
- provider_type: remote::chromadb
- provider_type: remote::pgvector
- provider_type: remote::qdrant
- provider_type: remote::weaviate
files:
- provider_type: inline::localfs
safety:
@@ -128,6 +128,21 @@ providers:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/pgvector_registry.db
- provider_id: ${env.QDRANT_URL:+qdrant}
provider_type: remote::qdrant
config:
api_key: ${env.QDRANT_API_KEY:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/qdrant_registry.db
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
provider_type: remote::weaviate
config:
weaviate_api_key: null
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/weaviate_registry.db
files:
- provider_id: meta-reference-files
provider_type: inline::localfs

@@ -240,7 +255,7 @@ tool_groups:
provider_id: rag-runtime
server:
port: 8321
vector_stores:
default_embedding_model_id: sentence-transformers/nomic-ai/nomic-embed-text-v1.5
telemetry:
enabled: true
vector_stores:
default_embedding_model_id: sentence-transformers/nomic-ai/nomic-embed-text-v1.5
@@ -26,6 +26,8 @@ distribution_spec:
- provider_type: inline::milvus
- provider_type: remote::chromadb
- provider_type: remote::pgvector
- provider_type: remote::qdrant
- provider_type: remote::weaviate
files:
- provider_type: inline::localfs
safety:
@@ -128,6 +128,21 @@ providers:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
- provider_id: ${env.QDRANT_URL:+qdrant}
provider_type: remote::qdrant
config:
api_key: ${env.QDRANT_API_KEY:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/qdrant_registry.db
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
provider_type: remote::weaviate
config:
weaviate_api_key: null
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/weaviate_registry.db
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
@@ -32,6 +32,8 @@ from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOC
from llama_stack.providers.remote.vector_io.pgvector.config import (
PGVectorVectorIOConfig,
)
from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOConfig
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig

@@ -114,6 +116,8 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
BuildProvider(provider_type="inline::milvus"),
BuildProvider(provider_type="remote::chromadb"),
BuildProvider(provider_type="remote::pgvector"),
BuildProvider(provider_type="remote::qdrant"),
BuildProvider(provider_type="remote::weaviate"),
],
"files": [BuildProvider(provider_type="inline::localfs")],
"safety": [

@@ -222,6 +226,22 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
password="${env.PGVECTOR_PASSWORD:=}",
),
),
Provider(
provider_id="${env.QDRANT_URL:+qdrant}",
provider_type="remote::qdrant",
config=QdrantVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}",
url="${env.QDRANT_URL:=}",
),
),
Provider(
provider_id="${env.WEAVIATE_CLUSTER_URL:+weaviate}",
provider_type="remote::weaviate",
config=WeaviateVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}",
cluster_url="${env.WEAVIATE_CLUSTER_URL:=}",
),
),
],
"files": [files_provider],
},
@ -6,29 +6,14 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import ChromaVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: ChromaVectorIOConfig, deps: dict[Api, Any], run_config: StackRunConfig | None = None
|
||||
):
|
||||
from llama_stack.providers.remote.vector_io.chroma.chroma import (
|
||||
ChromaVectorIOAdapter,
|
||||
)
|
||||
async def get_provider_impl(config: ChromaVectorIOConfig, deps: dict[Api, Any]):
|
||||
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaVectorIOAdapter
|
||||
|
||||
vector_stores_config = None
|
||||
if run_config and run_config.vector_stores:
|
||||
vector_stores_config = run_config.vector_stores
|
||||
|
||||
impl = ChromaVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.models],
|
||||
deps.get(Api.files),
|
||||
vector_stores_config,
|
||||
)
|
||||
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -24,7 +24,6 @@ class ChromaVectorIOConfig(BaseModel):
|
|||
return {
|
||||
"db_path": db_path,
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="chroma_inline_registry.db",
|
||||
__distro_dir__=__distro_dir__, db_name="chroma_inline_registry.db"
|
||||
),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,29 +6,16 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import FaissVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: FaissVectorIOConfig, deps: dict[Api, Any], run_config: StackRunConfig | None = None
|
||||
):
|
||||
async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
|
||||
from .faiss import FaissVectorIOAdapter
|
||||
|
||||
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
vector_stores_config = None
|
||||
if run_config and run_config.vector_stores:
|
||||
vector_stores_config = run_config.vector_stores
|
||||
|
||||
impl = FaissVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.models],
|
||||
deps.get(Api.files),
|
||||
vector_stores_config,
|
||||
)
|
||||
impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -8,10 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
|
@ -22,8 +19,5 @@ class FaissVectorIOConfig(BaseModel):
|
|||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="faiss_store.db",
|
||||
)
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(__distro_dir__=__distro_dir__, db_name="faiss_store.db")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,28 +17,14 @@ from numpy.typing import NDArray
|
|||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.core.datatypes import VectorStoresConfig
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
VectorDBsProtocolPrivate,
|
||||
)
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
|
||||
|
||||
from .config import FaissVectorIOConfig
|
||||
|
||||
|
|
@ -156,12 +142,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
|
||||
await self._save_index()
|
||||
|
||||
async def query_vector(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
|
||||
chunks = []
|
||||
scores = []
|
||||
|
|
@ -176,12 +157,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
raise NotImplementedError(
|
||||
"Keyword search is not supported - underlying DB FAISS does not support this search mode"
|
||||
)
|
||||
|
|
@ -201,19 +177,10 @@ class FaissIndex(EmbeddingIndex):
|
|||
|
||||
|
||||
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(
|
||||
self,
|
||||
config: FaissVectorIOConfig,
|
||||
inference_api: Inference,
|
||||
models_api: Models,
|
||||
files_api: Files | None,
|
||||
vector_stores_config: VectorStoresConfig | None = None,
|
||||
) -> None:
|
||||
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.models_api = models_api
|
||||
self.vector_stores_config = vector_stores_config
|
||||
self.cache: dict[str, VectorDBWithIndex] = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
|
@ -255,17 +222,11 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
except Exception as e:
|
||||
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
) -> None:
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
assert self.kvstore is not None
|
||||
|
||||
key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
|
||||
await self.kvstore.set(
|
||||
key=key,
|
||||
value=vector_db.model_dump_json(),
|
||||
)
|
||||
await self.kvstore.set(key=key, value=vector_db.model_dump_json())
|
||||
|
||||
# Store in cache
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(
|
||||
|
|
@ -288,12 +249,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
del self.cache[vector_db_id]
|
||||
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_db_id}")
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = self.cache.get(vector_db_id)
|
||||
if index is None:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found. found: {self.cache.keys()}")
|
||||
|
|
@ -301,10 +257,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = self.cache.get(vector_db_id)
|
||||
if index is None:
|
||||
|
|
|
|||
|
|
@ -6,27 +6,14 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import MilvusVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: MilvusVectorIOConfig, deps: dict[Api, Any], run_config: StackRunConfig | None = None
|
||||
):
|
||||
async def get_provider_impl(config: MilvusVectorIOConfig, deps: dict[Api, Any]):
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
|
||||
|
||||
vector_stores_config = None
|
||||
if run_config and run_config.vector_stores:
|
||||
vector_stores_config = run_config.vector_stores
|
||||
|
||||
impl = MilvusVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps.get(Api.models),
|
||||
deps.get(Api.files),
|
||||
vector_stores_config,
|
||||
)
|
||||
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -8,10 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
|
@ -26,7 +23,6 @@ class MilvusVectorIOConfig(BaseModel):
|
|||
return {
|
||||
"db_path": "${env.MILVUS_DB_PATH:=" + __distro_dir__ + "}/" + "milvus.db",
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="milvus_registry.db",
|
||||
__distro_dir__=__distro_dir__, db_name="milvus_registry.db"
|
||||
),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,28 +6,15 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import QdrantVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: QdrantVectorIOConfig, deps: dict[Api, Any], run_config: StackRunConfig | None = None
|
||||
):
|
||||
async def get_provider_impl(config: QdrantVectorIOConfig, deps: dict[Api, Any]):
|
||||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
|
||||
|
||||
vector_stores_config = None
|
||||
if run_config and run_config.vector_stores:
|
||||
vector_stores_config = run_config.vector_stores
|
||||
|
||||
assert isinstance(config, QdrantVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = QdrantVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.models],
|
||||
deps.get(Api.files),
|
||||
vector_stores_config,
|
||||
)
|
||||
impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -9,10 +9,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,28 +6,15 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import SQLiteVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: SQLiteVectorIOConfig, deps: dict[Api, Any], run_config: StackRunConfig | None = None
|
||||
):
|
||||
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
|
||||
from .sqlite_vec import SQLiteVecVectorIOAdapter
|
||||
|
||||
vector_stores_config = None
|
||||
if run_config and run_config.vector_stores:
|
||||
vector_stores_config = run_config.vector_stores
|
||||
|
||||
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = SQLiteVecVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.models],
|
||||
deps.get(Api.files),
|
||||
vector_stores_config,
|
||||
)
|
||||
impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -8,10 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
|
||||
|
||||
class SQLiteVectorIOConfig(BaseModel):
|
||||
|
|
@ -23,7 +20,6 @@ class SQLiteVectorIOConfig(BaseModel):
|
|||
return {
|
||||
"db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + "sqlite_vec.db",
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="sqlite_vec_registry.db",
|
||||
__distro_dir__=__distro_dir__, db_name="sqlite_vec_registry.db"
|
||||
),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,14 +17,8 @@ from numpy.typing import NDArray
|
|||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.core.datatypes import VectorStoresConfig
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
|
@ -176,32 +170,18 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
|
||||
# Insert vector embeddings
|
||||
embedding_data = [
|
||||
(
|
||||
(
|
||||
chunk.chunk_id,
|
||||
serialize_vector(emb.tolist()),
|
||||
)
|
||||
)
|
||||
((chunk.chunk_id, serialize_vector(emb.tolist())))
|
||||
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
|
||||
]
|
||||
cur.executemany(
|
||||
f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);",
|
||||
embedding_data,
|
||||
)
|
||||
cur.executemany(f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);", embedding_data)
|
||||
|
||||
# Insert FTS content
|
||||
fts_data = [(chunk.chunk_id, chunk.content) for chunk in batch_chunks]
|
||||
# DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
|
||||
cur.executemany(
|
||||
f"DELETE FROM [{self.fts_table}] WHERE id = ?;",
|
||||
[(row[0],) for row in fts_data],
|
||||
)
|
||||
cur.executemany(f"DELETE FROM [{self.fts_table}] WHERE id = ?;", [(row[0],) for row in fts_data])
|
||||
|
||||
# INSERT new entries
|
||||
cur.executemany(
|
||||
f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);",
|
||||
fts_data,
|
||||
)
|
||||
cur.executemany(f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);", fts_data)
|
||||
|
||||
connection.commit()
|
||||
|
||||
|
|
@ -217,12 +197,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
# Run batch insertion in a background thread
|
||||
await asyncio.to_thread(_execute_all_batch_inserts)
|
||||
|
||||
async def query_vector(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Performs vector-based search using a virtual table for vector similarity.
|
||||
"""
|
||||
|
|
@ -262,12 +237,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
scores.append(score)
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search.
|
||||
"""
|
||||
|
|
@ -411,19 +381,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
inference_api: Inference,
|
||||
models_api: Models,
|
||||
files_api: Files | None,
|
||||
vector_stores_config: VectorStoresConfig | None = None,
|
||||
) -> None:
|
||||
def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.models_api = models_api
|
||||
self.vector_stores_config = vector_stores_config
|
||||
self.cache: dict[str, VectorDBWithIndex] = {}
|
||||
self.vector_db_store = None
|
||||
|
||||
|
|
@ -436,9 +397,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
for db_json in stored_vector_dbs:
|
||||
vector_db = VectorDB.model_validate_json(db_json)
|
||||
index = await SQLiteVecIndex.create(
|
||||
vector_db.embedding_dimension,
|
||||
self.config.db_path,
|
||||
vector_db.identifier,
|
||||
vector_db.embedding_dimension, self.config.db_path, vector_db.identifier
|
||||
)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
|
||||
|
|
@ -453,11 +412,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
return [v.vector_db for v in self.cache.values()]
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
index = await SQLiteVecIndex.create(
|
||||
vector_db.embedding_dimension,
|
||||
self.config.db_path,
|
||||
vector_db.identifier,
|
||||
)
|
||||
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
|
||||
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
|
||||
|
|
|
|||
|
|
@ -4,27 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import ChromaVectorIOConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(
|
||||
config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec], run_config: StackRunConfig | None = None
|
||||
):
|
||||
async def get_adapter_impl(config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec]):
|
||||
from .chroma import ChromaVectorIOAdapter
|
||||
|
||||
vector_stores_config = None
|
||||
if run_config and run_config.vector_stores:
|
||||
vector_stores_config = run_config.vector_stores
|
||||
|
||||
impl = ChromaVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.models],
|
||||
deps.get(Api.files),
|
||||
vector_stores_config,
|
||||
)
|
||||
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -13,25 +13,15 @@ from numpy.typing import NDArray
|
|||
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.core.datatypes import VectorStoresConfig
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
|
||||
|
||||
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
|
||||
|
||||
|
|
@ -70,19 +60,13 @@ class ChromaIndex(EmbeddingIndex):
|
|||
|
||||
ids = [f"{c.metadata.get('document_id', '')}:{c.chunk_id}" for c in chunks]
|
||||
await maybe_await(
|
||||
self.collection.add(
|
||||
documents=[chunk.model_dump_json() for chunk in chunks],
|
||||
embeddings=embeddings,
|
||||
ids=ids,
|
||||
)
|
||||
self.collection.add(documents=[chunk.model_dump_json() for chunk in chunks], embeddings=embeddings, ids=ids)
|
||||
)
|
||||
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
results = await maybe_await(
|
||||
self.collection.query(
|
||||
query_embeddings=[embedding.tolist()],
|
||||
n_results=k,
|
||||
include=["documents", "distances"],
|
||||
query_embeddings=[embedding.tolist()], n_results=k, include=["documents", "distances"]
|
||||
)
|
||||
)
|
||||
distances = results["distances"][0]
|
||||
|
|
@ -110,12 +94,7 @@ class ChromaIndex(EmbeddingIndex):
|
|||
async def delete(self):
|
||||
await maybe_await(self.client.delete_collection(self.collection.name))
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in Chroma")
|
||||
|
||||
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
|
|
@ -140,16 +119,12 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self,
|
||||
config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
|
||||
inference_api: Inference,
|
||||
models_apis: Models,
|
||||
files_api: Files | None,
|
||||
vector_stores_config: VectorStoresConfig | None = None,
|
||||
) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.models_api = models_apis
|
||||
self.vector_stores_config = vector_stores_config
|
||||
self.client = None
|
||||
self.cache = {}
|
||||
self.vector_db_store = None
|
||||
|
|
@ -176,14 +151,10 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
) -> None:
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
collection = await maybe_await(
|
||||
self.client.get_or_create_collection(
|
||||
name=vector_db.identifier,
|
||||
metadata={"vector_db": vector_db.model_dump_json()},
|
||||
name=vector_db.identifier, metadata={"vector_db": vector_db.model_dump_json()}
|
||||
)
|
||||
)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(
|
||||
|
|
@ -198,12 +169,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if index is None:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
|
||||
|
|
@ -211,10 +177,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ class ChromaVectorIOConfig(BaseModel):
|
|||
return {
|
||||
"url": url,
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="chroma_remote_registry.db",
|
||||
__distro_dir__=__distro_dir__, db_name="chroma_remote_registry.db"
|
||||
),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,28 +4,15 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import MilvusVectorIOConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(
|
||||
config: MilvusVectorIOConfig, deps: dict[Api, ProviderSpec], run_config: StackRunConfig | None = None
|
||||
):
|
||||
async def get_adapter_impl(config: MilvusVectorIOConfig, deps: dict[Api, ProviderSpec]):
|
||||
from .milvus import MilvusVectorIOAdapter
|
||||
|
||||
vector_stores_config = None
|
||||
if run_config and run_config.vector_stores:
|
||||
vector_stores_config = run_config.vector_stores
|
||||
|
||||
assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = MilvusVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.models],
|
||||
deps.get(Api.files),
|
||||
vector_stores_config,
|
||||
)
|
||||
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -29,7 +29,6 @@ class MilvusVectorIOConfig(BaseModel):
|
|||
"uri": "${env.MILVUS_ENDPOINT}",
|
||||
"token": "${env.MILVUS_TOKEN}",
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="milvus_remote_registry.db",
|
||||
__distro_dir__=__distro_dir__, db_name="milvus_remote_registry.db"
|
||||
),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,14 +14,8 @@ from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusC
|
|||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.core.datatypes import VectorStoresConfig
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
|
||||
|
|
@ -75,46 +69,23 @@ class MilvusIndex(EmbeddingIndex):
|
|||
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
|
||||
# Create schema for vector search
|
||||
schema = self.client.create_schema()
|
||||
schema.add_field(
|
||||
field_name="chunk_id",
|
||||
datatype=DataType.VARCHAR,
|
||||
is_primary=True,
|
||||
max_length=100,
|
||||
)
|
||||
schema.add_field(field_name="chunk_id", datatype=DataType.VARCHAR, is_primary=True, max_length=100)
|
||||
schema.add_field(
|
||||
field_name="content",
|
||||
datatype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
enable_analyzer=True, # Enable text analysis for BM25
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="vector",
|
||||
datatype=DataType.FLOAT_VECTOR,
|
||||
dim=len(embeddings[0]),
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="chunk_content",
|
||||
datatype=DataType.JSON,
|
||||
)
|
||||
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=len(embeddings[0]))
|
||||
schema.add_field(field_name="chunk_content", datatype=DataType.JSON)
|
||||
# Add sparse vector field for BM25 (required by the function)
|
||||
schema.add_field(
|
||||
field_name="sparse",
|
||||
datatype=DataType.SPARSE_FLOAT_VECTOR,
|
||||
)
|
||||
schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR)
|
||||
|
||||
# Create indexes
|
||||
index_params = self.client.prepare_index_params()
|
||||
index_params.add_index(
|
||||
field_name="vector",
|
||||
index_type="FLAT",
|
||||
metric_type="COSINE",
|
||||
)
|
||||
index_params.add_index(field_name="vector", index_type="FLAT", metric_type="COSINE")
|
||||
# Add index for sparse field (required by BM25 function)
|
||||
index_params.add_index(
|
||||
field_name="sparse",
|
||||
index_type="SPARSE_INVERTED_INDEX",
|
||||
metric_type="BM25",
|
||||
)
|
||||
index_params.add_index(field_name="sparse", index_type="SPARSE_INVERTED_INDEX", metric_type="BM25")
|
||||
|
||||
# Add BM25 function for full-text search
|
||||
bm25_function = Function(
|
||||
|
|
@ -145,11 +116,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
}
|
||||
)
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
self.client.insert,
|
||||
self.collection_name,
|
||||
data=data,
|
||||
)
|
||||
await asyncio.to_thread(self.client.insert, self.collection_name, data=data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
|
||||
raise e
|
||||
|
|
@ -168,12 +135,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
scores = [res["distance"] for res in search_res[0]]
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Perform BM25-based keyword search using Milvus's built-in full-text search.
|
||||
"""
|
||||
|
|
@ -211,12 +173,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
# Fallback to simple text search
|
||||
return await self._fallback_keyword_search(query_string, k, score_threshold)
|
||||
|
||||
async def _fallback_keyword_search(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def _fallback_keyword_search(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Fallback to simple text search when BM25 search is not available.
|
||||
"""
|
||||
|
|
@ -309,17 +266,13 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self,
|
||||
config: RemoteMilvusVectorIOConfig | InlineMilvusVectorIOConfig,
|
||||
inference_api: Inference,
|
||||
models_api: Models | None,
|
||||
files_api: Files | None,
|
||||
vector_stores_config: VectorStoresConfig | None = None,
|
||||
) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.cache = {}
|
||||
self.client = None
|
||||
self.inference_api = inference_api
|
||||
self.models_api = models_api
|
||||
self.vector_stores_config = vector_stores_config
|
||||
self.vector_db_store = None
|
||||
self.metadata_collection_name = "openai_vector_stores_metadata"
|
||||
|
||||
|
|
@ -358,10 +311,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
) -> None:
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
||||
consistency_level = self.config.consistency_level
|
||||
else:
|
||||
|
|
@ -398,12 +348,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
|
@ -411,10 +356,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
|
|
|
|||
|
|
@ -4,26 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import PGVectorVectorIOConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(
|
||||
config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec], run_config: StackRunConfig | None = None
|
||||
):
|
||||
async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec]):
|
||||
from .pgvector import PGVectorVectorIOAdapter
|
||||
|
||||
vector_stores_config = None
|
||||
if run_config and run_config.vector_stores:
|
||||
vector_stores_config = run_config.vector_stores
|
||||
impl = PGVectorVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.models],
|
||||
deps.get(Api.files),
|
||||
vector_stores_config,
|
||||
)
|
||||
impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -8,10 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
|
@ -42,7 +39,6 @@ class PGVectorVectorIOConfig(BaseModel):
|
|||
"user": user,
|
||||
"password": password,
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="pgvector_registry.db",
|
||||
__distro_dir__=__distro_dir__, db_name="pgvector_registry.db"
|
||||
),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,27 +16,15 @@ from pydantic import BaseModel, TypeAdapter
|
|||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.core.datatypes import VectorStoresConfig
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator, sanitize_collection_name
|
||||
|
||||
from .config import PGVectorVectorIOConfig
|
||||
|
|
@ -206,12 +194,7 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Performs keyword-based search using PostgreSQL's full-text search with ts_rank scoring.
|
||||
|
||||
|
|
@@ -318,7 +301,7 @@ class PGVectorIndex(EmbeddingIndex):
         """Remove a chunk from the PostgreSQL table."""
         chunk_ids = [c.chunk_id for c in chunks_for_deletion]
         with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
-            cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,))
+            cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids))
 
     def get_pgvector_search_function(self) -> str:
         return self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION[self.distance_metric]
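A note on the two execute calls above: psycopg2 adapts a Python list to a PostgreSQL array, but only when the list is passed as a single parameter inside the parameter sequence. The snippet below is illustrative (the table name and ids are made up) and shows the distinction.

    # Illustrative only: how psycopg2 treats the second argument of execute().
    import psycopg2

    conn = psycopg2.connect("dbname=demo user=demo password=demo host=localhost")
    ids = ["chunk-1", "chunk-2"]
    with conn.cursor() as cur:
        # (ids,) is a one-element parameter sequence; the whole list is bound to %s as one array.
        cur.execute("DELETE FROM chunks WHERE id = ANY(%s)", (ids,))
        # (ids) is just ids itself, so each list element is treated as a separate parameter,
        # which only lines up with a single %s when the list happens to contain exactly one element.
    conn.commit()
    conn.close()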
@@ -342,18 +325,11 @@ class PGVectorIndex(EmbeddingIndex):
 
 class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
     def __init__(
-        self,
-        config: PGVectorVectorIOConfig,
-        inference_api: Inference,
-        models_api: Models,
-        files_api: Files | None = None,
-        vector_stores_config: VectorStoresConfig | None = None,
+        self, config: PGVectorVectorIOConfig, inference_api: Inference, files_api: Files | None = None
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
-        self.models_api = models_api
-        self.vector_stores_config = vector_stores_config
         self.conn = None
         self.cache = {}
         self.vector_db_store = None
@@ -410,11 +386,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
             vector_db=vector_db, dimension=vector_db.embedding_dimension, conn=self.conn, kvstore=self.kvstore
         )
         await pgvector_index.initialize()
-        index = VectorDBWithIndex(
-            vector_db,
-            index=pgvector_index,
-            inference_api=self.inference_api,
-        )
+        index = VectorDBWithIndex(vector_db, index=pgvector_index, inference_api=self.inference_api)
         self.cache[vector_db.identifier] = index
 
     async def unregister_vector_db(self, vector_db_id: str) -> None:
@@ -427,20 +399,12 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
         assert self.kvstore is not None
         await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_db_id}")
 
-    async def insert_chunks(
-        self,
-        vector_db_id: str,
-        chunks: list[Chunk],
-        ttl_seconds: int | None = None,
-    ) -> None:
+    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         await index.insert_chunks(chunks)
 
     async def query_chunks(
-        self,
-        vector_db_id: str,
-        query: InterleavedContent,
-        params: dict[str, Any] | None = None,
+        self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
     ) -> QueryChunksResponse:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         return await index.query_chunks(query, params)
@@ -4,27 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.datatypes import Api, ProviderSpec
 
 from .config import QdrantVectorIOConfig
 
 
-async def get_adapter_impl(
-    config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec], run_config: StackRunConfig | None = None
-):
+async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .qdrant import QdrantVectorIOAdapter
 
-    vector_stores_config = None
-    if run_config and run_config.vector_stores:
-        vector_stores_config = run_config.vector_stores
-
-    impl = QdrantVectorIOAdapter(
-        config,
-        deps[Api.inference],
-        deps[Api.models],
-        deps.get(Api.files),
-        vector_stores_config,
-    )
+    impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl
@@ -8,10 +8,7 @@ from typing import Any
 from pydantic import BaseModel
 
-from llama_stack.providers.utils.kvstore.config import (
-    KVStoreConfig,
-    SqliteKVStoreConfig,
-)
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
 from llama_stack.schema_utils import json_schema_type

@@ -34,7 +31,6 @@ class QdrantVectorIOConfig(BaseModel):
         return {
             "api_key": "${env.QDRANT_API_KEY:=}",
             "kvstore": SqliteKVStoreConfig.sample_run_config(
-                __distro_dir__=__distro_dir__,
-                db_name="qdrant_registry.db",
+                __distro_dir__=__distro_dir__, db_name="qdrant_registry.db"
             ),
         }
@@ -16,7 +16,6 @@ from qdrant_client.models import PointStruct
 from llama_stack.apis.common.errors import VectorStoreNotFoundError
 from llama_stack.apis.files import Files
 from llama_stack.apis.inference import Inference, InterleavedContent
-from llama_stack.apis.models import Models
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import (
     Chunk,

@@ -25,17 +24,12 @@ from llama_stack.apis.vector_io import (
     VectorStoreChunkingStrategy,
     VectorStoreFileObject,
 )
-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
 from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
-from llama_stack.providers.utils.memory.vector_store import (
-    ChunkForDeletion,
-    EmbeddingIndex,
-    VectorDBWithIndex,
-)
+from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
 
 from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig

@@ -100,8 +94,7 @@ class QdrantIndex(EmbeddingIndex):
         chunk_ids = [convert_id(c.chunk_id) for c in chunks_for_deletion]
         try:
             await self.client.delete(
-                collection_name=self.collection_name,
-                points_selector=models.PointIdsList(points=chunk_ids),
+                collection_name=self.collection_name, points_selector=models.PointIdsList(points=chunk_ids)
             )
         except Exception as e:
             log.error(f"Error deleting chunks from Qdrant collection {self.collection_name}: {e}")

@@ -134,12 +127,7 @@ class QdrantIndex(EmbeddingIndex):
         return QueryChunksResponse(chunks=chunks, scores=scores)
 
-    async def query_keyword(
-        self,
-        query_string: str,
-        k: int,
-        score_threshold: float,
-    ) -> QueryChunksResponse:
+    async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
         raise NotImplementedError("Keyword search is not supported in Qdrant")
 
     async def query_hybrid(

@@ -162,17 +150,13 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         self,
         config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig,
         inference_api: Inference,
-        models_api: Models,
         files_api: Files | None = None,
-        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.client: AsyncQdrantClient = None
         self.cache = {}
         self.inference_api = inference_api
-        self.models_api = models_api
-        self.vector_stores_config = vector_stores_config
         self.vector_db_store = None
         self._qdrant_lock = asyncio.Lock()

@@ -187,11 +171,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
 
         for vector_db_data in stored_vector_dbs:
             vector_db = VectorDB.model_validate_json(vector_db_data)
-            index = VectorDBWithIndex(
-                vector_db,
-                QdrantIndex(self.client, vector_db.identifier),
-                self.inference_api,
-            )
+            index = VectorDBWithIndex(vector_db, QdrantIndex(self.client, vector_db.identifier), self.inference_api)
             self.cache[vector_db.identifier] = index
         self.openai_vector_stores = await self._load_openai_vector_stores()

@@ -200,18 +180,13 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         # Clean up mixin resources (file batch tasks)
         await super().shutdown()
 
-    async def register_vector_db(
-        self,
-        vector_db: VectorDB,
-    ) -> None:
+    async def register_vector_db(self, vector_db: VectorDB) -> None:
         assert self.kvstore is not None
         key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
         await self.kvstore.set(key=key, value=vector_db.model_dump_json())
 
         index = VectorDBWithIndex(
-            vector_db=vector_db,
-            index=QdrantIndex(self.client, vector_db.identifier),
-            inference_api=self.inference_api,
+            vector_db=vector_db, index=QdrantIndex(self.client, vector_db.identifier), inference_api=self.inference_api
         )
 
         self.cache[vector_db.identifier] = index

@@ -243,12 +218,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
             self.cache[vector_db_id] = index
         return index
 
-    async def insert_chunks(
-        self,
-        vector_db_id: str,
-        chunks: list[Chunk],
-        ttl_seconds: int | None = None,
-    ) -> None:
+    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
             raise VectorStoreNotFoundError(vector_db_id)

@@ -256,10 +226,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         await index.insert_chunks(chunks)
 
     async def query_chunks(
-        self,
-        vector_db_id: str,
-        query: InterleavedContent,
-        params: dict[str, Any] | None = None,
+        self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
     ) -> QueryChunksResponse:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
@@ -4,27 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.datatypes import Api, ProviderSpec
 
 from .config import WeaviateVectorIOConfig
 
 
-async def get_adapter_impl(
-    config: WeaviateVectorIOConfig, deps: dict[Api, ProviderSpec], run_config: StackRunConfig | None = None
-):
+async def get_adapter_impl(config: WeaviateVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .weaviate import WeaviateVectorIOAdapter
 
-    vector_stores_config = None
-    if run_config and run_config.vector_stores:
-        vector_stores_config = run_config.vector_stores
-
-    impl = WeaviateVectorIOAdapter(
-        config,
-        deps[Api.inference],
-        deps[Api.models],
-        deps.get(Api.files),
-        vector_stores_config,
-    )
+    impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl
@@ -8,10 +8,7 @@ from typing import Any
 from pydantic import BaseModel, Field
 
-from llama_stack.providers.utils.kvstore.config import (
-    KVStoreConfig,
-    SqliteKVStoreConfig,
-)
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
 from llama_stack.schema_utils import json_schema_type

@@ -22,16 +19,11 @@ class WeaviateVectorIOConfig(BaseModel):
     kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
 
     @classmethod
-    def sample_run_config(
-        cls,
-        __distro_dir__: str,
-        **kwargs: Any,
-    ) -> dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {
             "weaviate_api_key": None,
             "weaviate_cluster_url": "${env.WEAVIATE_CLUSTER_URL:=localhost:8080}",
             "kvstore": SqliteKVStoreConfig.sample_run_config(
-                __distro_dir__=__distro_dir__,
-                db_name="weaviate_registry.db",
+                __distro_dir__=__distro_dir__, db_name="weaviate_registry.db"
             ),
         }
@@ -16,18 +16,14 @@ from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.common.errors import VectorStoreNotFoundError
 from llama_stack.apis.files import Files
 from llama_stack.apis.inference import Inference
-from llama_stack.apis.models import Models
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.api import KVStore
-from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
-    OpenAIVectorStoreMixin,
-)
+from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import (
     RERANKER_TYPE_RRF,
     ChunkForDeletion,

@@ -49,12 +45,7 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten
 
 
 class WeaviateIndex(EmbeddingIndex):
-    def __init__(
-        self,
-        client: weaviate.WeaviateClient,
-        collection_name: str,
-        kvstore: KVStore | None = None,
-    ):
+    def __init__(self, client: weaviate.WeaviateClient, collection_name: str, kvstore: KVStore | None = None):
         self.client = client
         self.collection_name = sanitize_collection_name(collection_name, weaviate_format=True)
         self.kvstore = kvstore

@@ -109,9 +100,7 @@ class WeaviateIndex(EmbeddingIndex):
 
         try:
             results = collection.query.near_vector(
-                near_vector=embedding.tolist(),
-                limit=k,
-                return_metadata=wvc.query.MetadataQuery(distance=True),
+                near_vector=embedding.tolist(), limit=k, return_metadata=wvc.query.MetadataQuery(distance=True)
             )
         except Exception as e:
             log.error(f"Weaviate client vector search failed: {e}")

@@ -154,12 +143,7 @@ class WeaviateIndex(EmbeddingIndex):
         collection = self.client.collections.get(sanitized_collection_name)
         collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids))
 
-    async def query_keyword(
-        self,
-        query_string: str,
-        k: int,
-        score_threshold: float,
-    ) -> QueryChunksResponse:
+    async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
         """
         Performs BM25-based keyword search using Weaviate's built-in full-text search.
         Args:
@@ -176,9 +160,7 @@ class WeaviateIndex(EmbeddingIndex):
         # Perform BM25 keyword search on chunk_content field
         try:
             results = collection.query.bm25(
-                query=query_string,
-                limit=k,
-                return_metadata=wvc.query.MetadataQuery(score=True),
+                query=query_string, limit=k, return_metadata=wvc.query.MetadataQuery(score=True)
             )
         except Exception as e:
             log.error(f"Weaviate client keyword search failed: {e}")
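As a standalone reference for the BM25 call used above, the sketch below issues the same kind of query directly with the Weaviate v4 Python client. It is illustrative only: the collection name, property name, and local connection details are assumptions, not values taken from the provider.

    # Illustrative sketch of a BM25 keyword query with the Weaviate v4 client.
    import weaviate
    import weaviate.classes as wvc

    client = weaviate.connect_to_local(host="localhost", port=8080)
    try:
        collection = client.collections.get("Chunks")
        results = collection.query.bm25(
            query="keyword search example",
            limit=5,
            return_metadata=wvc.query.MetadataQuery(score=True),
        )
        for obj in results.objects:
            print(obj.properties.get("chunk_content"), obj.metadata.score)
    finally:
        client.close()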
@@ -275,25 +257,11 @@ class WeaviateIndex(EmbeddingIndex):
         return QueryChunksResponse(chunks=chunks, scores=scores)
 
 
-class WeaviateVectorIOAdapter(
-    OpenAIVectorStoreMixin,
-    VectorIO,
-    NeedsRequestProviderData,
-    VectorDBsProtocolPrivate,
-):
-    def __init__(
-        self,
-        config: WeaviateVectorIOConfig,
-        inference_api: Inference,
-        models_api: Models,
-        files_api: Files | None,
-        vector_stores_config: VectorStoresConfig | None = None,
-    ) -> None:
+class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProviderData, VectorDBsProtocolPrivate):
+    def __init__(self, config: WeaviateVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
-        self.models_api = models_api
-        self.vector_stores_config = vector_stores_config
         self.client_cache = {}
         self.cache = {}
         self.vector_db_store = None

@@ -304,10 +272,7 @@ class WeaviateVectorIOAdapter(
             log.info("Using Weaviate locally in container")
             host, port = self.config.weaviate_cluster_url.split(":")
             key = "local_test"
-            client = weaviate.connect_to_local(
-                host=host,
-                port=port,
-            )
+            client = weaviate.connect_to_local(host=host, port=port)
         else:
             log.info("Using Weaviate remote cluster with URL")
             key = f"{self.config.weaviate_cluster_url}::{self.config.weaviate_api_key}"

@@ -337,15 +302,9 @@ class WeaviateVectorIOAdapter(
         for raw in stored:
             vector_db = VectorDB.model_validate_json(raw)
             client = self._get_client()
-            idx = WeaviateIndex(
-                client=client,
-                collection_name=vector_db.identifier,
-                kvstore=self.kvstore,
-            )
+            idx = WeaviateIndex(client=client, collection_name=vector_db.identifier, kvstore=self.kvstore)
             self.cache[vector_db.identifier] = VectorDBWithIndex(
-                vector_db=vector_db,
-                index=idx,
-                inference_api=self.inference_api,
+                vector_db=vector_db, index=idx, inference_api=self.inference_api
             )
 
         # Load OpenAI vector stores metadata into cache

@@ -357,10 +316,7 @@ class WeaviateVectorIOAdapter(
         # Clean up mixin resources (file batch tasks)
         await super().shutdown()
 
-    async def register_vector_db(
-        self,
-        vector_db: VectorDB,
-    ) -> None:
+    async def register_vector_db(self, vector_db: VectorDB) -> None:
         client = self._get_client()
         sanitized_collection_name = sanitize_collection_name(vector_db.identifier, weaviate_format=True)
         # Create collection if it doesn't exist

@@ -369,17 +325,12 @@ class WeaviateVectorIOAdapter(
                 name=sanitized_collection_name,
                 vectorizer_config=wvc.config.Configure.Vectorizer.none(),
                 properties=[
-                    wvc.config.Property(
-                        name="chunk_content",
-                        data_type=wvc.config.DataType.TEXT,
-                    ),
+                    wvc.config.Property(name="chunk_content", data_type=wvc.config.DataType.TEXT),
                 ],
             )
 
         self.cache[vector_db.identifier] = VectorDBWithIndex(
-            vector_db,
-            WeaviateIndex(client=client, collection_name=sanitized_collection_name),
-            self.inference_api,
+            vector_db, WeaviateIndex(client=client, collection_name=sanitized_collection_name), self.inference_api
         )
 
     async def unregister_vector_db(self, vector_db_id: str) -> None:

@@ -415,12 +366,7 @@ class WeaviateVectorIOAdapter(
             self.cache[vector_db_id] = index
         return index
 
-    async def insert_chunks(
-        self,
-        vector_db_id: str,
-        chunks: list[Chunk],
-        ttl_seconds: int | None = None,
-    ) -> None:
+    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
             raise VectorStoreNotFoundError(vector_db_id)

@@ -428,10 +374,7 @@ class WeaviateVectorIOAdapter(
         await index.insert_chunks(chunks)
 
     async def query_chunks(
-        self,
-        vector_db_id: str,
-        query: InterleavedContent,
-        params: dict[str, Any] | None = None,
+        self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
     ) -> QueryChunksResponse:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
@@ -17,7 +17,6 @@ from pydantic import TypeAdapter
 
 from llama_stack.apis.common.errors import VectorStoreNotFoundError
 from llama_stack.apis.files import Files, OpenAIFileObject
-from llama_stack.apis.models import Model, Models
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import (
     Chunk,

@@ -44,7 +43,6 @@ from llama_stack.apis.vector_io import (
     VectorStoreSearchResponse,
     VectorStoreSearchResponsePage,
 )
-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.id_generation import generate_object_id
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.kvstore.api import KVStore

@@ -90,9 +88,6 @@ class OpenAIVectorStoreMixin(ABC):
         self.openai_file_batches: dict[str, dict[str, Any]] = {}
         self.files_api = files_api
         self.kvstore = kvstore
-        # These will be set by implementing classes
-        self.models_api: Models | None = None
-        self.vector_stores_config: VectorStoresConfig | None = None
         self._last_file_batch_cleanup_time = 0
         self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}
@@ -398,21 +393,7 @@ class OpenAIVectorStoreMixin(ABC):
         vector_db_id = provider_vector_db_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
 
         if embedding_model is None:
-            result = await self._get_default_embedding_model_and_dimension()
-            if result is None:
-                raise ValueError(
-                    "embedding_model is required in extra_body when creating a vector store. "
-                    "No default embedding model could be determined automatically."
-                )
-            embedding_model, embedding_dimension = result
-        elif embedding_dimension is None:
-            # Embedding model was provided but dimension wasn't, look it up
-            embedding_dimension = await self._get_embedding_dimension_for_model(embedding_model)
-            if embedding_dimension is None:
-                raise ValueError(
-                    f"Could not determine embedding dimension for model '{embedding_model}'. "
-                    "Please provide embedding_dimension in extra_body or ensure the model metadata contains embedding_dimension."
-                )
+            raise ValueError("embedding_model is required")
 
         if embedding_dimension is None:
             raise ValueError("Embedding dimension is required")
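With the fallback logic removed, callers have to name the embedding model (and its dimension) themselves when creating a vector store. A minimal client-side sketch, assuming an OpenAI-compatible client pointed at a running stack; the base URL, API key, model id, and dimension below are placeholders, not values mandated by this change:

    # Illustrative only: pass the embedding model explicitly via extra_body.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")
    vector_store = client.vector_stores.create(
        name="my_store",
        extra_body={
            "embedding_model": "sentence-transformers/nomic-ai/nomic-embed-text-v1.5",
            "embedding_dimension": 768,
        },
    )
    print(vector_store.id)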
@@ -479,64 +460,6 @@ class OpenAIVectorStoreMixin(ABC):
         store_info = self.openai_vector_stores[vector_db_id]
         return VectorStoreObject.model_validate(store_info)
 
-    async def _get_embedding_dimension_for_model(self, model_id: str) -> int | None:
-        """Get embedding dimension for a specific model by looking it up in the models API.
-
-        Args:
-            model_id: The identifier of the embedding model (supports both prefixed and non-prefixed)
-
-        Returns:
-            The embedding dimension for the model, or None if not found
-        """
-        if not self.models_api:
-            return None
-
-        models_response = await self.models_api.list_models()
-        models_list = models_response.data if hasattr(models_response, "data") else models_response
-
-        for model in models_list:
-            if not isinstance(model, Model):
-                continue
-            if model.model_type != "embedding":
-                continue
-
-            # Check for exact match first
-            if model.identifier == model_id:
-                embedding_dimension = model.metadata.get("embedding_dimension")
-                if embedding_dimension is not None:
-                    return int(embedding_dimension)
-                else:
-                    logger.warning(f"Model {model_id} found but has no embedding_dimension in metadata")
-                return None
-
-            # Check for prefixed/unprefixed variations
-            # If model_id is unprefixed, check if it matches the resource_id
-            if model.provider_resource_id == model_id:
-                embedding_dimension = model.metadata.get("embedding_dimension")
-                if embedding_dimension is not None:
-                    return int(embedding_dimension)
-
-        return None
-
-    async def _get_default_embedding_model_and_dimension(self) -> tuple[str, int] | None:
-        """Get default embedding model from vector stores config.
-
-        Returns None if no vector stores config is provided.
-        """
-        if not self.vector_stores_config:
-            logger.info("No vector stores config provided")
-            return None
-
-        model_id = self.vector_stores_config.default_embedding_model_id
-        embedding_dimension = await self._get_embedding_dimension_for_model(model_id)
-        if embedding_dimension is None:
-            raise ValueError(f"Embedding model '{model_id}' not found or has no embedding_dimension in metadata")
-
-        logger.debug(
-            f"Using default embedding model from vector stores config: {model_id} with dimension {embedding_dimension}"
-        )
-        return model_id, embedding_dimension
-
     async def openai_list_vector_stores(
         self,
         limit: int | None = 20,
@@ -241,7 +241,7 @@ def instantiate_llama_stack_client(session):
     # --stack-config bypasses template so need this to set default embedding model
     if "vector_io" in config and "inference" in config:
         run_config.vector_stores = VectorStoresConfig(
-            default_embedding_model_id="inline::sentence-transformers/nomic-ai/nomic-embed-text-v1.5"
+            default_embedding_model_id="sentence-transformers/nomic-ai/nomic-embed-text-v1.5"
        )
 
     run_config_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml")
tests/unit/distribution/test_single_provider_filter.py (new file, 88 lines)

@@ -0,0 +1,88 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest

from llama_stack.cli.stack._build import _apply_single_provider_filter
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
from llama_stack.core.utils.image_types import LlamaStackImageType


def test_filters_single_api():
    """Test filtering keeps only specified provider for one API."""
    build_config = BuildConfig(
        image_type=LlamaStackImageType.VENV.value,
        distribution_spec=DistributionSpec(
            providers={
                "vector_io": [
                    BuildProvider(provider_type="inline::faiss"),
                    BuildProvider(provider_type="inline::sqlite-vec"),
                ],
                "inference": [
                    BuildProvider(provider_type="remote::openai"),
                ],
            },
            description="Test",
        ),
    )

    filtered = _apply_single_provider_filter(build_config, "vector_io=inline::sqlite-vec")

    assert len(filtered.distribution_spec.providers["vector_io"]) == 1
    assert filtered.distribution_spec.providers["vector_io"][0].provider_type == "inline::sqlite-vec"
    assert len(filtered.distribution_spec.providers["inference"]) == 1  # unchanged


def test_filters_multiple_apis():
    """Test filtering multiple APIs."""
    build_config = BuildConfig(
        image_type=LlamaStackImageType.VENV.value,
        distribution_spec=DistributionSpec(
            providers={
                "vector_io": [
                    BuildProvider(provider_type="inline::faiss"),
                    BuildProvider(provider_type="inline::sqlite-vec"),
                ],
                "inference": [
                    BuildProvider(provider_type="remote::openai"),
                    BuildProvider(provider_type="remote::anthropic"),
                ],
            },
            description="Test",
        ),
    )

    filtered = _apply_single_provider_filter(build_config, "vector_io=inline::faiss,inference=remote::openai")

    assert len(filtered.distribution_spec.providers["vector_io"]) == 1
    assert filtered.distribution_spec.providers["vector_io"][0].provider_type == "inline::faiss"
    assert len(filtered.distribution_spec.providers["inference"]) == 1
    assert filtered.distribution_spec.providers["inference"][0].provider_type == "remote::openai"


def test_provider_not_found_exits():
    """Test error when specified provider doesn't exist."""
    build_config = BuildConfig(
        image_type=LlamaStackImageType.VENV.value,
        distribution_spec=DistributionSpec(
            providers={"vector_io": [BuildProvider(provider_type="inline::faiss")]},
            description="Test",
        ),
    )

    with pytest.raises(SystemExit):
        _apply_single_provider_filter(build_config, "vector_io=inline::nonexistent")


def test_invalid_format_exits():
    """Test error for invalid filter format."""
    build_config = BuildConfig(
        image_type=LlamaStackImageType.VENV.value,
        distribution_spec=DistributionSpec(providers={}, description="Test"),
    )

    with pytest.raises(SystemExit):
        _apply_single_provider_filter(build_config, "invalid_format")
@@ -144,7 +144,6 @@ async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inf
         config=config,
         inference_api=mock_inference_api,
         files_api=None,
-        models_api=None,
     )
     collection_id = f"sqlite_test_collection_{np.random.randint(1e6)}"
     await adapter.initialize()

@@ -183,7 +182,6 @@ async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding
         config=config,
         inference_api=mock_inference_api,
         files_api=None,
-        models_api=None,
     )
     await adapter.initialize()
     await adapter.register_vector_db(
@@ -11,7 +11,6 @@ import numpy as np
 import pytest
 
 from llama_stack.apis.files import Files
-from llama_stack.apis.models import Models
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
 from llama_stack.providers.datatypes import HealthStatus

@@ -76,12 +75,6 @@ def mock_files_api():
     return mock_api
 
 
-@pytest.fixture
-def mock_models_api():
-    mock_api = MagicMock(spec=Models)
-    return mock_api
-
-
 @pytest.fixture
 def faiss_config():
     config = MagicMock(spec=FaissVectorIOConfig)

@@ -117,7 +110,7 @@ async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_
     assert response.chunks[1] == sample_chunks[1]
 
 
-async def test_health_success(mock_models_api):
+async def test_health_success():
     """Test that the health check returns OK status when faiss is working correctly."""
     # Create a fresh instance of FaissVectorIOAdapter for testing
     config = MagicMock()

@@ -126,9 +119,7 @@ async def test_health_success(mock_models_api):
 
     with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat:
         mock_index_flat.return_value = MagicMock()
-        adapter = FaissVectorIOAdapter(
-            config=config, inference_api=inference_api, models_api=mock_models_api, files_api=files_api
-        )
+        adapter = FaissVectorIOAdapter(config=config, inference_api=inference_api, files_api=files_api)
 
         # Calling the health method directly
         response = await adapter.health()

@@ -142,7 +133,7 @@ async def test_health_success(mock_models_api):
         mock_index_flat.assert_called_once_with(128)  # VECTOR_DIMENSION is 128
 
 
-async def test_health_failure(mock_models_api):
+async def test_health_failure():
     """Test that the health check returns ERROR status when faiss encounters an error."""
     # Create a fresh instance of FaissVectorIOAdapter for testing
     config = MagicMock()

@@ -152,9 +143,7 @@ async def test_health_failure(mock_models_api):
     with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat:
         mock_index_flat.side_effect = Exception("Test error")
 
-        adapter = FaissVectorIOAdapter(
-            config=config, inference_api=inference_api, models_api=mock_models_api, files_api=files_api
-        )
+        adapter = FaissVectorIOAdapter(config=config, inference_api=inference_api, files_api=files_api)
 
         # Calling the health method directly
         response = await adapter.health()
@@ -1162,5 +1162,5 @@ async def test_embedding_config_required_model_missing(vector_io_adapter):
     # Test with no embedding model provided
     params = OpenAICreateVectorStoreRequestWithExtraBody(name="test_store", metadata={})
 
-    with pytest.raises(ValueError, match="embedding_model is required in extra_body when creating a vector store"):
+    with pytest.raises(ValueError, match="embedding_model is required"):
         await vector_io_adapter.openai_create_vector_store(params)