Mirror of https://github.com/meta-llama/llama-stack.git
Synced 2025-12-12 04:00:42 +00:00
adding back relevant vector_db files
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

Squashed commit messages (each carrying the same sign-off):
- fix tests
- updating tests and fixing routing logic for single provider
- setting default provider to update tests
- updated provider_id
- updated VectorStoreConfig to use (provider_id, embedding_model_id) and add default vector store provider
- special handling for replay mode for available providers
This commit is contained in:
parent accc4c437e
commit b3addc94d1
23 changed files with 637 additions and 261 deletions
@@ -144,7 +144,7 @@ jobs:
       - name: Build Llama Stack
         run: |
-          uv run --no-sync llama stack build --distro starter --image-type venv --single-provider "vector_io=${{ matrix.vector-io-provider }}"
+          uv run --no-sync llama stack build --template ci-tests --image-type venv

       - name: Check Storage and Memory Available Before Tests
         if: ${{ always() }}
@@ -168,7 +168,7 @@ jobs:
           WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }}
         run: |
           uv run --no-sync \
-            pytest -sv --stack-config ~/.llama/distributions/starter/starter-filtered-run.yaml \
+            pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \
             tests/integration/vector_io

       - name: Check Storage and Memory Available After Tests
@@ -121,6 +121,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
     models = "models"
     shields = "shields"
+    vector_dbs = "vector_dbs"  # only used for routing
     datasets = "datasets"
     scoring_functions = "scoring_functions"
     benchmarks = "benchmarks"
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Literal
+from typing import Literal, Protocol, runtime_checkable

 from pydantic import BaseModel

@@ -59,3 +59,35 @@ class ListVectorDBsResponse(BaseModel):
     """

     data: list[VectorDB]
+
+
+@runtime_checkable
+class VectorDBs(Protocol):
+    """Internal protocol for vector_dbs routing - no public API endpoints."""
+
+    async def list_vector_dbs(self) -> ListVectorDBsResponse:
+        """Internal method to list vector databases."""
+        ...
+
+    async def get_vector_db(
+        self,
+        vector_db_id: str,
+    ) -> VectorDB:
+        """Internal method to get a vector database by ID."""
+        ...
+
+    async def register_vector_db(
+        self,
+        vector_db_id: str,
+        embedding_model: str,
+        embedding_dimension: int | None = 384,
+        provider_id: str | None = None,
+        vector_db_name: str | None = None,
+        provider_vector_db_id: str | None = None,
+    ) -> VectorDB:
+        """Internal method to register a vector database."""
+        ...
+
+    async def unregister_vector_db(self, vector_db_id: str) -> None:
+        """Internal method to unregister a vector database."""
+        ...
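Because the new VectorDBs protocol is runtime_checkable, implementations satisfy it structurally rather than by inheritance. A minimal sketch of that idea, using simplified stand-ins for the VectorDB and ListVectorDBsResponse models (not the real llama_stack classes) and only two of the four methods:

from typing import Protocol, runtime_checkable

from pydantic import BaseModel


class VectorDB(BaseModel):  # simplified stand-in for the real model
    identifier: str
    embedding_model: str
    embedding_dimension: int = 384


class ListVectorDBsResponse(BaseModel):  # simplified stand-in
    data: list[VectorDB]


@runtime_checkable
class VectorDBs(Protocol):
    async def list_vector_dbs(self) -> ListVectorDBsResponse: ...
    async def get_vector_db(self, vector_db_id: str) -> VectorDB: ...


class InMemoryVectorDBs:
    """Satisfies the protocol structurally -- no inheritance required."""

    def __init__(self) -> None:
        self._dbs: dict[str, VectorDB] = {}

    async def list_vector_dbs(self) -> ListVectorDBsResponse:
        return ListVectorDBsResponse(data=list(self._dbs.values()))

    async def get_vector_db(self, vector_db_id: str) -> VectorDB:
        return self._dbs[vector_db_id]


# runtime_checkable only verifies method presence, which is enough here
assert isinstance(InMemoryVectorDBs(), VectorDBs)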
@@ -50,85 +50,6 @@ from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
 DISTRIBS_PATH = Path(__file__).parent.parent.parent / "distributions"


-def _apply_single_provider_filter(build_config: BuildConfig, single_provider_arg: str) -> BuildConfig:
-    """Filter a distribution to only include specified providers for certain APIs."""
-    # Parse the single-provider argument using the same logic as --providers
-    provider_filters: dict[str, str] = {}
-    for api_provider in single_provider_arg.split(","):
-        if "=" not in api_provider:
-            cprint(
-                "Could not parse `--single-provider`. Please ensure the list is in the format api1=provider1,api2=provider2",
-                color="red",
-                file=sys.stderr,
-            )
-            sys.exit(1)
-        api, provider_type = api_provider.split("=")
-        provider_filters[api] = provider_type
-
-    # Create a copy of the build config to modify
-    filtered_build_config = BuildConfig(
-        image_type=build_config.image_type,
-        image_name=build_config.image_name,
-        external_providers_dir=build_config.external_providers_dir,
-        external_apis_dir=build_config.external_apis_dir,
-        distribution_spec=DistributionSpec(
-            providers={},
-            description=build_config.distribution_spec.description,
-        ),
-    )
-
-    # Copy all providers, but filter the specified APIs
-    for api, providers in build_config.distribution_spec.providers.items():
-        if api in provider_filters:
-            target_provider_type = provider_filters[api]
-            filtered_providers = [p for p in providers if p.provider_type == target_provider_type]
-            if not filtered_providers:
-                cprint(
-                    f"Provider {target_provider_type} not found in distribution for API {api}",
-                    color="red",
-                    file=sys.stderr,
-                )
-                sys.exit(1)
-            filtered_build_config.distribution_spec.providers[api] = filtered_providers
-        else:
-            # Keep all providers for unfiltered APIs
-            filtered_build_config.distribution_spec.providers[api] = providers
-
-    return filtered_build_config
-
-
-def _generate_filtered_run_config(
-    build_config: BuildConfig,
-    build_dir: Path,
-    distro_name: str,
-) -> Path:
-    """
-    Generate a filtered run.yaml by starting with the original distribution's run.yaml
-    and filtering the providers according to the build_config.
-    """
-    # Load the original distribution's run.yaml
-    distro_resource = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
-
-    with importlib.resources.as_file(distro_resource) as path:
-        with open(path) as f:
-            original_config = yaml.safe_load(f)
-
-    # Apply provider filtering to the loaded config
-    for api, providers in build_config.distribution_spec.providers.items():
-        if api in original_config.get("providers", {}):
-            # Filter this API to only include the providers from build_config
-            provider_types = {p.provider_type for p in providers}
-            filtered_providers = [p for p in original_config["providers"][api] if p["provider_type"] in provider_types]
-            original_config["providers"][api] = filtered_providers
-
-    # Write the filtered config
-    run_config_file = build_dir / f"{distro_name}-filtered-run.yaml"
-    with open(run_config_file, "w") as f:
-        yaml.dump(original_config, f, sort_keys=False)
-
-    return run_config_file
-
-
 @lru_cache
 def available_distros_specs() -> dict[str, BuildConfig]:
     import yaml
@@ -172,11 +93,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
         )
         sys.exit(1)
     build_config = available_distros[distro_name]
-
-    # Apply single-provider filtering if specified
-    if args.single_provider:
-        build_config = _apply_single_provider_filter(build_config, args.single_provider)

     if args.image_type:
         build_config.image_type = args.image_type
     else:
@@ -329,7 +245,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
             image_name=image_name,
             config_path=args.config,
             distro_name=distro_name,
-            is_filtered=bool(args.single_provider),
         )

     except (Exception, RuntimeError) as exc:
@@ -448,7 +363,6 @@ def _run_stack_build_command_from_build_config(
     image_name: str | None = None,
     distro_name: str | None = None,
     config_path: str | None = None,
-    is_filtered: bool = False,
 ) -> Path | Traversable:
     image_name = image_name or build_config.image_name
     if build_config.image_type == LlamaStackImageType.CONTAINER.value:
@@ -521,19 +435,12 @@ def _run_stack_build_command_from_build_config(
             raise RuntimeError(f"Failed to build image {image_name}")

     if distro_name:
-        # If single-provider filtering was applied, generate a filtered run config
-        # Otherwise, copy run.yaml from distribution as before
-        if is_filtered:
-            run_config_file = _generate_filtered_run_config(build_config, build_dir, distro_name)
-            distro_path = run_config_file  # Use the generated file as the distro_path
-        else:
         # copy run.yaml from distribution to build_dir instead of generating it again
-        distro_resource = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
+        distro_path = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
         run_config_file = build_dir / f"{distro_name}-run.yaml"
-        with importlib.resources.as_file(distro_resource) as path:
+        with importlib.resources.as_file(distro_path) as path:
             shutil.copy(path, run_config_file)
-        distro_path = run_config_file  # Update distro_path to point to the copied file

     cprint("Build Successful!", color="green", file=sys.stderr)
     cprint(f"You can find the newly-built distribution here: {run_config_file}", color="blue", file=sys.stderr)
@@ -92,13 +92,6 @@ the build. If not specified, currently active environment will be used if found.
             help="Build a config for a list of providers and only those providers. This list is formatted like: api1=provider1,api2=provider2. Where there can be multiple providers per API.",
         )

-        self.parser.add_argument(
-            "--single-provider",
-            type=str,
-            default=None,
-            help="Limit a distribution to a single provider for specific APIs. Format: api1=provider1,api2=provider2. Use with --distro to filter an existing distribution.",
-        )

     def _run_stack_build_command(self, args: argparse.Namespace) -> None:
         # always keep implementation completely silo-ed away from CLI so CLI
         # can be fast to load and reduces dependencies
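The `api1=provider1,api2=provider2` format stays in use by the surviving --providers flag (and was shared by the removed --single-provider). A standalone sketch of that parsing, with a hypothetical helper name rather than the CLI's own code path:

# Hypothetical standalone parser for the "api1=provider1,api2=provider2"
# format used by --providers (and formerly --single-provider).
def parse_provider_filters(spec: str) -> dict[str, str]:
    filters: dict[str, str] = {}
    for pair in spec.split(","):
        if "=" not in pair:
            raise ValueError(f"Malformed entry {pair!r}; expected api=provider")
        api, provider_type = pair.split("=", 1)
        filters[api.strip()] = provider_type.strip()
    return filters


assert parse_provider_filters("vector_io=inline::faiss") == {"vector_io": "inline::faiss"}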
@@ -354,10 +354,14 @@ class AuthenticationRequiredError(Exception):
 class VectorStoresConfig(BaseModel):
     """Configuration for vector stores in the stack."""

-    default_embedding_model_id: str = Field(
+    embedding_model_id: str = Field(
         ...,
         description="ID of the embedding model to use as default for vector stores when none is specified. Must reference a model defined in the 'models' section.",
     )
+    provider_id: str | None = Field(
+        default=None,
+        description="ID of the vector_io provider to use as default when multiple providers are available and none is specified.",
+    )


 class QuotaPeriod(StrEnum):
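With the rename, a run config supplies embedding_model_id and, optionally, provider_id. A sketch of constructing the config programmatically, using an abridged copy of the model above (field descriptions omitted) and the starter distribution's embedding model id; the "faiss" provider id is illustrative:

from pydantic import BaseModel, Field


class VectorStoresConfig(BaseModel):  # abridged copy of the class above
    embedding_model_id: str = Field(...)
    provider_id: str | None = Field(default=None)


cfg = VectorStoresConfig(
    embedding_model_id="sentence-transformers/nomic-ai/nomic-embed-text-v1.5",
    provider_id="faiss",  # optional: pins the default vector_io provider
)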
@@ -63,6 +63,10 @@ def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
             routing_table_api=Api.tool_groups,
             router_api=Api.tool_runtime,
         ),
+        AutoRoutedApiInfo(
+            routing_table_api=Api.vector_dbs,
+            router_api=Api.vector_io,
+        ),
     ]
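Each AutoRoutedApiInfo entry pairs a routing-table API with the router API it backs; vector_dbs now backs vector_io the same way tool_groups backs tool_runtime. A sketch of consuming such pairs, with stand-in types and a hypothetical lookup helper:

from dataclasses import dataclass


@dataclass
class AutoRoutedApiInfo:  # stand-in with the two fields shown in the diff
    routing_table_api: str
    router_api: str


PAIRS = [AutoRoutedApiInfo("vector_dbs", "vector_io")]


def routing_table_for(router_api: str) -> str | None:
    # Hypothetical reverse lookup from router API to its routing table
    return next((p.routing_table_api for p in PAIRS if p.router_api == router_api), None)


assert routing_table_for("vector_io") == "vector_dbs"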
@@ -28,6 +28,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
 from llama_stack.apis.shields import Shields
 from llama_stack.apis.telemetry import Telemetry
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
+from llama_stack.apis.vector_dbs import VectorDBs
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
 from llama_stack.core.client import get_client_impl

@@ -80,6 +81,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
         Api.inspect: Inspect,
         Api.batches: Batches,
         Api.vector_io: VectorIO,
+        Api.vector_dbs: VectorDBs,
         Api.models: Models,
         Api.safety: Safety,
         Api.shields: Shields,
@@ -26,6 +26,7 @@ async def get_routing_table_impl(
     from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
     from ..routing_tables.shields import ShieldsRoutingTable
     from ..routing_tables.toolgroups import ToolGroupsRoutingTable
+    from ..routing_tables.vector_dbs import VectorDBsRoutingTable

     api_to_tables = {
         "models": ModelsRoutingTable,

@@ -34,6 +35,7 @@ async def get_routing_table_impl(
         "scoring_functions": ScoringFunctionsRoutingTable,
         "benchmarks": BenchmarksRoutingTable,
         "tool_groups": ToolGroupsRoutingTable,
+        "vector_dbs": VectorDBsRoutingTable,
     }

     if api.value not in api_to_tables:
@@ -31,7 +31,6 @@ from llama_stack.apis.vector_io import (
     VectorStoreObject,
     VectorStoreSearchResponsePage,
 )
-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable

@@ -44,7 +43,7 @@ class VectorIORouter(VectorIO):
     def __init__(
         self,
         routing_table: RoutingTable,
-        vector_stores_config: VectorStoresConfig | None = None,
+        vector_stores_config=None,
     ) -> None:
         logger.debug("Initializing VectorIORouter")
         self.routing_table = routing_table

@@ -125,9 +124,9 @@ class VectorIORouter(VectorIO):
         embedding_dimension = extra.get("embedding_dimension")
         provider_id = extra.get("provider_id")

         # Use default embedding model if not specified
         if embedding_model is None and self.vector_stores_config is not None:
-            embedding_model = self.vector_stores_config.default_embedding_model_id
-            logger.debug(f"Using default embedding model: {embedding_model}")
+            embedding_model = self.vector_stores_config.embedding_model_id

         if embedding_model is not None and embedding_dimension is None:
             embedding_dimension = await self._get_embedding_model_dimension(embedding_model)

@@ -139,10 +138,23 @@ class VectorIORouter(VectorIO):
             raise ValueError("No vector_io providers available")
         if num_providers > 1:
             available_providers = list(self.routing_table.impls_by_provider_id.keys())
+            # Use default configured provider
+            if self.vector_stores_config and self.vector_stores_config.provider_id:
+                default_provider = self.vector_stores_config.provider_id
+                if default_provider in available_providers:
+                    provider_id = default_provider
+                    logger.debug(f"Using configured default vector store provider: {provider_id}")
+                else:
+                    raise ValueError(
+                        f"Configured default vector store provider '{default_provider}' not found. "
+                        f"Available providers: {available_providers}"
+                    )
+            else:
                 raise ValueError(
                     f"Multiple vector_io providers available. Please specify provider_id in extra_body. "
                     f"Available providers: {available_providers}"
                 )
         else:
             provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]

         vector_db_id = f"vs_{uuid.uuid4()}"

@@ -250,8 +262,7 @@ class VectorIORouter(VectorIO):
         vector_store_id: str,
     ) -> VectorStoreDeleteResponse:
         logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
-        provider = await self.routing_table.get_provider_impl(vector_store_id)
-        return await provider.openai_delete_vector_store(vector_store_id)
+        return await self.routing_table.openai_delete_vector_store(vector_store_id)

     async def openai_search_vector_store(
         self,
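For callers, the router change means extra_body needs a provider_id only when several vector_io providers are configured and no vector_stores.provider_id default is set. A sketch mirroring the integration tests further down; the base_url and provider id are placeholders, and a running stack is assumed:

# Sketch of creating a vector store against a running stack, mirroring the
# tests in this diff. base_url is a placeholder for your deployment.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

vector_store = client.vector_stores.create(
    name="my_store",
    extra_body={
        # Both keys become optional once the vector_stores config supplies defaults.
        "embedding_model": "sentence-transformers/nomic-ai/nomic-embed-text-v1.5",
        "provider_id": "faiss",
    },
)
print(vector_store.id)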
@@ -134,12 +134,15 @@ class CommonRoutingTableImpl(RoutingTable):
         from .scoring_functions import ScoringFunctionsRoutingTable
         from .shields import ShieldsRoutingTable
         from .toolgroups import ToolGroupsRoutingTable
+        from .vector_dbs import VectorDBsRoutingTable

         def apiname_object():
             if isinstance(self, ModelsRoutingTable):
                 return ("Inference", "model")
             elif isinstance(self, ShieldsRoutingTable):
                 return ("Safety", "shield")
+            elif isinstance(self, VectorDBsRoutingTable):
+                return ("VectorIO", "vector_db")
             elif isinstance(self, DatasetsRoutingTable):
                 return ("DatasetIO", "dataset")
             elif isinstance(self, ScoringFunctionsRoutingTable):
llama_stack/core/routing_tables/vector_dbs.py (new file, 323 lines)

@@ -0,0 +1,323 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import TypeAdapter

from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType

# Removed VectorDBs import to avoid exposing public API
from llama_stack.apis.vector_io.vector_io import (
    OpenAICreateVectorStoreRequestWithExtraBody,
    SearchRankingOptions,
    VectorStoreChunkingStrategy,
    VectorStoreDeleteResponse,
    VectorStoreFileContentsResponse,
    VectorStoreFileDeleteResponse,
    VectorStoreFileObject,
    VectorStoreFileStatus,
    VectorStoreObject,
    VectorStoreSearchResponsePage,
)
from llama_stack.core.datatypes import (
    VectorDBWithOwner,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl, lookup_model

logger = get_logger(name=__name__, category="core::routing_tables")


class VectorDBsRoutingTable(CommonRoutingTableImpl):
    """Internal routing table for vector_db operations.

    Does not inherit from VectorDBs to avoid exposing public API endpoints.
    Only provides internal routing functionality for VectorIORouter.
    """

    # Internal methods only - no public API exposure

    async def register_vector_db(
        self,
        vector_db_id: str,
        embedding_model: str,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
        vector_db_name: str | None = None,
    ) -> Any:
        if provider_id is None:
            if len(self.impls_by_provider_id) > 0:
                provider_id = list(self.impls_by_provider_id.keys())[0]
                if len(self.impls_by_provider_id) > 1:
                    logger.warning(
                        f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
                    )
            else:
                raise ValueError("No provider available. Please configure a vector_io provider.")
        model = await lookup_model(self, embedding_model)
        if model is None:
            raise ModelNotFoundError(embedding_model)
        if model.model_type != ModelType.embedding:
            raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
        if "embedding_dimension" not in model.metadata:
            raise ValueError(f"Model {embedding_model} does not have an embedding dimension")

        try:
            provider = self.impls_by_provider_id[provider_id]
        except KeyError:
            available_providers = list(self.impls_by_provider_id.keys())
            raise ValueError(
                f"Provider '{provider_id}' not found in routing table. Available providers: {available_providers}"
            ) from None
        logger.warning(
            "VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly."
        )
        request = OpenAICreateVectorStoreRequestWithExtraBody(
            name=vector_db_name or vector_db_id,
            embedding_model=embedding_model,
            embedding_dimension=model.metadata["embedding_dimension"],
            provider_id=provider_id,
            provider_vector_db_id=provider_vector_db_id,
        )
        vector_store = await provider.openai_create_vector_store(request)

        vector_store_id = vector_store.id
        actual_provider_vector_db_id = provider_vector_db_id or vector_store_id
        logger.warning(
            f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name"
        )

        vector_db_data = {
            "identifier": vector_store_id,
            "type": ResourceType.vector_db.value,
            "provider_id": provider_id,
            "provider_resource_id": actual_provider_vector_db_id,
            "embedding_model": embedding_model,
            "embedding_dimension": model.metadata["embedding_dimension"],
            "vector_db_name": vector_store.name,
        }
        vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
        await self.register_object(vector_db)
        return vector_db

    async def openai_retrieve_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreObject:
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store(vector_store_id)

    async def openai_update_vector_store(
        self,
        vector_store_id: str,
        name: str | None = None,
        expires_after: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> VectorStoreObject:
        await self.assert_action_allowed("update", "vector_db", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_update_vector_store(
            vector_store_id=vector_store_id,
            name=name,
            expires_after=expires_after,
            metadata=metadata,
        )

    async def openai_delete_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreDeleteResponse:
        await self.assert_action_allowed("delete", "vector_db", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        result = await provider.openai_delete_vector_store(vector_store_id)
        await self.unregister_vector_db(vector_store_id)
        return result

    async def unregister_vector_db(self, vector_store_id: str) -> None:
        """Remove the vector store from the routing table registry."""
        try:
            vector_db_obj = await self.get_object_by_identifier("vector_db", vector_store_id)
            if vector_db_obj:
                await self.unregister_object(vector_db_obj)
        except Exception as e:
            # Log the error but don't fail the operation
            logger.warning(f"Failed to unregister vector store {vector_store_id} from routing table: {e}")

    async def openai_search_vector_store(
        self,
        vector_store_id: str,
        query: str | list[str],
        filters: dict[str, Any] | None = None,
        max_num_results: int | None = 10,
        ranking_options: SearchRankingOptions | None = None,
        rewrite_query: bool | None = False,
        search_mode: str | None = "vector",
    ) -> VectorStoreSearchResponsePage:
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_search_vector_store(
            vector_store_id=vector_store_id,
            query=query,
            filters=filters,
            max_num_results=max_num_results,
            ranking_options=ranking_options,
            rewrite_query=rewrite_query,
            search_mode=search_mode,
        )

    async def openai_attach_file_to_vector_store(
        self,
        vector_store_id: str,
        file_id: str,
        attributes: dict[str, Any] | None = None,
        chunking_strategy: VectorStoreChunkingStrategy | None = None,
    ) -> VectorStoreFileObject:
        await self.assert_action_allowed("update", "vector_db", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_attach_file_to_vector_store(
            vector_store_id=vector_store_id,
            file_id=file_id,
            attributes=attributes,
            chunking_strategy=chunking_strategy,
        )

    async def openai_list_files_in_vector_store(
        self,
        vector_store_id: str,
        limit: int | None = 20,
        order: str | None = "desc",
        after: str | None = None,
        before: str | None = None,
        filter: VectorStoreFileStatus | None = None,
    ) -> list[VectorStoreFileObject]:
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_list_files_in_vector_store(
            vector_store_id=vector_store_id,
            limit=limit,
            order=order,
            after=after,
            before=before,
            filter=filter,
        )

    async def openai_retrieve_vector_store_file(
        self,
        vector_store_id: str,
        file_id: str,
    ) -> VectorStoreFileObject:
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store_file(
            vector_store_id=vector_store_id,
            file_id=file_id,
        )

    async def openai_retrieve_vector_store_file_contents(
        self,
        vector_store_id: str,
        file_id: str,
    ) -> VectorStoreFileContentsResponse:
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store_file_contents(
            vector_store_id=vector_store_id,
            file_id=file_id,
        )

    async def openai_update_vector_store_file(
        self,
        vector_store_id: str,
        file_id: str,
        attributes: dict[str, Any],
    ) -> VectorStoreFileObject:
        await self.assert_action_allowed("update", "vector_db", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_update_vector_store_file(
            vector_store_id=vector_store_id,
            file_id=file_id,
            attributes=attributes,
        )

    async def openai_delete_vector_store_file(
        self,
        vector_store_id: str,
        file_id: str,
    ) -> VectorStoreFileDeleteResponse:
        await self.assert_action_allowed("delete", "vector_db", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_delete_vector_store_file(
            vector_store_id=vector_store_id,
            file_id=file_id,
        )

    async def openai_create_vector_store_file_batch(
        self,
        vector_store_id: str,
        file_ids: list[str],
        attributes: dict[str, Any] | None = None,
        chunking_strategy: Any | None = None,
    ):
        await self.assert_action_allowed("update", "vector_db", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_create_vector_store_file_batch(
            vector_store_id=vector_store_id,
            file_ids=file_ids,
            attributes=attributes,
            chunking_strategy=chunking_strategy,
        )

    async def openai_retrieve_vector_store_file_batch(
        self,
        batch_id: str,
        vector_store_id: str,
    ):
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store_file_batch(
            batch_id=batch_id,
            vector_store_id=vector_store_id,
        )

    async def openai_list_files_in_vector_store_file_batch(
        self,
        batch_id: str,
        vector_store_id: str,
        after: str | None = None,
        before: str | None = None,
        filter: str | None = None,
        limit: int | None = 20,
        order: str | None = "desc",
    ):
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_list_files_in_vector_store_file_batch(
            batch_id=batch_id,
            vector_store_id=vector_store_id,
            after=after,
            before=before,
            filter=filter,
            limit=limit,
            order=order,
        )

    async def openai_cancel_vector_store_file_batch(
        self,
        batch_id: str,
        vector_store_id: str,
    ):
        await self.assert_action_allowed("update", "vector_db", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_cancel_vector_store_file_batch(
            batch_id=batch_id,
            vector_store_id=vector_store_id,
        )
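register_vector_db now proxies to openai_create_vector_store, logs a deprecation warning, and keys the registry by the generated vector store id rather than the requested vector_db_id. A sketch of a call site; routing_table stands for an initialized VectorDBsRoutingTable, and the model id must already be registered as an embedding model:

# Sketch: registering a legacy VectorDB through the new routing table.
async def register_legacy_db(routing_table):
    vector_db = await routing_table.register_vector_db(
        vector_db_id="my_legacy_db",  # kept only as the display name
        embedding_model="sentence-transformers/nomic-ai/nomic-embed-text-v1.5",
        provider_id="faiss",  # illustrative provider id
    )
    # The returned identifier is the generated vector store id ("vs_..."),
    # not the vector_db_id that was passed in.
    return vector_db.identifier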
@@ -135,7 +135,7 @@ async def validate_vector_stores_config(run_config: StackRunConfig, impls: dict[
         return

     vector_stores_config = run_config.vector_stores
-    default_model_id = vector_stores_config.default_embedding_model_id
+    default_model_id = vector_stores_config.embedding_model_id

     if Api.models not in impls:
         raise ValueError(f"Models API is not available but vector_stores config requires model '{default_model_id}'")
@@ -255,4 +255,4 @@ server:
   telemetry:
     enabled: true
 vector_stores:
-  default_embedding_model_id: sentence-transformers/nomic-ai/nomic-embed-text-v1.5
+  embedding_model_id: sentence-transformers/nomic-ai/nomic-embed-text-v1.5
@@ -258,4 +258,4 @@ server:
   telemetry:
     enabled: true
 vector_stores:
-  default_embedding_model_id: sentence-transformers/nomic-ai/nomic-embed-text-v1.5
+  embedding_model_id: sentence-transformers/nomic-ai/nomic-embed-text-v1.5
@@ -255,4 +255,4 @@ server:
   telemetry:
     enabled: true
 vector_stores:
-  default_embedding_model_id: sentence-transformers/nomic-ai/nomic-embed-text-v1.5
+  embedding_model_id: sentence-transformers/nomic-ai/nomic-embed-text-v1.5
@@ -249,7 +249,7 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
             default_tool_groups=default_tool_groups,
             default_shields=default_shields,
             vector_stores_config=VectorStoresConfig(
-                default_embedding_model_id="sentence-transformers/nomic-ai/nomic-embed-text-v1.5"
+                embedding_model_id="sentence-transformers/nomic-ai/nomic-embed-text-v1.5"
             ),
         ),
     },
@@ -317,3 +317,72 @@ def pytest_ignore_collect(path: str, config: pytest.Config) -> bool:
         if p.is_relative_to(rp):
             return False
     return True
+
+
+def get_vector_io_provider_ids(client):
+    """Get all available vector_io provider IDs."""
+    providers = [p for p in client.providers.list() if p.api == "vector_io"]
+    return [p.provider_id for p in providers]
+
+
+def vector_provider_wrapper(func):
+    """Decorator to run a test against all available vector_io providers."""
+    import functools
+    import os
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        # Get the vector_io_provider_id from the test arguments
+        import inspect
+
+        sig = inspect.signature(func)
+        bound_args = sig.bind(*args, **kwargs)
+        bound_args.apply_defaults()
+
+        vector_io_provider_id = bound_args.arguments.get("vector_io_provider_id")
+        if not vector_io_provider_id:
+            pytest.skip("No vector_io_provider_id provided")
+
+        # Get client_with_models to check available providers
+        client_with_models = bound_args.arguments.get("client_with_models")
+        if client_with_models:
+            available_providers = get_vector_io_provider_ids(client_with_models)
+            if vector_io_provider_id not in available_providers:
+                pytest.skip(f"Provider '{vector_io_provider_id}' not available. Available: {available_providers}")
+
+        return func(*args, **kwargs)
+
+    # For replay tests, only use providers that are available in ci-tests environment
+    if os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE") == "replay":
+        all_providers = ["faiss", "sqlite-vec"]
+    else:
+        # For live tests, try all providers (they'll skip if not available)
+        all_providers = [
+            "faiss",
+            "sqlite-vec",
+            "milvus",
+            "chromadb",
+            "pgvector",
+            "weaviate",
+            "qdrant",
+        ]
+
+    return pytest.mark.parametrize("vector_io_provider_id", all_providers)(wrapper)
+
+
+@pytest.fixture
+def vector_io_provider_id(request, client_with_models):
+    """Fixture that provides a specific vector_io provider ID, skipping if not available."""
+    if hasattr(request, "param"):
+        requested_provider = request.param
+        available_providers = get_vector_io_provider_ids(client_with_models)
+
+        if requested_provider not in available_providers:
+            pytest.skip(f"Provider '{requested_provider}' not available. Available: {available_providers}")
+
+        return requested_provider
+    else:
+        provider_ids = get_vector_io_provider_ids(client_with_models)
+        if not provider_ids:
+            pytest.skip("No vector_io providers available")
+        return provider_ids[0]
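Combined, the decorator and fixture parametrize a test over every provider id and skip ids the running stack does not expose. A minimal sketch of a test using them, with fixtures as defined above:

@vector_provider_wrapper
def test_create_on_each_provider(compat_client_with_empty_stores, client_with_models, vector_io_provider_id):
    # Runs once per provider id; skips providers the stack does not expose.
    store = compat_client_with_empty_stores.vector_stores.create(
        extra_body={"provider_id": vector_io_provider_id},
    )
    assert store.id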
@@ -241,7 +241,7 @@ def instantiate_llama_stack_client(session):
     # --stack-config bypasses template so need this to set default embedding model
     if "vector_io" in config and "inference" in config:
         run_config.vector_stores = VectorStoresConfig(
-            default_embedding_model_id="sentence-transformers/nomic-ai/nomic-embed-text-v1.5"
+            embedding_model_id="inline::sentence-transformers/nomic-ai/nomic-embed-text-v1.5"
         )

     run_config_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml")
@@ -16,6 +16,8 @@ from llama_stack.apis.vector_io import Chunk
 from llama_stack.core.library_client import LlamaStackAsLibraryClient
 from llama_stack.log import get_logger

+from ..conftest import vector_provider_wrapper
+
 logger = get_logger(name=__name__, category="vector_io")

@@ -133,8 +135,9 @@ def compat_client_with_empty_stores(compat_client):
     clear_files()


+@vector_provider_wrapper
 def test_openai_create_vector_store(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """Test creating a vector store using OpenAI API."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -146,6 +149,7 @@ def test_openai_create_vector_store(
         metadata={"purpose": "testing", "environment": "integration"},
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -159,14 +163,18 @@ def test_openai_create_vector_store(
     assert hasattr(vector_store, "created_at")


-def test_openai_create_vector_store_default(compat_client_with_empty_stores, client_with_models):
+@vector_provider_wrapper
+def test_openai_create_vector_store_default(compat_client_with_empty_stores, client_with_models, vector_io_provider_id):
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
-    vector_store = compat_client_with_empty_stores.vector_stores.create()
+    vector_store = compat_client_with_empty_stores.vector_stores.create(
+        extra_body={"provider_id": vector_io_provider_id}
+    )
     assert vector_store.id


+@vector_provider_wrapper
 def test_openai_list_vector_stores(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """Test listing vector stores using OpenAI API."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -179,6 +187,7 @@ def test_openai_list_vector_stores(
         metadata={"type": "test"},
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )
     store2 = client.vector_stores.create(

@@ -186,6 +195,7 @@ def test_openai_list_vector_stores(
         metadata={"type": "test"},
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -206,8 +216,9 @@ def test_openai_list_vector_stores(
     assert len(limited_response.data) == 1


+@vector_provider_wrapper
 def test_openai_retrieve_vector_store(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """Test retrieving a specific vector store using OpenAI API."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -220,6 +231,7 @@ def test_openai_retrieve_vector_store(
         metadata={"purpose": "retrieval_test"},
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -233,8 +245,9 @@ def test_openai_retrieve_vector_store(
     assert retrieved_store.object == "vector_store"


+@vector_provider_wrapper
 def test_openai_update_vector_store(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """Test modifying a vector store using OpenAI API."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -247,6 +260,7 @@ def test_openai_update_vector_store(
         metadata={"version": "1.0"},
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )
     time.sleep(1)

@@ -264,8 +278,9 @@ def test_openai_update_vector_store(
     assert modified_store.last_active_at > created_store.last_active_at


+@vector_provider_wrapper
 def test_openai_delete_vector_store(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """Test deleting a vector store using OpenAI API."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -278,6 +293,7 @@ def test_openai_delete_vector_store(
         metadata={"purpose": "deletion_test"},
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -294,8 +310,9 @@ def test_openai_delete_vector_store(
         client.vector_stores.retrieve(vector_store_id=created_store.id)


+@vector_provider_wrapper
 def test_openai_vector_store_search_empty(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """Test searching an empty vector store using OpenAI API."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -308,6 +325,7 @@ def test_openai_vector_store_search_empty(
         metadata={"purpose": "search_testing"},
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -323,8 +341,14 @@ def test_openai_vector_store_search_empty(
     assert search_response.has_more is False


+@vector_provider_wrapper
 def test_openai_vector_store_with_chunks(
-    compat_client_with_empty_stores, client_with_models, sample_chunks, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores,
+    client_with_models,
+    sample_chunks,
+    embedding_model_id,
+    embedding_dimension,
+    vector_io_provider_id,
 ):
     """Test vector store functionality with actual chunks using both OpenAI and native APIs."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -338,6 +362,7 @@ def test_openai_vector_store_with_chunks(
         metadata={"purpose": "chunks_testing"},
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -380,6 +405,7 @@ def test_openai_vector_store_with_chunks(
         ("What inspires neural networks?", "doc4", "ai"),
     ],
 )
+@vector_provider_wrapper
 def test_openai_vector_store_search_relevance(
     compat_client_with_empty_stores,
     client_with_models,

@@ -387,6 +413,7 @@ def test_openai_vector_store_search_relevance(
     test_case,
     embedding_model_id,
     embedding_dimension,
+    vector_io_provider_id,
 ):
     """Test that OpenAI vector store search returns relevant results for different queries."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -402,6 +429,7 @@ def test_openai_vector_store_search_relevance(
         metadata={"purpose": "relevance_testing"},
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -430,8 +458,14 @@ def test_openai_vector_store_search_relevance(
     assert top_result.score > 0


+@vector_provider_wrapper
 def test_openai_vector_store_search_with_ranking_options(
-    compat_client_with_empty_stores, client_with_models, sample_chunks, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores,
+    client_with_models,
+    sample_chunks,
+    embedding_model_id,
+    embedding_dimension,
+    vector_io_provider_id,
 ):
     """Test OpenAI vector store search with ranking options."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -445,6 +479,7 @@ def test_openai_vector_store_search_with_ranking_options(
         metadata={"purpose": "ranking_testing"},
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -483,8 +518,14 @@ def test_openai_vector_store_search_with_ranking_options(
         assert result.score >= threshold


+@vector_provider_wrapper
 def test_openai_vector_store_search_with_high_score_filter(
-    compat_client_with_empty_stores, client_with_models, sample_chunks, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores,
+    client_with_models,
+    sample_chunks,
+    embedding_model_id,
+    embedding_dimension,
+    vector_io_provider_id,
 ):
     """Test that searching with text very similar to a document and high score threshold returns only that document."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -498,6 +539,7 @@ def test_openai_vector_store_search_with_high_score_filter(
         metadata={"purpose": "high_score_filtering"},
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -542,8 +584,14 @@ def test_openai_vector_store_search_with_high_score_filter(
     assert "python" in top_content.lower() or "programming" in top_content.lower()


+@vector_provider_wrapper
 def test_openai_vector_store_search_with_max_num_results(
-    compat_client_with_empty_stores, client_with_models, sample_chunks, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores,
+    client_with_models,
+    sample_chunks,
+    embedding_model_id,
+    embedding_dimension,
+    vector_io_provider_id,
 ):
     """Test OpenAI vector store search with max_num_results."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -557,6 +605,7 @@ def test_openai_vector_store_search_with_max_num_results(
         metadata={"purpose": "max_num_results_testing"},
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -577,8 +626,9 @@ def test_openai_vector_store_search_with_max_num_results(
     assert len(search_response.data) == 2


+@vector_provider_wrapper
 def test_openai_vector_store_attach_file(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """Test OpenAI vector store attach file."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -591,6 +641,7 @@ def test_openai_vector_store_attach_file(
         name="test_store",
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -637,8 +688,9 @@ def test_openai_vector_store_attach_file(
     assert "foobazbar" in top_content.lower()


+@vector_provider_wrapper
 def test_openai_vector_store_attach_files_on_creation(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """Test OpenAI vector store attach files on creation."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -668,6 +720,7 @@ def test_openai_vector_store_attach_files_on_creation(
         file_ids=file_ids,
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -700,8 +753,9 @@ def test_openai_vector_store_attach_files_on_creation(
     assert updated_vector_store.file_counts.failed == 0


+@vector_provider_wrapper
 def test_openai_vector_store_list_files(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """Test OpenAI vector store list files."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -714,6 +768,7 @@ def test_openai_vector_store_list_files(
         name="test_store",
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -773,8 +828,9 @@ def test_openai_vector_store_list_files(
     assert updated_vector_store.file_counts.in_progress == 0


+@vector_provider_wrapper
 def test_openai_vector_store_list_files_invalid_vector_store(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """Test OpenAI vector store list files with invalid vector store ID."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -789,8 +845,9 @@ def test_openai_vector_store_list_files_invalid_vector_store(
         compat_client.vector_stores.files.list(vector_store_id="abc123")


+@vector_provider_wrapper
 def test_openai_vector_store_retrieve_file_contents(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """Test OpenAI vector store retrieve file contents."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -803,6 +860,7 @@ def test_openai_vector_store_retrieve_file_contents(
         name="test_store",
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -848,8 +906,9 @@ def test_openai_vector_store_retrieve_file_contents(
     assert file_contents.attributes == attributes


+@vector_provider_wrapper
 def test_openai_vector_store_delete_file(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """Test OpenAI vector store delete file."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -862,6 +921,7 @@ def test_openai_vector_store_delete_file(
         name="test_store",
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -912,8 +972,9 @@ def test_openai_vector_store_delete_file(
     assert updated_vector_store.file_counts.in_progress == 0


+@vector_provider_wrapper
 def test_openai_vector_store_delete_file_removes_from_vector_store(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """Test OpenAI vector store delete file removes from vector store."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -926,6 +987,7 @@ def test_openai_vector_store_delete_file_removes_from_vector_store(
         name="test_store",
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -962,8 +1024,9 @@ def test_openai_vector_store_delete_file_removes_from_vector_store(
     assert not search_response.data


+@vector_provider_wrapper
 def test_openai_vector_store_update_file(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """Test OpenAI vector store update file."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -976,6 +1039,7 @@ def test_openai_vector_store_update_file(
         name="test_store",
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -1017,8 +1081,9 @@ def test_openai_vector_store_update_file(
     assert retrieved_file.attributes["foo"] == "baz"


+@vector_provider_wrapper
 def test_create_vector_store_files_duplicate_vector_store_name(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """
     This test confirms that client.vector_stores.create() creates a unique ID

@@ -1044,6 +1109,7 @@ def test_create_vector_store_files_duplicate_vector_store_name(
         name="test_store_with_files",
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )
     assert vector_store.file_counts.completed == 0

@@ -1056,6 +1122,7 @@ def test_create_vector_store_files_duplicate_vector_store_name(
         name="test_store_with_files",
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -1086,8 +1153,15 @@ def test_create_vector_store_files_duplicate_vector_store_name(


 @pytest.mark.parametrize("search_mode", ["vector", "keyword", "hybrid"])
+@vector_provider_wrapper
 def test_openai_vector_store_search_modes(
-    llama_stack_client, client_with_models, sample_chunks, search_mode, embedding_model_id, embedding_dimension
+    llama_stack_client,
+    client_with_models,
+    sample_chunks,
+    search_mode,
+    embedding_model_id,
+    embedding_dimension,
+    vector_io_provider_id,
 ):
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
     skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_models, search_mode)

@@ -1097,6 +1171,7 @@ def test_openai_vector_store_search_modes(
         metadata={"purpose": "search_mode_testing"},
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -1115,8 +1190,9 @@ def test_openai_vector_store_search_modes(
     assert search_response is not None


+@vector_provider_wrapper
 def test_openai_vector_store_file_batch_create_and_retrieve(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """Test creating and retrieving a vector store file batch."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -1128,6 +1204,7 @@ def test_openai_vector_store_file_batch_create_and_retrieve(
         name="batch_test_store",
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -1178,8 +1255,9 @@ def test_openai_vector_store_file_batch_create_and_retrieve(
     assert retrieved_batch.status == "completed"  # Should be completed after processing


+@vector_provider_wrapper
 def test_openai_vector_store_file_batch_list_files(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """Test listing files in a vector store file batch."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -1191,6 +1269,7 @@ def test_openai_vector_store_file_batch_list_files(
         name="batch_list_test_store",
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -1271,8 +1350,9 @@ def test_openai_vector_store_file_batch_list_files(
     assert first_page_ids.isdisjoint(second_page_ids)


+@vector_provider_wrapper
 def test_openai_vector_store_file_batch_cancel(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """Test cancelling a vector store file batch."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -1284,6 +1364,7 @@ def test_openai_vector_store_file_batch_cancel(
         name="batch_cancel_test_store",
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -1326,8 +1407,9 @@ def test_openai_vector_store_file_batch_cancel(
     assert final_batch.status in ["completed", "cancelled"]


+@vector_provider_wrapper
 def test_openai_vector_store_file_batch_retrieve_contents(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """Test retrieving file contents after file batch processing."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -1339,6 +1421,7 @@ def test_openai_vector_store_file_batch_retrieve_contents(
         name="batch_contents_test_store",
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -1399,8 +1482,9 @@ def test_openai_vector_store_file_batch_retrieve_contents(
         assert file_data[i][1].decode("utf-8") in content_text


+@vector_provider_wrapper
 def test_openai_vector_store_file_batch_error_handling(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """Test error handling for file batch operations."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -1412,6 +1496,7 @@ def test_openai_vector_store_file_batch_error_handling(
         name="batch_error_test_store",
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -1456,8 +1541,9 @@ def test_openai_vector_store_file_batch_error_handling(
     )


+@vector_provider_wrapper
 def test_openai_vector_store_embedding_config_from_metadata(
-    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
+    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     """Test that embedding configuration works from metadata source."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

@@ -1471,6 +1557,9 @@ def test_openai_vector_store_embedding_config_from_metadata(
             "embedding_dimension": str(embedding_dimension),
             "test_source": "metadata",
         },
+        extra_body={
+            "provider_id": vector_io_provider_id,
+        },
     )

     assert vector_store_metadata is not None

@@ -1489,6 +1578,7 @@ def test_openai_vector_store_embedding_config_from_metadata(
         extra_body={
             "embedding_model": embedding_model_id,
             "embedding_dimension": int(embedding_dimension),  # Ensure same type/value
+            "provider_id": vector_io_provider_id,
         },
)
|
||||
|
||||
|
|
|
|||
|
|
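Every hunk in this test file applies the same change: vector-store creation in the OpenAI-compat tests now pins both the embedding model and the vector_io backend through `extra_body`. A minimal sketch of the call these hunks converge on, assuming the pytest fixtures used throughout this diff (`compat_client_with_empty_stores` yielding `compat_client`, plus `embedding_model_id` and `vector_io_provider_id`):

```python
# Sketch of the creation pattern the hunks above converge on; the fixture
# names here are the ones assumed by this diff, not new API surface.
vector_store = compat_client.vector_stores.create(
    name="batch_test_store",
    extra_body={
        "embedding_model": embedding_model_id,  # embedding model used to index chunks
        "provider_id": vector_io_provider_id,   # vector_io backend that stores them
    },
)
assert vector_store.id.startswith("vs_")  # stores are addressed by a vs_* id
```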
@@ -8,6 +8,8 @@ import pytest

 from llama_stack.apis.vector_io import Chunk

+from ..conftest import vector_provider_wrapper
+

 @pytest.fixture(scope="session")
 def sample_chunks():

@@ -46,12 +48,13 @@ def client_with_empty_registry(client_with_models):
     clear_registry()


-def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embedding_dimension):
+@vector_provider_wrapper
+def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id):
     vector_db_name = "test_vector_db"
     create_response = client_with_empty_registry.vector_stores.create(
         name=vector_db_name,
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -65,12 +68,13 @@ def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embe
     assert response.id.startswith("vs_")


-def test_vector_db_register(client_with_empty_registry, embedding_model_id, embedding_dimension):
+@vector_provider_wrapper
+def test_vector_db_register(client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id):
     vector_db_name = "test_vector_db"
     response = client_with_empty_registry.vector_stores.create(
         name=vector_db_name,
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -100,12 +104,15 @@ def test_vector_db_register(client_with_empty_registry, embedding_model_id, embe
         ("How does machine learning improve over time?", "doc2"),
     ],
 )
-def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case):
+@vector_provider_wrapper
+def test_insert_chunks(
+    client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case, vector_io_provider_id
+):
     vector_db_name = "test_vector_db"
     create_response = client_with_empty_registry.vector_stores.create(
         name=vector_db_name,
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -135,7 +142,10 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding
     assert top_match.metadata["document_id"] == expected_doc_id, f"Query '{query}' should match {expected_doc_id}"


-def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, embedding_model_id, embedding_dimension):
+@vector_provider_wrapper
+def test_insert_chunks_with_precomputed_embeddings(
+    client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id
+):
     vector_io_provider_params_dict = {
         "inline::milvus": {"score_threshold": -1.0},
         "inline::qdrant": {"score_threshold": -1.0},

@@ -145,7 +155,7 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e
     register_response = client_with_empty_registry.vector_stores.create(
         name=vector_db_name,
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -181,8 +191,9 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e


 # expect this test to fail
+@vector_provider_wrapper
 def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
-    client_with_empty_registry, embedding_model_id, embedding_dimension
+    client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     vector_io_provider_params_dict = {
         "inline::milvus": {"score_threshold": 0.0},

@@ -194,6 +205,7 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
         name=vector_db_name,
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -226,33 +238,44 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
     assert response.chunks[0].metadata["source"] == "precomputed"


-def test_auto_extract_embedding_dimension(client_with_empty_registry, embedding_model_id):
+@vector_provider_wrapper
+def test_auto_extract_embedding_dimension(
+    client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id
+):
+    # This test specifically tests embedding model override, so we keep embedding_model
     vs = client_with_empty_registry.vector_stores.create(
-        name="test_auto_extract", extra_body={"embedding_model": embedding_model_id}
+        name="test_auto_extract",
+        extra_body={"embedding_model": embedding_model_id, "provider_id": vector_io_provider_id},
     )
     assert vs.id is not None


-def test_provider_auto_selection_single_provider(client_with_empty_registry, embedding_model_id):
+@vector_provider_wrapper
+def test_provider_auto_selection_single_provider(
+    client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id
+):
     providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
     if len(providers) != 1:
         pytest.skip(f"Test requires exactly one vector_io provider, found {len(providers)}")

-    vs = client_with_empty_registry.vector_stores.create(
-        name="test_auto_provider", extra_body={"embedding_model": embedding_model_id}
-    )
+    # Test that when only one provider is available, it's auto-selected (no provider_id needed)
+    vs = client_with_empty_registry.vector_stores.create(name="test_auto_provider")
     assert vs.id is not None


-def test_provider_id_override(client_with_empty_registry, embedding_model_id):
+@vector_provider_wrapper
+def test_provider_id_override(
+    client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id
+):
     providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
     if len(providers) != 1:
         pytest.skip(f"Test requires exactly one vector_io provider, found {len(providers)}")

     provider_id = providers[0].provider_id

+    # Test explicit provider_id specification (using default embedding model)
     vs = client_with_empty_registry.vector_stores.create(
-        name="test_provider_override", extra_body={"embedding_model": embedding_model_id, "provider_id": provider_id}
+        name="test_provider_override", extra_body={"provider_id": provider_id}
     )
     assert vs.id is not None
     assert vs.metadata.get("provider_id") == provider_id
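The last three tests above pin down the routing rules this commit changes: with exactly one vector_io provider registered, `provider_id` may be omitted and the router auto-selects it; with an explicit `provider_id` in `extra_body`, that provider is used and echoed back in the store's metadata. A sketch of both paths, assuming `client` is a stack client whose registry has been cleared (illustration only, not library code):

```python
# The two routing paths the tests above exercise.
providers = [p for p in client.providers.list() if p.api == "vector_io"]

if len(providers) == 1:
    # Single provider: the router auto-selects it, so no extra_body is needed.
    vs = client.vector_stores.create(name="test_auto_provider")
else:
    # Multiple providers: pin one explicitly; the choice is echoed in metadata.
    vs = client.vector_stores.create(
        name="test_provider_override",
        extra_body={"provider_id": providers[0].provider_id},
    )
    assert vs.metadata.get("provider_id") == providers[0].provider_id

assert vs.id is not None
```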
@@ -20,7 +20,7 @@ class TestVectorStoresValidation:
     async def test_validate_missing_model(self):
         """Test validation fails when model not found."""
         run_config = StackRunConfig(
-            image_name="test", providers={}, vector_stores=VectorStoresConfig(default_embedding_model_id="missing")
+            image_name="test", providers={}, vector_stores=VectorStoresConfig(embedding_model_id="missing")
         )
         mock_models = AsyncMock()
         mock_models.list_models.return_value = []

@@ -31,7 +31,7 @@ class TestVectorStoresValidation:
     async def test_validate_success(self):
         """Test validation passes with valid model."""
         run_config = StackRunConfig(
-            image_name="test", providers={}, vector_stores=VectorStoresConfig(default_embedding_model_id="valid")
+            image_name="test", providers={}, vector_stores=VectorStoresConfig(embedding_model_id="valid")
         )
         mock_models = AsyncMock()
         mock_models.list_models.return_value = [
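Both hunks are the same rename: the `VectorStoresConfig` field `default_embedding_model_id` becomes `embedding_model_id`. For reference, a config built against the new field name; the import location is an assumption based on other imports in this commit, and "valid" is the placeholder model id the tests use:

```python
# Import location assumed from the imports visible elsewhere in this commit.
from llama_stack.core.datatypes import StackRunConfig, VectorStoresConfig

run_config = StackRunConfig(
    image_name="test",
    providers={},
    vector_stores=VectorStoresConfig(embedding_model_id="valid"),  # was: default_embedding_model_id
)
```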
@@ -1,88 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import pytest
-
-from llama_stack.cli.stack._build import _apply_single_provider_filter
-from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
-from llama_stack.core.utils.image_types import LlamaStackImageType
-
-
-def test_filters_single_api():
-    """Test filtering keeps only specified provider for one API."""
-    build_config = BuildConfig(
-        image_type=LlamaStackImageType.VENV.value,
-        distribution_spec=DistributionSpec(
-            providers={
-                "vector_io": [
-                    BuildProvider(provider_type="inline::faiss"),
-                    BuildProvider(provider_type="inline::sqlite-vec"),
-                ],
-                "inference": [
-                    BuildProvider(provider_type="remote::openai"),
-                ],
-            },
-            description="Test",
-        ),
-    )
-
-    filtered = _apply_single_provider_filter(build_config, "vector_io=inline::sqlite-vec")
-
-    assert len(filtered.distribution_spec.providers["vector_io"]) == 1
-    assert filtered.distribution_spec.providers["vector_io"][0].provider_type == "inline::sqlite-vec"
-    assert len(filtered.distribution_spec.providers["inference"]) == 1  # unchanged
-
-
-def test_filters_multiple_apis():
-    """Test filtering multiple APIs."""
-    build_config = BuildConfig(
-        image_type=LlamaStackImageType.VENV.value,
-        distribution_spec=DistributionSpec(
-            providers={
-                "vector_io": [
-                    BuildProvider(provider_type="inline::faiss"),
-                    BuildProvider(provider_type="inline::sqlite-vec"),
-                ],
-                "inference": [
-                    BuildProvider(provider_type="remote::openai"),
-                    BuildProvider(provider_type="remote::anthropic"),
-                ],
-            },
-            description="Test",
-        ),
-    )
-
-    filtered = _apply_single_provider_filter(build_config, "vector_io=inline::faiss,inference=remote::openai")
-
-    assert len(filtered.distribution_spec.providers["vector_io"]) == 1
-    assert filtered.distribution_spec.providers["vector_io"][0].provider_type == "inline::faiss"
-    assert len(filtered.distribution_spec.providers["inference"]) == 1
-    assert filtered.distribution_spec.providers["inference"][0].provider_type == "remote::openai"
-
-
-def test_provider_not_found_exits():
-    """Test error when specified provider doesn't exist."""
-    build_config = BuildConfig(
-        image_type=LlamaStackImageType.VENV.value,
-        distribution_spec=DistributionSpec(
-            providers={"vector_io": [BuildProvider(provider_type="inline::faiss")]},
-            description="Test",
-        ),
-    )
-
-    with pytest.raises(SystemExit):
-        _apply_single_provider_filter(build_config, "vector_io=inline::nonexistent")
-
-
-def test_invalid_format_exits():
-    """Test error for invalid filter format."""
-    build_config = BuildConfig(
-        image_type=LlamaStackImageType.VENV.value,
-        distribution_spec=DistributionSpec(providers={}, description="Test"),
-    )
-
-    with pytest.raises(SystemExit):
-        _apply_single_provider_filter(build_config, "invalid_format")
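These deleted tests documented the contract of the removed single-provider filter: a comma-separated `api=provider_type` string, keeping exactly one provider per named API, leaving unnamed APIs untouched, and exiting on an unknown provider or a malformed pair. A rough sketch of that contract (illustrative only, not the removed implementation):

```python
# Rough sketch of the contract the deleted tests asserted; the function name
# and error messages here are illustrative.
import sys


def apply_single_provider_filter(providers: dict[str, list], spec: str) -> dict[str, list]:
    """Keep one provider per API, given "api=provider_type[,api=provider_type...]"."""
    for pair in spec.split(","):
        if "=" not in pair:
            sys.exit(f"Invalid filter format: {pair!r}")  # test_invalid_format_exits
        api, provider_type = pair.split("=", 1)
        kept = [p for p in providers.get(api, []) if p.provider_type == provider_type]
        if not kept:
            sys.exit(f"Provider {provider_type!r} not found for API {api!r}")  # test_provider_not_found_exits
        providers[api] = kept  # APIs not named in the spec stay unchanged
    return providers
```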