mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 04:00:42 +00:00

adding back relevant vector_db files

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

Squashed follow-up commits (each signed off by Francisco Javier Arceo <farceo@redhat.com>):
- fix tests
- updating tests and fixing routing logic for single provider
- setting default provider to update tests
- updated provider_id
- updated VectorStoreConfig to use (provider_id, embedding_model_id) and add default vector store provider
- special handling for replay mode for available providers

This commit is contained in:
parent accc4c437e
commit b3addc94d1

23 changed files with 637 additions and 261 deletions
@@ -144,7 +144,7 @@ jobs:
       - name: Build Llama Stack
         run: |
-          uv run --no-sync llama stack build --distro starter --image-type venv --single-provider "vector_io=${{ matrix.vector-io-provider }}"
+          uv run --no-sync llama stack build --template ci-tests --image-type venv

       - name: Check Storage and Memory Available Before Tests
         if: ${{ always() }}

@@ -168,7 +168,7 @@ jobs:
           WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }}
         run: |
           uv run --no-sync \
-            pytest -sv --stack-config ~/.llama/distributions/starter/starter-filtered-run.yaml \
+            pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \
             tests/integration/vector_io

       - name: Check Storage and Memory Available After Tests
@@ -121,6 +121,7 @@ class Api(Enum, metaclass=DynamicApiMeta):

     models = "models"
     shields = "shields"
+    vector_dbs = "vector_dbs"  # only used for routing
     datasets = "datasets"
     scoring_functions = "scoring_functions"
     benchmarks = "benchmarks"
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Literal
+from typing import Literal, Protocol, runtime_checkable

 from pydantic import BaseModel

@@ -59,3 +59,35 @@ class ListVectorDBsResponse(BaseModel):
     """

     data: list[VectorDB]
+
+
+@runtime_checkable
+class VectorDBs(Protocol):
+    """Internal protocol for vector_dbs routing - no public API endpoints."""
+
+    async def list_vector_dbs(self) -> ListVectorDBsResponse:
+        """Internal method to list vector databases."""
+        ...
+
+    async def get_vector_db(
+        self,
+        vector_db_id: str,
+    ) -> VectorDB:
+        """Internal method to get a vector database by ID."""
+        ...
+
+    async def register_vector_db(
+        self,
+        vector_db_id: str,
+        embedding_model: str,
+        embedding_dimension: int | None = 384,
+        provider_id: str | None = None,
+        vector_db_name: str | None = None,
+        provider_vector_db_id: str | None = None,
+    ) -> VectorDB:
+        """Internal method to register a vector database."""
+        ...
+
+    async def unregister_vector_db(self, vector_db_id: str) -> None:
+        """Internal method to unregister a vector database."""
+        ...
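Note: since the new VectorDBs protocol is marked @runtime_checkable, implementations can be verified structurally with isinstance(). A minimal, self-contained sketch of that behavior follows; VectorDBsSketch and _FakeImpl are illustrative stand-ins, not code from this commit.

# runtime_checkable lets isinstance() check that an object exposes the
# protocol's method names; signatures and return types are not verified.
from typing import Protocol, runtime_checkable

@runtime_checkable
class VectorDBsSketch(Protocol):
    async def list_vector_dbs(self): ...
    async def get_vector_db(self, vector_db_id: str): ...

class _FakeImpl:  # hypothetical stand-in implementation
    async def list_vector_dbs(self):
        return []

    async def get_vector_db(self, vector_db_id: str):
        return {"identifier": vector_db_id}

assert isinstance(_FakeImpl(), VectorDBsSketch)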
@@ -50,85 +50,6 @@ from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig

 DISTRIBS_PATH = Path(__file__).parent.parent.parent / "distributions"


-def _apply_single_provider_filter(build_config: BuildConfig, single_provider_arg: str) -> BuildConfig:
-    """Filter a distribution to only include specified providers for certain APIs."""
-    # Parse the single-provider argument using the same logic as --providers
-    provider_filters: dict[str, str] = {}
-    for api_provider in single_provider_arg.split(","):
-        if "=" not in api_provider:
-            cprint(
-                "Could not parse `--single-provider`. Please ensure the list is in the format api1=provider1,api2=provider2",
-                color="red",
-                file=sys.stderr,
-            )
-            sys.exit(1)
-        api, provider_type = api_provider.split("=")
-        provider_filters[api] = provider_type
-
-    # Create a copy of the build config to modify
-    filtered_build_config = BuildConfig(
-        image_type=build_config.image_type,
-        image_name=build_config.image_name,
-        external_providers_dir=build_config.external_providers_dir,
-        external_apis_dir=build_config.external_apis_dir,
-        distribution_spec=DistributionSpec(
-            providers={},
-            description=build_config.distribution_spec.description,
-        ),
-    )
-
-    # Copy all providers, but filter the specified APIs
-    for api, providers in build_config.distribution_spec.providers.items():
-        if api in provider_filters:
-            target_provider_type = provider_filters[api]
-            filtered_providers = [p for p in providers if p.provider_type == target_provider_type]
-            if not filtered_providers:
-                cprint(
-                    f"Provider {target_provider_type} not found in distribution for API {api}",
-                    color="red",
-                    file=sys.stderr,
-                )
-                sys.exit(1)
-            filtered_build_config.distribution_spec.providers[api] = filtered_providers
-        else:
-            # Keep all providers for unfiltered APIs
-            filtered_build_config.distribution_spec.providers[api] = providers
-
-    return filtered_build_config
-
-
-def _generate_filtered_run_config(
-    build_config: BuildConfig,
-    build_dir: Path,
-    distro_name: str,
-) -> Path:
-    """
-    Generate a filtered run.yaml by starting with the original distribution's run.yaml
-    and filtering the providers according to the build_config.
-    """
-    # Load the original distribution's run.yaml
-    distro_resource = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
-
-    with importlib.resources.as_file(distro_resource) as path:
-        with open(path) as f:
-            original_config = yaml.safe_load(f)
-
-    # Apply provider filtering to the loaded config
-    for api, providers in build_config.distribution_spec.providers.items():
-        if api in original_config.get("providers", {}):
-            # Filter this API to only include the providers from build_config
-            provider_types = {p.provider_type for p in providers}
-            filtered_providers = [p for p in original_config["providers"][api] if p["provider_type"] in provider_types]
-            original_config["providers"][api] = filtered_providers
-
-    # Write the filtered config
-    run_config_file = build_dir / f"{distro_name}-filtered-run.yaml"
-    with open(run_config_file, "w") as f:
-        yaml.dump(original_config, f, sort_keys=False)
-
-    return run_config_file
-
-
 @lru_cache
 def available_distros_specs() -> dict[str, BuildConfig]:
     import yaml
@@ -172,11 +93,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
         )
         sys.exit(1)
     build_config = available_distros[distro_name]

-    # Apply single-provider filtering if specified
-    if args.single_provider:
-        build_config = _apply_single_provider_filter(build_config, args.single_provider)
-
     if args.image_type:
         build_config.image_type = args.image_type
     else:
@@ -329,7 +245,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
             image_name=image_name,
             config_path=args.config,
             distro_name=distro_name,
-            is_filtered=bool(args.single_provider),
         )

     except (Exception, RuntimeError) as exc:
@@ -448,7 +363,6 @@ def _run_stack_build_command_from_build_config(
     image_name: str | None = None,
     distro_name: str | None = None,
     config_path: str | None = None,
-    is_filtered: bool = False,
 ) -> Path | Traversable:
     image_name = image_name or build_config.image_name
     if build_config.image_type == LlamaStackImageType.CONTAINER.value:
@@ -521,19 +435,12 @@ def _run_stack_build_command_from_build_config(
         raise RuntimeError(f"Failed to build image {image_name}")

     if distro_name:
-        # If single-provider filtering was applied, generate a filtered run config
-        # Otherwise, copy run.yaml from distribution as before
-        if is_filtered:
-            run_config_file = _generate_filtered_run_config(build_config, build_dir, distro_name)
-            distro_path = run_config_file  # Use the generated file as the distro_path
-        else:
-            # copy run.yaml from distribution to build_dir instead of generating it again
-            distro_resource = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
-            run_config_file = build_dir / f"{distro_name}-run.yaml"
-
-            with importlib.resources.as_file(distro_resource) as path:
-                shutil.copy(path, run_config_file)
-            distro_path = run_config_file  # Update distro_path to point to the copied file
+        # copy run.yaml from distribution to build_dir instead of generating it again
+        distro_path = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
+        run_config_file = build_dir / f"{distro_name}-run.yaml"
+
+        with importlib.resources.as_file(distro_path) as path:
+            shutil.copy(path, run_config_file)

     cprint("Build Successful!", color="green", file=sys.stderr)
     cprint(f"You can find the newly-built distribution here: {run_config_file}", color="blue", file=sys.stderr)
@@ -92,13 +92,6 @@ the build. If not specified, currently active environment will be used if found.
             help="Build a config for a list of providers and only those providers. This list is formatted like: api1=provider1,api2=provider2. Where there can be multiple providers per API.",
         )

-        self.parser.add_argument(
-            "--single-provider",
-            type=str,
-            default=None,
-            help="Limit a distribution to a single provider for specific APIs. Format: api1=provider1,api2=provider2. Use with --distro to filter an existing distribution.",
-        )
-
     def _run_stack_build_command(self, args: argparse.Namespace) -> None:
         # always keep implementation completely silo-ed away from CLI so CLI
         # can be fast to load and reduces dependencies
@@ -354,10 +354,14 @@ class AuthenticationRequiredError(Exception):
 class VectorStoresConfig(BaseModel):
     """Configuration for vector stores in the stack."""

-    default_embedding_model_id: str = Field(
+    embedding_model_id: str = Field(
         ...,
         description="ID of the embedding model to use as default for vector stores when none is specified. Must reference a model defined in the 'models' section.",
     )
+    provider_id: str | None = Field(
+        default=None,
+        description="ID of the vector_io provider to use as default when multiple providers are available and none is specified.",
+    )


 class QuotaPeriod(StrEnum):
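For reference, a minimal sketch of constructing the reshaped config with both fields from the hunk above; the "faiss" provider id is an assumed example value, not taken from this diff.

# Mirrors the (embedding_model_id, provider_id) pair introduced above.
from pydantic import BaseModel, Field

class VectorStoresConfigSketch(BaseModel):
    embedding_model_id: str = Field(...)           # required: default embedding model
    provider_id: str | None = Field(default=None)  # optional: default vector_io provider

cfg = VectorStoresConfigSketch(
    embedding_model_id="sentence-transformers/nomic-ai/nomic-embed-text-v1.5",
    provider_id="faiss",  # assumed example; any configured vector_io provider id works
)
print(cfg.model_dump())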
@@ -63,6 +63,10 @@ def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
             routing_table_api=Api.tool_groups,
             router_api=Api.tool_runtime,
         ),
+        AutoRoutedApiInfo(
+            routing_table_api=Api.vector_dbs,
+            router_api=Api.vector_io,
+        ),
     ]
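To make the routing pair concrete, a small illustrative model of what this entry declares: calls to the public vector_io API get resolved through the internal vector_dbs routing table. The classes below are simplified stand-ins, not the real dataclasses.

# Simplified stand-ins for Api and AutoRoutedApiInfo, illustration only.
from dataclasses import dataclass
from enum import Enum

class ApiSketch(Enum):
    vector_dbs = "vector_dbs"
    vector_io = "vector_io"

@dataclass
class AutoRoutedApiInfoSketch:
    routing_table_api: ApiSketch  # table that tracks registrations
    router_api: ApiSketch         # public API whose calls get routed

entry = AutoRoutedApiInfoSketch(
    routing_table_api=ApiSketch.vector_dbs,
    router_api=ApiSketch.vector_io,
)
assert entry.router_api is ApiSketch.vector_io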
@@ -28,6 +28,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
 from llama_stack.apis.shields import Shields
 from llama_stack.apis.telemetry import Telemetry
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
+from llama_stack.apis.vector_dbs import VectorDBs
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
 from llama_stack.core.client import get_client_impl

@@ -80,6 +81,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
         Api.inspect: Inspect,
         Api.batches: Batches,
         Api.vector_io: VectorIO,
+        Api.vector_dbs: VectorDBs,
         Api.models: Models,
         Api.safety: Safety,
         Api.shields: Shields,
@@ -26,6 +26,7 @@ async def get_routing_table_impl(
     from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
     from ..routing_tables.shields import ShieldsRoutingTable
     from ..routing_tables.toolgroups import ToolGroupsRoutingTable
+    from ..routing_tables.vector_dbs import VectorDBsRoutingTable

     api_to_tables = {
         "models": ModelsRoutingTable,

@@ -34,6 +35,7 @@ async def get_routing_table_impl(
         "scoring_functions": ScoringFunctionsRoutingTable,
         "benchmarks": BenchmarksRoutingTable,
         "tool_groups": ToolGroupsRoutingTable,
+        "vector_dbs": VectorDBsRoutingTable,
     }

     if api.value not in api_to_tables:
@@ -31,7 +31,6 @@ from llama_stack.apis.vector_io import (
     VectorStoreObject,
     VectorStoreSearchResponsePage,
 )
-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable

@@ -44,7 +43,7 @@ class VectorIORouter(VectorIO):
     def __init__(
         self,
         routing_table: RoutingTable,
-        vector_stores_config: VectorStoresConfig | None = None,
+        vector_stores_config=None,
     ) -> None:
         logger.debug("Initializing VectorIORouter")
         self.routing_table = routing_table

@@ -125,9 +124,9 @@ class VectorIORouter(VectorIO):
         embedding_dimension = extra.get("embedding_dimension")
         provider_id = extra.get("provider_id")

+        # Use default embedding model if not specified
         if embedding_model is None and self.vector_stores_config is not None:
-            embedding_model = self.vector_stores_config.default_embedding_model_id
-            logger.debug(f"Using default embedding model: {embedding_model}")
+            embedding_model = self.vector_stores_config.embedding_model_id

         if embedding_model is not None and embedding_dimension is None:
             embedding_dimension = await self._get_embedding_model_dimension(embedding_model)

@@ -139,11 +138,24 @@ class VectorIORouter(VectorIO):
             raise ValueError("No vector_io providers available")
         if num_providers > 1:
             available_providers = list(self.routing_table.impls_by_provider_id.keys())
-            raise ValueError(
-                f"Multiple vector_io providers available. Please specify provider_id in extra_body. "
-                f"Available providers: {available_providers}"
-            )
-        provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]
+            # Use default configured provider
+            if self.vector_stores_config and self.vector_stores_config.provider_id:
+                default_provider = self.vector_stores_config.provider_id
+                if default_provider in available_providers:
+                    provider_id = default_provider
+                    logger.debug(f"Using configured default vector store provider: {provider_id}")
+                else:
+                    raise ValueError(
+                        f"Configured default vector store provider '{default_provider}' not found. "
+                        f"Available providers: {available_providers}"
+                    )
+            else:
+                raise ValueError(
+                    f"Multiple vector_io providers available. Please specify provider_id in extra_body. "
+                    f"Available providers: {available_providers}"
+                )
+        else:
+            provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]

         vector_db_id = f"vs_{uuid.uuid4()}"
         registered_vector_db = await self.routing_table.register_vector_db(
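The branching above encodes a provider-selection order. A condensed, hedged mirror of those rules as a pure function; resolve_provider_id is an illustrative helper name, not code from this commit.

# Selection order implied by the router changes above: an explicit provider_id
# wins, a single available provider is used as-is, otherwise the configured
# default must exist; with no default and several providers, fail loudly.
def resolve_provider_id(requested: str | None, configured_default: str | None,
                        available: list[str]) -> str:
    if requested is not None:
        return requested  # provider_id passed via extra_body
    if not available:
        raise ValueError("No vector_io providers available")
    if len(available) == 1:
        return available[0]
    if configured_default is None:
        raise ValueError(f"Multiple vector_io providers available: {available}")
    if configured_default not in available:
        raise ValueError(f"Configured default '{configured_default}' not found: {available}")
    return configured_default

assert resolve_provider_id(None, "faiss", ["faiss", "sqlite-vec"]) == "faiss"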
@@ -250,8 +262,7 @@ class VectorIORouter(VectorIO):
         vector_store_id: str,
     ) -> VectorStoreDeleteResponse:
         logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
-        provider = await self.routing_table.get_provider_impl(vector_store_id)
-        return await provider.openai_delete_vector_store(vector_store_id)
+        return await self.routing_table.openai_delete_vector_store(vector_store_id)

     async def openai_search_vector_store(
         self,
@@ -134,12 +134,15 @@ class CommonRoutingTableImpl(RoutingTable):
         from .scoring_functions import ScoringFunctionsRoutingTable
         from .shields import ShieldsRoutingTable
         from .toolgroups import ToolGroupsRoutingTable
+        from .vector_dbs import VectorDBsRoutingTable

         def apiname_object():
             if isinstance(self, ModelsRoutingTable):
                 return ("Inference", "model")
             elif isinstance(self, ShieldsRoutingTable):
                 return ("Safety", "shield")
+            elif isinstance(self, VectorDBsRoutingTable):
+                return ("VectorIO", "vector_db")
             elif isinstance(self, DatasetsRoutingTable):
                 return ("DatasetIO", "dataset")
             elif isinstance(self, ScoringFunctionsRoutingTable):
llama_stack/core/routing_tables/vector_dbs.py (new file, 323 lines)

@@ -0,0 +1,323 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any
+
+from pydantic import TypeAdapter
+
+from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
+from llama_stack.apis.models import ModelType
+from llama_stack.apis.resource import ResourceType
+
+# Removed VectorDBs import to avoid exposing public API
+from llama_stack.apis.vector_io.vector_io import (
+    OpenAICreateVectorStoreRequestWithExtraBody,
+    SearchRankingOptions,
+    VectorStoreChunkingStrategy,
+    VectorStoreDeleteResponse,
+    VectorStoreFileContentsResponse,
+    VectorStoreFileDeleteResponse,
+    VectorStoreFileObject,
+    VectorStoreFileStatus,
+    VectorStoreObject,
+    VectorStoreSearchResponsePage,
+)
+from llama_stack.core.datatypes import (
+    VectorDBWithOwner,
+)
+from llama_stack.log import get_logger
+
+from .common import CommonRoutingTableImpl, lookup_model
+
+logger = get_logger(name=__name__, category="core::routing_tables")
+
+
+class VectorDBsRoutingTable(CommonRoutingTableImpl):
+    """Internal routing table for vector_db operations.
+
+    Does not inherit from VectorDBs to avoid exposing public API endpoints.
+    Only provides internal routing functionality for VectorIORouter.
+    """
+
+    # Internal methods only - no public API exposure
+
+    async def register_vector_db(
+        self,
+        vector_db_id: str,
+        embedding_model: str,
+        embedding_dimension: int | None = 384,
+        provider_id: str | None = None,
+        provider_vector_db_id: str | None = None,
+        vector_db_name: str | None = None,
+    ) -> Any:
+        if provider_id is None:
+            if len(self.impls_by_provider_id) > 0:
+                provider_id = list(self.impls_by_provider_id.keys())[0]
+                if len(self.impls_by_provider_id) > 1:
+                    logger.warning(
+                        f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
+                    )
+            else:
+                raise ValueError("No provider available. Please configure a vector_io provider.")
+        model = await lookup_model(self, embedding_model)
+        if model is None:
+            raise ModelNotFoundError(embedding_model)
+        if model.model_type != ModelType.embedding:
+            raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
+        if "embedding_dimension" not in model.metadata:
+            raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
+
+        try:
+            provider = self.impls_by_provider_id[provider_id]
+        except KeyError:
+            available_providers = list(self.impls_by_provider_id.keys())
+            raise ValueError(
+                f"Provider '{provider_id}' not found in routing table. Available providers: {available_providers}"
+            ) from None
+        logger.warning(
+            "VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly."
+        )
+        request = OpenAICreateVectorStoreRequestWithExtraBody(
+            name=vector_db_name or vector_db_id,
+            embedding_model=embedding_model,
+            embedding_dimension=model.metadata["embedding_dimension"],
+            provider_id=provider_id,
+            provider_vector_db_id=provider_vector_db_id,
+        )
+        vector_store = await provider.openai_create_vector_store(request)
+
+        vector_store_id = vector_store.id
+        actual_provider_vector_db_id = provider_vector_db_id or vector_store_id
+        logger.warning(
+            f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name"
+        )
+
+        vector_db_data = {
+            "identifier": vector_store_id,
+            "type": ResourceType.vector_db.value,
+            "provider_id": provider_id,
+            "provider_resource_id": actual_provider_vector_db_id,
+            "embedding_model": embedding_model,
+            "embedding_dimension": model.metadata["embedding_dimension"],
+            "vector_db_name": vector_store.name,
+        }
+        vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
+        await self.register_object(vector_db)
+        return vector_db
+
+    async def openai_retrieve_vector_store(
+        self,
+        vector_store_id: str,
+    ) -> VectorStoreObject:
+        await self.assert_action_allowed("read", "vector_db", vector_store_id)
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_retrieve_vector_store(vector_store_id)
+
+    async def openai_update_vector_store(
+        self,
+        vector_store_id: str,
+        name: str | None = None,
+        expires_after: dict[str, Any] | None = None,
+        metadata: dict[str, Any] | None = None,
+    ) -> VectorStoreObject:
+        await self.assert_action_allowed("update", "vector_db", vector_store_id)
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_update_vector_store(
+            vector_store_id=vector_store_id,
+            name=name,
+            expires_after=expires_after,
+            metadata=metadata,
+        )
+
+    async def openai_delete_vector_store(
+        self,
+        vector_store_id: str,
+    ) -> VectorStoreDeleteResponse:
+        await self.assert_action_allowed("delete", "vector_db", vector_store_id)
+        provider = await self.get_provider_impl(vector_store_id)
+        result = await provider.openai_delete_vector_store(vector_store_id)
+        await self.unregister_vector_db(vector_store_id)
+        return result
+
+    async def unregister_vector_db(self, vector_store_id: str) -> None:
+        """Remove the vector store from the routing table registry."""
+        try:
+            vector_db_obj = await self.get_object_by_identifier("vector_db", vector_store_id)
+            if vector_db_obj:
+                await self.unregister_object(vector_db_obj)
+        except Exception as e:
+            # Log the error but don't fail the operation
+            logger.warning(f"Failed to unregister vector store {vector_store_id} from routing table: {e}")
+
+    async def openai_search_vector_store(
+        self,
+        vector_store_id: str,
+        query: str | list[str],
+        filters: dict[str, Any] | None = None,
+        max_num_results: int | None = 10,
+        ranking_options: SearchRankingOptions | None = None,
+        rewrite_query: bool | None = False,
+        search_mode: str | None = "vector",
+    ) -> VectorStoreSearchResponsePage:
+        await self.assert_action_allowed("read", "vector_db", vector_store_id)
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_search_vector_store(
+            vector_store_id=vector_store_id,
+            query=query,
+            filters=filters,
+            max_num_results=max_num_results,
+            ranking_options=ranking_options,
+            rewrite_query=rewrite_query,
+            search_mode=search_mode,
+        )
+
+    async def openai_attach_file_to_vector_store(
+        self,
+        vector_store_id: str,
+        file_id: str,
+        attributes: dict[str, Any] | None = None,
+        chunking_strategy: VectorStoreChunkingStrategy | None = None,
+    ) -> VectorStoreFileObject:
+        await self.assert_action_allowed("update", "vector_db", vector_store_id)
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_attach_file_to_vector_store(
+            vector_store_id=vector_store_id,
+            file_id=file_id,
+            attributes=attributes,
+            chunking_strategy=chunking_strategy,
+        )
+
+    async def openai_list_files_in_vector_store(
+        self,
+        vector_store_id: str,
+        limit: int | None = 20,
+        order: str | None = "desc",
+        after: str | None = None,
+        before: str | None = None,
+        filter: VectorStoreFileStatus | None = None,
+    ) -> list[VectorStoreFileObject]:
+        await self.assert_action_allowed("read", "vector_db", vector_store_id)
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_list_files_in_vector_store(
+            vector_store_id=vector_store_id,
+            limit=limit,
+            order=order,
+            after=after,
+            before=before,
+            filter=filter,
+        )
+
+    async def openai_retrieve_vector_store_file(
+        self,
+        vector_store_id: str,
+        file_id: str,
+    ) -> VectorStoreFileObject:
+        await self.assert_action_allowed("read", "vector_db", vector_store_id)
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_retrieve_vector_store_file(
+            vector_store_id=vector_store_id,
+            file_id=file_id,
+        )
+
+    async def openai_retrieve_vector_store_file_contents(
+        self,
+        vector_store_id: str,
+        file_id: str,
+    ) -> VectorStoreFileContentsResponse:
+        await self.assert_action_allowed("read", "vector_db", vector_store_id)
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_retrieve_vector_store_file_contents(
+            vector_store_id=vector_store_id,
+            file_id=file_id,
+        )
+
+    async def openai_update_vector_store_file(
+        self,
+        vector_store_id: str,
+        file_id: str,
+        attributes: dict[str, Any],
+    ) -> VectorStoreFileObject:
+        await self.assert_action_allowed("update", "vector_db", vector_store_id)
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_update_vector_store_file(
+            vector_store_id=vector_store_id,
+            file_id=file_id,
+            attributes=attributes,
+        )
+
+    async def openai_delete_vector_store_file(
+        self,
+        vector_store_id: str,
+        file_id: str,
+    ) -> VectorStoreFileDeleteResponse:
+        await self.assert_action_allowed("delete", "vector_db", vector_store_id)
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_delete_vector_store_file(
+            vector_store_id=vector_store_id,
+            file_id=file_id,
+        )
+
+    async def openai_create_vector_store_file_batch(
+        self,
+        vector_store_id: str,
+        file_ids: list[str],
+        attributes: dict[str, Any] | None = None,
+        chunking_strategy: Any | None = None,
+    ):
+        await self.assert_action_allowed("update", "vector_db", vector_store_id)
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_create_vector_store_file_batch(
+            vector_store_id=vector_store_id,
+            file_ids=file_ids,
+            attributes=attributes,
+            chunking_strategy=chunking_strategy,
+        )
+
+    async def openai_retrieve_vector_store_file_batch(
+        self,
+        batch_id: str,
+        vector_store_id: str,
+    ):
+        await self.assert_action_allowed("read", "vector_db", vector_store_id)
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_retrieve_vector_store_file_batch(
+            batch_id=batch_id,
+            vector_store_id=vector_store_id,
+        )
+
+    async def openai_list_files_in_vector_store_file_batch(
+        self,
+        batch_id: str,
+        vector_store_id: str,
+        after: str | None = None,
+        before: str | None = None,
+        filter: str | None = None,
+        limit: int | None = 20,
+        order: str | None = "desc",
+    ):
+        await self.assert_action_allowed("read", "vector_db", vector_store_id)
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_list_files_in_vector_store_file_batch(
+            batch_id=batch_id,
+            vector_store_id=vector_store_id,
+            after=after,
+            before=before,
+            filter=filter,
+            limit=limit,
+            order=order,
+        )
+
+    async def openai_cancel_vector_store_file_batch(
+        self,
+        batch_id: str,
+        vector_store_id: str,
+    ):
+        await self.assert_action_allowed("update", "vector_db", vector_store_id)
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_cancel_vector_store_file_batch(
+            batch_id=batch_id,
+            vector_store_id=vector_store_id,
+        )
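Every OpenAI-compatible method in the new routing table above follows the same shape: check the access policy, resolve the owning provider, delegate. A condensed, runnable sketch of that pattern; the stub provider and permissive ACL are hypothetical stand-ins, not code from this commit.

# check -> resolve -> delegate, as used by each method in the file above.
import asyncio

class _StubProvider:  # hypothetical provider backend
    async def openai_retrieve_vector_store(self, vector_store_id: str) -> dict:
        return {"id": vector_store_id, "object": "vector_store"}

class RoutingTablePatternSketch:
    def __init__(self) -> None:
        self._providers = {"vs_123": _StubProvider()}

    async def assert_action_allowed(self, action: str, kind: str, obj_id: str) -> None:
        pass  # the real table consults the stack's access-control policy here

    async def get_provider_impl(self, vector_store_id: str) -> _StubProvider:
        return self._providers[vector_store_id]  # real table: registry lookup by id

    async def openai_retrieve_vector_store(self, vector_store_id: str) -> dict:
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store(vector_store_id)

print(asyncio.run(RoutingTablePatternSketch().openai_retrieve_vector_store("vs_123")))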
@@ -135,7 +135,7 @@ async def validate_vector_stores_config(run_config: StackRunConfig, impls: dict[
         return

     vector_stores_config = run_config.vector_stores
-    default_model_id = vector_stores_config.default_embedding_model_id
+    default_model_id = vector_stores_config.embedding_model_id

     if Api.models not in impls:
         raise ValueError(f"Models API is not available but vector_stores config requires model '{default_model_id}'")
@@ -255,4 +255,4 @@ server:
 telemetry:
   enabled: true
 vector_stores:
-  default_embedding_model_id: sentence-transformers/nomic-ai/nomic-embed-text-v1.5
+  embedding_model_id: sentence-transformers/nomic-ai/nomic-embed-text-v1.5
@@ -258,4 +258,4 @@ server:
 telemetry:
   enabled: true
 vector_stores:
-  default_embedding_model_id: sentence-transformers/nomic-ai/nomic-embed-text-v1.5
+  embedding_model_id: sentence-transformers/nomic-ai/nomic-embed-text-v1.5
@@ -255,4 +255,4 @@ server:
 telemetry:
   enabled: true
 vector_stores:
-  default_embedding_model_id: sentence-transformers/nomic-ai/nomic-embed-text-v1.5
+  embedding_model_id: sentence-transformers/nomic-ai/nomic-embed-text-v1.5
@@ -249,7 +249,7 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
             default_tool_groups=default_tool_groups,
             default_shields=default_shields,
             vector_stores_config=VectorStoresConfig(
-                default_embedding_model_id="sentence-transformers/nomic-ai/nomic-embed-text-v1.5"
+                embedding_model_id="sentence-transformers/nomic-ai/nomic-embed-text-v1.5"
             ),
         ),
     },
@@ -317,3 +317,72 @@ def pytest_ignore_collect(path: str, config: pytest.Config) -> bool:
         if p.is_relative_to(rp):
             return False
     return True
+
+
+def get_vector_io_provider_ids(client):
+    """Get all available vector_io provider IDs."""
+    providers = [p for p in client.providers.list() if p.api == "vector_io"]
+    return [p.provider_id for p in providers]
+
+
+def vector_provider_wrapper(func):
+    """Decorator to run a test against all available vector_io providers."""
+    import functools
+    import os
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        # Get the vector_io_provider_id from the test arguments
+        import inspect
+
+        sig = inspect.signature(func)
+        bound_args = sig.bind(*args, **kwargs)
+        bound_args.apply_defaults()
+
+        vector_io_provider_id = bound_args.arguments.get("vector_io_provider_id")
+        if not vector_io_provider_id:
+            pytest.skip("No vector_io_provider_id provided")
+
+        # Get client_with_models to check available providers
+        client_with_models = bound_args.arguments.get("client_with_models")
+        if client_with_models:
+            available_providers = get_vector_io_provider_ids(client_with_models)
+            if vector_io_provider_id not in available_providers:
+                pytest.skip(f"Provider '{vector_io_provider_id}' not available. Available: {available_providers}")
+
+        return func(*args, **kwargs)
+
+    # For replay tests, only use providers that are available in ci-tests environment
+    if os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE") == "replay":
+        all_providers = ["faiss", "sqlite-vec"]
+    else:
+        # For live tests, try all providers (they'll skip if not available)
+        all_providers = [
+            "faiss",
+            "sqlite-vec",
+            "milvus",
+            "chromadb",
+            "pgvector",
+            "weaviate",
+            "qdrant",
+        ]
+
+    return pytest.mark.parametrize("vector_io_provider_id", all_providers)(wrapper)
+
+
+@pytest.fixture
+def vector_io_provider_id(request, client_with_models):
+    """Fixture that provides a specific vector_io provider ID, skipping if not available."""
+    if hasattr(request, "param"):
+        requested_provider = request.param
+        available_providers = get_vector_io_provider_ids(client_with_models)
+
+        if requested_provider not in available_providers:
+            pytest.skip(f"Provider '{requested_provider}' not available. Available: {available_providers}")
+
+        return requested_provider
+    else:
+        provider_ids = get_vector_io_provider_ids(client_with_models)
+        if not provider_ids:
+            pytest.skip("No vector_io providers available")
+        return provider_ids[0]
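A hedged sketch of how the test suite consumes this decorator/fixture pair, mirroring the test changes further below; the test body here is illustrative, not one of the actual tests.

# vector_provider_wrapper parametrizes a test over provider ids and skips the
# ones a given stack doesn't expose; the fixture resolves the active id.
from ..conftest import vector_provider_wrapper

@vector_provider_wrapper
def test_store_roundtrip(compat_client_with_empty_stores, client_with_models, vector_io_provider_id):
    store = compat_client_with_empty_stores.vector_stores.create(
        extra_body={"provider_id": vector_io_provider_id},
    )
    assert store.id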
@@ -241,7 +241,7 @@ def instantiate_llama_stack_client(session):
         # --stack-config bypasses template so need this to set default embedding model
         if "vector_io" in config and "inference" in config:
             run_config.vector_stores = VectorStoresConfig(
-                default_embedding_model_id="sentence-transformers/nomic-ai/nomic-embed-text-v1.5"
+                embedding_model_id="inline::sentence-transformers/nomic-ai/nomic-embed-text-v1.5"
             )

         run_config_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml")
@@ -16,6 +16,8 @@ from llama_stack.apis.vector_io import Chunk
 from llama_stack.core.library_client import LlamaStackAsLibraryClient
 from llama_stack.log import get_logger

+from ..conftest import vector_provider_wrapper
+
 logger = get_logger(name=__name__, category="vector_io")


@@ -133,8 +135,9 @@ def compat_client_with_empty_stores(compat_client):
     clear_files()


+@vector_provider_wrapper
 def test_openai_create_vector_store(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """Test creating a vector store using OpenAI API."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -146,6 +149,7 @@ def test_openai_create_vector_store(
         metadata={"purpose": "testing", "environment": "integration"},
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -159,14 +163,18 @@ def test_openai_create_vector_store(
     assert hasattr(vector_store, "created_at")


-def test_openai_create_vector_store_default(compat_client_with_empty_stores, client_with_models):
+@vector_provider_wrapper
+def test_openai_create_vector_store_default(compat_client_with_empty_stores, client_with_models, vector_io_provider_id):
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
-    vector_store = compat_client_with_empty_stores.vector_stores.create()
+    vector_store = compat_client_with_empty_stores.vector_stores.create(
+        extra_body={"provider_id": vector_io_provider_id}
+    )
     assert vector_store.id


+@vector_provider_wrapper
 def test_openai_list_vector_stores(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """Test listing vector stores using OpenAI API."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -179,6 +187,7 @@ def test_openai_list_vector_stores(
         metadata={"type": "test"},
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )
     store2 = client.vector_stores.create(

@@ -186,6 +195,7 @@ def test_openai_list_vector_stores(
         metadata={"type": "test"},
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -206,8 +216,9 @@ def test_openai_list_vector_stores(
     assert len(limited_response.data) == 1


+@vector_provider_wrapper
 def test_openai_retrieve_vector_store(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """Test retrieving a specific vector store using OpenAI API."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -220,6 +231,7 @@ def test_openai_retrieve_vector_store(
         metadata={"purpose": "retrieval_test"},
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -233,8 +245,9 @@ def test_openai_retrieve_vector_store(
     assert retrieved_store.object == "vector_store"


+@vector_provider_wrapper
 def test_openai_update_vector_store(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """Test modifying a vector store using OpenAI API."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -247,6 +260,7 @@ def test_openai_update_vector_store(
         metadata={"version": "1.0"},
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )
     time.sleep(1)

@@ -264,8 +278,9 @@ def test_openai_update_vector_store(
     assert modified_store.last_active_at > created_store.last_active_at


+@vector_provider_wrapper
 def test_openai_delete_vector_store(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """Test deleting a vector store using OpenAI API."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -278,6 +293,7 @@ def test_openai_delete_vector_store(
         metadata={"purpose": "deletion_test"},
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -294,8 +310,9 @@ def test_openai_delete_vector_store(
         client.vector_stores.retrieve(vector_store_id=created_store.id)


+@vector_provider_wrapper
 def test_openai_vector_store_search_empty(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """Test searching an empty vector store using OpenAI API."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -308,6 +325,7 @@ def test_openai_vector_store_search_empty(
         metadata={"purpose": "search_testing"},
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -323,8 +341,14 @@ def test_openai_vector_store_search_empty(
     assert search_response.has_more is False


+@vector_provider_wrapper
 def test_openai_vector_store_with_chunks(
-    compat_client_with_empty_stores, client_with_models, sample_chunks, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores,
+    client_with_models,
+    sample_chunks,
+    embedding_model_id,
+    embedding_dimension,
+    vector_io_provider_id,
 ):
     """Test vector store functionality with actual chunks using both OpenAI and native APIs."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -338,6 +362,7 @@ def test_openai_vector_store_with_chunks(
         metadata={"purpose": "chunks_testing"},
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -380,6 +405,7 @@ def test_openai_vector_store_with_chunks(
         ("What inspires neural networks?", "doc4", "ai"),
     ],
 )
+@vector_provider_wrapper
 def test_openai_vector_store_search_relevance(
     compat_client_with_empty_stores,
     client_with_models,

@@ -387,6 +413,7 @@ def test_openai_vector_store_search_relevance(
     test_case,
     embedding_model_id,
     embedding_dimension,
+    vector_io_provider_id,
 ):
     """Test that OpenAI vector store search returns relevant results for different queries."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -402,6 +429,7 @@ def test_openai_vector_store_search_relevance(
         metadata={"purpose": "relevance_testing"},
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -430,8 +458,14 @@ def test_openai_vector_store_search_relevance(
     assert top_result.score > 0


+@vector_provider_wrapper
 def test_openai_vector_store_search_with_ranking_options(
-    compat_client_with_empty_stores, client_with_models, sample_chunks, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores,
+    client_with_models,
+    sample_chunks,
+    embedding_model_id,
+    embedding_dimension,
+    vector_io_provider_id,
 ):
     """Test OpenAI vector store search with ranking options."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -445,6 +479,7 @@ def test_openai_vector_store_search_with_ranking_options(
         metadata={"purpose": "ranking_testing"},
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -483,8 +518,14 @@ def test_openai_vector_store_search_with_ranking_options(
     assert result.score >= threshold


+@vector_provider_wrapper
 def test_openai_vector_store_search_with_high_score_filter(
-    compat_client_with_empty_stores, client_with_models, sample_chunks, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores,
+    client_with_models,
+    sample_chunks,
+    embedding_model_id,
+    embedding_dimension,
+    vector_io_provider_id,
 ):
     """Test that searching with text very similar to a document and high score threshold returns only that document."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -498,6 +539,7 @@ def test_openai_vector_store_search_with_high_score_filter(
         metadata={"purpose": "high_score_filtering"},
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -542,8 +584,14 @@ def test_openai_vector_store_search_with_high_score_filter(
     assert "python" in top_content.lower() or "programming" in top_content.lower()


+@vector_provider_wrapper
 def test_openai_vector_store_search_with_max_num_results(
-    compat_client_with_empty_stores, client_with_models, sample_chunks, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores,
+    client_with_models,
+    sample_chunks,
+    embedding_model_id,
+    embedding_dimension,
+    vector_io_provider_id,
 ):
     """Test OpenAI vector store search with max_num_results."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -557,6 +605,7 @@ def test_openai_vector_store_search_with_max_num_results(
         metadata={"purpose": "max_num_results_testing"},
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -577,8 +626,9 @@ def test_openai_vector_store_search_with_max_num_results(
     assert len(search_response.data) == 2


+@vector_provider_wrapper
 def test_openai_vector_store_attach_file(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
|
||||||
):
|
):
|
||||||
"""Test OpenAI vector store attach file."""
|
"""Test OpenAI vector store attach file."""
|
||||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||||
|
|
@ -591,6 +641,7 @@ def test_openai_vector_store_attach_file(
|
||||||
name="test_store",
|
name="test_store",
|
||||||
extra_body={
|
extra_body={
|
||||||
"embedding_model": embedding_model_id,
|
"embedding_model": embedding_model_id,
|
||||||
|
"provider_id": vector_io_provider_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -637,8 +688,9 @@ def test_openai_vector_store_attach_file(
|
||||||
assert "foobazbar" in top_content.lower()
|
assert "foobazbar" in top_content.lower()
|
||||||
|
|
||||||
|
|
||||||
|
@vector_provider_wrapper
|
||||||
def test_openai_vector_store_attach_files_on_creation(
|
def test_openai_vector_store_attach_files_on_creation(
|
||||||
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
|
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
|
||||||
):
|
):
|
||||||
"""Test OpenAI vector store attach files on creation."""
|
"""Test OpenAI vector store attach files on creation."""
|
||||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||||
|
|
@ -668,6 +720,7 @@ def test_openai_vector_store_attach_files_on_creation(
|
||||||
file_ids=file_ids,
|
file_ids=file_ids,
|
||||||
extra_body={
|
extra_body={
|
||||||
"embedding_model": embedding_model_id,
|
"embedding_model": embedding_model_id,
|
||||||
|
"provider_id": vector_io_provider_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -700,8 +753,9 @@ def test_openai_vector_store_attach_files_on_creation(
|
||||||
assert updated_vector_store.file_counts.failed == 0
|
assert updated_vector_store.file_counts.failed == 0
|
||||||
|
|
||||||
|
|
||||||
|
@vector_provider_wrapper
|
||||||
def test_openai_vector_store_list_files(
|
def test_openai_vector_store_list_files(
|
||||||
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
|
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
|
||||||
):
|
):
|
||||||
"""Test OpenAI vector store list files."""
|
"""Test OpenAI vector store list files."""
|
||||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||||
|
|
@ -714,6 +768,7 @@ def test_openai_vector_store_list_files(
|
||||||
name="test_store",
|
name="test_store",
|
||||||
extra_body={
|
extra_body={
|
||||||
"embedding_model": embedding_model_id,
|
"embedding_model": embedding_model_id,
|
||||||
|
"provider_id": vector_io_provider_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -773,8 +828,9 @@ def test_openai_vector_store_list_files(
|
||||||
assert updated_vector_store.file_counts.in_progress == 0
|
assert updated_vector_store.file_counts.in_progress == 0
|
||||||
|
|
||||||
|
|
||||||
|
@vector_provider_wrapper
|
||||||
def test_openai_vector_store_list_files_invalid_vector_store(
|
def test_openai_vector_store_list_files_invalid_vector_store(
|
||||||
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
|
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
|
||||||
):
|
):
|
||||||
"""Test OpenAI vector store list files with invalid vector store ID."""
|
"""Test OpenAI vector store list files with invalid vector store ID."""
|
||||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||||
|
|
@ -789,8 +845,9 @@ def test_openai_vector_store_list_files_invalid_vector_store(
|
||||||
compat_client.vector_stores.files.list(vector_store_id="abc123")
|
compat_client.vector_stores.files.list(vector_store_id="abc123")
|
||||||
|
|
||||||
|
|
||||||
|
@vector_provider_wrapper
|
||||||
def test_openai_vector_store_retrieve_file_contents(
|
def test_openai_vector_store_retrieve_file_contents(
|
||||||
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
|
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
|
||||||
):
|
):
|
||||||
"""Test OpenAI vector store retrieve file contents."""
|
"""Test OpenAI vector store retrieve file contents."""
|
||||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||||
|
|
@ -803,6 +860,7 @@ def test_openai_vector_store_retrieve_file_contents(
|
||||||
name="test_store",
|
name="test_store",
|
||||||
extra_body={
|
extra_body={
|
||||||
"embedding_model": embedding_model_id,
|
"embedding_model": embedding_model_id,
|
||||||
|
"provider_id": vector_io_provider_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -848,8 +906,9 @@ def test_openai_vector_store_retrieve_file_contents(
|
||||||
assert file_contents.attributes == attributes
|
assert file_contents.attributes == attributes
|
||||||
|
|
||||||
|
|
||||||
|
@vector_provider_wrapper
|
||||||
def test_openai_vector_store_delete_file(
|
def test_openai_vector_store_delete_file(
|
||||||
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
|
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
|
||||||
):
|
):
|
||||||
"""Test OpenAI vector store delete file."""
|
"""Test OpenAI vector store delete file."""
|
||||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||||
|
|
@ -862,6 +921,7 @@ def test_openai_vector_store_delete_file(
|
||||||
name="test_store",
|
name="test_store",
|
||||||
extra_body={
|
extra_body={
|
||||||
"embedding_model": embedding_model_id,
|
"embedding_model": embedding_model_id,
|
||||||
|
"provider_id": vector_io_provider_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -912,8 +972,9 @@ def test_openai_vector_store_delete_file(
|
||||||
assert updated_vector_store.file_counts.in_progress == 0
|
assert updated_vector_store.file_counts.in_progress == 0
|
||||||
|
|
||||||
|
|
||||||
|
@vector_provider_wrapper
|
||||||
def test_openai_vector_store_delete_file_removes_from_vector_store(
|
def test_openai_vector_store_delete_file_removes_from_vector_store(
|
||||||
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
|
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
|
||||||
):
|
):
|
||||||
"""Test OpenAI vector store delete file removes from vector store."""
|
"""Test OpenAI vector store delete file removes from vector store."""
|
||||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||||
|
|
@ -926,6 +987,7 @@ def test_openai_vector_store_delete_file_removes_from_vector_store(
|
||||||
name="test_store",
|
name="test_store",
|
||||||
extra_body={
|
extra_body={
|
||||||
"embedding_model": embedding_model_id,
|
"embedding_model": embedding_model_id,
|
||||||
|
"provider_id": vector_io_provider_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -962,8 +1024,9 @@ def test_openai_vector_store_delete_file_removes_from_vector_store(
|
||||||
assert not search_response.data
|
assert not search_response.data
|
||||||
|
|
||||||
|
|
||||||
|
@vector_provider_wrapper
|
||||||
def test_openai_vector_store_update_file(
|
def test_openai_vector_store_update_file(
|
||||||
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
|
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
|
||||||
):
|
):
|
||||||
"""Test OpenAI vector store update file."""
|
"""Test OpenAI vector store update file."""
|
||||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||||
|
|
@ -976,6 +1039,7 @@ def test_openai_vector_store_update_file(
|
||||||
name="test_store",
|
name="test_store",
|
||||||
extra_body={
|
extra_body={
|
||||||
"embedding_model": embedding_model_id,
|
"embedding_model": embedding_model_id,
|
||||||
|
"provider_id": vector_io_provider_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1017,8 +1081,9 @@ def test_openai_vector_store_update_file(
|
||||||
assert retrieved_file.attributes["foo"] == "baz"
|
assert retrieved_file.attributes["foo"] == "baz"
|
||||||
|
|
||||||
|
|
||||||
|
@vector_provider_wrapper
|
||||||
def test_create_vector_store_files_duplicate_vector_store_name(
|
def test_create_vector_store_files_duplicate_vector_store_name(
|
||||||
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
|
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
This test confirms that client.vector_stores.create() creates a unique ID
|
This test confirms that client.vector_stores.create() creates a unique ID
|
||||||
|
|
@ -1044,6 +1109,7 @@ def test_create_vector_store_files_duplicate_vector_store_name(
|
||||||
name="test_store_with_files",
|
name="test_store_with_files",
|
||||||
extra_body={
|
extra_body={
|
||||||
"embedding_model": embedding_model_id,
|
"embedding_model": embedding_model_id,
|
||||||
|
"provider_id": vector_io_provider_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert vector_store.file_counts.completed == 0
|
assert vector_store.file_counts.completed == 0
|
||||||
|
|
@ -1056,6 +1122,7 @@ def test_create_vector_store_files_duplicate_vector_store_name(
|
||||||
name="test_store_with_files",
|
name="test_store_with_files",
|
||||||
extra_body={
|
extra_body={
|
||||||
"embedding_model": embedding_model_id,
|
"embedding_model": embedding_model_id,
|
||||||
|
"provider_id": vector_io_provider_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1086,8 +1153,15 @@ def test_create_vector_store_files_duplicate_vector_store_name(
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("search_mode", ["vector", "keyword", "hybrid"])
|
@pytest.mark.parametrize("search_mode", ["vector", "keyword", "hybrid"])
|
||||||
|
@vector_provider_wrapper
|
||||||
def test_openai_vector_store_search_modes(
|
def test_openai_vector_store_search_modes(
|
||||||
llama_stack_client, client_with_models, sample_chunks, search_mode, embedding_model_id, embedding_dimension
|
llama_stack_client,
|
||||||
|
client_with_models,
|
||||||
|
sample_chunks,
|
||||||
|
search_mode,
|
||||||
|
embedding_model_id,
|
||||||
|
embedding_dimension,
|
||||||
|
vector_io_provider_id,
|
||||||
):
|
):
|
||||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||||
skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_models, search_mode)
|
skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_models, search_mode)
|
||||||
|
|
@ -1097,6 +1171,7 @@ def test_openai_vector_store_search_modes(
|
||||||
metadata={"purpose": "search_mode_testing"},
|
metadata={"purpose": "search_mode_testing"},
|
||||||
extra_body={
|
extra_body={
|
||||||
"embedding_model": embedding_model_id,
|
"embedding_model": embedding_model_id,
|
||||||
|
"provider_id": vector_io_provider_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1115,8 +1190,9 @@ def test_openai_vector_store_search_modes(
|
||||||
assert search_response is not None
|
assert search_response is not None
|
||||||
|
|
||||||
|
|
||||||
|
@vector_provider_wrapper
|
||||||
def test_openai_vector_store_file_batch_create_and_retrieve(
|
def test_openai_vector_store_file_batch_create_and_retrieve(
|
||||||
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
|
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
|
||||||
):
|
):
|
||||||
"""Test creating and retrieving a vector store file batch."""
|
"""Test creating and retrieving a vector store file batch."""
|
||||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||||
|
|
@ -1128,6 +1204,7 @@ def test_openai_vector_store_file_batch_create_and_retrieve(
|
||||||
name="batch_test_store",
|
name="batch_test_store",
|
||||||
extra_body={
|
extra_body={
|
||||||
"embedding_model": embedding_model_id,
|
"embedding_model": embedding_model_id,
|
||||||
|
"provider_id": vector_io_provider_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1178,8 +1255,9 @@ def test_openai_vector_store_file_batch_create_and_retrieve(
|
||||||
assert retrieved_batch.status == "completed" # Should be completed after processing
|
assert retrieved_batch.status == "completed" # Should be completed after processing
|
||||||
|
|
||||||
|
|
||||||
|
@vector_provider_wrapper
|
||||||
def test_openai_vector_store_file_batch_list_files(
|
def test_openai_vector_store_file_batch_list_files(
|
||||||
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
|
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
|
||||||
):
|
):
|
||||||
"""Test listing files in a vector store file batch."""
|
"""Test listing files in a vector store file batch."""
|
||||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||||
|
|
@ -1191,6 +1269,7 @@ def test_openai_vector_store_file_batch_list_files(
|
||||||
name="batch_list_test_store",
|
name="batch_list_test_store",
|
||||||
extra_body={
|
extra_body={
|
||||||
"embedding_model": embedding_model_id,
|
"embedding_model": embedding_model_id,
|
||||||
|
"provider_id": vector_io_provider_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1271,8 +1350,9 @@ def test_openai_vector_store_file_batch_list_files(
|
||||||
assert first_page_ids.isdisjoint(second_page_ids)
|
assert first_page_ids.isdisjoint(second_page_ids)
|
||||||
|
|
||||||
|
|
||||||
|
@vector_provider_wrapper
|
||||||
def test_openai_vector_store_file_batch_cancel(
|
def test_openai_vector_store_file_batch_cancel(
|
||||||
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
|
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
|
||||||
):
|
):
|
||||||
"""Test cancelling a vector store file batch."""
|
"""Test cancelling a vector store file batch."""
|
||||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||||
|
|
@ -1284,6 +1364,7 @@ def test_openai_vector_store_file_batch_cancel(
|
||||||
name="batch_cancel_test_store",
|
name="batch_cancel_test_store",
|
||||||
extra_body={
|
extra_body={
|
||||||
"embedding_model": embedding_model_id,
|
"embedding_model": embedding_model_id,
|
||||||
|
"provider_id": vector_io_provider_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1326,8 +1407,9 @@ def test_openai_vector_store_file_batch_cancel(
|
||||||
assert final_batch.status in ["completed", "cancelled"]
|
assert final_batch.status in ["completed", "cancelled"]
|
||||||
|
|
||||||
|
|
||||||
|
@vector_provider_wrapper
|
||||||
def test_openai_vector_store_file_batch_retrieve_contents(
|
def test_openai_vector_store_file_batch_retrieve_contents(
|
||||||
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
|
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
|
||||||
):
|
):
|
||||||
"""Test retrieving file contents after file batch processing."""
|
"""Test retrieving file contents after file batch processing."""
|
||||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||||
|
|
@ -1339,6 +1421,7 @@ def test_openai_vector_store_file_batch_retrieve_contents(
|
||||||
name="batch_contents_test_store",
|
name="batch_contents_test_store",
|
||||||
extra_body={
|
extra_body={
|
||||||
"embedding_model": embedding_model_id,
|
"embedding_model": embedding_model_id,
|
||||||
|
"provider_id": vector_io_provider_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1399,8 +1482,9 @@ def test_openai_vector_store_file_batch_retrieve_contents(
|
||||||
assert file_data[i][1].decode("utf-8") in content_text
|
assert file_data[i][1].decode("utf-8") in content_text
|
||||||
|
|
||||||
|
|
||||||
|
@vector_provider_wrapper
|
||||||
def test_openai_vector_store_file_batch_error_handling(
|
def test_openai_vector_store_file_batch_error_handling(
|
||||||
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
|
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
|
||||||
):
|
):
|
||||||
"""Test error handling for file batch operations."""
|
"""Test error handling for file batch operations."""
|
||||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||||
|
|
@ -1412,6 +1496,7 @@ def test_openai_vector_store_file_batch_error_handling(
|
||||||
name="batch_error_test_store",
|
name="batch_error_test_store",
|
||||||
extra_body={
|
extra_body={
|
||||||
"embedding_model": embedding_model_id,
|
"embedding_model": embedding_model_id,
|
||||||
|
"provider_id": vector_io_provider_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1456,8 +1541,9 @@ def test_openai_vector_store_file_batch_error_handling(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@vector_provider_wrapper
|
||||||
def test_openai_vector_store_embedding_config_from_metadata(
|
def test_openai_vector_store_embedding_config_from_metadata(
|
||||||
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
|
compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
|
||||||
):
|
):
|
||||||
"""Test that embedding configuration works from metadata source."""
|
"""Test that embedding configuration works from metadata source."""
|
||||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||||
|
|
@ -1471,6 +1557,9 @@ def test_openai_vector_store_embedding_config_from_metadata(
|
||||||
"embedding_dimension": str(embedding_dimension),
|
"embedding_dimension": str(embedding_dimension),
|
||||||
"test_source": "metadata",
|
"test_source": "metadata",
|
||||||
},
|
},
|
||||||
|
extra_body={
|
||||||
|
"provider_id": vector_io_provider_id,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert vector_store_metadata is not None
|
assert vector_store_metadata is not None
|
||||||
|
|
@ -1489,6 +1578,7 @@ def test_openai_vector_store_embedding_config_from_metadata(
|
||||||
extra_body={
|
extra_body={
|
||||||
"embedding_model": embedding_model_id,
|
"embedding_model": embedding_model_id,
|
||||||
"embedding_dimension": int(embedding_dimension), # Ensure same type/value
|
"embedding_dimension": int(embedding_dimension), # Ensure same type/value
|
||||||
|
"provider_id": vector_io_provider_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
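Note: the recurring change above is that every OpenAI-compatible vector-store test now pins its vector_io provider through extra_body. A minimal standalone sketch of the same client call pattern follows; the base URL, model id, and provider id are placeholders for illustration, not values from this commit:

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:8321")  # placeholder endpoint

    # Non-OpenAI routing fields travel in extra_body: the stack routes the request
    # to the vector_io provider named by provider_id, or auto-selects when exactly
    # one provider is configured (per the routing change in this commit).
    vector_store = client.vector_stores.create(
        name="my_store",
        extra_body={
            "embedding_model": "all-MiniLM-L6-v2",  # placeholder embedding model id
            "provider_id": "faiss",                 # placeholder provider id
        },
    )
    assert vector_store.id.startswith("vs_")  # ids carry the "vs_" prefix, as the tests assert
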
@@ -8,6 +8,8 @@ import pytest

 from llama_stack.apis.vector_io import Chunk

+from ..conftest import vector_provider_wrapper
+

 @pytest.fixture(scope="session")
 def sample_chunks():

@@ -46,12 +48,13 @@ def client_with_empty_registry(client_with_models):
     clear_registry()


-def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embedding_dimension):
+@vector_provider_wrapper
+def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id):
     vector_db_name = "test_vector_db"
     create_response = client_with_empty_registry.vector_stores.create(
         name=vector_db_name,
         extra_body={
-            "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -65,12 +68,13 @@ def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embe
     assert response.id.startswith("vs_")


-def test_vector_db_register(client_with_empty_registry, embedding_model_id, embedding_dimension):
+@vector_provider_wrapper
+def test_vector_db_register(client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id):
     vector_db_name = "test_vector_db"
     response = client_with_empty_registry.vector_stores.create(
         name=vector_db_name,
         extra_body={
-            "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -100,12 +104,15 @@ def test_vector_db_register(client_with_empty_registry, embedding_model_id, embe
         ("How does machine learning improve over time?", "doc2"),
     ],
 )
-def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case):
+@vector_provider_wrapper
+def test_insert_chunks(
+    client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case, vector_io_provider_id
+):
     vector_db_name = "test_vector_db"
     create_response = client_with_empty_registry.vector_stores.create(
         name=vector_db_name,
         extra_body={
-            "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -135,7 +142,10 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding
     assert top_match.metadata["document_id"] == expected_doc_id, f"Query '{query}' should match {expected_doc_id}"


-def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, embedding_model_id, embedding_dimension):
+@vector_provider_wrapper
+def test_insert_chunks_with_precomputed_embeddings(
+    client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id
+):
     vector_io_provider_params_dict = {
         "inline::milvus": {"score_threshold": -1.0},
         "inline::qdrant": {"score_threshold": -1.0},

@@ -145,7 +155,7 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e
     register_response = client_with_empty_registry.vector_stores.create(
         name=vector_db_name,
         extra_body={
-            "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -181,8 +191,9 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e


 # expect this test to fail
+@vector_provider_wrapper
 def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
-    client_with_empty_registry, embedding_model_id, embedding_dimension
+    client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     vector_io_provider_params_dict = {
         "inline::milvus": {"score_threshold": 0.0},

@@ -194,6 +205,7 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
         name=vector_db_name,
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -226,33 +238,44 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
     assert response.chunks[0].metadata["source"] == "precomputed"


-def test_auto_extract_embedding_dimension(client_with_empty_registry, embedding_model_id):
+@vector_provider_wrapper
+def test_auto_extract_embedding_dimension(
+    client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id
+):
+    # This test specifically tests embedding model override, so we keep embedding_model
     vs = client_with_empty_registry.vector_stores.create(
-        name="test_auto_extract", extra_body={"embedding_model": embedding_model_id}
+        name="test_auto_extract",
+        extra_body={"embedding_model": embedding_model_id, "provider_id": vector_io_provider_id},
     )
     assert vs.id is not None


-def test_provider_auto_selection_single_provider(client_with_empty_registry, embedding_model_id):
+@vector_provider_wrapper
+def test_provider_auto_selection_single_provider(
+    client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id
+):
     providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
     if len(providers) != 1:
         pytest.skip(f"Test requires exactly one vector_io provider, found {len(providers)}")

-    vs = client_with_empty_registry.vector_stores.create(
-        name="test_auto_provider", extra_body={"embedding_model": embedding_model_id}
-    )
+    # Test that when only one provider is available, it's auto-selected (no provider_id needed)
+    vs = client_with_empty_registry.vector_stores.create(name="test_auto_provider")
     assert vs.id is not None


-def test_provider_id_override(client_with_empty_registry, embedding_model_id):
+@vector_provider_wrapper
+def test_provider_id_override(
+    client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id
+):
     providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
     if len(providers) != 1:
         pytest.skip(f"Test requires exactly one vector_io provider, found {len(providers)}")

     provider_id = providers[0].provider_id

+    # Test explicit provider_id specification (using default embedding model)
     vs = client_with_empty_registry.vector_stores.create(
-        name="test_provider_override", extra_body={"embedding_model": embedding_model_id, "provider_id": provider_id}
+        name="test_provider_override", extra_body={"provider_id": provider_id}
     )
     assert vs.id is not None
     assert vs.metadata.get("provider_id") == provider_id

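The conftest helper imported above (vector_provider_wrapper) is not shown in this diff. A plausible minimal sketch, assuming it parametrizes each test over the configured vector_io providers so that vector_io_provider_id arrives as a test argument:

    import pytest

    # Hypothetical stand-in: the real helper presumably discovers providers from the
    # running stack (with the special handling for replay mode this commit mentions);
    # a static list is used here purely for illustration.
    def vector_provider_wrapper(func):
        providers = ["faiss", "sqlite-vec"]  # placeholder provider ids
        return pytest.mark.parametrize("vector_io_provider_id", providers)(func)
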
@@ -20,7 +20,7 @@ class TestVectorStoresValidation:
     async def test_validate_missing_model(self):
         """Test validation fails when model not found."""
         run_config = StackRunConfig(
-            image_name="test", providers={}, vector_stores=VectorStoresConfig(default_embedding_model_id="missing")
+            image_name="test", providers={}, vector_stores=VectorStoresConfig(embedding_model_id="missing")
         )
         mock_models = AsyncMock()
         mock_models.list_models.return_value = []

@@ -31,7 +31,7 @@ class TestVectorStoresValidation:
     async def test_validate_success(self):
         """Test validation passes with valid model."""
         run_config = StackRunConfig(
-            image_name="test", providers={}, vector_stores=VectorStoresConfig(default_embedding_model_id="valid")
+            image_name="test", providers={}, vector_stores=VectorStoresConfig(embedding_model_id="valid")
        )
         mock_models = AsyncMock()
         mock_models.list_models.return_value = [

@@ -1,88 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import pytest
-
-from llama_stack.cli.stack._build import _apply_single_provider_filter
-from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
-from llama_stack.core.utils.image_types import LlamaStackImageType
-
-
-def test_filters_single_api():
-    """Test filtering keeps only specified provider for one API."""
-    build_config = BuildConfig(
-        image_type=LlamaStackImageType.VENV.value,
-        distribution_spec=DistributionSpec(
-            providers={
-                "vector_io": [
-                    BuildProvider(provider_type="inline::faiss"),
-                    BuildProvider(provider_type="inline::sqlite-vec"),
-                ],
-                "inference": [
-                    BuildProvider(provider_type="remote::openai"),
-                ],
-            },
-            description="Test",
-        ),
-    )
-
-    filtered = _apply_single_provider_filter(build_config, "vector_io=inline::sqlite-vec")
-
-    assert len(filtered.distribution_spec.providers["vector_io"]) == 1
-    assert filtered.distribution_spec.providers["vector_io"][0].provider_type == "inline::sqlite-vec"
-    assert len(filtered.distribution_spec.providers["inference"]) == 1  # unchanged
-
-
-def test_filters_multiple_apis():
-    """Test filtering multiple APIs."""
-    build_config = BuildConfig(
-        image_type=LlamaStackImageType.VENV.value,
-        distribution_spec=DistributionSpec(
-            providers={
-                "vector_io": [
-                    BuildProvider(provider_type="inline::faiss"),
-                    BuildProvider(provider_type="inline::sqlite-vec"),
-                ],
-                "inference": [
-                    BuildProvider(provider_type="remote::openai"),
-                    BuildProvider(provider_type="remote::anthropic"),
-                ],
-            },
-            description="Test",
-        ),
-    )
-
-    filtered = _apply_single_provider_filter(build_config, "vector_io=inline::faiss,inference=remote::openai")
-
-    assert len(filtered.distribution_spec.providers["vector_io"]) == 1
-    assert filtered.distribution_spec.providers["vector_io"][0].provider_type == "inline::faiss"
-    assert len(filtered.distribution_spec.providers["inference"]) == 1
-    assert filtered.distribution_spec.providers["inference"][0].provider_type == "remote::openai"
-
-
-def test_provider_not_found_exits():
-    """Test error when specified provider doesn't exist."""
-    build_config = BuildConfig(
-        image_type=LlamaStackImageType.VENV.value,
-        distribution_spec=DistributionSpec(
-            providers={"vector_io": [BuildProvider(provider_type="inline::faiss")]},
-            description="Test",
-        ),
-    )
-
-    with pytest.raises(SystemExit):
-        _apply_single_provider_filter(build_config, "vector_io=inline::nonexistent")
-
-
-def test_invalid_format_exits():
-    """Test error for invalid filter format."""
-    build_config = BuildConfig(
-        image_type=LlamaStackImageType.VENV.value,
-        distribution_spec=DistributionSpec(providers={}, description="Test"),
-    )
-
-    with pytest.raises(SystemExit):
-        _apply_single_provider_filter(build_config, "invalid_format")