mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-15 00:52:38 +00:00
chore: Updating Vector IO integration tests to use llama stack build
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
a701f68bd7
commit
da7b39a3e3
13 changed files with 298 additions and 19 deletions
|
|
@ -50,6 +50,84 @@ from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
|||
DISTRIBS_PATH = Path(__file__).parent.parent.parent / "distributions"
|
||||
|
||||
|
||||
def _apply_single_provider_filter(build_config: BuildConfig, single_provider_arg: str) -> BuildConfig:
|
||||
"""Filter a distribution to only include specified providers for certain APIs."""
|
||||
provider_filters: dict[str, str] = {}
|
||||
for api_provider in single_provider_arg.split(","):
|
||||
if "=" not in api_provider:
|
||||
cprint(
|
||||
"Could not parse `--single-provider`. Please ensure the list is in the format api1=provider1,api2=provider2",
|
||||
color="red",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
api, provider_type = api_provider.split("=")
|
||||
provider_filters[api] = provider_type
|
||||
|
||||
# Create a copy of the build config to modify
|
||||
filtered_build_config = BuildConfig(
|
||||
image_type=build_config.image_type,
|
||||
image_name=build_config.image_name,
|
||||
external_providers_dir=build_config.external_providers_dir,
|
||||
external_apis_dir=build_config.external_apis_dir,
|
||||
distribution_spec=DistributionSpec(
|
||||
providers={},
|
||||
description=build_config.distribution_spec.description,
|
||||
),
|
||||
)
|
||||
|
||||
# Copy all providers, but filter the specified APIs
|
||||
for api, providers in build_config.distribution_spec.providers.items():
|
||||
if api in provider_filters:
|
||||
target_provider_type = provider_filters[api]
|
||||
filtered_providers = [p for p in providers if p.provider_type == target_provider_type]
|
||||
if not filtered_providers:
|
||||
cprint(
|
||||
f"Provider {target_provider_type} not found in distribution for API {api}",
|
||||
color="red",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
filtered_build_config.distribution_spec.providers[api] = filtered_providers
|
||||
else:
|
||||
# Keep all providers for unfiltered APIs
|
||||
filtered_build_config.distribution_spec.providers[api] = providers
|
||||
|
||||
return filtered_build_config
|
||||
|
||||
|
||||
def _generate_filtered_run_config(
|
||||
build_config: BuildConfig,
|
||||
build_dir: Path,
|
||||
distro_name: str,
|
||||
) -> Path:
|
||||
"""
|
||||
Generate a filtered run.yaml by starting with the original distribution's run.yaml
|
||||
and filtering the providers according to the build_config.
|
||||
"""
|
||||
# Load the original distribution's run.yaml
|
||||
distro_resource = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
|
||||
|
||||
with importlib.resources.as_file(distro_resource) as path:
|
||||
with open(path) as f:
|
||||
original_config = yaml.safe_load(f)
|
||||
|
||||
# Apply provider filtering to the loaded config
|
||||
for api, providers in build_config.distribution_spec.providers.items():
|
||||
if api in original_config.get("providers", {}):
|
||||
# Filter this API to only include the providers from build_config
|
||||
provider_types = {p.provider_type for p in providers}
|
||||
filtered_providers = [p for p in original_config["providers"][api] if p["provider_type"] in provider_types]
|
||||
original_config["providers"][api] = filtered_providers
|
||||
|
||||
# Write the filtered run config
|
||||
run_config_file = build_dir / f"{distro_name}-filtered-run.yaml"
|
||||
with open(run_config_file, "w") as f:
|
||||
yaml.dump(original_config, f, sort_keys=False)
|
||||
|
||||
return run_config_file
|
||||
|
||||
|
||||
@lru_cache
|
||||
def available_distros_specs() -> dict[str, BuildConfig]:
|
||||
import yaml
|
||||
|
|
@ -93,6 +171,11 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
)
|
||||
sys.exit(1)
|
||||
build_config = available_distros[distro_name]
|
||||
|
||||
# Apply single-provider filtering if specified
|
||||
if args.single_provider:
|
||||
build_config = _apply_single_provider_filter(build_config, args.single_provider)
|
||||
|
||||
if args.image_type:
|
||||
build_config.image_type = args.image_type
|
||||
else:
|
||||
|
|
@ -245,6 +328,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
image_name=image_name,
|
||||
config_path=args.config,
|
||||
distro_name=distro_name,
|
||||
is_filtered=bool(args.single_provider),
|
||||
)
|
||||
|
||||
except (Exception, RuntimeError) as exc:
|
||||
|
|
@ -363,6 +447,7 @@ def _run_stack_build_command_from_build_config(
|
|||
image_name: str | None = None,
|
||||
distro_name: str | None = None,
|
||||
config_path: str | None = None,
|
||||
is_filtered: bool = False,
|
||||
) -> Path | Traversable:
|
||||
image_name = image_name or build_config.image_name
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||
|
|
@ -435,12 +520,19 @@ def _run_stack_build_command_from_build_config(
|
|||
raise RuntimeError(f"Failed to build image {image_name}")
|
||||
|
||||
if distro_name:
|
||||
# copy run.yaml from distribution to build_dir instead of generating it again
|
||||
distro_path = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
|
||||
run_config_file = build_dir / f"{distro_name}-run.yaml"
|
||||
# If single-provider filtering was applied, generate a filtered run config
|
||||
# Otherwise, copy run.yaml from distribution as before
|
||||
if is_filtered:
|
||||
run_config_file = _generate_filtered_run_config(build_config, build_dir, distro_name)
|
||||
distro_path = run_config_file # Use the generated file as the distro_path
|
||||
else:
|
||||
# copy run.yaml from distribution to build_dir instead of generating it again
|
||||
distro_resource = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
|
||||
run_config_file = build_dir / f"{distro_name}-run.yaml"
|
||||
|
||||
with importlib.resources.as_file(distro_path) as path:
|
||||
shutil.copy(path, run_config_file)
|
||||
with importlib.resources.as_file(distro_resource) as path:
|
||||
shutil.copy(path, run_config_file)
|
||||
distro_path = run_config_file # Update distro_path to point to the copied file
|
||||
|
||||
cprint("Build Successful!", color="green", file=sys.stderr)
|
||||
cprint(f"You can find the newly-built distribution here: {run_config_file}", color="blue", file=sys.stderr)
|
||||
|
|
|
|||
|
|
@ -92,6 +92,13 @@ the build. If not specified, currently active environment will be used if found.
|
|||
help="Build a config for a list of providers and only those providers. This list is formatted like: api1=provider1,api2=provider2. Where there can be multiple providers per API.",
|
||||
)
|
||||
|
||||
self.parser.add_argument(
|
||||
"--single-provider",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Limit a distribution to a single provider for specific APIs. Format: api1=provider1,api2=provider2. Use with --distro to filter an existing distribution.",
|
||||
)
|
||||
|
||||
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
|
||||
# always keep implementation completely silo-ed away from CLI so CLI
|
||||
# can be fast to load and reduces dependencies
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ distribution_spec:
|
|||
- provider_type: inline::milvus
|
||||
- provider_type: remote::chromadb
|
||||
- provider_type: remote::pgvector
|
||||
- provider_type: remote::weaviate
|
||||
- provider_type: remote::qdrant
|
||||
files:
|
||||
- provider_type: inline::localfs
|
||||
safety:
|
||||
|
|
|
|||
|
|
@ -128,6 +128,21 @@ providers:
|
|||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/pgvector_registry.db
|
||||
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
|
||||
provider_type: remote::weaviate
|
||||
config:
|
||||
weaviate_api_key: null
|
||||
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/weaviate_registry.db
|
||||
- provider_id: ${env.QDRANT_URL:+qdrant}
|
||||
provider_type: remote::qdrant
|
||||
config:
|
||||
api_key: ${env.QDRANT_API_KEY:=}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/qdrant_registry.db
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ distribution_spec:
|
|||
- provider_type: inline::milvus
|
||||
- provider_type: remote::chromadb
|
||||
- provider_type: remote::pgvector
|
||||
- provider_type: remote::weaviate
|
||||
- provider_type: remote::qdrant
|
||||
files:
|
||||
- provider_type: inline::localfs
|
||||
safety:
|
||||
|
|
|
|||
|
|
@ -128,6 +128,21 @@ providers:
|
|||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/pgvector_registry.db
|
||||
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
|
||||
provider_type: remote::weaviate
|
||||
config:
|
||||
weaviate_api_key: null
|
||||
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/weaviate_registry.db
|
||||
- provider_id: ${env.QDRANT_URL:+qdrant}
|
||||
provider_type: remote::qdrant
|
||||
config:
|
||||
api_key: ${env.QDRANT_API_KEY:=}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/qdrant_registry.db
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ distribution_spec:
|
|||
- provider_type: inline::milvus
|
||||
- provider_type: remote::chromadb
|
||||
- provider_type: remote::pgvector
|
||||
- provider_type: remote::weaviate
|
||||
- provider_type: remote::qdrant
|
||||
files:
|
||||
- provider_type: inline::localfs
|
||||
safety:
|
||||
|
|
|
|||
|
|
@ -128,6 +128,21 @@ providers:
|
|||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
|
||||
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
|
||||
provider_type: remote::weaviate
|
||||
config:
|
||||
weaviate_api_key: null
|
||||
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/weaviate_registry.db
|
||||
- provider_id: ${env.QDRANT_URL:+qdrant}
|
||||
provider_type: remote::qdrant
|
||||
config:
|
||||
api_key: ${env.QDRANT_API_KEY:=}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/qdrant_registry.db
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
|
|
|
|||
|
|
@ -31,6 +31,8 @@ from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOC
|
|||
from llama_stack.providers.remote.vector_io.pgvector.config import (
|
||||
PGVectorVectorIOConfig,
|
||||
)
|
||||
from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
|
||||
|
||||
|
||||
|
|
@ -113,6 +115,8 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
|||
BuildProvider(provider_type="inline::milvus"),
|
||||
BuildProvider(provider_type="remote::chromadb"),
|
||||
BuildProvider(provider_type="remote::pgvector"),
|
||||
BuildProvider(provider_type="remote::weaviate"),
|
||||
BuildProvider(provider_type="remote::qdrant"),
|
||||
],
|
||||
"files": [BuildProvider(provider_type="inline::localfs")],
|
||||
"safety": [
|
||||
|
|
@ -221,6 +225,16 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
|||
password="${env.PGVECTOR_PASSWORD:=}",
|
||||
),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.WEAVIATE_CLUSTER_URL:+weaviate}",
|
||||
provider_type="remote::weaviate",
|
||||
config=WeaviateVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.QDRANT_URL:+qdrant}",
|
||||
provider_type="remote::qdrant",
|
||||
config=QdrantVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
],
|
||||
"files": [files_provider],
|
||||
},
|
||||
|
|
|
|||
|
|
@ -30,8 +30,8 @@ class JobStatus(Enum):
|
|||
completed = "completed"
|
||||
|
||||
|
||||
type JobID = str
|
||||
type JobType = str
|
||||
JobID = str
|
||||
JobType = str
|
||||
|
||||
|
||||
class JobArtifact(BaseModel):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue