mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-31 16:01:46 +00:00
Start auto-generating { build, run, doc.md } for distributions
This commit is contained in:
parent
20bf2f50c2
commit
cfa913fdd5
11 changed files with 362 additions and 23 deletions
|
@ -13,20 +13,15 @@ apis:
|
|||
- safety
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: ollama0
|
||||
- provider_id: ollama
|
||||
provider_type: remote::ollama
|
||||
config:
|
||||
url: http://127.0.0.1:14343
|
||||
url: ${env.OLLAMA_URL:http://127.0.0.1:11434}
|
||||
safety:
|
||||
- provider_id: meta0
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
- provider_id: meta1
|
||||
provider_type: inline::prompt-guard
|
||||
config:
|
||||
model: Prompt-Guard-86M
|
||||
memory:
|
||||
- provider_id: meta0
|
||||
provider_type: inline::meta-reference
|
||||
|
@ -43,3 +38,10 @@ providers:
|
|||
- provider_id: meta0
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
models:
|
||||
- model_id: ${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}
|
||||
provider_id: ollama
|
||||
- model_id: ${env.SAFETY_MODEL:Llama-Guard-3-1B}
|
||||
provider_id: ollama
|
||||
shields:
|
||||
- shield_id: ${env.SAFETY_MODEL:Llama-Guard-3-1B}
|
||||
|
|
|
@ -13,20 +13,15 @@ apis:
|
|||
- safety
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: ollama0
|
||||
- provider_id: ollama
|
||||
provider_type: remote::ollama
|
||||
config:
|
||||
url: http://127.0.0.1:14343
|
||||
url: ${env.LLAMA_INFERENCE_OLLAMA_URL:http://127.0.0.1:11434}
|
||||
safety:
|
||||
- provider_id: meta0
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
- provider_id: meta1
|
||||
provider_type: inline::prompt-guard
|
||||
config:
|
||||
model: Prompt-Guard-86M
|
||||
memory:
|
||||
- provider_id: meta0
|
||||
provider_type: inline::meta-reference
|
||||
|
@ -43,3 +38,10 @@ providers:
|
|||
- provider_id: meta0
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
models:
|
||||
- model_id: ${env.LLAMA_INFERENCE_MODEL:Llama3.2-3B-Instruct}
|
||||
provider_id: ollama
|
||||
- model_id: ${env.LLAMA_SAFETY_MODEL:Llama-Guard-3-1B}
|
||||
provider_id: ollama
|
||||
shields:
|
||||
- shield_id: ${env.LLAMA_SAFETY_MODEL:Llama-Guard-3-1B}
|
||||
|
|
|
@ -16,7 +16,7 @@ providers:
|
|||
provider_type: remote::vllm
|
||||
config:
|
||||
# NOTE: replace with "localhost" if you are running in "host" network mode
|
||||
url: ${env.LLAMA_INFERENCE_VLLM_URL:http://host.docker.internal:5100/v1}
|
||||
url: ${env.VLLM_URL:http://host.docker.internal:5100/v1}
|
||||
max_tokens: ${env.MAX_TOKENS:4096}
|
||||
api_token: fake
|
||||
# serves safety llama_guard model
|
||||
|
@ -24,7 +24,7 @@ providers:
|
|||
provider_type: remote::vllm
|
||||
config:
|
||||
# NOTE: replace with "localhost" if you are running in "host" network mode
|
||||
url: ${env.LLAMA_SAFETY_VLLM_URL:http://host.docker.internal:5101/v1}
|
||||
url: ${env.SAFETY_VLLM_URL:http://host.docker.internal:5101/v1}
|
||||
max_tokens: ${env.MAX_TOKENS:4096}
|
||||
api_token: fake
|
||||
memory:
|
||||
|
@ -34,7 +34,7 @@ providers:
|
|||
kvstore:
|
||||
namespace: null
|
||||
type: sqlite
|
||||
db_path: "${env.SQLITE_STORE_DIR:/home/ashwin/.llama/distributions/remote-vllm}/faiss_store.db"
|
||||
db_path: "${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/faiss_store.db"
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
|
@ -50,7 +50,7 @@ providers:
|
|||
persistence_store:
|
||||
namespace: null
|
||||
type: sqlite
|
||||
db_path: "${env.SQLITE_STORE_DIR:/home/ashwin/.llama/distributions/remote-vllm}/agents_store.db"
|
||||
db_path: "${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/agents_store.db"
|
||||
telemetry:
|
||||
- provider_id: meta0
|
||||
provider_type: inline::meta-reference
|
||||
|
@ -58,11 +58,11 @@ providers:
|
|||
metadata_store:
|
||||
namespace: null
|
||||
type: sqlite
|
||||
db_path: "${env.SQLITE_STORE_DIR:/home/ashwin/.llama/distributions/remote-vllm}/registry.db"
|
||||
db_path: "${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db"
|
||||
models:
|
||||
- model_id: ${env.LLAMA_INFERENCE_MODEL:Llama3.1-8B-Instruct}
|
||||
- model_id: ${env.INFERENCE_MODEL:Llama3.1-8B-Instruct}
|
||||
provider_id: vllm-0
|
||||
- model_id: ${env.LLAMA_SAFETY_MODEL:Llama-Guard-3-1B}
|
||||
- model_id: ${env.SAFETY_MODEL:Llama-Guard-3-1B}
|
||||
provider_id: vllm-1
|
||||
shields:
|
||||
- shield_id: ${env.LLAMA_SAFETY_MODEL:Llama-Guard-3-1B}
|
||||
- shield_id: ${env.SAFETY_MODEL:Llama-Guard-3-1B}
|
||||
|
|
|
@ -313,7 +313,8 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
|
|||
else:
|
||||
value = default_val
|
||||
|
||||
return value
|
||||
# expand "~" from the values
|
||||
return os.path.expanduser(value)
|
||||
|
||||
try:
|
||||
return re.sub(pattern, get_env_var, config)
|
||||
|
|
|
@ -12,3 +12,11 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
|||
|
||||
class MetaReferenceAgentsImplConfig(BaseModel):
|
||||
persistence_store: KVStoreConfig = Field(default=SqliteKVStoreConfig())
|
||||
|
||||
@classmethod
|
||||
def sample_dict(cls):
|
||||
return {
|
||||
"persistence_store": SqliteKVStoreConfig.sample_dict(
|
||||
db_name="agents_store.db"
|
||||
),
|
||||
}
|
||||
|
|
|
@ -34,6 +34,16 @@ class VLLMConfig(BaseModel):
|
|||
default=0.3,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_dict(cls):
|
||||
return {
|
||||
"model": "${env.VLLM_INFERENCE_MODEL:Llama3.2-3B-Instruct}",
|
||||
"tensor_parallel_size": "${env.VLLM_TENSOR_PARALLEL_SIZE:1}",
|
||||
"max_tokens": "${env.VLLM_MAX_TOKENS:4096}",
|
||||
"enforce_eager": "${env.VLLM_ENFORCE_EAGER:False}",
|
||||
"gpu_memory_utilization": "${env.VLLM_GPU_MEMORY_UTILIZATION:0.3}",
|
||||
}
|
||||
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
def validate_model(cls, model: str) -> str:
|
||||
|
|
|
@ -11,3 +11,9 @@ from pydantic import BaseModel
|
|||
|
||||
class LlamaGuardConfig(BaseModel):
|
||||
excluded_categories: List[str] = []
|
||||
|
||||
@classmethod
|
||||
def sample_dict(cls):
|
||||
return {
|
||||
"excluded_categories": [],
|
||||
}
|
||||
|
|
|
@ -24,3 +24,12 @@ class VLLMInferenceAdapterConfig(BaseModel):
|
|||
default="fake",
|
||||
description="The API token",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_dict(cls):
|
||||
# TODO: we may need two modes, one for conda and one for docker
|
||||
return {
|
||||
"url": "${env.VLLM_URL:http://host.docker.internal:5100/v1}",
|
||||
"max_tokens": "${env.VLLM_MAX_TOKENS:4096}",
|
||||
"api_token": "${env.VLLM_API_TOKEN:fake}",
|
||||
}
|
||||
|
|
|
@ -36,6 +36,15 @@ class RedisKVStoreConfig(CommonConfig):
|
|||
def url(self) -> str:
|
||||
return f"redis://{self.host}:{self.port}"
|
||||
|
||||
@classmethod
|
||||
def sample_dict(cls):
|
||||
return {
|
||||
"type": "redis",
|
||||
"namespace": None,
|
||||
"host": "${env.REDIS_HOST:localhost}",
|
||||
"port": "${env.REDIS_PORT:6379}",
|
||||
}
|
||||
|
||||
|
||||
class SqliteKVStoreConfig(CommonConfig):
|
||||
type: Literal[KVStoreType.sqlite.value] = KVStoreType.sqlite.value
|
||||
|
@ -44,6 +53,14 @@ class SqliteKVStoreConfig(CommonConfig):
|
|||
description="File path for the sqlite database",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_dict(cls, db_name: str = "kvstore.db"):
|
||||
return {
|
||||
"type": "sqlite",
|
||||
"namespace": None,
|
||||
"db_path": "${env.SQLITE_STORE_DIR:~/.llama/runtime/" + db_name + "}",
|
||||
}
|
||||
|
||||
|
||||
class PostgresKVStoreConfig(CommonConfig):
|
||||
type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value
|
||||
|
@ -54,6 +71,19 @@ class PostgresKVStoreConfig(CommonConfig):
|
|||
password: Optional[str] = None
|
||||
table_name: str = "llamastack_kvstore"
|
||||
|
||||
@classmethod
|
||||
def sample_dict(cls, table_name: str = "llamastack_kvstore"):
|
||||
return {
|
||||
"type": "postgres",
|
||||
"namespace": None,
|
||||
"host": "${env.POSTGRES_HOST:localhost}",
|
||||
"port": "${env.POSTGRES_PORT:5432}",
|
||||
"db": "${env.POSTGRES_DB}",
|
||||
"user": "${env.POSTGRES_USER}",
|
||||
"password": "${env.POSTGRES_PASSWORD}",
|
||||
"table_name": "${env.POSTGRES_TABLE_NAME:" + table_name + "}",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@field_validator("table_name")
|
||||
def validate_table_name(cls, v: str) -> str:
|
||||
|
|
5
llama_stack/templates/__init__.py
Normal file
5
llama_stack/templates/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
266
llama_stack/templates/template.py
Normal file
266
llama_stack/templates/template.py
Normal file
|
@ -0,0 +1,266 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from io import StringIO
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
import jinja2
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
from llama_stack.distribution.datatypes import (
|
||||
BuildConfig,
|
||||
DistributionSpec,
|
||||
KVStoreConfig,
|
||||
ModelInput,
|
||||
Provider,
|
||||
ShieldInput,
|
||||
StackRunConfig,
|
||||
)
|
||||
|
||||
|
||||
class DistributionTemplate(BaseModel):
|
||||
"""
|
||||
Represents a Llama Stack distribution instance that can generate configuration
|
||||
and documentation files.
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
providers: Dict[str, List[str]]
|
||||
default_models: List[ModelInput]
|
||||
default_shields: Optional[List[ShieldInput]] = None
|
||||
|
||||
# Optional configuration
|
||||
metadata_store: Optional[KVStoreConfig] = None
|
||||
env_vars: Optional[Dict[str, str]] = None
|
||||
docker_image: Optional[str] = None
|
||||
|
||||
@property
|
||||
def distribution_spec(self) -> DistributionSpec:
|
||||
return DistributionSpec(
|
||||
description=self.description,
|
||||
docker_image=self.docker_image,
|
||||
providers=self.providers,
|
||||
)
|
||||
|
||||
def build_config(self) -> BuildConfig:
|
||||
return BuildConfig(
|
||||
name=self.name,
|
||||
distribution_spec=self.distribution_spec,
|
||||
image_type="conda", # default to conda, can be overridden
|
||||
)
|
||||
|
||||
def run_config(self, provider_configs: Dict[str, List[Provider]]) -> StackRunConfig:
|
||||
from datetime import datetime
|
||||
|
||||
# Get unique set of APIs from providers
|
||||
apis: Set[str] = set(self.providers.keys())
|
||||
|
||||
return StackRunConfig(
|
||||
image_name=self.name,
|
||||
docker_image=self.docker_image,
|
||||
built_at=datetime.now(),
|
||||
apis=list(apis),
|
||||
providers=provider_configs,
|
||||
metadata_store=self.metadata_store,
|
||||
models=self.default_models,
|
||||
shields=self.default_shields or [],
|
||||
)
|
||||
|
||||
def generate_markdown_docs(self) -> str:
|
||||
"""Generate markdown documentation using both Jinja2 templates and rich tables."""
|
||||
# First generate the providers table using rich
|
||||
output = StringIO()
|
||||
console = Console(file=output, force_terminal=False)
|
||||
|
||||
table = Table(title="Provider Configuration", show_header=True)
|
||||
table.add_column("API", style="bold")
|
||||
table.add_column("Provider(s)")
|
||||
|
||||
for api, providers in sorted(self.providers.items()):
|
||||
table.add_row(api, ", ".join(f"`{p}`" for p in providers))
|
||||
|
||||
console.print(table)
|
||||
providers_table = output.getvalue()
|
||||
|
||||
# Main documentation template
|
||||
template = """# {{ name }} Distribution
|
||||
|
||||
{{ description }}
|
||||
|
||||
## Provider Configuration
|
||||
|
||||
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations:
|
||||
|
||||
{{ providers_table }}
|
||||
|
||||
{%- if env_vars %}
|
||||
## Environment Variables
|
||||
|
||||
The following environment variables can be configured:
|
||||
|
||||
{% for var, description in env_vars.items() %}
|
||||
- `{{ var }}`: {{ description }}
|
||||
{% endfor %}
|
||||
{%- endif %}
|
||||
|
||||
## Example Usage
|
||||
|
||||
### Using Docker Compose
|
||||
|
||||
```bash
|
||||
$ cd distributions/{{ name }}
|
||||
$ docker compose up
|
||||
```
|
||||
|
||||
### Manual Configuration
|
||||
|
||||
You can also configure the distribution manually by creating a `run.yaml` file:
|
||||
|
||||
```yaml
|
||||
version: '2'
|
||||
image_name: {{ name }}
|
||||
apis:
|
||||
{% for api in providers.keys() %}
|
||||
- {{ api }}
|
||||
{% endfor %}
|
||||
|
||||
providers:
|
||||
{% for api, provider_list in providers.items() %}
|
||||
{{ api }}:
|
||||
{% for provider in provider_list %}
|
||||
- provider_id: {{ provider.lower() }}-0
|
||||
provider_type: {{ provider }}
|
||||
config: {}
|
||||
{% endfor %}
|
||||
{% endfor %}
|
||||
```
|
||||
|
||||
## Models
|
||||
|
||||
The following models are configured by default:
|
||||
{% for model in default_models %}
|
||||
- `{{ model.model_id }}`
|
||||
{% endfor %}
|
||||
|
||||
{%- if default_shields %}
|
||||
|
||||
## Safety Shields
|
||||
|
||||
The following safety shields are configured:
|
||||
{% for shield in default_shields %}
|
||||
- `{{ shield.shield_id }}`
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
"""
|
||||
# Render template with rich-generated table
|
||||
env = jinja2.Environment(trim_blocks=True, lstrip_blocks=True)
|
||||
template = env.from_string(template)
|
||||
return template.render(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
providers=self.providers,
|
||||
providers_table=providers_table,
|
||||
env_vars=self.env_vars,
|
||||
default_models=self.default_models,
|
||||
default_shields=self.default_shields,
|
||||
)
|
||||
|
||||
def save_distribution(self, output_dir: Path) -> None:
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save build.yaml
|
||||
build_config = self.build_config()
|
||||
with open(output_dir / "build.yaml", "w") as f:
|
||||
yaml.safe_dump(build_config.model_dump(), f, sort_keys=False)
|
||||
|
||||
# Save run.yaml template
|
||||
# Create a minimal provider config for the template
|
||||
provider_configs = {
|
||||
api: [
|
||||
Provider(
|
||||
provider_id=f"{provider.lower()}-0",
|
||||
provider_type=provider,
|
||||
config={},
|
||||
)
|
||||
for provider in providers
|
||||
]
|
||||
for api, providers in self.providers.items()
|
||||
}
|
||||
run_config = self.run_config(provider_configs)
|
||||
with open(output_dir / "run.yaml", "w") as f:
|
||||
yaml.safe_dump(run_config.model_dump(), f, sort_keys=False)
|
||||
|
||||
# Save documentation
|
||||
docs = self.generate_markdown_docs()
|
||||
with open(output_dir / f"{self.name}.md", "w") as f:
|
||||
f.write(docs)
|
||||
|
||||
@classmethod
|
||||
def vllm_distribution(cls) -> "DistributionTemplate":
|
||||
return cls(
|
||||
name="remote-vllm",
|
||||
description="Use (an external) vLLM server for running LLM inference",
|
||||
providers={
|
||||
"inference": ["remote::vllm"],
|
||||
"memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
|
||||
"safety": ["inline::llama-guard"],
|
||||
"agents": ["inline::meta-reference"],
|
||||
"telemetry": ["inline::meta-reference"],
|
||||
},
|
||||
default_models=[
|
||||
ModelInput(
|
||||
model_id="${env.LLAMA_INFERENCE_MODEL:Llama3.1-8B-Instruct}"
|
||||
),
|
||||
ModelInput(model_id="${env.LLAMA_SAFETY_MODEL:Llama-Guard-3-1B}"),
|
||||
],
|
||||
default_shields=[
|
||||
ShieldInput(shield_id="${env.LLAMA_SAFETY_MODEL:Llama-Guard-3-1B}")
|
||||
],
|
||||
env_vars={
|
||||
"LLAMA_INFERENCE_VLLM_URL": "URL of the vLLM inference server",
|
||||
"LLAMA_SAFETY_VLLM_URL": "URL of the vLLM safety server",
|
||||
"MAX_TOKENS": "Maximum number of tokens for generation",
|
||||
"LLAMA_INFERENCE_MODEL": "Name of the inference model to use",
|
||||
"LLAMA_SAFETY_MODEL": "Name of the safety model to use",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
parser = argparse.ArgumentParser(description="Generate a distribution template")
|
||||
parser.add_argument(
|
||||
"--type",
|
||||
choices=["vllm"],
|
||||
default="vllm",
|
||||
help="Type of distribution template to generate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Output directory for the distribution files",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.type == "vllm":
|
||||
template = DistributionTemplate.vllm_distribution()
|
||||
else:
|
||||
print(f"Unknown template type: {args.type}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
template.save_distribution(args.output_dir)
|
Loading…
Add table
Add a link
Reference in a new issue