Write a script to perform the codegen

Ashwin Bharambe 2024-11-17 14:01:04 -08:00
parent f38e76ee98
commit 0218e68849
9 changed files with 223 additions and 142 deletions

@@ -4,19 +4,22 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel, Field
from typing import Any, Dict
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore import KVStoreConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class MetaReferenceAgentsImplConfig(BaseModel):
persistence_store: KVStoreConfig = Field(default=SqliteKVStoreConfig())
persistence_store: KVStoreConfig
@classmethod
def sample_run_config(cls):
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
return {
"persistence_store": SqliteKVStoreConfig.sample_run_config(
db_name="agents_store.db"
),
__distro_dir__=__distro_dir__,
db_name="agents_store.db",
).model_dump(),
}
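
For orientation, here is a minimal sketch (not part of the commit) of roughly what this ends up contributing to a generated run config; the `distributions/remote-vllm` directory is only an illustrative value, borrowed from the codegen change later in this diff:

```python
# Illustrative only: approximate persistence_store entry for
# __distro_dir__="distributions/remote-vllm" (assumed example value).
sample = {
    "persistence_store": {
        "type": "sqlite",
        "namespace": None,
        "db_path": "${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm/agents_store.db}",
    }
}
```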

@@ -4,10 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
@@ -16,6 +17,13 @@ from llama_stack.providers.utils.kvstore.config import (
@json_schema_type
class FaissImplConfig(BaseModel):
kvstore: KVStoreConfig = SqliteKVStoreConfig(
db_path=(RUNTIME_BASE_DIR / "faiss_store.db").as_posix()
) # Uses SQLite config specific to FAISS storage
kvstore: KVStoreConfig
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
return {
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="faiss_store.db",
).model_dump(),
}

@@ -54,11 +54,15 @@ class SqliteKVStoreConfig(CommonConfig):
)
@classmethod
def sample_run_config(cls, dir: str = "runtime", db_name: str = "kvstore.db"):
def sample_run_config(
cls, __distro_dir__: str = "runtime", db_name: str = "kvstore.db"
):
return {
"type": "sqlite",
"namespace": None,
"db_path": "${env.SQLITE_STORE_DIR:~/.llama/" + f"{dir}/{db_name}" + "}",
"db_path": "${env.SQLITE_STORE_DIR:~/.llama/"
+ f"{__distro_dir__}/{db_name}"
+ "}",
}
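
The assembled placeholder follows the `${env.VAR:default}` convention used in the run configs: the environment variable is used when set, otherwise the `~/.llama/...` default applies. A minimal sketch (not part of the commit) of the string the new signature builds, under an assumed distro dir:

```python
# Illustrative only: the concatenation above, evaluated for an assumed distro dir.
__distro_dir__ = "distributions/remote-vllm"  # assumed example value
db_name = "kvstore.db"
db_path = "${env.SQLITE_STORE_DIR:~/.llama/" + f"{__distro_dir__}/{db_name}" + "}"
print(db_path)  # ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm/kvstore.db}
```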

@@ -0,0 +1,78 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import concurrent.futures
import importlib
from functools import partial
from pathlib import Path
from typing import Iterator
from rich.progress import Progress, SpinnerColumn, TextColumn
REPO_ROOT = Path(__file__).parent.parent.parent
def find_template_dirs(templates_dir: Path) -> Iterator[Path]:
"""Find immediate subdirectories in the templates folder."""
if not templates_dir.exists():
raise FileNotFoundError(f"Templates directory not found: {templates_dir}")
return (d for d in templates_dir.iterdir() if d.is_dir())
def process_template(template_dir: Path, progress) -> None:
"""Process a single template directory."""
progress.print(f"Processing {template_dir.name}")
try:
# Import the module directly
module_name = f"llama_stack.templates.{template_dir.name}"
module = importlib.import_module(module_name)
# Get and save the distribution template
if template_func := getattr(module, "get_distribution_template", None):
template = template_func()
template.save_distribution(
yaml_output_dir=REPO_ROOT / "distributions" / template.name,
doc_output_dir=REPO_ROOT
/ "docs/source/getting_started/distributions"
/ f"{template.distro_type}_distro",
)
else:
progress.print(
f"[yellow]Warning: {template_dir.name} has no get_distribution_template function"
)
except Exception as e:
progress.print(f"[red]Error processing {template_dir.name}: {str(e)}")
def main():
templates_dir = REPO_ROOT / "llama_stack" / "templates"
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
) as progress:
template_dirs = list(find_template_dirs(templates_dir))
task = progress.add_task(
"Processing distribution templates...", total=len(template_dirs)
)
# Create a partial function with the progress bar
process_func = partial(process_template, progress=progress)
# Process templates in parallel
with concurrent.futures.ThreadPoolExecutor() as executor:
# Submit all tasks and wait for completion
list(executor.map(process_func, template_dirs))
progress.update(task, advance=len(template_dirs))
if __name__ == "__main__":
main()
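
Stripped of the progress bar and thread pool, the per-template work in `process_template` reduces to an import plus a `save_distribution` call. A minimal sequential sketch, not part of the commit, with the template package name assumed purely for illustration:

```python
# Illustrative only: the essential single-template flow that the script runs in parallel.
import importlib
from pathlib import Path

repo_root = Path(".")   # assumes the current working directory is the repository root
name = "remote-vllm"    # assumed; the actual template directory name is not shown in this hunk
module = importlib.import_module(f"llama_stack.templates.{name}")
template = module.get_distribution_template()
template.save_distribution(
    yaml_output_dir=repo_root / "distributions" / template.name,
    doc_output_dir=repo_root
    / "docs/source/getting_started/distributions"
    / f"{template.distro_type}_distro",
)
```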

@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .vllm import get_distribution_template # noqa: F401

@@ -78,7 +78,7 @@ inference:
If you are using Conda, you can build and run the Llama Stack server with the following commands:
```bash
cd distributions/remote-vllm
llama stack build --template remote_vllm --image-type conda
llama stack build --template remote-vllm --image-type conda
llama stack run run.yaml
```

@@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::vllm"],
"memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
}
inference_provider = Provider(
provider_id="vllm-inference",
provider_type="remote::vllm",
config=VLLMInferenceAdapterConfig.sample_run_config(
url="${env.VLLM_URL}",
),
)
inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}",
provider_id="vllm-inference",
)
safety_model = ModelInput(
model_id="${env.SAFETY_MODEL}",
provider_id="vllm-safety",
)
return DistributionTemplate(
name="remote-vllm",
distro_type="self_hosted",
description="Use (an external) vLLM server for running LLM inference",
template_path=Path(__file__).parent / "doc_template.md",
providers=providers,
default_models=[inference_model, safety_model],
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider],
},
default_models=[inference_model],
),
"run-with-safety.yaml": RunConfigSettings(
provider_overrides={
"inference": [
inference_provider,
Provider(
provider_id="vllm-safety",
provider_type="remote::vllm",
config=VLLMInferenceAdapterConfig.sample_run_config(
url="${env.SAFETY_VLLM_URL}",
),
),
],
},
default_models=[
inference_model,
safety_model,
],
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
),
},
docker_compose_env_vars={
"LLAMASTACK_PORT": (
"5001",
"Port for the Llama Stack distribution server",
),
"INFERENCE_MODEL": (
"meta-llama/Llama-3.2-3B-Instruct",
"Inference model loaded into the vLLM server",
),
"VLLM_URL": (
"http://host.docker.internal:5100}/v1",
"URL of the vLLM server with the main inference model",
),
"MAX_TOKENS": (
"4096",
"Maximum number of tokens for generation",
),
"SAFETY_VLLM_URL": (
"http://host.docker.internal:5101/v1",
"URL of the vLLM server with the safety model",
),
"SAFETY_MODEL": (
"meta-llama/Llama-Guard-3-1B",
"Name of the safety (Llama-Guard) model to use",
),
},
)
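
Each `docker_compose_env_vars` entry pairs a default value with a human-readable description for the corresponding `${env.*}` placeholder used above. A small sketch (not part of the commit) of how those tuples can be read back from the template:

```python
# Illustrative only: print every variable with its default value and description.
template = get_distribution_template()
for var, (default, description) in template.docker_compose_env_vars.items():
    print(f"{var}={default}  # {description}")
```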

@@ -9,7 +9,7 @@ from datetime import datetime
from io import StringIO
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
from typing import Dict, List, Literal, Optional, Set, Tuple
import jinja2
import yaml
@@ -29,9 +29,6 @@ from llama_stack.distribution.datatypes import (
)
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.providers.remote.inference.vllm.config import (
VLLMInferenceAdapterConfig,
)
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
@@ -70,7 +67,9 @@ class RunConfigSettings(BaseModel):
config_class = instantiate_class_type(config_class)
if hasattr(config_class, "sample_run_config"):
config = config_class.sample_run_config()
config = config_class.sample_run_config(
__distro_dir__=f"distributions/{name}"
)
else:
config = {}
@@ -108,6 +107,7 @@ class DistributionTemplate(BaseModel):
name: str
description: str
distro_type: Literal["self_hosted", "remote_hosted", "ondevice"]
providers: Dict[str, List[str]]
run_configs: Dict[str, RunConfigSettings]
@@ -159,140 +159,21 @@ class DistributionTemplate(BaseModel):
default_models=self.default_models,
)
def save_distribution(self, output_dir: Path) -> None:
output_dir.mkdir(parents=True, exist_ok=True)
def save_distribution(self, yaml_output_dir: Path, doc_output_dir: Path) -> None:
for output_dir in [yaml_output_dir, doc_output_dir]:
output_dir.mkdir(parents=True, exist_ok=True)
build_config = self.build_config()
with open(output_dir / "build.yaml", "w") as f:
with open(yaml_output_dir / "build.yaml", "w") as f:
yaml.safe_dump(build_config.model_dump(), f, sort_keys=False)
for yaml_pth, settings in self.run_configs.items():
print(f"Generating {yaml_pth}")
print(f"Providers: {self.providers}")
run_config = settings.run_config(
self.name, self.providers, self.docker_image
)
with open(output_dir / yaml_pth, "w") as f:
with open(yaml_output_dir / yaml_pth, "w") as f:
yaml.safe_dump(run_config.model_dump(), f, sort_keys=False)
docs = self.generate_markdown_docs()
with open(output_dir / f"{self.name}.md", "w") as f:
with open(doc_output_dir / f"{self.name}.md", "w") as f:
f.write(docs)
@classmethod
def vllm_distribution(cls) -> "DistributionTemplate":
providers = {
"inference": ["remote::vllm"],
"memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
}
inference_provider = Provider(
provider_id="vllm-inference",
provider_type="remote::vllm",
config=VLLMInferenceAdapterConfig.sample_run_config(
url="${env.VLLM_URL}",
),
)
inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}",
provider_id="vllm-inference",
)
safety_model = ModelInput(
model_id="${env.SAFETY_MODEL}",
provider_id="vllm-safety",
)
return cls(
name="remote-vllm",
description="Use (an external) vLLM server for running LLM inference",
template_path=Path(__file__).parent / "remote-vllm" / "doc_template.md",
providers=providers,
default_models=[inference_model, safety_model],
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider],
},
default_models=[inference_model],
),
"safety-run.yaml": RunConfigSettings(
provider_overrides={
"inference": [
inference_provider,
Provider(
provider_id="vllm-safety",
provider_type="remote::vllm",
config=VLLMInferenceAdapterConfig.sample_run_config(
url="${env.SAFETY_VLLM_URL}",
),
),
],
},
default_models=[
inference_model,
safety_model,
],
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
),
},
docker_compose_env_vars={
"LLAMASTACK_PORT": (
"5001",
"Port for the Llama Stack distribution server",
),
"INFERENCE_MODEL": (
"meta-llama/Llama-3.2-3B-Instruct",
"Inference model loaded into the vLLM server",
),
"VLLM_URL": (
"http://host.docker.internal:5100}/v1",
"URL of the vLLM server with the main inference model",
),
"MAX_TOKENS": (
"4096",
"Maximum number of tokens for generation",
),
"SAFETY_VLLM_URL": (
"http://host.docker.internal:5101/v1",
"URL of the vLLM server with the safety model",
),
"SAFETY_MODEL": (
"meta-llama/Llama-Guard-3-1B",
"Name of the safety (Llama-Guard) model to use",
),
},
)
if __name__ == "__main__":
import argparse
import sys
from pathlib import Path
parser = argparse.ArgumentParser(description="Generate a distribution template")
parser.add_argument(
"--type",
choices=["vllm"],
default="vllm",
help="Type of distribution template to generate",
)
parser.add_argument(
"--output-dir",
type=Path,
required=True,
help="Output directory for the distribution files",
)
args = parser.parse_args()
if args.type == "vllm":
template = DistributionTemplate.vllm_distribution()
else:
print(f"Unknown template type: {args.type}", file=sys.stderr)
sys.exit(1)
template.save_distribution(args.output_dir)