llama-stack-mirror/llama_stack/templates/template.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from io import StringIO
from pathlib import Path
from typing import Dict, List, Optional, Set

import jinja2
import yaml
from pydantic import BaseModel
from rich.console import Console
from rich.table import Table

from llama_stack.distribution.datatypes import (
    BuildConfig,
    DistributionSpec,
    KVStoreConfig,
    ModelInput,
    Provider,
    ShieldInput,
    StackRunConfig,
)


class DistributionTemplate(BaseModel):
    """
    Represents a Llama Stack distribution instance that can generate configuration
    and documentation files.
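
    A minimal construction sketch (illustrative values only; the provider and
    model identifiers below are examples, not a canonical configuration):

        template = DistributionTemplate(
            name="my-distro",
            description="Example distribution",
            providers={"inference": ["remote::vllm"]},
            default_models=[ModelInput(model_id="Llama3.1-8B-Instruct")],
        )
        template.save_distribution(Path("./my-distro"))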
"""
name: str
description: str
providers: Dict[str, List[str]]
default_models: List[ModelInput]
default_shields: Optional[List[ShieldInput]] = None
# Optional configuration
metadata_store: Optional[KVStoreConfig] = None
env_vars: Optional[Dict[str, str]] = None
docker_image: Optional[str] = None
@property
def distribution_spec(self) -> DistributionSpec:
return DistributionSpec(
description=self.description,
docker_image=self.docker_image,
providers=self.providers,
)

    def build_config(self) -> BuildConfig:
        return BuildConfig(
            name=self.name,
            distribution_spec=self.distribution_spec,
            image_type="conda",  # default to conda, can be overridden
        )

    def run_config(self, provider_configs: Dict[str, List[Provider]]) -> StackRunConfig:
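        """Build a StackRunConfig from per-API provider configurations.

        A sketch of the `provider_configs` shape this method expects (the
        provider id and type below are illustrative assumptions):

            {
                "inference": [
                    Provider(provider_id="vllm-0", provider_type="remote::vllm", config={}),
                ],
            }
        """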
        from datetime import datetime

        # Get unique set of APIs from providers
        apis: Set[str] = set(self.providers.keys())

        return StackRunConfig(
            image_name=self.name,
            docker_image=self.docker_image,
            built_at=datetime.now(),
            apis=list(apis),
            providers=provider_configs,
            metadata_store=self.metadata_store,
            models=self.default_models,
            shields=self.default_shields or [],
        )

    def generate_markdown_docs(self) -> str:
        """Generate markdown documentation using both Jinja2 templates and rich tables."""
        # First generate the providers table using rich
        output = StringIO()
        console = Console(file=output, force_terminal=False)

        table = Table(title="Provider Configuration", show_header=True)
        table.add_column("API", style="bold")
        table.add_column("Provider(s)")

        for api, providers in sorted(self.providers.items()):
            table.add_row(api, ", ".join(f"`{p}`" for p in providers))

        console.print(table)
        providers_table = output.getvalue()

        # Main documentation template
        template = """# {{ name }} Distribution

{{ description }}

## Provider Configuration

The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations:

{{ providers_table }}
{%- if env_vars %}
## Environment Variables

The following environment variables can be configured:

{% for var, description in env_vars.items() %}
- `{{ var }}`: {{ description }}
{% endfor %}
{%- endif %}

## Example Usage

### Using Docker Compose

```bash
$ cd distributions/{{ name }}
$ docker compose up
```

### Manual Configuration

You can also configure the distribution manually by creating a `run.yaml` file:

```yaml
version: '2'
image_name: {{ name }}
apis:
{% for api in providers.keys() %}
  - {{ api }}
{% endfor %}
providers:
{% for api, provider_list in providers.items() %}
  {{ api }}:
{% for provider in provider_list %}
    - provider_id: {{ provider.lower() }}-0
      provider_type: {{ provider }}
      config: {}
{% endfor %}
{% endfor %}
```

## Models

The following models are configured by default:

{% for model in default_models %}
- `{{ model.model_id }}`
{% endfor %}
{%- if default_shields %}
## Safety Shields

The following safety shields are configured:

{% for shield in default_shields %}
- `{{ shield.shield_id }}`
{%- endfor %}
{%- endif %}
"""

        # Render template with rich-generated table
        env = jinja2.Environment(trim_blocks=True, lstrip_blocks=True)
        template = env.from_string(template)

        return template.render(
            name=self.name,
            description=self.description,
            providers=self.providers,
            providers_table=providers_table,
            env_vars=self.env_vars,
            default_models=self.default_models,
            default_shields=self.default_shields,
        )

    def save_distribution(self, output_dir: Path) -> None:
        output_dir.mkdir(parents=True, exist_ok=True)

        # Save build.yaml
        build_config = self.build_config()
        with open(output_dir / "build.yaml", "w") as f:
            yaml.safe_dump(build_config.model_dump(), f, sort_keys=False)

        # Save run.yaml template
        # Create a minimal provider config for the template
        provider_configs = {
            api: [
                Provider(
                    provider_id=f"{provider.lower()}-0",
                    provider_type=provider,
                    config={},
                )
                for provider in providers
            ]
            for api, providers in self.providers.items()
        }
        run_config = self.run_config(provider_configs)
        with open(output_dir / "run.yaml", "w") as f:
            yaml.safe_dump(run_config.model_dump(), f, sort_keys=False)

        # Save documentation
        docs = self.generate_markdown_docs()
        with open(output_dir / f"{self.name}.md", "w") as f:
            f.write(docs)
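
    # Resulting layout under `output_dir` (a sketch of what the three writes
    # above produce; `<name>` is this template's `name` field):
    #
    #   output_dir/
    #   ├── build.yaml   # serialized BuildConfig
    #   ├── run.yaml     # serialized StackRunConfig with placeholder providers
    #   └── <name>.md    # generated markdown documentation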

    @classmethod
    def vllm_distribution(cls) -> "DistributionTemplate":
        return cls(
            name="remote-vllm",
            description="Use (an external) vLLM server for running LLM inference",
            providers={
                "inference": ["remote::vllm"],
                "memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
                "safety": ["inline::llama-guard"],
                "agents": ["inline::meta-reference"],
                "telemetry": ["inline::meta-reference"],
            },
            default_models=[
                ModelInput(model_id="${env.LLAMA_INFERENCE_MODEL:Llama3.1-8B-Instruct}"),
                ModelInput(model_id="${env.LLAMA_SAFETY_MODEL:Llama-Guard-3-1B}"),
            ],
            default_shields=[
                ShieldInput(shield_id="${env.LLAMA_SAFETY_MODEL:Llama-Guard-3-1B}")
            ],
            env_vars={
                "LLAMA_INFERENCE_VLLM_URL": "URL of the vLLM inference server",
                "LLAMA_SAFETY_VLLM_URL": "URL of the vLLM safety server",
                "MAX_TOKENS": "Maximum number of tokens for generation",
                "LLAMA_INFERENCE_MODEL": "Name of the inference model to use",
                "LLAMA_SAFETY_MODEL": "Name of the safety model to use",
            },
        )
if __name__ == "__main__":
import argparse
import sys
from pathlib import Path
parser = argparse.ArgumentParser(description="Generate a distribution template")
parser.add_argument(
"--type",
choices=["vllm"],
default="vllm",
help="Type of distribution template to generate",
)
parser.add_argument(
"--output-dir",
type=Path,
required=True,
help="Output directory for the distribution files",
)
args = parser.parse_args()
if args.type == "vllm":
template = DistributionTemplate.vllm_distribution()
else:
print(f"Unknown template type: {args.type}", file=sys.stderr)
sys.exit(1)
template.save_distribution(args.output_dir)
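
# Example invocation (a sketch; assumes running from the repository root):
#
#   python llama_stack/templates/template.py --type vllm --output-dir ./distributions/remote-vllm
#
# This writes build.yaml, run.yaml, and remote-vllm.md into the output directory.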