mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 12:06:04 +00:00
chore: Updating Vector IO integration tests to use llama stack build
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
a701f68bd7
commit
da7b39a3e3
13 changed files with 298 additions and 19 deletions
|
|
@ -144,7 +144,7 @@ jobs:
|
||||||
|
|
||||||
- name: Build Llama Stack
|
- name: Build Llama Stack
|
||||||
run: |
|
run: |
|
||||||
uv run --no-sync llama stack build --template ci-tests --image-type venv
|
uv run --no-sync llama stack build --distro starter --image-type venv --single-provider "vector_io=${{ matrix.vector-io-provider }}"
|
||||||
|
|
||||||
- name: Check Storage and Memory Available Before Tests
|
- name: Check Storage and Memory Available Before Tests
|
||||||
if: ${{ always() }}
|
if: ${{ always() }}
|
||||||
|
|
@ -154,24 +154,23 @@ jobs:
|
||||||
|
|
||||||
- name: Run Vector IO Integration Tests
|
- name: Run Vector IO Integration Tests
|
||||||
env:
|
env:
|
||||||
ENABLE_CHROMADB: ${{ matrix.vector-io-provider == 'remote::chromadb' && 'true' || '' }}
|
# Set environment variables based on provider
|
||||||
|
MILVUS_URL: ${{ matrix.vector-io-provider == 'inline::milvus' && 'dummy' || '' }}
|
||||||
CHROMADB_URL: ${{ matrix.vector-io-provider == 'remote::chromadb' && 'http://localhost:8000' || '' }}
|
CHROMADB_URL: ${{ matrix.vector-io-provider == 'remote::chromadb' && 'http://localhost:8000' || '' }}
|
||||||
ENABLE_PGVECTOR: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'true' || '' }}
|
|
||||||
PGVECTOR_HOST: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'localhost' || '' }}
|
|
||||||
PGVECTOR_PORT: ${{ matrix.vector-io-provider == 'remote::pgvector' && '5432' || '' }}
|
|
||||||
PGVECTOR_DB: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }}
|
PGVECTOR_DB: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }}
|
||||||
PGVECTOR_USER: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }}
|
PGVECTOR_USER: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }}
|
||||||
PGVECTOR_PASSWORD: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }}
|
PGVECTOR_PASSWORD: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }}
|
||||||
ENABLE_QDRANT: ${{ matrix.vector-io-provider == 'remote::qdrant' && 'true' || '' }}
|
|
||||||
QDRANT_URL: ${{ matrix.vector-io-provider == 'remote::qdrant' && 'http://localhost:6333' || '' }}
|
|
||||||
ENABLE_WEAVIATE: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'true' || '' }}
|
|
||||||
WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }}
|
WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }}
|
||||||
|
QDRANT_URL: ${{ matrix.vector-io-provider == 'remote::qdrant' && 'http://localhost:6333' || '' }}
|
||||||
|
FAISS_URL: ${{ matrix.vector-io-provider == 'inline::faiss' && 'dummy' || '' }}
|
||||||
|
SQLITE_VEC_URL: ${{ matrix.vector-io-provider == 'inline::sqlite-vec' && 'dummy' || '' }}
|
||||||
run: |
|
run: |
|
||||||
|
echo "Testing provider: ${{ matrix.vector-io-provider }}"
|
||||||
|
echo "Environment variables set for this provider"
|
||||||
|
|
||||||
uv run --no-sync \
|
uv run --no-sync \
|
||||||
pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \
|
pytest -sv --stack-config ~/.llama/distributions/starter/starter-filtered-run.yaml \
|
||||||
tests/integration/vector_io \
|
tests/integration/vector_io
|
||||||
--embedding-model inline::sentence-transformers/nomic-ai/nomic-embed-text-v1.5 \
|
|
||||||
--embedding-dimension 768
|
|
||||||
|
|
||||||
- name: Check Storage and Memory Available After Tests
|
- name: Check Storage and Memory Available After Tests
|
||||||
if: ${{ always() }}
|
if: ${{ always() }}
|
||||||
|
|
|
||||||
|
|
@ -50,6 +50,84 @@ from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||||
DISTRIBS_PATH = Path(__file__).parent.parent.parent / "distributions"
|
DISTRIBS_PATH = Path(__file__).parent.parent.parent / "distributions"
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_single_provider_filter(build_config: BuildConfig, single_provider_arg: str) -> BuildConfig:
|
||||||
|
"""Filter a distribution to only include specified providers for certain APIs."""
|
||||||
|
provider_filters: dict[str, str] = {}
|
||||||
|
for api_provider in single_provider_arg.split(","):
|
||||||
|
if "=" not in api_provider:
|
||||||
|
cprint(
|
||||||
|
"Could not parse `--single-provider`. Please ensure the list is in the format api1=provider1,api2=provider2",
|
||||||
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
api, provider_type = api_provider.split("=")
|
||||||
|
provider_filters[api] = provider_type
|
||||||
|
|
||||||
|
# Create a copy of the build config to modify
|
||||||
|
filtered_build_config = BuildConfig(
|
||||||
|
image_type=build_config.image_type,
|
||||||
|
image_name=build_config.image_name,
|
||||||
|
external_providers_dir=build_config.external_providers_dir,
|
||||||
|
external_apis_dir=build_config.external_apis_dir,
|
||||||
|
distribution_spec=DistributionSpec(
|
||||||
|
providers={},
|
||||||
|
description=build_config.distribution_spec.description,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Copy all providers, but filter the specified APIs
|
||||||
|
for api, providers in build_config.distribution_spec.providers.items():
|
||||||
|
if api in provider_filters:
|
||||||
|
target_provider_type = provider_filters[api]
|
||||||
|
filtered_providers = [p for p in providers if p.provider_type == target_provider_type]
|
||||||
|
if not filtered_providers:
|
||||||
|
cprint(
|
||||||
|
f"Provider {target_provider_type} not found in distribution for API {api}",
|
||||||
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
filtered_build_config.distribution_spec.providers[api] = filtered_providers
|
||||||
|
else:
|
||||||
|
# Keep all providers for unfiltered APIs
|
||||||
|
filtered_build_config.distribution_spec.providers[api] = providers
|
||||||
|
|
||||||
|
return filtered_build_config
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_filtered_run_config(
|
||||||
|
build_config: BuildConfig,
|
||||||
|
build_dir: Path,
|
||||||
|
distro_name: str,
|
||||||
|
) -> Path:
|
||||||
|
"""
|
||||||
|
Generate a filtered run.yaml by starting with the original distribution's run.yaml
|
||||||
|
and filtering the providers according to the build_config.
|
||||||
|
"""
|
||||||
|
# Load the original distribution's run.yaml
|
||||||
|
distro_resource = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
|
||||||
|
|
||||||
|
with importlib.resources.as_file(distro_resource) as path:
|
||||||
|
with open(path) as f:
|
||||||
|
original_config = yaml.safe_load(f)
|
||||||
|
|
||||||
|
# Apply provider filtering to the loaded config
|
||||||
|
for api, providers in build_config.distribution_spec.providers.items():
|
||||||
|
if api in original_config.get("providers", {}):
|
||||||
|
# Filter this API to only include the providers from build_config
|
||||||
|
provider_types = {p.provider_type for p in providers}
|
||||||
|
filtered_providers = [p for p in original_config["providers"][api] if p["provider_type"] in provider_types]
|
||||||
|
original_config["providers"][api] = filtered_providers
|
||||||
|
|
||||||
|
# Write the filtered run config
|
||||||
|
run_config_file = build_dir / f"{distro_name}-filtered-run.yaml"
|
||||||
|
with open(run_config_file, "w") as f:
|
||||||
|
yaml.dump(original_config, f, sort_keys=False)
|
||||||
|
|
||||||
|
return run_config_file
|
||||||
|
|
||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
def available_distros_specs() -> dict[str, BuildConfig]:
|
def available_distros_specs() -> dict[str, BuildConfig]:
|
||||||
import yaml
|
import yaml
|
||||||
|
|
@ -93,6 +171,11 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
build_config = available_distros[distro_name]
|
build_config = available_distros[distro_name]
|
||||||
|
|
||||||
|
# Apply single-provider filtering if specified
|
||||||
|
if args.single_provider:
|
||||||
|
build_config = _apply_single_provider_filter(build_config, args.single_provider)
|
||||||
|
|
||||||
if args.image_type:
|
if args.image_type:
|
||||||
build_config.image_type = args.image_type
|
build_config.image_type = args.image_type
|
||||||
else:
|
else:
|
||||||
|
|
@ -245,6 +328,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
image_name=image_name,
|
image_name=image_name,
|
||||||
config_path=args.config,
|
config_path=args.config,
|
||||||
distro_name=distro_name,
|
distro_name=distro_name,
|
||||||
|
is_filtered=bool(args.single_provider),
|
||||||
)
|
)
|
||||||
|
|
||||||
except (Exception, RuntimeError) as exc:
|
except (Exception, RuntimeError) as exc:
|
||||||
|
|
@ -363,6 +447,7 @@ def _run_stack_build_command_from_build_config(
|
||||||
image_name: str | None = None,
|
image_name: str | None = None,
|
||||||
distro_name: str | None = None,
|
distro_name: str | None = None,
|
||||||
config_path: str | None = None,
|
config_path: str | None = None,
|
||||||
|
is_filtered: bool = False,
|
||||||
) -> Path | Traversable:
|
) -> Path | Traversable:
|
||||||
image_name = image_name or build_config.image_name
|
image_name = image_name or build_config.image_name
|
||||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||||
|
|
@ -435,12 +520,19 @@ def _run_stack_build_command_from_build_config(
|
||||||
raise RuntimeError(f"Failed to build image {image_name}")
|
raise RuntimeError(f"Failed to build image {image_name}")
|
||||||
|
|
||||||
if distro_name:
|
if distro_name:
|
||||||
|
# If single-provider filtering was applied, generate a filtered run config
|
||||||
|
# Otherwise, copy run.yaml from distribution as before
|
||||||
|
if is_filtered:
|
||||||
|
run_config_file = _generate_filtered_run_config(build_config, build_dir, distro_name)
|
||||||
|
distro_path = run_config_file # Use the generated file as the distro_path
|
||||||
|
else:
|
||||||
# copy run.yaml from distribution to build_dir instead of generating it again
|
# copy run.yaml from distribution to build_dir instead of generating it again
|
||||||
distro_path = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
|
distro_resource = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
|
||||||
run_config_file = build_dir / f"{distro_name}-run.yaml"
|
run_config_file = build_dir / f"{distro_name}-run.yaml"
|
||||||
|
|
||||||
with importlib.resources.as_file(distro_path) as path:
|
with importlib.resources.as_file(distro_resource) as path:
|
||||||
shutil.copy(path, run_config_file)
|
shutil.copy(path, run_config_file)
|
||||||
|
distro_path = run_config_file # Update distro_path to point to the copied file
|
||||||
|
|
||||||
cprint("Build Successful!", color="green", file=sys.stderr)
|
cprint("Build Successful!", color="green", file=sys.stderr)
|
||||||
cprint(f"You can find the newly-built distribution here: {run_config_file}", color="blue", file=sys.stderr)
|
cprint(f"You can find the newly-built distribution here: {run_config_file}", color="blue", file=sys.stderr)
|
||||||
|
|
|
||||||
|
|
@ -92,6 +92,13 @@ the build. If not specified, currently active environment will be used if found.
|
||||||
help="Build a config for a list of providers and only those providers. This list is formatted like: api1=provider1,api2=provider2. Where there can be multiple providers per API.",
|
help="Build a config for a list of providers and only those providers. This list is formatted like: api1=provider1,api2=provider2. Where there can be multiple providers per API.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--single-provider",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Limit a distribution to a single provider for specific APIs. Format: api1=provider1,api2=provider2. Use with --distro to filter an existing distribution.",
|
||||||
|
)
|
||||||
|
|
||||||
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
|
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
|
||||||
# always keep implementation completely silo-ed away from CLI so CLI
|
# always keep implementation completely silo-ed away from CLI so CLI
|
||||||
# can be fast to load and reduces dependencies
|
# can be fast to load and reduces dependencies
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,8 @@ distribution_spec:
|
||||||
- provider_type: inline::milvus
|
- provider_type: inline::milvus
|
||||||
- provider_type: remote::chromadb
|
- provider_type: remote::chromadb
|
||||||
- provider_type: remote::pgvector
|
- provider_type: remote::pgvector
|
||||||
|
- provider_type: remote::weaviate
|
||||||
|
- provider_type: remote::qdrant
|
||||||
files:
|
files:
|
||||||
- provider_type: inline::localfs
|
- provider_type: inline::localfs
|
||||||
safety:
|
safety:
|
||||||
|
|
|
||||||
|
|
@ -128,6 +128,21 @@ providers:
|
||||||
kvstore:
|
kvstore:
|
||||||
type: sqlite
|
type: sqlite
|
||||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/pgvector_registry.db
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/pgvector_registry.db
|
||||||
|
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
|
||||||
|
provider_type: remote::weaviate
|
||||||
|
config:
|
||||||
|
weaviate_api_key: null
|
||||||
|
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
|
||||||
|
kvstore:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/weaviate_registry.db
|
||||||
|
- provider_id: ${env.QDRANT_URL:+qdrant}
|
||||||
|
provider_type: remote::qdrant
|
||||||
|
config:
|
||||||
|
api_key: ${env.QDRANT_API_KEY:=}
|
||||||
|
kvstore:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/qdrant_registry.db
|
||||||
files:
|
files:
|
||||||
- provider_id: meta-reference-files
|
- provider_id: meta-reference-files
|
||||||
provider_type: inline::localfs
|
provider_type: inline::localfs
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,8 @@ distribution_spec:
|
||||||
- provider_type: inline::milvus
|
- provider_type: inline::milvus
|
||||||
- provider_type: remote::chromadb
|
- provider_type: remote::chromadb
|
||||||
- provider_type: remote::pgvector
|
- provider_type: remote::pgvector
|
||||||
|
- provider_type: remote::weaviate
|
||||||
|
- provider_type: remote::qdrant
|
||||||
files:
|
files:
|
||||||
- provider_type: inline::localfs
|
- provider_type: inline::localfs
|
||||||
safety:
|
safety:
|
||||||
|
|
|
||||||
|
|
@ -128,6 +128,21 @@ providers:
|
||||||
kvstore:
|
kvstore:
|
||||||
type: sqlite
|
type: sqlite
|
||||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/pgvector_registry.db
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/pgvector_registry.db
|
||||||
|
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
|
||||||
|
provider_type: remote::weaviate
|
||||||
|
config:
|
||||||
|
weaviate_api_key: null
|
||||||
|
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
|
||||||
|
kvstore:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/weaviate_registry.db
|
||||||
|
- provider_id: ${env.QDRANT_URL:+qdrant}
|
||||||
|
provider_type: remote::qdrant
|
||||||
|
config:
|
||||||
|
api_key: ${env.QDRANT_API_KEY:=}
|
||||||
|
kvstore:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/qdrant_registry.db
|
||||||
files:
|
files:
|
||||||
- provider_id: meta-reference-files
|
- provider_id: meta-reference-files
|
||||||
provider_type: inline::localfs
|
provider_type: inline::localfs
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,8 @@ distribution_spec:
|
||||||
- provider_type: inline::milvus
|
- provider_type: inline::milvus
|
||||||
- provider_type: remote::chromadb
|
- provider_type: remote::chromadb
|
||||||
- provider_type: remote::pgvector
|
- provider_type: remote::pgvector
|
||||||
|
- provider_type: remote::weaviate
|
||||||
|
- provider_type: remote::qdrant
|
||||||
files:
|
files:
|
||||||
- provider_type: inline::localfs
|
- provider_type: inline::localfs
|
||||||
safety:
|
safety:
|
||||||
|
|
|
||||||
|
|
@ -128,6 +128,21 @@ providers:
|
||||||
kvstore:
|
kvstore:
|
||||||
type: sqlite
|
type: sqlite
|
||||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
|
||||||
|
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
|
||||||
|
provider_type: remote::weaviate
|
||||||
|
config:
|
||||||
|
weaviate_api_key: null
|
||||||
|
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
|
||||||
|
kvstore:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/weaviate_registry.db
|
||||||
|
- provider_id: ${env.QDRANT_URL:+qdrant}
|
||||||
|
provider_type: remote::qdrant
|
||||||
|
config:
|
||||||
|
api_key: ${env.QDRANT_API_KEY:=}
|
||||||
|
kvstore:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/qdrant_registry.db
|
||||||
files:
|
files:
|
||||||
- provider_id: meta-reference-files
|
- provider_id: meta-reference-files
|
||||||
provider_type: inline::localfs
|
provider_type: inline::localfs
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,8 @@ from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOC
|
||||||
from llama_stack.providers.remote.vector_io.pgvector.config import (
|
from llama_stack.providers.remote.vector_io.pgvector.config import (
|
||||||
PGVectorVectorIOConfig,
|
PGVectorVectorIOConfig,
|
||||||
)
|
)
|
||||||
|
from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOConfig
|
||||||
|
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
|
||||||
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
|
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -113,6 +115,8 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
||||||
BuildProvider(provider_type="inline::milvus"),
|
BuildProvider(provider_type="inline::milvus"),
|
||||||
BuildProvider(provider_type="remote::chromadb"),
|
BuildProvider(provider_type="remote::chromadb"),
|
||||||
BuildProvider(provider_type="remote::pgvector"),
|
BuildProvider(provider_type="remote::pgvector"),
|
||||||
|
BuildProvider(provider_type="remote::weaviate"),
|
||||||
|
BuildProvider(provider_type="remote::qdrant"),
|
||||||
],
|
],
|
||||||
"files": [BuildProvider(provider_type="inline::localfs")],
|
"files": [BuildProvider(provider_type="inline::localfs")],
|
||||||
"safety": [
|
"safety": [
|
||||||
|
|
@ -221,6 +225,16 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
||||||
password="${env.PGVECTOR_PASSWORD:=}",
|
password="${env.PGVECTOR_PASSWORD:=}",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
Provider(
|
||||||
|
provider_id="${env.WEAVIATE_CLUSTER_URL:+weaviate}",
|
||||||
|
provider_type="remote::weaviate",
|
||||||
|
config=WeaviateVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||||
|
),
|
||||||
|
Provider(
|
||||||
|
provider_id="${env.QDRANT_URL:+qdrant}",
|
||||||
|
provider_type="remote::qdrant",
|
||||||
|
config=QdrantVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||||
|
),
|
||||||
],
|
],
|
||||||
"files": [files_provider],
|
"files": [files_provider],
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -30,8 +30,8 @@ class JobStatus(Enum):
|
||||||
completed = "completed"
|
completed = "completed"
|
||||||
|
|
||||||
|
|
||||||
type JobID = str
|
JobID = str
|
||||||
type JobType = str
|
JobType = str
|
||||||
|
|
||||||
|
|
||||||
class JobArtifact(BaseModel):
|
class JobArtifact(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -153,6 +153,29 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
|
||||||
"text_model": "groq/llama-3.3-70b-versatile",
|
"text_model": "groq/llama-3.3-70b-versatile",
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
"milvus": Setup(
|
||||||
|
name="milvus",
|
||||||
|
description="Milvus vector database provider for vector_io tests",
|
||||||
|
env={
|
||||||
|
"MILVUS_URL": "dummy",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"chromadb": Setup(
|
||||||
|
name="chromadb",
|
||||||
|
description="ChromaDB vector database provider for vector_io tests",
|
||||||
|
env={
|
||||||
|
"CHROMADB_URL": "http://localhost:8000",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"pgvector": Setup(
|
||||||
|
name="pgvector",
|
||||||
|
description="PGVector database provider for vector_io tests",
|
||||||
|
env={
|
||||||
|
"PGVECTOR_DB": "llama_stack_test",
|
||||||
|
"PGVECTOR_USER": "postgres",
|
||||||
|
"PGVECTOR_PASSWORD": "password",
|
||||||
|
},
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -179,4 +202,9 @@ SUITE_DEFINITIONS: dict[str, Suite] = {
|
||||||
roots=["tests/integration/inference/test_vision_inference.py"],
|
roots=["tests/integration/inference/test_vision_inference.py"],
|
||||||
default_setup="ollama-vision",
|
default_setup="ollama-vision",
|
||||||
),
|
),
|
||||||
|
"vector_io": Suite(
|
||||||
|
name="vector_io",
|
||||||
|
roots=["tests/integration/vector_io"],
|
||||||
|
default_setup="milvus",
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
|
||||||
88
tests/unit/distribution/test_single_provider_filter.py
Normal file
88
tests/unit/distribution/test_single_provider_filter.py
Normal file
|
|
@ -0,0 +1,88 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.cli.stack._build import _apply_single_provider_filter
|
||||||
|
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
|
||||||
|
from llama_stack.core.utils.image_types import LlamaStackImageType
|
||||||
|
|
||||||
|
|
||||||
|
def test_filters_single_api():
|
||||||
|
"""Test filtering keeps only specified provider for one API."""
|
||||||
|
build_config = BuildConfig(
|
||||||
|
image_type=LlamaStackImageType.VENV.value,
|
||||||
|
distribution_spec=DistributionSpec(
|
||||||
|
providers={
|
||||||
|
"vector_io": [
|
||||||
|
BuildProvider(provider_type="inline::faiss"),
|
||||||
|
BuildProvider(provider_type="inline::sqlite-vec"),
|
||||||
|
],
|
||||||
|
"inference": [
|
||||||
|
BuildProvider(provider_type="remote::openai"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
description="Test",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
filtered = _apply_single_provider_filter(build_config, "vector_io=inline::sqlite-vec")
|
||||||
|
|
||||||
|
assert len(filtered.distribution_spec.providers["vector_io"]) == 1
|
||||||
|
assert filtered.distribution_spec.providers["vector_io"][0].provider_type == "inline::sqlite-vec"
|
||||||
|
assert len(filtered.distribution_spec.providers["inference"]) == 1 # unchanged
|
||||||
|
|
||||||
|
|
||||||
|
def test_filters_multiple_apis():
|
||||||
|
"""Test filtering multiple APIs."""
|
||||||
|
build_config = BuildConfig(
|
||||||
|
image_type=LlamaStackImageType.VENV.value,
|
||||||
|
distribution_spec=DistributionSpec(
|
||||||
|
providers={
|
||||||
|
"vector_io": [
|
||||||
|
BuildProvider(provider_type="inline::faiss"),
|
||||||
|
BuildProvider(provider_type="inline::sqlite-vec"),
|
||||||
|
],
|
||||||
|
"inference": [
|
||||||
|
BuildProvider(provider_type="remote::openai"),
|
||||||
|
BuildProvider(provider_type="remote::anthropic"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
description="Test",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
filtered = _apply_single_provider_filter(build_config, "vector_io=inline::faiss,inference=remote::openai")
|
||||||
|
|
||||||
|
assert len(filtered.distribution_spec.providers["vector_io"]) == 1
|
||||||
|
assert filtered.distribution_spec.providers["vector_io"][0].provider_type == "inline::faiss"
|
||||||
|
assert len(filtered.distribution_spec.providers["inference"]) == 1
|
||||||
|
assert filtered.distribution_spec.providers["inference"][0].provider_type == "remote::openai"
|
||||||
|
|
||||||
|
|
||||||
|
def test_provider_not_found_exits():
|
||||||
|
"""Test error when specified provider doesn't exist."""
|
||||||
|
build_config = BuildConfig(
|
||||||
|
image_type=LlamaStackImageType.VENV.value,
|
||||||
|
distribution_spec=DistributionSpec(
|
||||||
|
providers={"vector_io": [BuildProvider(provider_type="inline::faiss")]},
|
||||||
|
description="Test",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
_apply_single_provider_filter(build_config, "vector_io=inline::nonexistent")
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_format_exits():
|
||||||
|
"""Test error for invalid filter format."""
|
||||||
|
build_config = BuildConfig(
|
||||||
|
image_type=LlamaStackImageType.VENV.value,
|
||||||
|
distribution_spec=DistributionSpec(providers={}, description="Test"),
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
_apply_single_provider_filter(build_config, "invalid_format")
|
||||||
Loading…
Add table
Add a link
Reference in a new issue