diff --git a/distributions/ollama-gpu/run.yaml b/distributions/ollama-gpu/run.yaml
index c702b878e..1d928ec25 100644
--- a/distributions/ollama-gpu/run.yaml
+++ b/distributions/ollama-gpu/run.yaml
@@ -13,20 +13,15 @@ apis:
 - safety
 providers:
   inference:
-  - provider_id: ollama0
+  - provider_id: ollama
     provider_type: remote::ollama
     config:
-      url: http://127.0.0.1:14343
+      url: ${env.OLLAMA_URL:http://127.0.0.1:11434}
   safety:
   - provider_id: meta0
     provider_type: inline::llama-guard
     config:
-      model: Llama-Guard-3-1B
       excluded_categories: []
-  - provider_id: meta1
-    provider_type: inline::prompt-guard
-    config:
-      model: Prompt-Guard-86M
   memory:
   - provider_id: meta0
     provider_type: inline::meta-reference
@@ -43,3 +38,10 @@ providers:
   - provider_id: meta0
     provider_type: inline::meta-reference
     config: {}
+models:
+  - model_id: ${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}
+    provider_id: ollama
+  - model_id: ${env.SAFETY_MODEL:Llama-Guard-3-1B}
+    provider_id: ollama
+shields:
+  - shield_id: ${env.SAFETY_MODEL:Llama-Guard-3-1B}
diff --git a/distributions/ollama/run.yaml b/distributions/ollama/run.yaml
index c702b878e..461f64609 100644
--- a/distributions/ollama/run.yaml
+++ b/distributions/ollama/run.yaml
@@ -13,20 +13,15 @@ apis:
 - safety
 providers:
   inference:
-  - provider_id: ollama0
+  - provider_id: ollama
    provider_type: remote::ollama
     config:
-      url: http://127.0.0.1:14343
+      url: ${env.LLAMA_INFERENCE_OLLAMA_URL:http://127.0.0.1:11434}
   safety:
   - provider_id: meta0
     provider_type: inline::llama-guard
     config:
-      model: Llama-Guard-3-1B
       excluded_categories: []
-  - provider_id: meta1
-    provider_type: inline::prompt-guard
-    config:
-      model: Prompt-Guard-86M
   memory:
   - provider_id: meta0
     provider_type: inline::meta-reference
@@ -43,3 +38,10 @@ providers:
   - provider_id: meta0
     provider_type: inline::meta-reference
     config: {}
+models:
+  - model_id: ${env.LLAMA_INFERENCE_MODEL:Llama3.2-3B-Instruct}
+    provider_id: ollama
+  - model_id: ${env.LLAMA_SAFETY_MODEL:Llama-Guard-3-1B}
+    provider_id: ollama
+shields:
+  - shield_id: ${env.LLAMA_SAFETY_MODEL:Llama-Guard-3-1B}
diff --git a/distributions/remote-vllm/run.yaml b/distributions/remote-vllm/run.yaml
index eae5b8a6f..e6be2bd06 100644
--- a/distributions/remote-vllm/run.yaml
+++ b/distributions/remote-vllm/run.yaml
@@ -16,7 +16,7 @@ providers:
     provider_type: remote::vllm
     config:
       # NOTE: replace with "localhost" if you are running in "host" network mode
-      url: ${env.LLAMA_INFERENCE_VLLM_URL:http://host.docker.internal:5100/v1}
+      url: ${env.VLLM_URL:http://host.docker.internal:5100/v1}
       max_tokens: ${env.MAX_TOKENS:4096}
       api_token: fake
   # serves safety llama_guard model
@@ -24,7 +24,7 @@ providers:
     provider_type: remote::vllm
     config:
       # NOTE: replace with "localhost" if you are running in "host" network mode
-      url: ${env.LLAMA_SAFETY_VLLM_URL:http://host.docker.internal:5101/v1}
+      url: ${env.SAFETY_VLLM_URL:http://host.docker.internal:5101/v1}
       max_tokens: ${env.MAX_TOKENS:4096}
       api_token: fake
   memory:
@@ -34,7 +34,7 @@ providers:
       kvstore:
         namespace: null
         type: sqlite
-        db_path: "${env.SQLITE_STORE_DIR:/home/ashwin/.llama/distributions/remote-vllm}/faiss_store.db"
+        db_path: "${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/faiss_store.db"
   safety:
   - provider_id: llama-guard
     provider_type: inline::llama-guard
@@ -50,7 +50,7 @@ providers:
       persistence_store:
         namespace: null
        type: sqlite
-        db_path: "${env.SQLITE_STORE_DIR:/home/ashwin/.llama/distributions/remote-vllm}/agents_store.db"
+        db_path: "${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/agents_store.db"
   telemetry:
   - provider_id: meta0
     provider_type: inline::meta-reference
@@ -58,11 +58,11 @@ providers:
 metadata_store:
   namespace: null
   type: sqlite
-  db_path: "${env.SQLITE_STORE_DIR:/home/ashwin/.llama/distributions/remote-vllm}/registry.db"
+  db_path: "${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db"
 models:
-  - model_id: ${env.LLAMA_INFERENCE_MODEL:Llama3.1-8B-Instruct}
+  - model_id: ${env.INFERENCE_MODEL:Llama3.1-8B-Instruct}
     provider_id: vllm-0
-  - model_id: ${env.LLAMA_SAFETY_MODEL:Llama-Guard-3-1B}
+  - model_id: ${env.SAFETY_MODEL:Llama-Guard-3-1B}
     provider_id: vllm-1
 shields:
-  - shield_id: ${env.LLAMA_SAFETY_MODEL:Llama-Guard-3-1B}
+  - shield_id: ${env.SAFETY_MODEL:Llama-Guard-3-1B}
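
A note on the `${env.VAR:default}` placeholders these run.yaml changes introduce: each one is resolved against the process environment, with the text after the colon used as a fallback when the variable is unset. The sketch below is illustrative only (the regex and function names are assumptions); the authoritative implementation is `replace_env_vars` in the server.py hunk that follows.

```python
import os
import re

# Illustrative resolver for "${env.VAR:default}" placeholders; the real
# logic lives in replace_env_vars (see the server.py change below).
PATTERN = re.compile(r"\$\{env\.([A-Z0-9_]+)(?::([^}]*))?\}")

def resolve(text: str) -> str:
    def substitute(match: re.Match) -> str:
        env_var, default_val = match.group(1), match.group(2)
        value = os.environ.get(env_var) or default_val
        if value is None:
            raise ValueError(f"env var {env_var} has no value and no default")
        return value
    return PATTERN.sub(substitute, text)

# With OLLAMA_URL unset, the YAML default wins:
print(resolve("${env.OLLAMA_URL:http://127.0.0.1:11434}"))
# => http://127.0.0.1:11434
```
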
"${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/agents_store.db" telemetry: - provider_id: meta0 provider_type: inline::meta-reference @@ -58,11 +58,11 @@ providers: metadata_store: namespace: null type: sqlite - db_path: "${env.SQLITE_STORE_DIR:/home/ashwin/.llama/distributions/remote-vllm}/registry.db" + db_path: "${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db" models: - - model_id: ${env.LLAMA_INFERENCE_MODEL:Llama3.1-8B-Instruct} + - model_id: ${env.INFERENCE_MODEL:Llama3.1-8B-Instruct} provider_id: vllm-0 - - model_id: ${env.LLAMA_SAFETY_MODEL:Llama-Guard-3-1B} + - model_id: ${env.SAFETY_MODEL:Llama-Guard-3-1B} provider_id: vllm-1 shields: - - shield_id: ${env.LLAMA_SAFETY_MODEL:Llama-Guard-3-1B} + - shield_id: ${env.SAFETY_MODEL:Llama-Guard-3-1B} diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 0cfd11eda..7494e9367 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -313,7 +313,8 @@ def replace_env_vars(config: Any, path: str = "") -> Any: else: value = default_val - return value + # expand "~" from the values + return os.path.expanduser(value) try: return re.sub(pattern, get_env_var, config) diff --git a/llama_stack/providers/inline/agents/meta_reference/config.py b/llama_stack/providers/inline/agents/meta_reference/config.py index 2770ed13c..44628758a 100644 --- a/llama_stack/providers/inline/agents/meta_reference/config.py +++ b/llama_stack/providers/inline/agents/meta_reference/config.py @@ -12,3 +12,11 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig class MetaReferenceAgentsImplConfig(BaseModel): persistence_store: KVStoreConfig = Field(default=SqliteKVStoreConfig()) + + @classmethod + def sample_dict(cls): + return { + "persistence_store": SqliteKVStoreConfig.sample_dict( + db_name="agents_store.db" + ), + } diff --git a/llama_stack/providers/inline/inference/vllm/config.py b/llama_stack/providers/inline/inference/vllm/config.py index a7469ebde..a633dffb6 100644 --- a/llama_stack/providers/inline/inference/vllm/config.py +++ b/llama_stack/providers/inline/inference/vllm/config.py @@ -34,6 +34,16 @@ class VLLMConfig(BaseModel): default=0.3, ) + @classmethod + def sample_dict(cls): + return { + "model": "${env.VLLM_INFERENCE_MODEL:Llama3.2-3B-Instruct}", + "tensor_parallel_size": "${env.VLLM_TENSOR_PARALLEL_SIZE:1}", + "max_tokens": "${env.VLLM_MAX_TOKENS:4096}", + "enforce_eager": "${env.VLLM_ENFORCE_EAGER:False}", + "gpu_memory_utilization": "${env.VLLM_GPU_MEMORY_UTILIZATION:0.3}", + } + @field_validator("model") @classmethod def validate_model(cls, model: str) -> str: diff --git a/llama_stack/providers/inline/safety/llama_guard/config.py b/llama_stack/providers/inline/safety/llama_guard/config.py index 72036fd1c..4d9e2b969 100644 --- a/llama_stack/providers/inline/safety/llama_guard/config.py +++ b/llama_stack/providers/inline/safety/llama_guard/config.py @@ -11,3 +11,9 @@ from pydantic import BaseModel class LlamaGuardConfig(BaseModel): excluded_categories: List[str] = [] + + @classmethod + def sample_dict(cls): + return { + "excluded_categories": [], + } diff --git a/llama_stack/providers/remote/inference/vllm/config.py b/llama_stack/providers/remote/inference/vllm/config.py index 50a174589..8aa7af4f0 100644 --- a/llama_stack/providers/remote/inference/vllm/config.py +++ b/llama_stack/providers/remote/inference/vllm/config.py @@ -24,3 +24,12 @@ class VLLMInferenceAdapterConfig(BaseModel): 
default="fake", description="The API token", ) + + @classmethod + def sample_dict(cls): + # TODO: we may need two modes, one for conda and one for docker + return { + "url": "${env.VLLM_URL:http://host.docker.internal:5100/v1}", + "max_tokens": "${env.VLLM_MAX_TOKENS:4096}", + "api_token": "${env.VLLM_API_TOKEN:fake}", + } diff --git a/llama_stack/providers/utils/kvstore/config.py b/llama_stack/providers/utils/kvstore/config.py index 0a21bf4ca..5559a99f2 100644 --- a/llama_stack/providers/utils/kvstore/config.py +++ b/llama_stack/providers/utils/kvstore/config.py @@ -36,6 +36,15 @@ class RedisKVStoreConfig(CommonConfig): def url(self) -> str: return f"redis://{self.host}:{self.port}" + @classmethod + def sample_dict(cls): + return { + "type": "redis", + "namespace": None, + "host": "${env.REDIS_HOST:localhost}", + "port": "${env.REDIS_PORT:6379}", + } + class SqliteKVStoreConfig(CommonConfig): type: Literal[KVStoreType.sqlite.value] = KVStoreType.sqlite.value @@ -44,6 +53,14 @@ class SqliteKVStoreConfig(CommonConfig): description="File path for the sqlite database", ) + @classmethod + def sample_dict(cls, db_name: str = "kvstore.db"): + return { + "type": "sqlite", + "namespace": None, + "db_path": "${env.SQLITE_STORE_DIR:~/.llama/runtime/" + db_name + "}", + } + class PostgresKVStoreConfig(CommonConfig): type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value @@ -54,6 +71,19 @@ class PostgresKVStoreConfig(CommonConfig): password: Optional[str] = None table_name: str = "llamastack_kvstore" + @classmethod + def sample_dict(cls, table_name: str = "llamastack_kvstore"): + return { + "type": "postgres", + "namespace": None, + "host": "${env.POSTGRES_HOST:localhost}", + "port": "${env.POSTGRES_PORT:5432}", + "db": "${env.POSTGRES_DB}", + "user": "${env.POSTGRES_USER}", + "password": "${env.POSTGRES_PASSWORD}", + "table_name": "${env.POSTGRES_TABLE_NAME:" + table_name + "}", + } + @classmethod @field_validator("table_name") def validate_table_name(cls, v: str) -> str: diff --git a/llama_stack/templates/__init__.py b/llama_stack/templates/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/templates/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/templates/template.py b/llama_stack/templates/template.py new file mode 100644 index 000000000..57fcbe962 --- /dev/null +++ b/llama_stack/templates/template.py @@ -0,0 +1,266 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from io import StringIO + +from pathlib import Path +from typing import Dict, List, Optional, Set + +import jinja2 +import yaml +from pydantic import BaseModel + +from rich.console import Console +from rich.table import Table + +from llama_stack.distribution.datatypes import ( + BuildConfig, + DistributionSpec, + KVStoreConfig, + ModelInput, + Provider, + ShieldInput, + StackRunConfig, +) + + +class DistributionTemplate(BaseModel): + """ + Represents a Llama Stack distribution instance that can generate configuration + and documentation files. 
+ """ + + name: str + description: str + providers: Dict[str, List[str]] + default_models: List[ModelInput] + default_shields: Optional[List[ShieldInput]] = None + + # Optional configuration + metadata_store: Optional[KVStoreConfig] = None + env_vars: Optional[Dict[str, str]] = None + docker_image: Optional[str] = None + + @property + def distribution_spec(self) -> DistributionSpec: + return DistributionSpec( + description=self.description, + docker_image=self.docker_image, + providers=self.providers, + ) + + def build_config(self) -> BuildConfig: + return BuildConfig( + name=self.name, + distribution_spec=self.distribution_spec, + image_type="conda", # default to conda, can be overridden + ) + + def run_config(self, provider_configs: Dict[str, List[Provider]]) -> StackRunConfig: + from datetime import datetime + + # Get unique set of APIs from providers + apis: Set[str] = set(self.providers.keys()) + + return StackRunConfig( + image_name=self.name, + docker_image=self.docker_image, + built_at=datetime.now(), + apis=list(apis), + providers=provider_configs, + metadata_store=self.metadata_store, + models=self.default_models, + shields=self.default_shields or [], + ) + + def generate_markdown_docs(self) -> str: + """Generate markdown documentation using both Jinja2 templates and rich tables.""" + # First generate the providers table using rich + output = StringIO() + console = Console(file=output, force_terminal=False) + + table = Table(title="Provider Configuration", show_header=True) + table.add_column("API", style="bold") + table.add_column("Provider(s)") + + for api, providers in sorted(self.providers.items()): + table.add_row(api, ", ".join(f"`{p}`" for p in providers)) + + console.print(table) + providers_table = output.getvalue() + + # Main documentation template + template = """# {{ name }} Distribution + +{{ description }} + +## Provider Configuration + +The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations: + +{{ providers_table }} + +{%- if env_vars %} +## Environment Variables + +The following environment variables can be configured: + +{% for var, description in env_vars.items() %} +- `{{ var }}`: {{ description }} +{% endfor %} +{%- endif %} + +## Example Usage + +### Using Docker Compose + +```bash +$ cd distributions/{{ name }} +$ docker compose up +``` + +### Manual Configuration + +You can also configure the distribution manually by creating a `run.yaml` file: + +```yaml +version: '2' +image_name: {{ name }} +apis: +{% for api in providers.keys() %} + - {{ api }} +{% endfor %} + +providers: +{% for api, provider_list in providers.items() %} + {{ api }}: + {% for provider in provider_list %} + - provider_id: {{ provider.lower() }}-0 + provider_type: {{ provider }} + config: {} + {% endfor %} +{% endfor %} +``` + +## Models + +The following models are configured by default: +{% for model in default_models %} +- `{{ model.model_id }}` +{% endfor %} + +{%- if default_shields %} + +## Safety Shields + +The following safety shields are configured: +{% for shield in default_shields %} +- `{{ shield.shield_id }}` +{%- endfor %} +{%- endif %} +""" + # Render template with rich-generated table + env = jinja2.Environment(trim_blocks=True, lstrip_blocks=True) + template = env.from_string(template) + return template.render( + name=self.name, + description=self.description, + providers=self.providers, + providers_table=providers_table, + env_vars=self.env_vars, + default_models=self.default_models, + default_shields=self.default_shields, 
+        )
+
+    def save_distribution(self, output_dir: Path) -> None:
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        # Save build.yaml
+        build_config = self.build_config()
+        with open(output_dir / "build.yaml", "w") as f:
+            yaml.safe_dump(build_config.model_dump(), f, sort_keys=False)
+
+        # Save run.yaml template
+        # Create a minimal provider config for the template
+        provider_configs = {
+            api: [
+                Provider(
+                    provider_id=f"{provider.lower()}-0",
+                    provider_type=provider,
+                    config={},
+                )
+                for provider in providers
+            ]
+            for api, providers in self.providers.items()
+        }
+        run_config = self.run_config(provider_configs)
+        with open(output_dir / "run.yaml", "w") as f:
+            yaml.safe_dump(run_config.model_dump(), f, sort_keys=False)
+
+        # Save documentation
+        docs = self.generate_markdown_docs()
+        with open(output_dir / f"{self.name}.md", "w") as f:
+            f.write(docs)
+
+    @classmethod
+    def vllm_distribution(cls) -> "DistributionTemplate":
+        return cls(
+            name="remote-vllm",
+            description="Use (an external) vLLM server for running LLM inference",
+            providers={
+                "inference": ["remote::vllm"],
+                "memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
+                "safety": ["inline::llama-guard"],
+                "agents": ["inline::meta-reference"],
+                "telemetry": ["inline::meta-reference"],
+            },
+            default_models=[
+                ModelInput(
+                    model_id="${env.INFERENCE_MODEL:Llama3.1-8B-Instruct}"
+                ),
+                ModelInput(model_id="${env.SAFETY_MODEL:Llama-Guard-3-1B}"),
+            ],
+            default_shields=[
+                ShieldInput(shield_id="${env.SAFETY_MODEL:Llama-Guard-3-1B}")
+            ],
+            env_vars={
+                "VLLM_URL": "URL of the vLLM inference server",
+                "SAFETY_VLLM_URL": "URL of the vLLM safety server",
+                "MAX_TOKENS": "Maximum number of tokens for generation",
+                "INFERENCE_MODEL": "Name of the inference model to use",
+                "SAFETY_MODEL": "Name of the safety model to use",
+            },
+        )
+
+
+if __name__ == "__main__":
+    import argparse
+    import sys
+
+    parser = argparse.ArgumentParser(description="Generate a distribution template")
+    parser.add_argument(
+        "--type",
+        choices=["vllm"],
+        default="vllm",
+        help="Type of distribution template to generate",
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=Path,
+        required=True,
+        help="Output directory for the distribution files",
+    )
+
+    args = parser.parse_args()
+
+    if args.type == "vllm":
+        template = DistributionTemplate.vllm_distribution()
+    else:
+        print(f"Unknown template type: {args.type}", file=sys.stderr)
+        sys.exit(1)
+
+    template.save_distribution(args.output_dir)
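
To exercise the new module end to end: assuming `llama_stack` is installed from this branch, the snippet below regenerates the remote-vllm distribution files. It mirrors the `__main__` block above, so `python llama_stack/templates/template.py --output-dir distributions/remote-vllm` should produce the same result.

```python
from pathlib import Path

from llama_stack.templates.template import DistributionTemplate

# Writes build.yaml, run.yaml, and remote-vllm.md into the target directory.
template = DistributionTemplate.vllm_distribution()
template.save_distribution(Path("distributions/remote-vllm"))
```
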