mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-24 00:47:00 +00:00
# What does this PR do?
Today, external providers are installed via the `external_providers_dir`
in the config. This requires users to understand the `ProviderSpec`
and set up their directories accordingly. This process splits up the
config for the stack across multiple files, directories, and formats.
Most (if not all) external providers today have a
[get_provider_spec](559cb18fbb/src/ramalama_stack/provider.py#L9)
method that sits unused. Utilizing this method rather than the
providers.d route allows for a much easier installation process for
external providers and limits the amount of extra configuration a
regular user has to do to get their stack off the ground.
To accomplish this and wire it throughout the build process, introduce
the concept of a `module` for users to specify for an external provider
upon build time. In order to facilitate this, align the build and run
spec to use `Provider` class rather than the stringified provider_type
that build currently uses.
For example, say this is in your build config:
```
- provider_id: ramalama
provider_type: remote::ramalama
module: ramalama_stack
```
During build (in the various `build_...` scripts), in addition to
installing any pip dependencies, we will also install this module and use
the `get_provider_spec` method to retrieve the ProviderSpec that is
currently specified using `providers.d`.
In production so far, providing instructions for installing external
providers for users has been difficult: they need to install the module
as a pre-req, create the providers.d directory, copy in the provider
spec, and also copy in the necessary build/run yaml files. Accessing an
external provider should be as easy as possible, and pointing to its
installable module aligns more with the rest of our build and dependency
management process.
For now, `external_providers_dir` still exists as an alternative, more
declarative method of using external providers.
## Test Plan
Added an integration test that installs an external provider from a module,
and more unit test coverage for `get_provider_registry`.
(the warning in yellow is expected: the module is installed inside of
the build env, not where we are running the command)
<img width="1119" height="400" alt="Screenshot 2025-07-24 at 11 30
48 AM"
src="https://github.com/user-attachments/assets/1efbaf45-b9e8-451a-bd63-264ed664706d"
/>
<img width="1154" height="618" alt="Screenshot 2025-07-24 at 11 31
14 AM"
src="https://github.com/user-attachments/assets/feb2b3ea-c5dd-418e-9662-9a3bd5dd6bdc"
/>
---------
Signed-off-by: Charlie Doern <cdoern@redhat.com>
404 lines
16 KiB
Python
404 lines
16 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from pathlib import Path
|
|
from typing import Any, Literal
|
|
|
|
import jinja2
|
|
import rich
|
|
import yaml
|
|
from pydantic import BaseModel, Field
|
|
|
|
from llama_stack.apis.datasets import DatasetPurpose
|
|
from llama_stack.apis.models import ModelType
|
|
from llama_stack.distribution.datatypes import (
|
|
LLAMA_STACK_RUN_CONFIG_VERSION,
|
|
Api,
|
|
BenchmarkInput,
|
|
BuildConfig,
|
|
DatasetInput,
|
|
DistributionSpec,
|
|
ModelInput,
|
|
Provider,
|
|
ShieldInput,
|
|
ToolGroupInput,
|
|
)
|
|
from llama_stack.distribution.distribution import get_provider_registry
|
|
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
|
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
|
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
|
from llama_stack.providers.utils.kvstore.config import get_pip_packages as get_kv_pip_packages
|
|
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
|
from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages
|
|
|
|
|
|
def filter_empty_values(obj: Any) -> Any:
    """Recursively drop selected "empty" values from nested dicts and lists.

    Dictionary entries are dropped (at any nesting depth) when:
      - the key is ``"module"`` and the value is the empty string ``""``;
      - the key is ``"config"`` and the value is an empty dict ``{}``;
      - the key is ``"container_image"`` and the value is falsy (``None``, ``""``, ...).
    Every other dictionary entry is kept and its value filtered recursively.
    In particular, ``None`` values and empty strings/dicts under any other key
    are preserved.

    List items that are ``None`` are dropped; all other items are kept and
    filtered recursively.

    Args:
        obj: Arbitrary value, typically the output of ``model_dump()`` or a
            plain dict destined for YAML serialization.

    Returns:
        A new filtered structure. ``None`` is returned unchanged, as is any
        value that is neither a dict nor a list.
    """
    if obj is None:
        return None

    if isinstance(obj, dict):
        filtered: dict = {}
        for key, value in obj.items():
            if key == "module" and isinstance(value, str) and value == "":
                # Skip empty module strings
                continue
            if key == "config" and isinstance(value, dict) and not value:
                # Skip empty config dictionaries
                continue
            if key == "container_image" and not value:
                # Skip empty / unset container_image names
                continue
            # All other entries are kept, even if the filtered value is None.
            filtered[key] = filter_empty_values(value)
        return filtered

    if isinstance(obj, list):
        # Recursion only returns None for None input, so this drops None items.
        return [filtered_item for filtered_item in (filter_empty_values(item) for item in obj) if filtered_item is not None]

    # Scalars (including empty strings outside the special keys) pass through.
    return obj
|
|
|
|
|
|
def get_model_registry(
    available_models: dict[str, list[ProviderModelEntry]],
) -> tuple[list[ModelInput], bool]:
    """Turn per-provider model entries into a flat list of ``ModelInput``.

    Each entry contributes one ``ModelInput`` for its ``provider_model_id`` and
    one for every alias. Entries are scanned in iteration order. If an id was
    already contributed by an earlier entry, a warning is printed once. Every
    resulting ``model_id`` is then prefixed with ``"<provider_id>/"``, except
    ids that already contain the provider id as a substring. Repeats within a
    single entry are not treated as conflicts.

    Args:
        available_models: Mapping of provider id to that provider's model entries.

    Returns:
        ``(models, ids_conflict)``, where ``ids_conflict`` says whether the
        prefixing was applied.
    """
    seen_ids: set[str] = set()
    ids_conflict = False
    for entries in available_models.values():
        for entry in entries:
            entry_ids = [entry.provider_model_id] + entry.aliases
            clashes = [candidate for candidate in entry_ids if candidate in seen_ids]
            if clashes:
                ids_conflict = True
                rich.print(
                    f"[yellow]Model id {clashes[0]} conflicts; all model ids will be prefixed with provider id[/yellow]"
                )
                break
            # Register this entry's ids only after checking them all.
            seen_ids.update(entry_ids)
        if ids_conflict:
            break

    def _identifier(provider_id: str, model_id: str) -> str:
        # Prefix only on conflict, and only if the provider id is not already in the id.
        if ids_conflict and provider_id not in model_id:
            return f"{provider_id}/{model_id}"
        return model_id

    models = [
        ModelInput(
            model_id=_identifier(provider_id, model_id),
            provider_model_id=entry.provider_model_id,
            provider_id=provider_id,
            model_type=entry.model_type,
            metadata=entry.metadata,
        )
        for provider_id, entries in available_models.items()
        for entry in entries
        for model_id in [entry.provider_model_id] + entry.aliases
    ]
    return models, ids_conflict
|
|
|
|
|
|
def get_shield_registry(
    available_safety_models: dict[str, list[ProviderModelEntry]],
    ids_conflict_in_models: bool,
) -> list[ShieldInput]:
    """Turn per-provider safety-model entries into a flat list of ``ShieldInput``.

    Each entry yields one shield for its ``provider_model_id`` and one per alias.
    If an id was already contributed by an earlier entry (in iteration order),
    a warning is printed once. Every ``shield_id`` is then prefixed with
    ``"<provider_id>/"``, except ids that already contain the provider id.

    ``provider_shield_id`` follows the caller-supplied ``ids_conflict_in_models``
    flag, not this function's own shield conflict check. When that flag is set,
    ``provider_shield_id`` is ``"<provider_id>/<provider_model_id>"``. This is
    presumably so it matches a model id that was prefixed during model
    registration — confirm with the caller.

    Args:
        available_safety_models: Mapping of provider id to its safety model entries.
        ids_conflict_in_models: Whether model ids were prefixed because of a conflict.

    Returns:
        The list of ``ShieldInput`` objects.
    """
    seen_ids: set[str] = set()
    ids_conflict = False
    for entries in available_safety_models.values():
        for entry in entries:
            entry_ids = [entry.provider_model_id] + entry.aliases
            clashes = [candidate for candidate in entry_ids if candidate in seen_ids]
            if clashes:
                ids_conflict = True
                rich.print(
                    f"[yellow]Shield id {clashes[0]} conflicts; all shield ids will be prefixed with provider id[/yellow]"
                )
                break
            # Register this entry's ids only after checking them all.
            seen_ids.update(entry_ids)
        if ids_conflict:
            break

    shields: list[ShieldInput] = []
    for provider_id, entries in available_safety_models.items():
        for entry in entries:
            if ids_conflict_in_models:
                provider_shield_id = f"{provider_id}/{entry.provider_model_id}"
            else:
                provider_shield_id = entry.provider_model_id
            for model_id in [entry.provider_model_id] + entry.aliases:
                if ids_conflict and provider_id not in model_id:
                    shield_id = f"{provider_id}/{model_id}"
                else:
                    shield_id = model_id
                shields.append(ShieldInput(shield_id=shield_id, provider_shield_id=provider_shield_id))

    return shields
|
|
|
|
|
|
class DefaultModel(BaseModel):
    """A model entry listed in a distribution's generated markdown documentation."""

    # Model identifier as shown in the docs (populated from a provider model id).
    model_id: str
    # Extra annotation text (e.g. aliases / provider); may be an empty string.
    doc_string: str
|
|
|
|
|
|
class RunConfigSettings(BaseModel):
    """Settings used to render one run configuration for a distribution.

    ``provider_overrides`` replaces the distribution's provider list for a given
    API, taken verbatim. The ``default_*`` lists populate the matching sections
    of the run config. ``metadata_store`` / ``inference_store``, when set,
    replace the SQLite defaults.
    """

    provider_overrides: dict[str, list[Provider]] = Field(default_factory=dict)
    default_models: list[ModelInput] | None = None
    default_shields: list[ShieldInput] | None = None
    default_tool_groups: list[ToolGroupInput] | None = None
    default_datasets: list[DatasetInput] | None = None
    default_benchmarks: list[BenchmarkInput] | None = None
    metadata_store: dict | None = None
    inference_store: dict | None = None

    def run_config(
        self,
        name: str,
        providers: dict[str, list[Provider]],
        container_image: str | None = None,
    ) -> dict:
        """Render a plain dict matching the ``StackRunConfig`` structure.

        For each API, the override list (if non-empty) is emitted as-is.
        Otherwise, each provider gets the ``sample_run_config`` of its registered
        config class, or ``{}`` if that class has none.

        Args:
            name: Distribution name. Used as ``image_name`` and to build the
                ``~/.llama/distributions/<name>`` paths passed to sample configs.
            providers: Providers keyed by API string. These objects are NOT
                modified; rendered configs are applied to copies.
            container_image: Optional container image recorded in the output.

        Returns:
            A dict suitable for YAML serialization.

        Raises:
            ValueError: If a provider type is not in the provider registry for its API.
        """
        provider_registry = get_provider_registry()
        distro_dir = f"~/.llama/distributions/{name}"
        provider_configs: dict[str, list[dict]] = {}
        for api_str, provider_objs in providers.items():
            if api_providers := self.provider_overrides.get(api_str):
                # Convert Provider objects to dicts for YAML serialization
                provider_configs[api_str] = [p.model_dump(exclude_none=True) for p in api_providers]
                continue

            provider_configs[api_str] = []
            for provider in provider_objs:
                api = Api(api_str)
                if provider.provider_type not in provider_registry[api]:
                    raise ValueError(f"Unknown provider type: {provider.provider_type} for API: {api_str}")

                config_class = provider_registry[api][provider.provider_type].config_class
                assert config_class is not None, (
                    f"No config class for provider type: {provider.provider_type} for API: {api_str}"
                )

                config_class = instantiate_class_type(config_class)
                if hasattr(config_class, "sample_run_config"):
                    config = config_class.sample_run_config(__distro_dir__=distro_dir)
                else:
                    config = {}

                # Render into a copy. Assigning to `provider.config` would mutate the
                # caller's objects (e.g. a template's shared providers), leaking one
                # run config's paths into later calls.
                rendered = provider.model_copy(update={"config": config})
                # Convert Provider object to dict for YAML serialization
                provider_configs[api_str].append(rendered.model_dump(exclude_none=True))

        # Get unique set of APIs from providers
        apis = sorted(providers.keys())

        # Return a dict that matches StackRunConfig structure
        return {
            "version": LLAMA_STACK_RUN_CONFIG_VERSION,
            "image_name": name,
            "container_image": container_image,
            "apis": apis,
            "providers": provider_configs,
            "metadata_store": self.metadata_store
            or SqliteKVStoreConfig.sample_run_config(
                __distro_dir__=distro_dir,
                db_name="registry.db",
            ),
            "inference_store": self.inference_store
            or SqliteSqlStoreConfig.sample_run_config(
                __distro_dir__=distro_dir,
                db_name="inference_store.db",
            ),
            "models": [m.model_dump(exclude_none=True) for m in (self.default_models or [])],
            "shields": [s.model_dump(exclude_none=True) for s in (self.default_shields or [])],
            "vector_dbs": [],
            "datasets": [d.model_dump(exclude_none=True) for d in (self.default_datasets or [])],
            "scoring_fns": [],
            "benchmarks": [b.model_dump(exclude_none=True) for b in (self.default_benchmarks or [])],
            "tool_groups": [t.model_dump(exclude_none=True) for t in (self.default_tool_groups or [])],
            "server": {
                "port": 8321,
            },
        }
|
|
|
|
|
|
class DistributionTemplate(BaseModel):
    """
    Represents a Llama Stack distribution instance that can generate configuration
    and documentation files.

    ``save_distribution`` writes ``build.yaml``, one YAML file per entry in
    ``run_configs`` (keyed by output file name), and, when ``template_path`` is
    set, ``<name>.md`` rendered from that Jinja template.
    """

    name: str
    description: str
    distro_type: Literal["self_hosted", "remote_hosted", "ondevice"]

    # Providers keyed by API string; shared by the build config and every run config.
    providers: dict[str, list[Provider]]
    # Run-config settings keyed by output YAML file name.
    run_configs: dict[str, RunConfigSettings]
    # Jinja template for the markdown docs; docs are skipped when None.
    template_path: Path | None = None

    # Optional configuration
    # Passed straight to the doc template; tuple meaning presumably (default, description) — confirm.
    run_config_env_vars: dict[str, tuple[str, str]] | None = None
    container_image: str | None = None

    available_models_by_provider: dict[str, list[ProviderModelEntry]] | None = None

    # we may want to specify additional pip packages without necessarily indicating a
    # specific "default" inference store (which is what typically used to dictate additional
    # pip packages)
    additional_pip_packages: list[str] | None = None

    def build_config(self) -> BuildConfig:
        """Return the ``BuildConfig`` for this distribution.

        Pip packages required by each run config's inference and metadata stores,
        plus ``additional_pip_packages``, are collected, de-duplicated and sorted.
        Providers are copied with an empty config, keeping only ``provider_id``,
        ``provider_type`` and ``module``.
        """
        additional_pip_packages: list[str] = []
        for run_config in self.run_configs.values():
            run_config_ = run_config.run_config(self.name, self.providers, self.container_image)

            # TODO: This is a hack to get the dependencies for internal APIs into build
            # We should have a better way to do this by formalizing the concept of "internal" APIs
            # and providers, with a way to specify dependencies for them.

            if run_config_.get("inference_store"):
                additional_pip_packages.extend(get_sql_pip_packages(run_config_["inference_store"]))

            if run_config_.get("metadata_store"):
                additional_pip_packages.extend(get_kv_pip_packages(run_config_["metadata_store"]))

        if self.additional_pip_packages:
            additional_pip_packages.extend(self.additional_pip_packages)

        # Create minimal providers for build config (without runtime configs)
        build_providers: dict[str, list[Provider]] = {}
        for api, providers in self.providers.items():
            build_providers[api] = [
                Provider(
                    provider_id=provider.provider_id,
                    provider_type=provider.provider_type,
                    config={},  # Empty config for build
                    module=provider.module,
                )
                for provider in providers
            ]

        return BuildConfig(
            distribution_spec=DistributionSpec(
                description=self.description,
                container_image=self.container_image,
                providers=build_providers,
            ),
            image_type="conda",
            image_name=self.name,
            additional_pip_packages=sorted(set(additional_pip_packages)),
        )

    def generate_markdown_docs(self) -> str:
        """Render the markdown documentation from ``template_path``.

        The template receives ``name``, ``description``, ``providers``, a
        pre-rendered ``providers_table``, ``run_config_env_vars`` and
        ``default_models``. An auto-generated notice is inserted after an
        ``orphan`` front-matter header if present, otherwise at the top.

        Raises:
            ValueError: If ``template_path`` is not set.
        """
        if self.template_path is None:
            raise ValueError(f"Distribution {self.name} has no template_path; cannot generate markdown docs")

        providers_table = "| API | Provider(s) |\n"
        providers_table += "|-----|-------------|\n"

        for api, providers in sorted(self.providers.items()):
            providers_str = ", ".join(f"`{p.provider_type}`" for p in providers)
            providers_table += f"| {api} | {providers_str} |\n"

        # Explicit encoding so the result does not depend on the platform's locale.
        template = self.template_path.read_text(encoding="utf-8")
        comment = "<!-- This file was auto-generated by distro_codegen.py, please edit source -->\n"
        orphantext = "---\norphan: true\n---\n"

        if template.startswith(orphantext):
            template = template.replace(orphantext, orphantext + comment)
        else:
            template = comment + template

        # Render template with rich-generated table
        env = jinja2.Environment(
            trim_blocks=True,
            lstrip_blocks=True,
            # NOTE: autoescape is required to prevent XSS attacks
            autoescape=True,
        )
        template = env.from_string(template)

        default_models = []
        if self.available_models_by_provider:
            has_multiple_providers = len(self.available_models_by_provider) > 1
            for provider_id, model_entries in self.available_models_by_provider.items():
                for model_entry in model_entries:
                    doc_parts = []
                    if model_entry.aliases:
                        doc_parts.append(f"aliases: {', '.join(model_entry.aliases)}")
                    if has_multiple_providers:
                        doc_parts.append(f"provider: {provider_id}")

                    default_models.append(
                        DefaultModel(
                            model_id=model_entry.provider_model_id,
                            doc_string=(f"({' -- '.join(doc_parts)})" if doc_parts else ""),
                        )
                    )

        return template.render(
            name=self.name,
            description=self.description,
            providers=self.providers,
            providers_table=providers_table,
            run_config_env_vars=self.run_config_env_vars,
            default_models=default_models,
        )

    def save_distribution(self, yaml_output_dir: Path, doc_output_dir: Path) -> None:
        """Write build/run YAML files to ``yaml_output_dir`` and docs to ``doc_output_dir``.

        Both directories are created if missing. Registers YAML representers
        (on the default dumper and ``yaml.SafeDumper``) so ``ModelType`` and
        ``DatasetPurpose`` serialize as their ``.value`` strings.
        """

        def enum_representer(dumper, data):
            return dumper.represent_scalar("tag:yaml.org,2002:str", data.value)

        # Register YAML representer for ModelType
        yaml.add_representer(ModelType, enum_representer)
        yaml.add_representer(DatasetPurpose, enum_representer)
        yaml.SafeDumper.add_representer(ModelType, enum_representer)
        yaml.SafeDumper.add_representer(DatasetPurpose, enum_representer)

        for output_dir in [yaml_output_dir, doc_output_dir]:
            output_dir.mkdir(parents=True, exist_ok=True)

        build_config = self.build_config()
        with open(yaml_output_dir / "build.yaml", "w", encoding="utf-8") as f:
            yaml.safe_dump(
                filter_empty_values(build_config.model_dump(exclude_none=True)),
                f,
                sort_keys=False,
            )

        for yaml_pth, settings in self.run_configs.items():
            run_config = settings.run_config(self.name, self.providers, self.container_image)
            with open(yaml_output_dir / yaml_pth, "w", encoding="utf-8") as f:
                yaml.safe_dump(
                    filter_empty_values(run_config),
                    f,
                    sort_keys=False,
                )

        if self.template_path:
            docs = self.generate_markdown_docs()
            with open(doc_output_dir / f"{self.name}.md", "w", encoding="utf-8") as f:
                f.write(docs if docs.endswith("\n") else docs + "\n")
|