Start auto-generating { build, run, doc.md } for distributions

Ashwin Bharambe 2024-11-14 17:44:45 -08:00
parent 20bf2f50c2
commit cfa913fdd5
11 changed files with 362 additions and 23 deletions

View file

@@ -13,20 +13,15 @@ apis:
- safety
providers:
inference:
- provider_id: ollama0
- provider_id: ollama
provider_type: remote::ollama
config:
url: http://127.0.0.1:14343
url: ${env.OLLAMA_URL:http://127.0.0.1:11434}
safety:
- provider_id: meta0
provider_type: inline::llama-guard
config:
model: Llama-Guard-3-1B
excluded_categories: []
- provider_id: meta1
provider_type: inline::prompt-guard
config:
model: Prompt-Guard-86M
memory:
- provider_id: meta0
provider_type: inline::meta-reference
@@ -43,3 +38,10 @@ providers:
- provider_id: meta0
provider_type: inline::meta-reference
config: {}
models:
- model_id: ${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}
provider_id: ollama
- model_id: ${env.SAFETY_MODEL:Llama-Guard-3-1B}
provider_id: ollama
shields:
- shield_id: ${env.SAFETY_MODEL:Llama-Guard-3-1B}
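
The `${env.VAR:default}` references introduced above are resolved when the stack loads this file: the environment variable wins if it is set, otherwise the default after the colon is used. A minimal sketch of that behavior (the actual resolution happens in `replace_env_vars`, whose change appears further down):

```python
import os

# Sketch only: rough equivalent of ${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}
# and ${env.OLLAMA_URL:http://127.0.0.1:11434} as resolved at startup.
inference_model = os.environ.get("INFERENCE_MODEL", "Llama3.2-3B-Instruct")
ollama_url = os.environ.get("OLLAMA_URL", "http://127.0.0.1:11434")
print(inference_model, ollama_url)
```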

View file

@@ -13,20 +13,15 @@ apis:
- safety
providers:
inference:
- provider_id: ollama0
- provider_id: ollama
provider_type: remote::ollama
config:
url: http://127.0.0.1:14343
url: ${env.LLAMA_INFERENCE_OLLAMA_URL:http://127.0.0.1:11434}
safety:
- provider_id: meta0
provider_type: inline::llama-guard
config:
model: Llama-Guard-3-1B
excluded_categories: []
- provider_id: meta1
provider_type: inline::prompt-guard
config:
model: Prompt-Guard-86M
memory:
- provider_id: meta0
provider_type: inline::meta-reference
@@ -43,3 +38,10 @@ providers:
- provider_id: meta0
provider_type: inline::meta-reference
config: {}
models:
- model_id: ${env.LLAMA_INFERENCE_MODEL:Llama3.2-3B-Instruct}
provider_id: ollama
- model_id: ${env.LLAMA_SAFETY_MODEL:Llama-Guard-3-1B}
provider_id: ollama
shields:
- shield_id: ${env.LLAMA_SAFETY_MODEL:Llama-Guard-3-1B}

View file

@@ -16,7 +16,7 @@ providers:
provider_type: remote::vllm
config:
# NOTE: replace with "localhost" if you are running in "host" network mode
url: ${env.LLAMA_INFERENCE_VLLM_URL:http://host.docker.internal:5100/v1}
url: ${env.VLLM_URL:http://host.docker.internal:5100/v1}
max_tokens: ${env.MAX_TOKENS:4096}
api_token: fake
# serves safety llama_guard model
@@ -24,7 +24,7 @@ providers:
provider_type: remote::vllm
config:
# NOTE: replace with "localhost" if you are running in "host" network mode
url: ${env.LLAMA_SAFETY_VLLM_URL:http://host.docker.internal:5101/v1}
url: ${env.SAFETY_VLLM_URL:http://host.docker.internal:5101/v1}
max_tokens: ${env.MAX_TOKENS:4096}
api_token: fake
memory:
@@ -34,7 +34,7 @@ providers:
kvstore:
namespace: null
type: sqlite
db_path: "${env.SQLITE_STORE_DIR:/home/ashwin/.llama/distributions/remote-vllm}/faiss_store.db"
db_path: "${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/faiss_store.db"
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
@@ -50,7 +50,7 @@ providers:
persistence_store:
namespace: null
type: sqlite
db_path: "${env.SQLITE_STORE_DIR:/home/ashwin/.llama/distributions/remote-vllm}/agents_store.db"
db_path: "${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/agents_store.db"
telemetry:
- provider_id: meta0
provider_type: inline::meta-reference
@@ -58,11 +58,11 @@ providers:
metadata_store:
namespace: null
type: sqlite
db_path: "${env.SQLITE_STORE_DIR:/home/ashwin/.llama/distributions/remote-vllm}/registry.db"
db_path: "${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db"
models:
- model_id: ${env.LLAMA_INFERENCE_MODEL:Llama3.1-8B-Instruct}
- model_id: ${env.INFERENCE_MODEL:Llama3.1-8B-Instruct}
provider_id: vllm-0
- model_id: ${env.LLAMA_SAFETY_MODEL:Llama-Guard-3-1B}
- model_id: ${env.SAFETY_MODEL:Llama-Guard-3-1B}
provider_id: vllm-1
shields:
- shield_id: ${env.LLAMA_SAFETY_MODEL:Llama-Guard-3-1B}
- shield_id: ${env.SAFETY_MODEL:Llama-Guard-3-1B}

View file

@@ -313,7 +313,8 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
                 else:
                     value = default_val
-            return value
+            # expand "~" from the values
+            return os.path.expanduser(value)
         try:
             return re.sub(pattern, get_env_var, config)
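
For context, here is a self-contained sketch of the substitution this hunk modifies. The regex and error handling are assumptions (the diff only shows the tail of the inner callback), but it illustrates the resolution order: environment variable, then the default after the colon, then the `~` expansion added by this commit.

```python
import os
import re

# Assumed shape of the ${env.VAR:default} pattern used by replace_env_vars.
_PATTERN = r"\$\{env\.([A-Z0-9_]+)(?::([^}]*))?\}"


def resolve(value: str) -> str:
    def get_env_var(match: re.Match) -> str:
        env_var, default_val = match.group(1), match.group(2)
        resolved = os.environ.get(env_var) or default_val
        if resolved is None:
            raise ValueError(f"no value or default for {env_var}")
        # expand "~" from the values (the behavior added by this hunk)
        return os.path.expanduser(resolved)

    return re.sub(_PATTERN, get_env_var, value)


print(resolve("${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db"))
# -> /home/<user>/.llama/distributions/remote-vllm/registry.db
```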

View file

@@ -12,3 +12,11 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class MetaReferenceAgentsImplConfig(BaseModel):
persistence_store: KVStoreConfig = Field(default=SqliteKVStoreConfig())
@classmethod
def sample_dict(cls):
return {
"persistence_store": SqliteKVStoreConfig.sample_dict(
db_name="agents_store.db"
),
}

View file

@@ -34,6 +34,16 @@ class VLLMConfig(BaseModel):
default=0.3,
)
@classmethod
def sample_dict(cls):
return {
"model": "${env.VLLM_INFERENCE_MODEL:Llama3.2-3B-Instruct}",
"tensor_parallel_size": "${env.VLLM_TENSOR_PARALLEL_SIZE:1}",
"max_tokens": "${env.VLLM_MAX_TOKENS:4096}",
"enforce_eager": "${env.VLLM_ENFORCE_EAGER:False}",
"gpu_memory_utilization": "${env.VLLM_GPU_MEMORY_UTILIZATION:0.3}",
}
@field_validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
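
Note that the sample values above are all strings containing `${env...}` placeholders, even for numeric and boolean fields; once `replace_env_vars` substitutes them, pydantic coerces the resulting strings back to the declared types. A small stand-alone illustration (the `_Demo` model is a stand-in, not part of the commit):

```python
from pydantic import BaseModel


class _Demo(BaseModel):
    # Stand-in for VLLMConfig's non-string fields.
    tensor_parallel_size: int
    max_tokens: int
    enforce_eager: bool
    gpu_memory_utilization: float


print(_Demo(
    tensor_parallel_size="1",
    max_tokens="4096",
    enforce_eager="False",
    gpu_memory_utilization="0.3",
))
# tensor_parallel_size=1 max_tokens=4096 enforce_eager=False gpu_memory_utilization=0.3
```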

View file

@@ -11,3 +11,9 @@ from pydantic import BaseModel
class LlamaGuardConfig(BaseModel):
excluded_categories: List[str] = []
@classmethod
def sample_dict(cls):
return {
"excluded_categories": [],
}

View file

@@ -24,3 +24,12 @@ class VLLMInferenceAdapterConfig(BaseModel):
default="fake",
description="The API token",
)
@classmethod
def sample_dict(cls):
# TODO: we may need two modes, one for conda and one for docker
return {
"url": "${env.VLLM_URL:http://host.docker.internal:5100/v1}",
"max_tokens": "${env.VLLM_MAX_TOKENS:4096}",
"api_token": "${env.VLLM_API_TOKEN:fake}",
}

View file

@@ -36,6 +36,15 @@ class RedisKVStoreConfig(CommonConfig):
def url(self) -> str:
return f"redis://{self.host}:{self.port}"
@classmethod
def sample_dict(cls):
return {
"type": "redis",
"namespace": None,
"host": "${env.REDIS_HOST:localhost}",
"port": "${env.REDIS_PORT:6379}",
}
class SqliteKVStoreConfig(CommonConfig):
type: Literal[KVStoreType.sqlite.value] = KVStoreType.sqlite.value
@@ -44,6 +53,14 @@ class SqliteKVStoreConfig(CommonConfig):
description="File path for the sqlite database",
)
@classmethod
def sample_dict(cls, db_name: str = "kvstore.db"):
return {
"type": "sqlite",
"namespace": None,
"db_path": "${env.SQLITE_STORE_DIR:~/.llama/runtime/" + db_name + "}",
}
class PostgresKVStoreConfig(CommonConfig):
type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value
@@ -54,6 +71,19 @@ class PostgresKVStoreConfig(CommonConfig):
password: Optional[str] = None
table_name: str = "llamastack_kvstore"
@classmethod
def sample_dict(cls, table_name: str = "llamastack_kvstore"):
return {
"type": "postgres",
"namespace": None,
"host": "${env.POSTGRES_HOST:localhost}",
"port": "${env.POSTGRES_PORT:5432}",
"db": "${env.POSTGRES_DB}",
"user": "${env.POSTGRES_USER}",
"password": "${env.POSTGRES_PASSWORD}",
"table_name": "${env.POSTGRES_TABLE_NAME:" + table_name + "}",
}
@classmethod
@field_validator("table_name")
def validate_table_name(cls, v: str) -> str:
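
These new `sample_dict` helpers presumably feed the auto-generated configs. A sketch of how one such dict serializes into a `run.yaml` fragment; the helper mirrors `SqliteKVStoreConfig.sample_dict` rather than importing it, to stay self-contained:

```python
import yaml


def sqlite_sample_dict(db_name: str = "kvstore.db") -> dict:
    # Mirrors SqliteKVStoreConfig.sample_dict above (illustration only).
    return {
        "type": "sqlite",
        "namespace": None,
        "db_path": "${env.SQLITE_STORE_DIR:~/.llama/runtime/" + db_name + "}",
    }


print(yaml.safe_dump({"metadata_store": sqlite_sample_dict("registry.db")}, sort_keys=False))
# metadata_store:
#   type: sqlite
#   namespace: null
#   db_path: ${env.SQLITE_STORE_DIR:~/.llama/runtime/registry.db}
```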

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -0,0 +1,266 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from io import StringIO
from pathlib import Path
from typing import Dict, List, Optional, Set
import jinja2
import yaml
from pydantic import BaseModel
from rich.console import Console
from rich.table import Table
from llama_stack.distribution.datatypes import (
BuildConfig,
DistributionSpec,
KVStoreConfig,
ModelInput,
Provider,
ShieldInput,
StackRunConfig,
)
class DistributionTemplate(BaseModel):
"""
Represents a Llama Stack distribution instance that can generate configuration
and documentation files.
"""
name: str
description: str
providers: Dict[str, List[str]]
default_models: List[ModelInput]
default_shields: Optional[List[ShieldInput]] = None
# Optional configuration
metadata_store: Optional[KVStoreConfig] = None
env_vars: Optional[Dict[str, str]] = None
docker_image: Optional[str] = None
@property
def distribution_spec(self) -> DistributionSpec:
return DistributionSpec(
description=self.description,
docker_image=self.docker_image,
providers=self.providers,
)
def build_config(self) -> BuildConfig:
return BuildConfig(
name=self.name,
distribution_spec=self.distribution_spec,
image_type="conda", # default to conda, can be overridden
)
def run_config(self, provider_configs: Dict[str, List[Provider]]) -> StackRunConfig:
from datetime import datetime
# Get unique set of APIs from providers
apis: Set[str] = set(self.providers.keys())
return StackRunConfig(
image_name=self.name,
docker_image=self.docker_image,
built_at=datetime.now(),
apis=list(apis),
providers=provider_configs,
metadata_store=self.metadata_store,
models=self.default_models,
shields=self.default_shields or [],
)
def generate_markdown_docs(self) -> str:
"""Generate markdown documentation using both Jinja2 templates and rich tables."""
# First generate the providers table using rich
output = StringIO()
console = Console(file=output, force_terminal=False)
table = Table(title="Provider Configuration", show_header=True)
table.add_column("API", style="bold")
table.add_column("Provider(s)")
for api, providers in sorted(self.providers.items()):
table.add_row(api, ", ".join(f"`{p}`" for p in providers))
console.print(table)
providers_table = output.getvalue()
# Main documentation template
template = """# {{ name }} Distribution
{{ description }}
## Provider Configuration
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations:
{{ providers_table }}
{%- if env_vars %}
## Environment Variables
The following environment variables can be configured:
{% for var, description in env_vars.items() %}
- `{{ var }}`: {{ description }}
{% endfor %}
{%- endif %}
## Example Usage
### Using Docker Compose
```bash
$ cd distributions/{{ name }}
$ docker compose up
```
### Manual Configuration
You can also configure the distribution manually by creating a `run.yaml` file:
```yaml
version: '2'
image_name: {{ name }}
apis:
{% for api in providers.keys() %}
- {{ api }}
{% endfor %}
providers:
{% for api, provider_list in providers.items() %}
{{ api }}:
{% for provider in provider_list %}
- provider_id: {{ provider.lower() }}-0
provider_type: {{ provider }}
config: {}
{% endfor %}
{% endfor %}
```
## Models
The following models are configured by default:
{% for model in default_models %}
- `{{ model.model_id }}`
{% endfor %}
{%- if default_shields %}
## Safety Shields
The following safety shields are configured:
{% for shield in default_shields %}
- `{{ shield.shield_id }}`
{%- endfor %}
{%- endif %}
"""
# Render template with rich-generated table
env = jinja2.Environment(trim_blocks=True, lstrip_blocks=True)
template = env.from_string(template)
return template.render(
name=self.name,
description=self.description,
providers=self.providers,
providers_table=providers_table,
env_vars=self.env_vars,
default_models=self.default_models,
default_shields=self.default_shields,
)
def save_distribution(self, output_dir: Path) -> None:
output_dir.mkdir(parents=True, exist_ok=True)
# Save build.yaml
build_config = self.build_config()
with open(output_dir / "build.yaml", "w") as f:
yaml.safe_dump(build_config.model_dump(), f, sort_keys=False)
# Save run.yaml template
# Create a minimal provider config for the template
provider_configs = {
api: [
Provider(
provider_id=f"{provider.lower()}-0",
provider_type=provider,
config={},
)
for provider in providers
]
for api, providers in self.providers.items()
}
run_config = self.run_config(provider_configs)
with open(output_dir / "run.yaml", "w") as f:
yaml.safe_dump(run_config.model_dump(), f, sort_keys=False)
# Save documentation
docs = self.generate_markdown_docs()
with open(output_dir / f"{self.name}.md", "w") as f:
f.write(docs)
@classmethod
def vllm_distribution(cls) -> "DistributionTemplate":
return cls(
name="remote-vllm",
description="Use (an external) vLLM server for running LLM inference",
providers={
"inference": ["remote::vllm"],
"memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
},
default_models=[
ModelInput(
model_id="${env.LLAMA_INFERENCE_MODEL:Llama3.1-8B-Instruct}"
),
ModelInput(model_id="${env.LLAMA_SAFETY_MODEL:Llama-Guard-3-1B}"),
],
default_shields=[
ShieldInput(shield_id="${env.LLAMA_SAFETY_MODEL:Llama-Guard-3-1B}")
],
env_vars={
"LLAMA_INFERENCE_VLLM_URL": "URL of the vLLM inference server",
"LLAMA_SAFETY_VLLM_URL": "URL of the vLLM safety server",
"MAX_TOKENS": "Maximum number of tokens for generation",
"LLAMA_INFERENCE_MODEL": "Name of the inference model to use",
"LLAMA_SAFETY_MODEL": "Name of the safety model to use",
},
)
if __name__ == "__main__":
import argparse
import sys
from pathlib import Path
parser = argparse.ArgumentParser(description="Generate a distribution template")
parser.add_argument(
"--type",
choices=["vllm"],
default="vllm",
help="Type of distribution template to generate",
)
parser.add_argument(
"--output-dir",
type=Path,
required=True,
help="Output directory for the distribution files",
)
args = parser.parse_args()
if args.type == "vllm":
template = DistributionTemplate.vllm_distribution()
else:
print(f"Unknown template type: {args.type}", file=sys.stderr)
sys.exit(1)
template.save_distribution(args.output_dir)
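
The `__main__` block above means the generator can be run directly (`--type vllm --output-dir <dir>`) or driven programmatically. A minimal programmatic sketch; it assumes `DistributionTemplate` (defined in this new file) is in scope, since the diff does not reveal the module's path:

```python
from pathlib import Path

# Sketch: drive the generator programmatically instead of via the CLI above.
template = DistributionTemplate.vllm_distribution()
template.save_distribution(Path("distributions/remote-vllm"))
# -> writes build.yaml, run.yaml and remote-vllm.md into distributions/remote-vllm/
print(template.generate_markdown_docs()[:200])  # preview the generated docs
```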