Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-12 12:06:04 +00:00)
update resolver to only pass vector_stores section of run config
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

Using Router only from VectorDBs
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

removing model_api from vector store providers
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

fix test
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

updating integration tests
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

special handling for replay mode for available providers
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent 24a1430c8b
commit accc4c437e
46 changed files with 397 additions and 702 deletions
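For context, the workflow change below exercises the new `--single-provider` build flag: it filters an existing distribution down to a single provider for the listed APIs and writes a `<distro>-filtered-run.yaml` that the tests then point at. A minimal sketch of that flow, assuming a local checkout; the `inline::faiss` provider name is only an illustrative choice, not something this commit pins down:

    # build the starter distro, keeping only one vector_io provider (inline::faiss is an assumed example)
    uv run --no-sync llama stack build --distro starter --image-type venv --single-provider "vector_io=inline::faiss"

    # run the vector_io integration tests against the filtered run config the build produced
    uv run --no-sync pytest -sv \
      --stack-config ~/.llama/distributions/starter/starter-filtered-run.yaml \
      tests/integration/vector_io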
@@ -144,7 +144,7 @@ jobs:
       - name: Build Llama Stack
         run: |
-          uv run --no-sync llama stack build --template ci-tests --image-type venv
+          uv run --no-sync llama stack build --distro starter --image-type venv --single-provider "vector_io=${{ matrix.vector-io-provider }}"

       - name: Check Storage and Memory Available Before Tests
         if: ${{ always() }}
@@ -168,7 +168,7 @@ jobs:
           WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }}
         run: |
           uv run --no-sync \
-            pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \
+            pytest -sv --stack-config ~/.llama/distributions/starter/starter-filtered-run.yaml \
             tests/integration/vector_io

       - name: Check Storage and Memory Available After Tests

@@ -50,6 +50,85 @@ from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
 DISTRIBS_PATH = Path(__file__).parent.parent.parent / "distributions"


+def _apply_single_provider_filter(build_config: BuildConfig, single_provider_arg: str) -> BuildConfig:
+    """Filter a distribution to only include specified providers for certain APIs."""
+    # Parse the single-provider argument using the same logic as --providers
+    provider_filters: dict[str, str] = {}
+    for api_provider in single_provider_arg.split(","):
+        if "=" not in api_provider:
+            cprint(
+                "Could not parse `--single-provider`. Please ensure the list is in the format api1=provider1,api2=provider2",
+                color="red",
+                file=sys.stderr,
+            )
+            sys.exit(1)
+        api, provider_type = api_provider.split("=")
+        provider_filters[api] = provider_type
+
+    # Create a copy of the build config to modify
+    filtered_build_config = BuildConfig(
+        image_type=build_config.image_type,
+        image_name=build_config.image_name,
+        external_providers_dir=build_config.external_providers_dir,
+        external_apis_dir=build_config.external_apis_dir,
+        distribution_spec=DistributionSpec(
+            providers={},
+            description=build_config.distribution_spec.description,
+        ),
+    )
+
+    # Copy all providers, but filter the specified APIs
+    for api, providers in build_config.distribution_spec.providers.items():
+        if api in provider_filters:
+            target_provider_type = provider_filters[api]
+            filtered_providers = [p for p in providers if p.provider_type == target_provider_type]
+            if not filtered_providers:
+                cprint(
+                    f"Provider {target_provider_type} not found in distribution for API {api}",
+                    color="red",
+                    file=sys.stderr,
+                )
+                sys.exit(1)
+            filtered_build_config.distribution_spec.providers[api] = filtered_providers
+        else:
+            # Keep all providers for unfiltered APIs
+            filtered_build_config.distribution_spec.providers[api] = providers
+
+    return filtered_build_config
+
+
+def _generate_filtered_run_config(
+    build_config: BuildConfig,
+    build_dir: Path,
+    distro_name: str,
+) -> Path:
+    """
+    Generate a filtered run.yaml by starting with the original distribution's run.yaml
+    and filtering the providers according to the build_config.
+    """
+    # Load the original distribution's run.yaml
+    distro_resource = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
+
+    with importlib.resources.as_file(distro_resource) as path:
+        with open(path) as f:
+            original_config = yaml.safe_load(f)
+
+    # Apply provider filtering to the loaded config
+    for api, providers in build_config.distribution_spec.providers.items():
+        if api in original_config.get("providers", {}):
+            # Filter this API to only include the providers from build_config
+            provider_types = {p.provider_type for p in providers}
+            filtered_providers = [p for p in original_config["providers"][api] if p["provider_type"] in provider_types]
+            original_config["providers"][api] = filtered_providers
+
+    # Write the filtered config
+    run_config_file = build_dir / f"{distro_name}-filtered-run.yaml"
+    with open(run_config_file, "w") as f:
+        yaml.dump(original_config, f, sort_keys=False)
+
+    return run_config_file
+
+
 @lru_cache
 def available_distros_specs() -> dict[str, BuildConfig]:
     import yaml
@@ -93,6 +172,11 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
            )
            sys.exit(1)
        build_config = available_distros[distro_name]
+
+        # Apply single-provider filtering if specified
+        if args.single_provider:
+            build_config = _apply_single_provider_filter(build_config, args.single_provider)
+
        if args.image_type:
            build_config.image_type = args.image_type
        else:
@@ -245,6 +329,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
            image_name=image_name,
            config_path=args.config,
            distro_name=distro_name,
+            is_filtered=bool(args.single_provider),
        )

    except (Exception, RuntimeError) as exc:
@@ -363,6 +448,7 @@ def _run_stack_build_command_from_build_config(
    image_name: str | None = None,
    distro_name: str | None = None,
    config_path: str | None = None,
+    is_filtered: bool = False,
 ) -> Path | Traversable:
    image_name = image_name or build_config.image_name
    if build_config.image_type == LlamaStackImageType.CONTAINER.value:
@@ -435,12 +521,19 @@ def _run_stack_build_command_from_build_config(
            raise RuntimeError(f"Failed to build image {image_name}")

    if distro_name:
-        # copy run.yaml from distribution to build_dir instead of generating it again
-        distro_path = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
-        run_config_file = build_dir / f"{distro_name}-run.yaml"
+        # If single-provider filtering was applied, generate a filtered run config
+        # Otherwise, copy run.yaml from distribution as before
+        if is_filtered:
+            run_config_file = _generate_filtered_run_config(build_config, build_dir, distro_name)
+            distro_path = run_config_file  # Use the generated file as the distro_path
+        else:
+            # copy run.yaml from distribution to build_dir instead of generating it again
+            distro_resource = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
+            run_config_file = build_dir / f"{distro_name}-run.yaml"

-        with importlib.resources.as_file(distro_path) as path:
-            shutil.copy(path, run_config_file)
+            with importlib.resources.as_file(distro_resource) as path:
+                shutil.copy(path, run_config_file)
+            distro_path = run_config_file  # Update distro_path to point to the copied file

    cprint("Build Successful!", color="green", file=sys.stderr)
    cprint(f"You can find the newly-built distribution here: {run_config_file}", color="blue", file=sys.stderr)

@@ -92,6 +92,13 @@ the build. If not specified, currently active environment will be used if found.
            help="Build a config for a list of providers and only those providers. This list is formatted like: api1=provider1,api2=provider2. Where there can be multiple providers per API.",
        )

+        self.parser.add_argument(
+            "--single-provider",
+            type=str,
+            default=None,
+            help="Limit a distribution to a single provider for specific APIs. Format: api1=provider1,api2=provider2. Use with --distro to filter an existing distribution.",
+        )
+
    def _run_stack_build_command(self, args: argparse.Namespace) -> None:
        # always keep implementation completely silo-ed away from CLI so CLI
        # can be fast to load and reduces dependencies

@@ -409,10 +409,6 @@ async def instantiate_provider(
        if "telemetry_enabled" in inspect.signature(getattr(module, method)).parameters and run_config.telemetry:
            args.append(run_config.telemetry.enabled)

-        # vector_io providers need access to run_config.vector_stores
-        if provider_spec.api == Api.vector_io and "run_config" in inspect.signature(getattr(module, method)).parameters:
-            args.append(run_config)
-
        fn = getattr(module, method)
        impl = await fn(*args)
        impl.__provider_id__ = provider.provider_id

@@ -84,6 +84,9 @@ async def get_auto_router_impl(
        await inference_store.initialize()
        api_to_dep_impl["store"] = inference_store

+    if api == Api.vector_io and run_config.vector_stores:
+        api_to_dep_impl["vector_stores_config"] = run_config.vector_stores
+
    impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
    await impl.initialize()
    return impl

@@ -31,6 +31,7 @@ from llama_stack.apis.vector_io import (
    VectorStoreObject,
    VectorStoreSearchResponsePage,
 )
+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
@@ -43,9 +44,11 @@ class VectorIORouter(VectorIO):
    def __init__(
        self,
        routing_table: RoutingTable,
+        vector_stores_config: VectorStoresConfig | None = None,
    ) -> None:
        logger.debug("Initializing VectorIORouter")
        self.routing_table = routing_table
+        self.vector_stores_config = vector_stores_config

    async def initialize(self) -> None:
        logger.debug("VectorIORouter.initialize")
@@ -122,6 +125,10 @@ class VectorIORouter(VectorIO):
        embedding_dimension = extra.get("embedding_dimension")
        provider_id = extra.get("provider_id")

+        if embedding_model is None and self.vector_stores_config is not None:
+            embedding_model = self.vector_stores_config.default_embedding_model_id
+            logger.debug(f"Using default embedding model: {embedding_model}")
+
        if embedding_model is not None and embedding_dimension is None:
            embedding_dimension = await self._get_embedding_model_dimension(embedding_model)

@@ -25,6 +25,8 @@ distribution_spec:
    - provider_type: inline::milvus
    - provider_type: remote::chromadb
    - provider_type: remote::pgvector
+    - provider_type: remote::qdrant
+    - provider_type: remote::weaviate
  files:
    - provider_type: inline::localfs
  safety:

@@ -128,6 +128,21 @@ providers:
      kvstore:
        type: sqlite
        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/pgvector_registry.db
+  - provider_id: ${env.QDRANT_URL:+qdrant}
+    provider_type: remote::qdrant
+    config:
+      api_key: ${env.QDRANT_API_KEY:=}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/qdrant_registry.db
+  - provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
+    provider_type: remote::weaviate
+    config:
+      weaviate_api_key: null
+      weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/weaviate_registry.db
  files:
  - provider_id: meta-reference-files
    provider_type: inline::localfs

@@ -26,6 +26,8 @@ distribution_spec:
    - provider_type: inline::milvus
    - provider_type: remote::chromadb
    - provider_type: remote::pgvector
+    - provider_type: remote::qdrant
+    - provider_type: remote::weaviate
  files:
    - provider_type: inline::localfs
  safety:

@@ -128,6 +128,21 @@ providers:
      kvstore:
        type: sqlite
        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/pgvector_registry.db
+  - provider_id: ${env.QDRANT_URL:+qdrant}
+    provider_type: remote::qdrant
+    config:
+      api_key: ${env.QDRANT_API_KEY:=}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/qdrant_registry.db
+  - provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
+    provider_type: remote::weaviate
+    config:
+      weaviate_api_key: null
+      weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/weaviate_registry.db
  files:
  - provider_id: meta-reference-files
    provider_type: inline::localfs
@@ -240,7 +255,7 @@ tool_groups:
    provider_id: rag-runtime
 server:
  port: 8321
-vector_stores:
-  default_embedding_model_id: sentence-transformers/nomic-ai/nomic-embed-text-v1.5
 telemetry:
  enabled: true
+vector_stores:
+  default_embedding_model_id: sentence-transformers/nomic-ai/nomic-embed-text-v1.5

@@ -26,6 +26,8 @@ distribution_spec:
    - provider_type: inline::milvus
    - provider_type: remote::chromadb
    - provider_type: remote::pgvector
+    - provider_type: remote::qdrant
+    - provider_type: remote::weaviate
  files:
    - provider_type: inline::localfs
  safety:

@@ -128,6 +128,21 @@ providers:
      kvstore:
        type: sqlite
        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
+  - provider_id: ${env.QDRANT_URL:+qdrant}
+    provider_type: remote::qdrant
+    config:
+      api_key: ${env.QDRANT_API_KEY:=}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/qdrant_registry.db
+  - provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
+    provider_type: remote::weaviate
+    config:
+      weaviate_api_key: null
+      weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/weaviate_registry.db
  files:
  - provider_id: meta-reference-files
    provider_type: inline::localfs

@@ -32,6 +32,8 @@ from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOC
 from llama_stack.providers.remote.vector_io.pgvector.config import (
    PGVectorVectorIOConfig,
 )
+from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOConfig
+from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
 from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
@@ -114,6 +116,8 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
            BuildProvider(provider_type="inline::milvus"),
            BuildProvider(provider_type="remote::chromadb"),
            BuildProvider(provider_type="remote::pgvector"),
+            BuildProvider(provider_type="remote::qdrant"),
+            BuildProvider(provider_type="remote::weaviate"),
        ],
        "files": [BuildProvider(provider_type="inline::localfs")],
        "safety": [
@@ -222,6 +226,22 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
                    password="${env.PGVECTOR_PASSWORD:=}",
                ),
            ),
+            Provider(
+                provider_id="${env.QDRANT_URL:+qdrant}",
+                provider_type="remote::qdrant",
+                config=QdrantVectorIOConfig.sample_run_config(
+                    f"~/.llama/distributions/{name}",
+                    url="${env.QDRANT_URL:=}",
+                ),
+            ),
+            Provider(
+                provider_id="${env.WEAVIATE_CLUSTER_URL:+weaviate}",
+                provider_type="remote::weaviate",
+                config=WeaviateVectorIOConfig.sample_run_config(
+                    f"~/.llama/distributions/{name}",
+                    cluster_url="${env.WEAVIATE_CLUSTER_URL:=}",
+                ),
+            ),
        ],
        "files": [files_provider],
    },

@@ -6,29 +6,14 @@

 from typing import Any

-from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.datatypes import Api

 from .config import ChromaVectorIOConfig


-async def get_provider_impl(
-    config: ChromaVectorIOConfig, deps: dict[Api, Any], run_config: StackRunConfig | None = None
-):
-    from llama_stack.providers.remote.vector_io.chroma.chroma import (
-        ChromaVectorIOAdapter,
-    )
-
-    vector_stores_config = None
-    if run_config and run_config.vector_stores:
-        vector_stores_config = run_config.vector_stores
-
-    impl = ChromaVectorIOAdapter(
-        config,
-        deps[Api.inference],
-        deps[Api.models],
-        deps.get(Api.files),
-        vector_stores_config,
-    )
+async def get_provider_impl(config: ChromaVectorIOConfig, deps: dict[Api, Any]):
+    from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaVectorIOAdapter
+
+    impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
    await impl.initialize()
    return impl

@@ -24,7 +24,6 @@ class ChromaVectorIOConfig(BaseModel):
        return {
            "db_path": db_path,
            "kvstore": SqliteKVStoreConfig.sample_run_config(
-                __distro_dir__=__distro_dir__,
-                db_name="chroma_inline_registry.db",
+                __distro_dir__=__distro_dir__, db_name="chroma_inline_registry.db"
            ),
        }

@@ -6,29 +6,16 @@

 from typing import Any

-from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.datatypes import Api

 from .config import FaissVectorIOConfig


-async def get_provider_impl(
-    config: FaissVectorIOConfig, deps: dict[Api, Any], run_config: StackRunConfig | None = None
-):
+async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
    from .faiss import FaissVectorIOAdapter

    assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"

-    vector_stores_config = None
-    if run_config and run_config.vector_stores:
-        vector_stores_config = run_config.vector_stores
-
-    impl = FaissVectorIOAdapter(
-        config,
-        deps[Api.inference],
-        deps[Api.models],
-        deps.get(Api.files),
-        vector_stores_config,
-    )
+    impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
    await impl.initialize()
    return impl

@@ -8,10 +8,7 @@ from typing import Any

 from pydantic import BaseModel

-from llama_stack.providers.utils.kvstore.config import (
-    KVStoreConfig,
-    SqliteKVStoreConfig,
-)
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
 from llama_stack.schema_utils import json_schema_type
@@ -22,8 +19,5 @@ class FaissVectorIOConfig(BaseModel):
    @classmethod
    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
        return {
-            "kvstore": SqliteKVStoreConfig.sample_run_config(
-                __distro_dir__=__distro_dir__,
-                db_name="faiss_store.db",
-            )
+            "kvstore": SqliteKVStoreConfig.sample_run_config(__distro_dir__=__distro_dir__, db_name="faiss_store.db")
        }

@@ -17,28 +17,14 @@ from numpy.typing import NDArray
 from llama_stack.apis.common.errors import VectorStoreNotFoundError
 from llama_stack.apis.files import Files
 from llama_stack.apis.inference import Inference, InterleavedContent
-from llama_stack.apis.models import Models
 from llama_stack.apis.vector_dbs import VectorDB
-from llama_stack.apis.vector_io import (
-    Chunk,
-    QueryChunksResponse,
-    VectorIO,
-)
-from llama_stack.core.datatypes import VectorStoresConfig
+from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.log import get_logger
-from llama_stack.providers.datatypes import (
-    HealthResponse,
-    HealthStatus,
-    VectorDBsProtocolPrivate,
-)
+from llama_stack.providers.datatypes import HealthResponse, HealthStatus, VectorDBsProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
-from llama_stack.providers.utils.memory.vector_store import (
-    ChunkForDeletion,
-    EmbeddingIndex,
-    VectorDBWithIndex,
-)
+from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex

 from .config import FaissVectorIOConfig
@@ -156,12 +142,7 @@ class FaissIndex(EmbeddingIndex):

        await self._save_index()

-    async def query_vector(
-        self,
-        embedding: NDArray,
-        k: int,
-        score_threshold: float,
-    ) -> QueryChunksResponse:
+    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
        distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
        chunks = []
        scores = []
@@ -176,12 +157,7 @@ class FaissIndex(EmbeddingIndex):

        return QueryChunksResponse(chunks=chunks, scores=scores)

-    async def query_keyword(
-        self,
-        query_string: str,
-        k: int,
-        score_threshold: float,
-    ) -> QueryChunksResponse:
+    async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
        raise NotImplementedError(
            "Keyword search is not supported - underlying DB FAISS does not support this search mode"
        )
@@ -201,19 +177,10 @@ class FaissIndex(EmbeddingIndex):


 class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
-    def __init__(
-        self,
-        config: FaissVectorIOConfig,
-        inference_api: Inference,
-        models_api: Models,
-        files_api: Files | None,
-        vector_stores_config: VectorStoresConfig | None = None,
-    ) -> None:
+    def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
        super().__init__(files_api=files_api, kvstore=None)
        self.config = config
        self.inference_api = inference_api
-        self.models_api = models_api
-        self.vector_stores_config = vector_stores_config
        self.cache: dict[str, VectorDBWithIndex] = {}

    async def initialize(self) -> None:
@@ -255,17 +222,11 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
        except Exception as e:
            return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")

-    async def register_vector_db(
-        self,
-        vector_db: VectorDB,
-    ) -> None:
+    async def register_vector_db(self, vector_db: VectorDB) -> None:
        assert self.kvstore is not None

        key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
-        await self.kvstore.set(
-            key=key,
-            value=vector_db.model_dump_json(),
-        )
+        await self.kvstore.set(key=key, value=vector_db.model_dump_json())

        # Store in cache
        self.cache[vector_db.identifier] = VectorDBWithIndex(
@@ -288,12 +249,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
        del self.cache[vector_db_id]
        await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_db_id}")

-    async def insert_chunks(
-        self,
-        vector_db_id: str,
-        chunks: list[Chunk],
-        ttl_seconds: int | None = None,
-    ) -> None:
+    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
        index = self.cache.get(vector_db_id)
        if index is None:
            raise ValueError(f"Vector DB {vector_db_id} not found. found: {self.cache.keys()}")
@@ -301,10 +257,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
        await index.insert_chunks(chunks)

    async def query_chunks(
-        self,
-        vector_db_id: str,
-        query: InterleavedContent,
-        params: dict[str, Any] | None = None,
+        self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
    ) -> QueryChunksResponse:
        index = self.cache.get(vector_db_id)
        if index is None:

@@ -6,27 +6,14 @@

 from typing import Any

-from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.datatypes import Api

 from .config import MilvusVectorIOConfig


-async def get_provider_impl(
-    config: MilvusVectorIOConfig, deps: dict[Api, Any], run_config: StackRunConfig | None = None
-):
+async def get_provider_impl(config: MilvusVectorIOConfig, deps: dict[Api, Any]):
    from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter

-    vector_stores_config = None
-    if run_config and run_config.vector_stores:
-        vector_stores_config = run_config.vector_stores
-
-    impl = MilvusVectorIOAdapter(
-        config,
-        deps[Api.inference],
-        deps.get(Api.models),
-        deps.get(Api.files),
-        vector_stores_config,
-    )
+    impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
    await impl.initialize()
    return impl

@@ -8,10 +8,7 @@ from typing import Any

 from pydantic import BaseModel, Field

-from llama_stack.providers.utils.kvstore.config import (
-    KVStoreConfig,
-    SqliteKVStoreConfig,
-)
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
 from llama_stack.schema_utils import json_schema_type
@@ -26,7 +23,6 @@ class MilvusVectorIOConfig(BaseModel):
        return {
            "db_path": "${env.MILVUS_DB_PATH:=" + __distro_dir__ + "}/" + "milvus.db",
            "kvstore": SqliteKVStoreConfig.sample_run_config(
-                __distro_dir__=__distro_dir__,
-                db_name="milvus_registry.db",
+                __distro_dir__=__distro_dir__, db_name="milvus_registry.db"
            ),
        }

@@ -6,28 +6,15 @@

 from typing import Any

-from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.datatypes import Api

 from .config import QdrantVectorIOConfig


-async def get_provider_impl(
-    config: QdrantVectorIOConfig, deps: dict[Api, Any], run_config: StackRunConfig | None = None
-):
+async def get_provider_impl(config: QdrantVectorIOConfig, deps: dict[Api, Any]):
    from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter

-    vector_stores_config = None
-    if run_config and run_config.vector_stores:
-        vector_stores_config = run_config.vector_stores
-
    assert isinstance(config, QdrantVectorIOConfig), f"Unexpected config type: {type(config)}"
-    impl = QdrantVectorIOAdapter(
-        config,
-        deps[Api.inference],
-        deps[Api.models],
-        deps.get(Api.files),
-        vector_stores_config,
-    )
+    impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
    await impl.initialize()
    return impl

@@ -9,10 +9,7 @@ from typing import Any

 from pydantic import BaseModel

-from llama_stack.providers.utils.kvstore.config import (
-    KVStoreConfig,
-    SqliteKVStoreConfig,
-)
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
 from llama_stack.schema_utils import json_schema_type

@@ -6,28 +6,15 @@

 from typing import Any

-from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.datatypes import Api

 from .config import SQLiteVectorIOConfig


-async def get_provider_impl(
-    config: SQLiteVectorIOConfig, deps: dict[Api, Any], run_config: StackRunConfig | None = None
-):
+async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
    from .sqlite_vec import SQLiteVecVectorIOAdapter

-    vector_stores_config = None
-    if run_config and run_config.vector_stores:
-        vector_stores_config = run_config.vector_stores
-
    assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
-    impl = SQLiteVecVectorIOAdapter(
-        config,
-        deps[Api.inference],
-        deps[Api.models],
-        deps.get(Api.files),
-        vector_stores_config,
-    )
+    impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
    await impl.initialize()
    return impl

@@ -8,10 +8,7 @@ from typing import Any

 from pydantic import BaseModel, Field

-from llama_stack.providers.utils.kvstore.config import (
-    KVStoreConfig,
-    SqliteKVStoreConfig,
-)
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig


 class SQLiteVectorIOConfig(BaseModel):
@@ -23,7 +20,6 @@ class SQLiteVectorIOConfig(BaseModel):
        return {
            "db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + "sqlite_vec.db",
            "kvstore": SqliteKVStoreConfig.sample_run_config(
-                __distro_dir__=__distro_dir__,
-                db_name="sqlite_vec_registry.db",
+                __distro_dir__=__distro_dir__, db_name="sqlite_vec_registry.db"
            ),
        }

@@ -17,14 +17,8 @@ from numpy.typing import NDArray
 from llama_stack.apis.common.errors import VectorStoreNotFoundError
 from llama_stack.apis.files import Files
 from llama_stack.apis.inference import Inference
-from llama_stack.apis.models import Models
 from llama_stack.apis.vector_dbs import VectorDB
-from llama_stack.apis.vector_io import (
-    Chunk,
-    QueryChunksResponse,
-    VectorIO,
-)
-from llama_stack.core.datatypes import VectorStoresConfig
+from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
@@ -176,32 +170,18 @@ class SQLiteVecIndex(EmbeddingIndex):

                # Insert vector embeddings
                embedding_data = [
-                    (
-                        (
-                            chunk.chunk_id,
-                            serialize_vector(emb.tolist()),
-                        )
-                    )
+                    ((chunk.chunk_id, serialize_vector(emb.tolist())))
                    for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
                ]
-                cur.executemany(
-                    f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);",
-                    embedding_data,
-                )
+                cur.executemany(f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);", embedding_data)

                # Insert FTS content
                fts_data = [(chunk.chunk_id, chunk.content) for chunk in batch_chunks]
                # DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
-                cur.executemany(
-                    f"DELETE FROM [{self.fts_table}] WHERE id = ?;",
-                    [(row[0],) for row in fts_data],
-                )
+                cur.executemany(f"DELETE FROM [{self.fts_table}] WHERE id = ?;", [(row[0],) for row in fts_data])

                # INSERT new entries
-                cur.executemany(
-                    f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);",
-                    fts_data,
-                )
+                cur.executemany(f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);", fts_data)

                connection.commit()
@@ -217,12 +197,7 @@ class SQLiteVecIndex(EmbeddingIndex):
        # Run batch insertion in a background thread
        await asyncio.to_thread(_execute_all_batch_inserts)

-    async def query_vector(
-        self,
-        embedding: NDArray,
-        k: int,
-        score_threshold: float,
-    ) -> QueryChunksResponse:
+    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
        """
        Performs vector-based search using a virtual table for vector similarity.
        """
@@ -262,12 +237,7 @@ class SQLiteVecIndex(EmbeddingIndex):
            scores.append(score)
        return QueryChunksResponse(chunks=chunks, scores=scores)

-    async def query_keyword(
-        self,
-        query_string: str,
-        k: int,
-        score_threshold: float,
-    ) -> QueryChunksResponse:
+    async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
        """
        Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search.
        """
@@ -411,19 +381,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
    and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
    """

-    def __init__(
-        self,
-        config,
-        inference_api: Inference,
-        models_api: Models,
-        files_api: Files | None,
-        vector_stores_config: VectorStoresConfig | None = None,
-    ) -> None:
+    def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
        super().__init__(files_api=files_api, kvstore=None)
        self.config = config
        self.inference_api = inference_api
-        self.models_api = models_api
-        self.vector_stores_config = vector_stores_config
        self.cache: dict[str, VectorDBWithIndex] = {}
        self.vector_db_store = None
@@ -436,9 +397,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
        for db_json in stored_vector_dbs:
            vector_db = VectorDB.model_validate_json(db_json)
            index = await SQLiteVecIndex.create(
-                vector_db.embedding_dimension,
-                self.config.db_path,
-                vector_db.identifier,
+                vector_db.embedding_dimension, self.config.db_path, vector_db.identifier
            )
            self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
@@ -453,11 +412,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
        return [v.vector_db for v in self.cache.values()]

    async def register_vector_db(self, vector_db: VectorDB) -> None:
-        index = await SQLiteVecIndex.create(
-            vector_db.embedding_dimension,
-            self.config.db_path,
-            vector_db.identifier,
-        )
+        index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier)
        self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)

    async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:

@@ -4,27 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.datatypes import Api, ProviderSpec

 from .config import ChromaVectorIOConfig


-async def get_adapter_impl(
-    config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec], run_config: StackRunConfig | None = None
-):
+async def get_adapter_impl(config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec]):
    from .chroma import ChromaVectorIOAdapter

-    vector_stores_config = None
-    if run_config and run_config.vector_stores:
-        vector_stores_config = run_config.vector_stores
-
-    impl = ChromaVectorIOAdapter(
-        config,
-        deps[Api.inference],
-        deps[Api.models],
-        deps.get(Api.files),
-        vector_stores_config,
-    )
+    impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
    await impl.initialize()
    return impl

@@ -13,25 +13,15 @@ from numpy.typing import NDArray

 from llama_stack.apis.files import Files
 from llama_stack.apis.inference import Inference, InterleavedContent
-from llama_stack.apis.models import Models
 from llama_stack.apis.vector_dbs import VectorDB
-from llama_stack.apis.vector_io import (
-    Chunk,
-    QueryChunksResponse,
-    VectorIO,
-)
-from llama_stack.core.datatypes import VectorStoresConfig
+from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
 from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
-from llama_stack.providers.utils.memory.vector_store import (
-    ChunkForDeletion,
-    EmbeddingIndex,
-    VectorDBWithIndex,
-)
+from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex

 from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
@@ -70,19 +60,13 @@ class ChromaIndex(EmbeddingIndex):

        ids = [f"{c.metadata.get('document_id', '')}:{c.chunk_id}" for c in chunks]
        await maybe_await(
-            self.collection.add(
-                documents=[chunk.model_dump_json() for chunk in chunks],
-                embeddings=embeddings,
-                ids=ids,
-            )
+            self.collection.add(documents=[chunk.model_dump_json() for chunk in chunks], embeddings=embeddings, ids=ids)
        )

    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
        results = await maybe_await(
            self.collection.query(
-                query_embeddings=[embedding.tolist()],
-                n_results=k,
-                include=["documents", "distances"],
+                query_embeddings=[embedding.tolist()], n_results=k, include=["documents", "distances"]
            )
        )
        distances = results["distances"][0]
@@ -110,12 +94,7 @@ class ChromaIndex(EmbeddingIndex):
    async def delete(self):
        await maybe_await(self.client.delete_collection(self.collection.name))

-    async def query_keyword(
-        self,
-        query_string: str,
-        k: int,
-        score_threshold: float,
-    ) -> QueryChunksResponse:
+    async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
        raise NotImplementedError("Keyword search is not supported in Chroma")

    async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
@@ -140,16 +119,12 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
        self,
        config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
        inference_api: Inference,
-        models_apis: Models,
        files_api: Files | None,
-        vector_stores_config: VectorStoresConfig | None = None,
    ) -> None:
        super().__init__(files_api=files_api, kvstore=None)
        log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
        self.config = config
        self.inference_api = inference_api
-        self.models_api = models_apis
-        self.vector_stores_config = vector_stores_config
        self.client = None
        self.cache = {}
        self.vector_db_store = None
@@ -176,14 +151,10 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
        # Clean up mixin resources (file batch tasks)
        await super().shutdown()

-    async def register_vector_db(
-        self,
-        vector_db: VectorDB,
-    ) -> None:
+    async def register_vector_db(self, vector_db: VectorDB) -> None:
        collection = await maybe_await(
            self.client.get_or_create_collection(
-                name=vector_db.identifier,
-                metadata={"vector_db": vector_db.model_dump_json()},
+                name=vector_db.identifier, metadata={"vector_db": vector_db.model_dump_json()}
            )
        )
        self.cache[vector_db.identifier] = VectorDBWithIndex(
@@ -198,12 +169,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
        await self.cache[vector_db_id].index.delete()
        del self.cache[vector_db_id]

-    async def insert_chunks(
-        self,
-        vector_db_id: str,
-        chunks: list[Chunk],
-        ttl_seconds: int | None = None,
-    ) -> None:
+    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
        index = await self._get_and_cache_vector_db_index(vector_db_id)
        if index is None:
            raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
@@ -211,10 +177,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
        await index.insert_chunks(chunks)

    async def query_chunks(
-        self,
-        vector_db_id: str,
-        query: InterleavedContent,
-        params: dict[str, Any] | None = None,
+        self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
    ) -> QueryChunksResponse:
        index = await self._get_and_cache_vector_db_index(vector_db_id)

@@ -22,7 +22,6 @@ class ChromaVectorIOConfig(BaseModel):
         return {
             "url": url,
             "kvstore": SqliteKVStoreConfig.sample_run_config(
-                __distro_dir__=__distro_dir__,
-                db_name="chroma_remote_registry.db",
+                __distro_dir__=__distro_dir__, db_name="chroma_remote_registry.db"
             ),
         }

@@ -4,28 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.datatypes import Api, ProviderSpec

 from .config import MilvusVectorIOConfig


-async def get_adapter_impl(
-    config: MilvusVectorIOConfig, deps: dict[Api, ProviderSpec], run_config: StackRunConfig | None = None
-):
+async def get_adapter_impl(config: MilvusVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .milvus import MilvusVectorIOAdapter

-    vector_stores_config = None
-    if run_config and run_config.vector_stores:
-        vector_stores_config = run_config.vector_stores
-
     assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"
-    impl = MilvusVectorIOAdapter(
-        config,
-        deps[Api.inference],
-        deps[Api.models],
-        deps.get(Api.files),
-        vector_stores_config,
-    )
+    impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl

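Note on the simplified wiring: with the run-config plumbing gone, a remote vector_io entrypoint only needs its own config plus the inference and (optional) files implementations in the deps mapping. A minimal sketch under that assumption; the impl variables are hypothetical, the module paths are inferred from the imports above, and the uri field mirrors the sample run config:

# Hypothetical wiring sketch; Api.models and the stack run config are no longer passed through.
from llama_stack.providers.datatypes import Api
from llama_stack.providers.remote.vector_io.milvus import get_adapter_impl
from llama_stack.providers.remote.vector_io.milvus.config import MilvusVectorIOConfig

async def wire_milvus(inference_impl, files_impl):
    config = MilvusVectorIOConfig(uri="http://localhost:19530")  # assumed field, see sample_run_config
    deps = {Api.inference: inference_impl, Api.files: files_impl}
    return await get_adapter_impl(config, deps)  # initializes the adapter and returns it
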
@@ -29,7 +29,6 @@ class MilvusVectorIOConfig(BaseModel):
             "uri": "${env.MILVUS_ENDPOINT}",
             "token": "${env.MILVUS_TOKEN}",
             "kvstore": SqliteKVStoreConfig.sample_run_config(
-                __distro_dir__=__distro_dir__,
-                db_name="milvus_remote_registry.db",
+                __distro_dir__=__distro_dir__, db_name="milvus_remote_registry.db"
             ),
         }

@@ -14,14 +14,8 @@ from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusC
 from llama_stack.apis.common.errors import VectorStoreNotFoundError
 from llama_stack.apis.files import Files
 from llama_stack.apis.inference import Inference, InterleavedContent
-from llama_stack.apis.models import Models
 from llama_stack.apis.vector_dbs import VectorDB
-from llama_stack.apis.vector_io import (
-    Chunk,
-    QueryChunksResponse,
-    VectorIO,
-)
-from llama_stack.core.datatypes import VectorStoresConfig
+from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
 from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig

@@ -75,46 +69,23 @@ class MilvusIndex(EmbeddingIndex):
             logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
             # Create schema for vector search
             schema = self.client.create_schema()
-            schema.add_field(
-                field_name="chunk_id",
-                datatype=DataType.VARCHAR,
-                is_primary=True,
-                max_length=100,
-            )
+            schema.add_field(field_name="chunk_id", datatype=DataType.VARCHAR, is_primary=True, max_length=100)
             schema.add_field(
                 field_name="content",
                 datatype=DataType.VARCHAR,
                 max_length=65535,
                 enable_analyzer=True,  # Enable text analysis for BM25
             )
-            schema.add_field(
-                field_name="vector",
-                datatype=DataType.FLOAT_VECTOR,
-                dim=len(embeddings[0]),
-            )
-            schema.add_field(
-                field_name="chunk_content",
-                datatype=DataType.JSON,
-            )
+            schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=len(embeddings[0]))
+            schema.add_field(field_name="chunk_content", datatype=DataType.JSON)
             # Add sparse vector field for BM25 (required by the function)
-            schema.add_field(
-                field_name="sparse",
-                datatype=DataType.SPARSE_FLOAT_VECTOR,
-            )
+            schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR)

             # Create indexes
             index_params = self.client.prepare_index_params()
-            index_params.add_index(
-                field_name="vector",
-                index_type="FLAT",
-                metric_type="COSINE",
-            )
+            index_params.add_index(field_name="vector", index_type="FLAT", metric_type="COSINE")
             # Add index for sparse field (required by BM25 function)
-            index_params.add_index(
-                field_name="sparse",
-                index_type="SPARSE_INVERTED_INDEX",
-                metric_type="BM25",
-            )
+            index_params.add_index(field_name="sparse", index_type="SPARSE_INVERTED_INDEX", metric_type="BM25")

             # Add BM25 function for full-text search
             bm25_function = Function(

@@ -145,11 +116,7 @@ class MilvusIndex(EmbeddingIndex):
             }
         )
         try:
-            await asyncio.to_thread(
-                self.client.insert,
-                self.collection_name,
-                data=data,
-            )
+            await asyncio.to_thread(self.client.insert, self.collection_name, data=data)
         except Exception as e:
             logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
             raise e

@@ -168,12 +135,7 @@ class MilvusIndex(EmbeddingIndex):
         scores = [res["distance"] for res in search_res[0]]
         return QueryChunksResponse(chunks=chunks, scores=scores)

-    async def query_keyword(
-        self,
-        query_string: str,
-        k: int,
-        score_threshold: float,
-    ) -> QueryChunksResponse:
+    async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
         """
         Perform BM25-based keyword search using Milvus's built-in full-text search.
         """

@@ -211,12 +173,7 @@ class MilvusIndex(EmbeddingIndex):
             # Fallback to simple text search
             return await self._fallback_keyword_search(query_string, k, score_threshold)

-    async def _fallback_keyword_search(
-        self,
-        query_string: str,
-        k: int,
-        score_threshold: float,
-    ) -> QueryChunksResponse:
+    async def _fallback_keyword_search(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
         """
         Fallback to simple text search when BM25 search is not available.
         """

@@ -309,17 +266,13 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         self,
         config: RemoteMilvusVectorIOConfig | InlineMilvusVectorIOConfig,
         inference_api: Inference,
-        models_api: Models | None,
         files_api: Files | None,
-        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.cache = {}
         self.client = None
         self.inference_api = inference_api
-        self.models_api = models_api
-        self.vector_stores_config = vector_stores_config
         self.vector_db_store = None
         self.metadata_collection_name = "openai_vector_stores_metadata"

@@ -358,10 +311,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         # Clean up mixin resources (file batch tasks)
         await super().shutdown()

-    async def register_vector_db(
-        self,
-        vector_db: VectorDB,
-    ) -> None:
+    async def register_vector_db(self, vector_db: VectorDB) -> None:
         if isinstance(self.config, RemoteMilvusVectorIOConfig):
             consistency_level = self.config.consistency_level
         else:

@@ -398,12 +348,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         await self.cache[vector_db_id].index.delete()
         del self.cache[vector_db_id]

-    async def insert_chunks(
-        self,
-        vector_db_id: str,
-        chunks: list[Chunk],
-        ttl_seconds: int | None = None,
-    ) -> None:
+    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
             raise VectorStoreNotFoundError(vector_db_id)

@@ -411,10 +356,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         await index.insert_chunks(chunks)

     async def query_chunks(
-        self,
-        vector_db_id: str,
-        query: InterleavedContent,
-        params: dict[str, Any] | None = None,
+        self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
     ) -> QueryChunksResponse:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:

@@ -4,26 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.datatypes import Api, ProviderSpec

 from .config import PGVectorVectorIOConfig


-async def get_adapter_impl(
-    config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec], run_config: StackRunConfig | None = None
-):
+async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .pgvector import PGVectorVectorIOAdapter

-    vector_stores_config = None
-    if run_config and run_config.vector_stores:
-        vector_stores_config = run_config.vector_stores
-    impl = PGVectorVectorIOAdapter(
-        config,
-        deps[Api.inference],
-        deps[Api.models],
-        deps.get(Api.files),
-        vector_stores_config,
-    )
+    impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl

@@ -8,10 +8,7 @@ from typing import Any

 from pydantic import BaseModel, Field

-from llama_stack.providers.utils.kvstore.config import (
-    KVStoreConfig,
-    SqliteKVStoreConfig,
-)
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
 from llama_stack.schema_utils import json_schema_type


@@ -42,7 +39,6 @@ class PGVectorVectorIOConfig(BaseModel):
             "user": user,
             "password": password,
             "kvstore": SqliteKVStoreConfig.sample_run_config(
-                __distro_dir__=__distro_dir__,
-                db_name="pgvector_registry.db",
+                __distro_dir__=__distro_dir__, db_name="pgvector_registry.db"
             ),
         }

@@ -16,27 +16,15 @@ from pydantic import BaseModel, TypeAdapter
 from llama_stack.apis.common.errors import VectorStoreNotFoundError
 from llama_stack.apis.files import Files
 from llama_stack.apis.inference import Inference, InterleavedContent
-from llama_stack.apis.models import Models
 from llama_stack.apis.vector_dbs import VectorDB
-from llama_stack.apis.vector_io import (
-    Chunk,
-    QueryChunksResponse,
-    VectorIO,
-)
-from llama_stack.core.datatypes import VectorStoresConfig
+from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    interleaved_content_as_str,
-)
+from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
-from llama_stack.providers.utils.memory.vector_store import (
-    ChunkForDeletion,
-    EmbeddingIndex,
-    VectorDBWithIndex,
-)
+from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
 from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator, sanitize_collection_name

 from .config import PGVectorVectorIOConfig

@@ -206,12 +194,7 @@ class PGVectorIndex(EmbeddingIndex):

         return QueryChunksResponse(chunks=chunks, scores=scores)

-    async def query_keyword(
-        self,
-        query_string: str,
-        k: int,
-        score_threshold: float,
-    ) -> QueryChunksResponse:
+    async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
         """
         Performs keyword-based search using PostgreSQL's full-text search with ts_rank scoring.

@@ -318,7 +301,7 @@ class PGVectorIndex(EmbeddingIndex):
         """Remove a chunk from the PostgreSQL table."""
         chunk_ids = [c.chunk_id for c in chunks_for_deletion]
         with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
-            cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,))
+            cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids))

     def get_pgvector_search_function(self) -> str:
         return self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION[self.distance_metric]

@@ -342,18 +325,11 @@ class PGVectorIndex(EmbeddingIndex):

 class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
     def __init__(
-        self,
-        config: PGVectorVectorIOConfig,
-        inference_api: Inference,
-        models_api: Models,
-        files_api: Files | None = None,
-        vector_stores_config: VectorStoresConfig | None = None,
+        self, config: PGVectorVectorIOConfig, inference_api: Inference, files_api: Files | None = None
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
-        self.models_api = models_api
-        self.vector_stores_config = vector_stores_config
         self.conn = None
         self.cache = {}
         self.vector_db_store = None

@@ -410,11 +386,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
             vector_db=vector_db, dimension=vector_db.embedding_dimension, conn=self.conn, kvstore=self.kvstore
         )
         await pgvector_index.initialize()
-        index = VectorDBWithIndex(
-            vector_db,
-            index=pgvector_index,
-            inference_api=self.inference_api,
-        )
+        index = VectorDBWithIndex(vector_db, index=pgvector_index, inference_api=self.inference_api)
         self.cache[vector_db.identifier] = index

     async def unregister_vector_db(self, vector_db_id: str) -> None:

@@ -427,20 +399,12 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
         assert self.kvstore is not None
         await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_db_id}")

-    async def insert_chunks(
-        self,
-        vector_db_id: str,
-        chunks: list[Chunk],
-        ttl_seconds: int | None = None,
-    ) -> None:
+    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         await index.insert_chunks(chunks)

     async def query_chunks(
-        self,
-        vector_db_id: str,
-        query: InterleavedContent,
-        params: dict[str, Any] | None = None,
+        self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
     ) -> QueryChunksResponse:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         return await index.query_chunks(query, params)

@@ -4,27 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.datatypes import Api, ProviderSpec

 from .config import QdrantVectorIOConfig


-async def get_adapter_impl(
-    config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec], run_config: StackRunConfig | None = None
-):
+async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .qdrant import QdrantVectorIOAdapter

-    vector_stores_config = None
-    if run_config and run_config.vector_stores:
-        vector_stores_config = run_config.vector_stores
-
-    impl = QdrantVectorIOAdapter(
-        config,
-        deps[Api.inference],
-        deps[Api.models],
-        deps.get(Api.files),
-        vector_stores_config,
-    )
+    impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl

@@ -8,10 +8,7 @@ from typing import Any

 from pydantic import BaseModel

-from llama_stack.providers.utils.kvstore.config import (
-    KVStoreConfig,
-    SqliteKVStoreConfig,
-)
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
 from llama_stack.schema_utils import json_schema_type


@@ -34,7 +31,6 @@ class QdrantVectorIOConfig(BaseModel):
         return {
             "api_key": "${env.QDRANT_API_KEY:=}",
             "kvstore": SqliteKVStoreConfig.sample_run_config(
-                __distro_dir__=__distro_dir__,
-                db_name="qdrant_registry.db",
+                __distro_dir__=__distro_dir__, db_name="qdrant_registry.db"
             ),
         }

@@ -16,7 +16,6 @@ from qdrant_client.models import PointStruct
 from llama_stack.apis.common.errors import VectorStoreNotFoundError
 from llama_stack.apis.files import Files
 from llama_stack.apis.inference import Inference, InterleavedContent
-from llama_stack.apis.models import Models
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import (
     Chunk,

@@ -25,17 +24,12 @@ from llama_stack.apis.vector_io import (
     VectorStoreChunkingStrategy,
     VectorStoreFileObject,
 )
-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
 from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
-from llama_stack.providers.utils.memory.vector_store import (
-    ChunkForDeletion,
-    EmbeddingIndex,
-    VectorDBWithIndex,
-)
+from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex

 from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig

@@ -100,8 +94,7 @@ class QdrantIndex(EmbeddingIndex):
         chunk_ids = [convert_id(c.chunk_id) for c in chunks_for_deletion]
         try:
             await self.client.delete(
-                collection_name=self.collection_name,
-                points_selector=models.PointIdsList(points=chunk_ids),
+                collection_name=self.collection_name, points_selector=models.PointIdsList(points=chunk_ids)
             )
         except Exception as e:
             log.error(f"Error deleting chunks from Qdrant collection {self.collection_name}: {e}")

@@ -134,12 +127,7 @@ class QdrantIndex(EmbeddingIndex):

         return QueryChunksResponse(chunks=chunks, scores=scores)

-    async def query_keyword(
-        self,
-        query_string: str,
-        k: int,
-        score_threshold: float,
-    ) -> QueryChunksResponse:
+    async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
         raise NotImplementedError("Keyword search is not supported in Qdrant")

     async def query_hybrid(

@@ -162,17 +150,13 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         self,
         config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig,
         inference_api: Inference,
-        models_api: Models,
         files_api: Files | None = None,
-        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.client: AsyncQdrantClient = None
         self.cache = {}
         self.inference_api = inference_api
-        self.models_api = models_api
-        self.vector_stores_config = vector_stores_config
         self.vector_db_store = None
         self._qdrant_lock = asyncio.Lock()

@@ -187,11 +171,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP

         for vector_db_data in stored_vector_dbs:
             vector_db = VectorDB.model_validate_json(vector_db_data)
-            index = VectorDBWithIndex(
-                vector_db,
-                QdrantIndex(self.client, vector_db.identifier),
-                self.inference_api,
-            )
+            index = VectorDBWithIndex(vector_db, QdrantIndex(self.client, vector_db.identifier), self.inference_api)
             self.cache[vector_db.identifier] = index
         self.openai_vector_stores = await self._load_openai_vector_stores()

@@ -200,18 +180,13 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         # Clean up mixin resources (file batch tasks)
         await super().shutdown()

-    async def register_vector_db(
-        self,
-        vector_db: VectorDB,
-    ) -> None:
+    async def register_vector_db(self, vector_db: VectorDB) -> None:
         assert self.kvstore is not None
         key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
         await self.kvstore.set(key=key, value=vector_db.model_dump_json())

         index = VectorDBWithIndex(
-            vector_db=vector_db,
-            index=QdrantIndex(self.client, vector_db.identifier),
-            inference_api=self.inference_api,
+            vector_db=vector_db, index=QdrantIndex(self.client, vector_db.identifier), inference_api=self.inference_api
         )

         self.cache[vector_db.identifier] = index

@@ -243,12 +218,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
             self.cache[vector_db_id] = index
         return index

-    async def insert_chunks(
-        self,
-        vector_db_id: str,
-        chunks: list[Chunk],
-        ttl_seconds: int | None = None,
-    ) -> None:
+    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
             raise VectorStoreNotFoundError(vector_db_id)

@@ -256,10 +226,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         await index.insert_chunks(chunks)

     async def query_chunks(
-        self,
-        vector_db_id: str,
-        query: InterleavedContent,
-        params: dict[str, Any] | None = None,
+        self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
     ) -> QueryChunksResponse:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:

@@ -4,27 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.datatypes import Api, ProviderSpec

 from .config import WeaviateVectorIOConfig


-async def get_adapter_impl(
-    config: WeaviateVectorIOConfig, deps: dict[Api, ProviderSpec], run_config: StackRunConfig | None = None
-):
+async def get_adapter_impl(config: WeaviateVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .weaviate import WeaviateVectorIOAdapter

-    vector_stores_config = None
-    if run_config and run_config.vector_stores:
-        vector_stores_config = run_config.vector_stores
-
-    impl = WeaviateVectorIOAdapter(
-        config,
-        deps[Api.inference],
-        deps[Api.models],
-        deps.get(Api.files),
-        vector_stores_config,
-    )
+    impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl

@@ -8,10 +8,7 @@ from typing import Any

 from pydantic import BaseModel, Field

-from llama_stack.providers.utils.kvstore.config import (
-    KVStoreConfig,
-    SqliteKVStoreConfig,
-)
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
 from llama_stack.schema_utils import json_schema_type


@@ -22,16 +19,11 @@ class WeaviateVectorIOConfig(BaseModel):
     kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)

     @classmethod
-    def sample_run_config(
-        cls,
-        __distro_dir__: str,
-        **kwargs: Any,
-    ) -> dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {
             "weaviate_api_key": None,
             "weaviate_cluster_url": "${env.WEAVIATE_CLUSTER_URL:=localhost:8080}",
             "kvstore": SqliteKVStoreConfig.sample_run_config(
-                __distro_dir__=__distro_dir__,
-                db_name="weaviate_registry.db",
+                __distro_dir__=__distro_dir__, db_name="weaviate_registry.db"
             ),
         }

@@ -16,18 +16,14 @@ from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.common.errors import VectorStoreNotFoundError
 from llama_stack.apis.files import Files
 from llama_stack.apis.inference import Inference
-from llama_stack.apis.models import Models
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.api import KVStore
-from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
-    OpenAIVectorStoreMixin,
-)
+from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import (
     RERANKER_TYPE_RRF,
     ChunkForDeletion,

@@ -49,12 +45,7 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten


 class WeaviateIndex(EmbeddingIndex):
-    def __init__(
-        self,
-        client: weaviate.WeaviateClient,
-        collection_name: str,
-        kvstore: KVStore | None = None,
-    ):
+    def __init__(self, client: weaviate.WeaviateClient, collection_name: str, kvstore: KVStore | None = None):
         self.client = client
         self.collection_name = sanitize_collection_name(collection_name, weaviate_format=True)
         self.kvstore = kvstore

@@ -109,9 +100,7 @@ class WeaviateIndex(EmbeddingIndex):

         try:
             results = collection.query.near_vector(
-                near_vector=embedding.tolist(),
-                limit=k,
-                return_metadata=wvc.query.MetadataQuery(distance=True),
+                near_vector=embedding.tolist(), limit=k, return_metadata=wvc.query.MetadataQuery(distance=True)
             )
         except Exception as e:
             log.error(f"Weaviate client vector search failed: {e}")

@@ -154,12 +143,7 @@ class WeaviateIndex(EmbeddingIndex):
         collection = self.client.collections.get(sanitized_collection_name)
         collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids))

-    async def query_keyword(
-        self,
-        query_string: str,
-        k: int,
-        score_threshold: float,
-    ) -> QueryChunksResponse:
+    async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
         """
         Performs BM25-based keyword search using Weaviate's built-in full-text search.
         Args:

@@ -176,9 +160,7 @@ class WeaviateIndex(EmbeddingIndex):
         # Perform BM25 keyword search on chunk_content field
         try:
             results = collection.query.bm25(
-                query=query_string,
-                limit=k,
-                return_metadata=wvc.query.MetadataQuery(score=True),
+                query=query_string, limit=k, return_metadata=wvc.query.MetadataQuery(score=True)
             )
         except Exception as e:
             log.error(f"Weaviate client keyword search failed: {e}")

@@ -275,25 +257,11 @@ class WeaviateIndex(EmbeddingIndex):
         return QueryChunksResponse(chunks=chunks, scores=scores)


-class WeaviateVectorIOAdapter(
-    OpenAIVectorStoreMixin,
-    VectorIO,
-    NeedsRequestProviderData,
-    VectorDBsProtocolPrivate,
-):
-    def __init__(
-        self,
-        config: WeaviateVectorIOConfig,
-        inference_api: Inference,
-        models_api: Models,
-        files_api: Files | None,
-        vector_stores_config: VectorStoresConfig | None = None,
-    ) -> None:
+class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProviderData, VectorDBsProtocolPrivate):
+    def __init__(self, config: WeaviateVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
-        self.models_api = models_api
-        self.vector_stores_config = vector_stores_config
         self.client_cache = {}
         self.cache = {}
         self.vector_db_store = None

@@ -304,10 +272,7 @@ class WeaviateVectorIOAdapter(
             log.info("Using Weaviate locally in container")
             host, port = self.config.weaviate_cluster_url.split(":")
             key = "local_test"
-            client = weaviate.connect_to_local(
-                host=host,
-                port=port,
-            )
+            client = weaviate.connect_to_local(host=host, port=port)
         else:
             log.info("Using Weaviate remote cluster with URL")
             key = f"{self.config.weaviate_cluster_url}::{self.config.weaviate_api_key}"

@@ -337,15 +302,9 @@ class WeaviateVectorIOAdapter(
         for raw in stored:
             vector_db = VectorDB.model_validate_json(raw)
             client = self._get_client()
-            idx = WeaviateIndex(
-                client=client,
-                collection_name=vector_db.identifier,
-                kvstore=self.kvstore,
-            )
+            idx = WeaviateIndex(client=client, collection_name=vector_db.identifier, kvstore=self.kvstore)
             self.cache[vector_db.identifier] = VectorDBWithIndex(
-                vector_db=vector_db,
-                index=idx,
-                inference_api=self.inference_api,
+                vector_db=vector_db, index=idx, inference_api=self.inference_api
             )

         # Load OpenAI vector stores metadata into cache

@@ -357,10 +316,7 @@ class WeaviateVectorIOAdapter(
         # Clean up mixin resources (file batch tasks)
         await super().shutdown()

-    async def register_vector_db(
-        self,
-        vector_db: VectorDB,
-    ) -> None:
+    async def register_vector_db(self, vector_db: VectorDB) -> None:
         client = self._get_client()
         sanitized_collection_name = sanitize_collection_name(vector_db.identifier, weaviate_format=True)
         # Create collection if it doesn't exist

@@ -369,17 +325,12 @@ class WeaviateVectorIOAdapter(
                 name=sanitized_collection_name,
                 vectorizer_config=wvc.config.Configure.Vectorizer.none(),
                 properties=[
-                    wvc.config.Property(
-                        name="chunk_content",
-                        data_type=wvc.config.DataType.TEXT,
-                    ),
+                    wvc.config.Property(name="chunk_content", data_type=wvc.config.DataType.TEXT),
                 ],
             )

         self.cache[vector_db.identifier] = VectorDBWithIndex(
-            vector_db,
-            WeaviateIndex(client=client, collection_name=sanitized_collection_name),
-            self.inference_api,
+            vector_db, WeaviateIndex(client=client, collection_name=sanitized_collection_name), self.inference_api
         )

     async def unregister_vector_db(self, vector_db_id: str) -> None:

@@ -415,12 +366,7 @@ class WeaviateVectorIOAdapter(
             self.cache[vector_db_id] = index
         return index

-    async def insert_chunks(
-        self,
-        vector_db_id: str,
-        chunks: list[Chunk],
-        ttl_seconds: int | None = None,
-    ) -> None:
+    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
             raise VectorStoreNotFoundError(vector_db_id)

@@ -428,10 +374,7 @@ class WeaviateVectorIOAdapter(
         await index.insert_chunks(chunks)

     async def query_chunks(
-        self,
-        vector_db_id: str,
-        query: InterleavedContent,
-        params: dict[str, Any] | None = None,
+        self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
     ) -> QueryChunksResponse:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:

@ -17,7 +17,6 @@ from pydantic import TypeAdapter
|
||||||
|
|
||||||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||||
from llama_stack.apis.files import Files, OpenAIFileObject
|
from llama_stack.apis.files import Files, OpenAIFileObject
|
||||||
from llama_stack.apis.models import Model, Models
|
|
||||||
from llama_stack.apis.vector_dbs import VectorDB
|
from llama_stack.apis.vector_dbs import VectorDB
|
||||||
from llama_stack.apis.vector_io import (
|
from llama_stack.apis.vector_io import (
|
||||||
Chunk,
|
Chunk,
|
||||||
|
|
@ -44,7 +43,6 @@ from llama_stack.apis.vector_io import (
|
||||||
VectorStoreSearchResponse,
|
VectorStoreSearchResponse,
|
||||||
VectorStoreSearchResponsePage,
|
VectorStoreSearchResponsePage,
|
||||||
)
|
)
|
||||||
from llama_stack.core.datatypes import VectorStoresConfig
|
|
||||||
from llama_stack.core.id_generation import generate_object_id
|
from llama_stack.core.id_generation import generate_object_id
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||||
|
|
@ -90,9 +88,6 @@ class OpenAIVectorStoreMixin(ABC):
|
||||||
self.openai_file_batches: dict[str, dict[str, Any]] = {}
|
self.openai_file_batches: dict[str, dict[str, Any]] = {}
|
||||||
self.files_api = files_api
|
self.files_api = files_api
|
||||||
self.kvstore = kvstore
|
self.kvstore = kvstore
|
||||||
# These will be set by implementing classes
|
|
||||||
self.models_api: Models | None = None
|
|
||||||
self.vector_stores_config: VectorStoresConfig | None = None
|
|
||||||
self._last_file_batch_cleanup_time = 0
|
self._last_file_batch_cleanup_time = 0
|
||||||
self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}
|
self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}
|
||||||
|
|
||||||
|
|
@ -398,21 +393,7 @@ class OpenAIVectorStoreMixin(ABC):
|
||||||
vector_db_id = provider_vector_db_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
|
vector_db_id = provider_vector_db_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
|
||||||
|
|
||||||
if embedding_model is None:
|
if embedding_model is None:
|
||||||
result = await self._get_default_embedding_model_and_dimension()
|
raise ValueError("embedding_model is required")
|
||||||
if result is None:
|
|
||||||
raise ValueError(
|
|
||||||
"embedding_model is required in extra_body when creating a vector store. "
|
|
||||||
"No default embedding model could be determined automatically."
|
|
||||||
)
|
|
||||||
embedding_model, embedding_dimension = result
|
|
||||||
elif embedding_dimension is None:
|
|
||||||
# Embedding model was provided but dimension wasn't, look it up
|
|
||||||
embedding_dimension = await self._get_embedding_dimension_for_model(embedding_model)
|
|
||||||
if embedding_dimension is None:
|
|
||||||
raise ValueError(
|
|
||||||
f"Could not determine embedding dimension for model '{embedding_model}'. "
|
|
||||||
"Please provide embedding_dimension in extra_body or ensure the model metadata contains embedding_dimension."
|
|
||||||
)
|
|
||||||
|
|
||||||
if embedding_dimension is None:
|
if embedding_dimension is None:
|
||||||
raise ValueError("Embedding dimension is required")
|
raise ValueError("Embedding dimension is required")
|
||||||
|
|
@ -479,64 +460,6 @@ class OpenAIVectorStoreMixin(ABC):
        store_info = self.openai_vector_stores[vector_db_id]
        return VectorStoreObject.model_validate(store_info)

    async def _get_embedding_dimension_for_model(self, model_id: str) -> int | None:
        """Get embedding dimension for a specific model by looking it up in the models API.

        Args:
            model_id: The identifier of the embedding model (supports both prefixed and non-prefixed)

        Returns:
            The embedding dimension for the model, or None if not found
        """
        if not self.models_api:
            return None

        models_response = await self.models_api.list_models()
        models_list = models_response.data if hasattr(models_response, "data") else models_response

        for model in models_list:
            if not isinstance(model, Model):
                continue
            if model.model_type != "embedding":
                continue

            # Check for exact match first
            if model.identifier == model_id:
                embedding_dimension = model.metadata.get("embedding_dimension")
                if embedding_dimension is not None:
                    return int(embedding_dimension)
                else:
                    logger.warning(f"Model {model_id} found but has no embedding_dimension in metadata")
                    return None

            # Check for prefixed/unprefixed variations
            # If model_id is unprefixed, check if it matches the resource_id
            if model.provider_resource_id == model_id:
                embedding_dimension = model.metadata.get("embedding_dimension")
                if embedding_dimension is not None:
                    return int(embedding_dimension)

        return None

    async def _get_default_embedding_model_and_dimension(self) -> tuple[str, int] | None:
        """Get default embedding model from vector stores config.

        Returns None if no vector stores config is provided.
        """
        if not self.vector_stores_config:
            logger.info("No vector stores config provided")
            return None

        model_id = self.vector_stores_config.default_embedding_model_id
        embedding_dimension = await self._get_embedding_dimension_for_model(model_id)
        if embedding_dimension is None:
            raise ValueError(f"Embedding model '{model_id}' not found or has no embedding_dimension in metadata")

        logger.debug(
            f"Using default embedding model from vector stores config: {model_id} with dimension {embedding_dimension}"
        )
        return model_id, embedding_dimension

    async def openai_list_vector_stores(
        self,
        limit: int | None = 20,
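For reference, a rough sketch of the kind of model registry entry the dimension lookup above reads: an embedding model whose metadata carries embedding_dimension. The constructor fields and values here are assumptions inferred from how the helper accesses them, not taken from this diff.

# Sketch only: field names mirror what _get_embedding_dimension_for_model() reads
# (model_type, identifier, provider_resource_id, metadata["embedding_dimension"]).
from llama_stack.apis.models import Model

embedding_model_entry = Model(
    identifier="sentence-transformers/nomic-ai/nomic-embed-text-v1.5",  # illustrative
    provider_id="sentence-transformers",  # assumed provider id
    provider_resource_id="nomic-ai/nomic-embed-text-v1.5",
    model_type="embedding",
    metadata={"embedding_dimension": 768},
)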
@ -241,7 +241,7 @@ def instantiate_llama_stack_client(session):
    # --stack-config bypasses template so need this to set default embedding model
    if "vector_io" in config and "inference" in config:
        run_config.vector_stores = VectorStoresConfig(
            default_embedding_model_id="inline::sentence-transformers/nomic-ai/nomic-embed-text-v1.5"
            default_embedding_model_id="sentence-transformers/nomic-ai/nomic-embed-text-v1.5"
        )

    run_config_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml")
88 tests/unit/distribution/test_single_provider_filter.py Normal file
@ -0,0 +1,88 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest

from llama_stack.cli.stack._build import _apply_single_provider_filter
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
from llama_stack.core.utils.image_types import LlamaStackImageType


def test_filters_single_api():
    """Test filtering keeps only specified provider for one API."""
    build_config = BuildConfig(
        image_type=LlamaStackImageType.VENV.value,
        distribution_spec=DistributionSpec(
            providers={
                "vector_io": [
                    BuildProvider(provider_type="inline::faiss"),
                    BuildProvider(provider_type="inline::sqlite-vec"),
                ],
                "inference": [
                    BuildProvider(provider_type="remote::openai"),
                ],
            },
            description="Test",
        ),
    )

    filtered = _apply_single_provider_filter(build_config, "vector_io=inline::sqlite-vec")

    assert len(filtered.distribution_spec.providers["vector_io"]) == 1
    assert filtered.distribution_spec.providers["vector_io"][0].provider_type == "inline::sqlite-vec"
    assert len(filtered.distribution_spec.providers["inference"]) == 1  # unchanged


def test_filters_multiple_apis():
    """Test filtering multiple APIs."""
    build_config = BuildConfig(
        image_type=LlamaStackImageType.VENV.value,
        distribution_spec=DistributionSpec(
            providers={
                "vector_io": [
                    BuildProvider(provider_type="inline::faiss"),
                    BuildProvider(provider_type="inline::sqlite-vec"),
                ],
                "inference": [
                    BuildProvider(provider_type="remote::openai"),
                    BuildProvider(provider_type="remote::anthropic"),
                ],
            },
            description="Test",
        ),
    )

    filtered = _apply_single_provider_filter(build_config, "vector_io=inline::faiss,inference=remote::openai")

    assert len(filtered.distribution_spec.providers["vector_io"]) == 1
    assert filtered.distribution_spec.providers["vector_io"][0].provider_type == "inline::faiss"
    assert len(filtered.distribution_spec.providers["inference"]) == 1
    assert filtered.distribution_spec.providers["inference"][0].provider_type == "remote::openai"


def test_provider_not_found_exits():
    """Test error when specified provider doesn't exist."""
    build_config = BuildConfig(
        image_type=LlamaStackImageType.VENV.value,
        distribution_spec=DistributionSpec(
            providers={"vector_io": [BuildProvider(provider_type="inline::faiss")]},
            description="Test",
        ),
    )

    with pytest.raises(SystemExit):
        _apply_single_provider_filter(build_config, "vector_io=inline::nonexistent")


def test_invalid_format_exits():
    """Test error for invalid filter format."""
    build_config = BuildConfig(
        image_type=LlamaStackImageType.VENV.value,
        distribution_spec=DistributionSpec(providers={}, description="Test"),
    )

    with pytest.raises(SystemExit):
        _apply_single_provider_filter(build_config, "invalid_format")
@ -144,7 +144,6 @@ async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inf
        config=config,
        inference_api=mock_inference_api,
        files_api=None,
        models_api=None,
    )
    collection_id = f"sqlite_test_collection_{np.random.randint(1e6)}"
    await adapter.initialize()
@ -183,7 +182,6 @@ async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding
        config=config,
        inference_api=mock_inference_api,
        files_api=None,
        models_api=None,
    )
    await adapter.initialize()
    await adapter.register_vector_db(
@ -11,7 +11,6 @@ import numpy as np
import pytest

from llama_stack.apis.files import Files
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.datatypes import HealthStatus
@ -76,12 +75,6 @@ def mock_files_api():
    return mock_api


@pytest.fixture
def mock_models_api():
    mock_api = MagicMock(spec=Models)
    return mock_api


@pytest.fixture
def faiss_config():
    config = MagicMock(spec=FaissVectorIOConfig)
@ -117,7 +110,7 @@ async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_
    assert response.chunks[1] == sample_chunks[1]


async def test_health_success(mock_models_api):
async def test_health_success():
    """Test that the health check returns OK status when faiss is working correctly."""
    # Create a fresh instance of FaissVectorIOAdapter for testing
    config = MagicMock()
@ -126,9 +119,7 @@ async def test_health_success(mock_models_api):
    with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat:
        mock_index_flat.return_value = MagicMock()
        adapter = FaissVectorIOAdapter(
            config=config, inference_api=inference_api, models_api=mock_models_api, files_api=files_api
        )
        adapter = FaissVectorIOAdapter(config=config, inference_api=inference_api, files_api=files_api)

        # Calling the health method directly
        response = await adapter.health()
@ -142,7 +133,7 @@ async def test_health_success(mock_models_api):
        mock_index_flat.assert_called_once_with(128)  # VECTOR_DIMENSION is 128


async def test_health_failure(mock_models_api):
async def test_health_failure():
    """Test that the health check returns ERROR status when faiss encounters an error."""
    # Create a fresh instance of FaissVectorIOAdapter for testing
    config = MagicMock()
@ -152,9 +143,7 @@ async def test_health_failure(mock_models_api):
    with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat:
        mock_index_flat.side_effect = Exception("Test error")

        adapter = FaissVectorIOAdapter(
            config=config, inference_api=inference_api, models_api=mock_models_api, files_api=files_api
        )
        adapter = FaissVectorIOAdapter(config=config, inference_api=inference_api, files_api=files_api)

        # Calling the health method directly
        response = await adapter.health()
@ -1162,5 +1162,5 @@ async def test_embedding_config_required_model_missing(vector_io_adapter):
    # Test with no embedding model provided
    params = OpenAICreateVectorStoreRequestWithExtraBody(name="test_store", metadata={})

    with pytest.raises(ValueError, match="embedding_model is required in extra_body when creating a vector store"):
    with pytest.raises(ValueError, match="embedding_model is required"):
        await vector_io_adapter.openai_create_vector_store(params)