mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 04:00:42 +00:00
chore: Updating Vector IO integration tests to use llama stack build
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
a701f68bd7
commit
da7b39a3e3
13 changed files with 298 additions and 19 deletions
|
|
@ -144,7 +144,7 @@ jobs:
|
|||
|
||||
- name: Build Llama Stack
|
||||
run: |
|
||||
uv run --no-sync llama stack build --template ci-tests --image-type venv
|
||||
uv run --no-sync llama stack build --distro starter --image-type venv --single-provider "vector_io=${{ matrix.vector-io-provider }}"
|
||||
|
||||
- name: Check Storage and Memory Available Before Tests
|
||||
if: ${{ always() }}
|
||||
|
|
@ -154,24 +154,23 @@ jobs:
|
|||
|
||||
- name: Run Vector IO Integration Tests
|
||||
env:
|
||||
ENABLE_CHROMADB: ${{ matrix.vector-io-provider == 'remote::chromadb' && 'true' || '' }}
|
||||
# Set environment variables based on provider
|
||||
MILVUS_URL: ${{ matrix.vector-io-provider == 'inline::milvus' && 'dummy' || '' }}
|
||||
CHROMADB_URL: ${{ matrix.vector-io-provider == 'remote::chromadb' && 'http://localhost:8000' || '' }}
|
||||
ENABLE_PGVECTOR: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'true' || '' }}
|
||||
PGVECTOR_HOST: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'localhost' || '' }}
|
||||
PGVECTOR_PORT: ${{ matrix.vector-io-provider == 'remote::pgvector' && '5432' || '' }}
|
||||
PGVECTOR_DB: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }}
|
||||
PGVECTOR_USER: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }}
|
||||
PGVECTOR_PASSWORD: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }}
|
||||
ENABLE_QDRANT: ${{ matrix.vector-io-provider == 'remote::qdrant' && 'true' || '' }}
|
||||
QDRANT_URL: ${{ matrix.vector-io-provider == 'remote::qdrant' && 'http://localhost:6333' || '' }}
|
||||
ENABLE_WEAVIATE: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'true' || '' }}
|
||||
WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }}
|
||||
QDRANT_URL: ${{ matrix.vector-io-provider == 'remote::qdrant' && 'http://localhost:6333' || '' }}
|
||||
FAISS_URL: ${{ matrix.vector-io-provider == 'inline::faiss' && 'dummy' || '' }}
|
||||
SQLITE_VEC_URL: ${{ matrix.vector-io-provider == 'inline::sqlite-vec' && 'dummy' || '' }}
|
||||
run: |
|
||||
echo "Testing provider: ${{ matrix.vector-io-provider }}"
|
||||
echo "Environment variables set for this provider"
|
||||
|
||||
uv run --no-sync \
|
||||
pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \
|
||||
tests/integration/vector_io \
|
||||
--embedding-model inline::sentence-transformers/nomic-ai/nomic-embed-text-v1.5 \
|
||||
--embedding-dimension 768
|
||||
pytest -sv --stack-config ~/.llama/distributions/starter/starter-filtered-run.yaml \
|
||||
tests/integration/vector_io
|
||||
|
||||
- name: Check Storage and Memory Available After Tests
|
||||
if: ${{ always() }}
|
||||
|
|
|
|||
|
|
@ -50,6 +50,84 @@ from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
|||
DISTRIBS_PATH = Path(__file__).parent.parent.parent / "distributions"
|
||||
|
||||
|
||||
def _apply_single_provider_filter(build_config: BuildConfig, single_provider_arg: str) -> BuildConfig:
|
||||
"""Filter a distribution to only include specified providers for certain APIs."""
|
||||
provider_filters: dict[str, str] = {}
|
||||
for api_provider in single_provider_arg.split(","):
|
||||
if "=" not in api_provider:
|
||||
cprint(
|
||||
"Could not parse `--single-provider`. Please ensure the list is in the format api1=provider1,api2=provider2",
|
||||
color="red",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
api, provider_type = api_provider.split("=")
|
||||
provider_filters[api] = provider_type
|
||||
|
||||
# Create a copy of the build config to modify
|
||||
filtered_build_config = BuildConfig(
|
||||
image_type=build_config.image_type,
|
||||
image_name=build_config.image_name,
|
||||
external_providers_dir=build_config.external_providers_dir,
|
||||
external_apis_dir=build_config.external_apis_dir,
|
||||
distribution_spec=DistributionSpec(
|
||||
providers={},
|
||||
description=build_config.distribution_spec.description,
|
||||
),
|
||||
)
|
||||
|
||||
# Copy all providers, but filter the specified APIs
|
||||
for api, providers in build_config.distribution_spec.providers.items():
|
||||
if api in provider_filters:
|
||||
target_provider_type = provider_filters[api]
|
||||
filtered_providers = [p for p in providers if p.provider_type == target_provider_type]
|
||||
if not filtered_providers:
|
||||
cprint(
|
||||
f"Provider {target_provider_type} not found in distribution for API {api}",
|
||||
color="red",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
filtered_build_config.distribution_spec.providers[api] = filtered_providers
|
||||
else:
|
||||
# Keep all providers for unfiltered APIs
|
||||
filtered_build_config.distribution_spec.providers[api] = providers
|
||||
|
||||
return filtered_build_config
|
||||
|
||||
|
||||
def _generate_filtered_run_config(
|
||||
build_config: BuildConfig,
|
||||
build_dir: Path,
|
||||
distro_name: str,
|
||||
) -> Path:
|
||||
"""
|
||||
Generate a filtered run.yaml by starting with the original distribution's run.yaml
|
||||
and filtering the providers according to the build_config.
|
||||
"""
|
||||
# Load the original distribution's run.yaml
|
||||
distro_resource = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
|
||||
|
||||
with importlib.resources.as_file(distro_resource) as path:
|
||||
with open(path) as f:
|
||||
original_config = yaml.safe_load(f)
|
||||
|
||||
# Apply provider filtering to the loaded config
|
||||
for api, providers in build_config.distribution_spec.providers.items():
|
||||
if api in original_config.get("providers", {}):
|
||||
# Filter this API to only include the providers from build_config
|
||||
provider_types = {p.provider_type for p in providers}
|
||||
filtered_providers = [p for p in original_config["providers"][api] if p["provider_type"] in provider_types]
|
||||
original_config["providers"][api] = filtered_providers
|
||||
|
||||
# Write the filtered run config
|
||||
run_config_file = build_dir / f"{distro_name}-filtered-run.yaml"
|
||||
with open(run_config_file, "w") as f:
|
||||
yaml.dump(original_config, f, sort_keys=False)
|
||||
|
||||
return run_config_file
|
||||
|
||||
|
||||
@lru_cache
|
||||
def available_distros_specs() -> dict[str, BuildConfig]:
|
||||
import yaml
|
||||
|
|
@ -93,6 +171,11 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
)
|
||||
sys.exit(1)
|
||||
build_config = available_distros[distro_name]
|
||||
|
||||
# Apply single-provider filtering if specified
|
||||
if args.single_provider:
|
||||
build_config = _apply_single_provider_filter(build_config, args.single_provider)
|
||||
|
||||
if args.image_type:
|
||||
build_config.image_type = args.image_type
|
||||
else:
|
||||
|
|
@ -245,6 +328,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
image_name=image_name,
|
||||
config_path=args.config,
|
||||
distro_name=distro_name,
|
||||
is_filtered=bool(args.single_provider),
|
||||
)
|
||||
|
||||
except (Exception, RuntimeError) as exc:
|
||||
|
|
@ -363,6 +447,7 @@ def _run_stack_build_command_from_build_config(
|
|||
image_name: str | None = None,
|
||||
distro_name: str | None = None,
|
||||
config_path: str | None = None,
|
||||
is_filtered: bool = False,
|
||||
) -> Path | Traversable:
|
||||
image_name = image_name or build_config.image_name
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||
|
|
@ -435,12 +520,19 @@ def _run_stack_build_command_from_build_config(
|
|||
raise RuntimeError(f"Failed to build image {image_name}")
|
||||
|
||||
if distro_name:
|
||||
# If single-provider filtering was applied, generate a filtered run config
|
||||
# Otherwise, copy run.yaml from distribution as before
|
||||
if is_filtered:
|
||||
run_config_file = _generate_filtered_run_config(build_config, build_dir, distro_name)
|
||||
distro_path = run_config_file # Use the generated file as the distro_path
|
||||
else:
|
||||
# copy run.yaml from distribution to build_dir instead of generating it again
|
||||
distro_path = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
|
||||
distro_resource = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
|
||||
run_config_file = build_dir / f"{distro_name}-run.yaml"
|
||||
|
||||
with importlib.resources.as_file(distro_path) as path:
|
||||
with importlib.resources.as_file(distro_resource) as path:
|
||||
shutil.copy(path, run_config_file)
|
||||
distro_path = run_config_file # Update distro_path to point to the copied file
|
||||
|
||||
cprint("Build Successful!", color="green", file=sys.stderr)
|
||||
cprint(f"You can find the newly-built distribution here: {run_config_file}", color="blue", file=sys.stderr)
|
||||
|
|
|
|||
|
|
@ -92,6 +92,13 @@ the build. If not specified, currently active environment will be used if found.
|
|||
help="Build a config for a list of providers and only those providers. This list is formatted like: api1=provider1,api2=provider2. Where there can be multiple providers per API.",
|
||||
)
|
||||
|
||||
self.parser.add_argument(
|
||||
"--single-provider",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Limit a distribution to a single provider for specific APIs. Format: api1=provider1,api2=provider2. Use with --distro to filter an existing distribution.",
|
||||
)
|
||||
|
||||
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
|
||||
# always keep implementation completely silo-ed away from CLI so CLI
|
||||
# can be fast to load and reduces dependencies
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ distribution_spec:
|
|||
- provider_type: inline::milvus
|
||||
- provider_type: remote::chromadb
|
||||
- provider_type: remote::pgvector
|
||||
- provider_type: remote::weaviate
|
||||
- provider_type: remote::qdrant
|
||||
files:
|
||||
- provider_type: inline::localfs
|
||||
safety:
|
||||
|
|
|
|||
|
|
@ -128,6 +128,21 @@ providers:
|
|||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/pgvector_registry.db
|
||||
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
|
||||
provider_type: remote::weaviate
|
||||
config:
|
||||
weaviate_api_key: null
|
||||
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/weaviate_registry.db
|
||||
- provider_id: ${env.QDRANT_URL:+qdrant}
|
||||
provider_type: remote::qdrant
|
||||
config:
|
||||
api_key: ${env.QDRANT_API_KEY:=}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/qdrant_registry.db
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ distribution_spec:
|
|||
- provider_type: inline::milvus
|
||||
- provider_type: remote::chromadb
|
||||
- provider_type: remote::pgvector
|
||||
- provider_type: remote::weaviate
|
||||
- provider_type: remote::qdrant
|
||||
files:
|
||||
- provider_type: inline::localfs
|
||||
safety:
|
||||
|
|
|
|||
|
|
@ -128,6 +128,21 @@ providers:
|
|||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/pgvector_registry.db
|
||||
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
|
||||
provider_type: remote::weaviate
|
||||
config:
|
||||
weaviate_api_key: null
|
||||
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/weaviate_registry.db
|
||||
- provider_id: ${env.QDRANT_URL:+qdrant}
|
||||
provider_type: remote::qdrant
|
||||
config:
|
||||
api_key: ${env.QDRANT_API_KEY:=}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/qdrant_registry.db
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ distribution_spec:
|
|||
- provider_type: inline::milvus
|
||||
- provider_type: remote::chromadb
|
||||
- provider_type: remote::pgvector
|
||||
- provider_type: remote::weaviate
|
||||
- provider_type: remote::qdrant
|
||||
files:
|
||||
- provider_type: inline::localfs
|
||||
safety:
|
||||
|
|
|
|||
|
|
@ -128,6 +128,21 @@ providers:
|
|||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
|
||||
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
|
||||
provider_type: remote::weaviate
|
||||
config:
|
||||
weaviate_api_key: null
|
||||
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/weaviate_registry.db
|
||||
- provider_id: ${env.QDRANT_URL:+qdrant}
|
||||
provider_type: remote::qdrant
|
||||
config:
|
||||
api_key: ${env.QDRANT_API_KEY:=}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/qdrant_registry.db
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
|
|
|
|||
|
|
@ -31,6 +31,8 @@ from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOC
|
|||
from llama_stack.providers.remote.vector_io.pgvector.config import (
|
||||
PGVectorVectorIOConfig,
|
||||
)
|
||||
from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
|
||||
|
||||
|
||||
|
|
@ -113,6 +115,8 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
|||
BuildProvider(provider_type="inline::milvus"),
|
||||
BuildProvider(provider_type="remote::chromadb"),
|
||||
BuildProvider(provider_type="remote::pgvector"),
|
||||
BuildProvider(provider_type="remote::weaviate"),
|
||||
BuildProvider(provider_type="remote::qdrant"),
|
||||
],
|
||||
"files": [BuildProvider(provider_type="inline::localfs")],
|
||||
"safety": [
|
||||
|
|
@ -221,6 +225,16 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
|||
password="${env.PGVECTOR_PASSWORD:=}",
|
||||
),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.WEAVIATE_CLUSTER_URL:+weaviate}",
|
||||
provider_type="remote::weaviate",
|
||||
config=WeaviateVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.QDRANT_URL:+qdrant}",
|
||||
provider_type="remote::qdrant",
|
||||
config=QdrantVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
],
|
||||
"files": [files_provider],
|
||||
},
|
||||
|
|
|
|||
|
|
@ -30,8 +30,8 @@ class JobStatus(Enum):
|
|||
completed = "completed"
|
||||
|
||||
|
||||
type JobID = str
|
||||
type JobType = str
|
||||
JobID = str
|
||||
JobType = str
|
||||
|
||||
|
||||
class JobArtifact(BaseModel):
|
||||
|
|
|
|||
|
|
@ -153,6 +153,29 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
|
|||
"text_model": "groq/llama-3.3-70b-versatile",
|
||||
},
|
||||
),
|
||||
"milvus": Setup(
|
||||
name="milvus",
|
||||
description="Milvus vector database provider for vector_io tests",
|
||||
env={
|
||||
"MILVUS_URL": "dummy",
|
||||
},
|
||||
),
|
||||
"chromadb": Setup(
|
||||
name="chromadb",
|
||||
description="ChromaDB vector database provider for vector_io tests",
|
||||
env={
|
||||
"CHROMADB_URL": "http://localhost:8000",
|
||||
},
|
||||
),
|
||||
"pgvector": Setup(
|
||||
name="pgvector",
|
||||
description="PGVector database provider for vector_io tests",
|
||||
env={
|
||||
"PGVECTOR_DB": "llama_stack_test",
|
||||
"PGVECTOR_USER": "postgres",
|
||||
"PGVECTOR_PASSWORD": "password",
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -179,4 +202,9 @@ SUITE_DEFINITIONS: dict[str, Suite] = {
|
|||
roots=["tests/integration/inference/test_vision_inference.py"],
|
||||
default_setup="ollama-vision",
|
||||
),
|
||||
"vector_io": Suite(
|
||||
name="vector_io",
|
||||
roots=["tests/integration/vector_io"],
|
||||
default_setup="milvus",
|
||||
),
|
||||
}
|
||||
|
|
|
|||
88
tests/unit/distribution/test_single_provider_filter.py
Normal file
88
tests/unit/distribution/test_single_provider_filter.py
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.cli.stack._build import _apply_single_provider_filter
|
||||
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
|
||||
from llama_stack.core.utils.image_types import LlamaStackImageType
|
||||
|
||||
|
||||
def test_filters_single_api():
|
||||
"""Test filtering keeps only specified provider for one API."""
|
||||
build_config = BuildConfig(
|
||||
image_type=LlamaStackImageType.VENV.value,
|
||||
distribution_spec=DistributionSpec(
|
||||
providers={
|
||||
"vector_io": [
|
||||
BuildProvider(provider_type="inline::faiss"),
|
||||
BuildProvider(provider_type="inline::sqlite-vec"),
|
||||
],
|
||||
"inference": [
|
||||
BuildProvider(provider_type="remote::openai"),
|
||||
],
|
||||
},
|
||||
description="Test",
|
||||
),
|
||||
)
|
||||
|
||||
filtered = _apply_single_provider_filter(build_config, "vector_io=inline::sqlite-vec")
|
||||
|
||||
assert len(filtered.distribution_spec.providers["vector_io"]) == 1
|
||||
assert filtered.distribution_spec.providers["vector_io"][0].provider_type == "inline::sqlite-vec"
|
||||
assert len(filtered.distribution_spec.providers["inference"]) == 1 # unchanged
|
||||
|
||||
|
||||
def test_filters_multiple_apis():
|
||||
"""Test filtering multiple APIs."""
|
||||
build_config = BuildConfig(
|
||||
image_type=LlamaStackImageType.VENV.value,
|
||||
distribution_spec=DistributionSpec(
|
||||
providers={
|
||||
"vector_io": [
|
||||
BuildProvider(provider_type="inline::faiss"),
|
||||
BuildProvider(provider_type="inline::sqlite-vec"),
|
||||
],
|
||||
"inference": [
|
||||
BuildProvider(provider_type="remote::openai"),
|
||||
BuildProvider(provider_type="remote::anthropic"),
|
||||
],
|
||||
},
|
||||
description="Test",
|
||||
),
|
||||
)
|
||||
|
||||
filtered = _apply_single_provider_filter(build_config, "vector_io=inline::faiss,inference=remote::openai")
|
||||
|
||||
assert len(filtered.distribution_spec.providers["vector_io"]) == 1
|
||||
assert filtered.distribution_spec.providers["vector_io"][0].provider_type == "inline::faiss"
|
||||
assert len(filtered.distribution_spec.providers["inference"]) == 1
|
||||
assert filtered.distribution_spec.providers["inference"][0].provider_type == "remote::openai"
|
||||
|
||||
|
||||
def test_provider_not_found_exits():
|
||||
"""Test error when specified provider doesn't exist."""
|
||||
build_config = BuildConfig(
|
||||
image_type=LlamaStackImageType.VENV.value,
|
||||
distribution_spec=DistributionSpec(
|
||||
providers={"vector_io": [BuildProvider(provider_type="inline::faiss")]},
|
||||
description="Test",
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
_apply_single_provider_filter(build_config, "vector_io=inline::nonexistent")
|
||||
|
||||
|
||||
def test_invalid_format_exits():
|
||||
"""Test error for invalid filter format."""
|
||||
build_config = BuildConfig(
|
||||
image_type=LlamaStackImageType.VENV.value,
|
||||
distribution_spec=DistributionSpec(providers={}, description="Test"),
|
||||
)
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
_apply_single_provider_filter(build_config, "invalid_format")
|
||||
Loading…
Add table
Add a link
Reference in a new issue