mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-23 00:27:26 +00:00
chore: Updating how default embedding model is set in stack (#3818)
# What does this PR do?
Refactor setting default vector store provider and embedding model to
use an optional `vector_stores` config in the `StackRunConfig` and clean
up code to do so (had to add back in some pieces of VectorDB). Also
added remote Qdrant and Weaviate to starter distro (based on other PR
where inference providers were added for UX).
New config is simply (default for Starter distro):
```yaml
vector_stores:
default_provider_id: faiss
default_embedding_model:
provider_id: sentence-transformers
model_id: nomic-ai/nomic-embed-text-v1.5
```
## Test Plan
CI and Unit tests.
---------
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
parent
2c43285e22
commit
48581bf651
48 changed files with 973 additions and 818 deletions
|
|
@ -317,3 +317,72 @@ def pytest_ignore_collect(path: str, config: pytest.Config) -> bool:
|
|||
if p.is_relative_to(rp):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_vector_io_provider_ids(client):
|
||||
"""Get all available vector_io provider IDs."""
|
||||
providers = [p for p in client.providers.list() if p.api == "vector_io"]
|
||||
return [p.provider_id for p in providers]
|
||||
|
||||
|
||||
def vector_provider_wrapper(func):
|
||||
"""Decorator to run a test against all available vector_io providers."""
|
||||
import functools
|
||||
import os
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Get the vector_io_provider_id from the test arguments
|
||||
import inspect
|
||||
|
||||
sig = inspect.signature(func)
|
||||
bound_args = sig.bind(*args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
|
||||
vector_io_provider_id = bound_args.arguments.get("vector_io_provider_id")
|
||||
if not vector_io_provider_id:
|
||||
pytest.skip("No vector_io_provider_id provided")
|
||||
|
||||
# Get client_with_models to check available providers
|
||||
client_with_models = bound_args.arguments.get("client_with_models")
|
||||
if client_with_models:
|
||||
available_providers = get_vector_io_provider_ids(client_with_models)
|
||||
if vector_io_provider_id not in available_providers:
|
||||
pytest.skip(f"Provider '{vector_io_provider_id}' not available. Available: {available_providers}")
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# For replay tests, only use providers that are available in ci-tests environment
|
||||
if os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE") == "replay":
|
||||
all_providers = ["faiss", "sqlite-vec"]
|
||||
else:
|
||||
# For live tests, try all providers (they'll skip if not available)
|
||||
all_providers = [
|
||||
"faiss",
|
||||
"sqlite-vec",
|
||||
"milvus",
|
||||
"chromadb",
|
||||
"pgvector",
|
||||
"weaviate",
|
||||
"qdrant",
|
||||
]
|
||||
|
||||
return pytest.mark.parametrize("vector_io_provider_id", all_providers)(wrapper)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_io_provider_id(request, client_with_models):
|
||||
"""Fixture that provides a specific vector_io provider ID, skipping if not available."""
|
||||
if hasattr(request, "param"):
|
||||
requested_provider = request.param
|
||||
available_providers = get_vector_io_provider_ids(client_with_models)
|
||||
|
||||
if requested_provider not in available_providers:
|
||||
pytest.skip(f"Provider '{requested_provider}' not available. Available: {available_providers}")
|
||||
|
||||
return requested_provider
|
||||
else:
|
||||
provider_ids = get_vector_io_provider_ids(client_with_models)
|
||||
if not provider_ids:
|
||||
pytest.skip("No vector_io providers available")
|
||||
return provider_ids[0]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue