Write a script to perform the codegen

This commit is contained in:
Ashwin Bharambe 2024-11-17 14:01:04 -08:00
parent f38e76ee98
commit 0218e68849
9 changed files with 223 additions and 142 deletions

View file

@ -9,7 +9,7 @@ from datetime import datetime
from io import StringIO
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
from typing import Dict, List, Literal, Optional, Set, Tuple
import jinja2
import yaml
@ -29,9 +29,6 @@ from llama_stack.distribution.datatypes import (
)
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.providers.remote.inference.vllm.config import (
VLLMInferenceAdapterConfig,
)
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
@ -70,7 +67,9 @@ class RunConfigSettings(BaseModel):
config_class = instantiate_class_type(config_class)
if hasattr(config_class, "sample_run_config"):
config = config_class.sample_run_config()
config = config_class.sample_run_config(
__distro_dir__=f"distributions/{name}"
)
else:
config = {}
@ -108,6 +107,7 @@ class DistributionTemplate(BaseModel):
name: str
description: str
distro_type: Literal["self_hosted", "remote_hosted", "ondevice"]
providers: Dict[str, List[str]]
run_configs: Dict[str, RunConfigSettings]
@ -159,140 +159,21 @@ class DistributionTemplate(BaseModel):
default_models=self.default_models,
)
def save_distribution(self, output_dir: Path) -> None:
output_dir.mkdir(parents=True, exist_ok=True)
def save_distribution(self, yaml_output_dir: Path, doc_output_dir: Path) -> None:
for output_dir in [yaml_output_dir, doc_output_dir]:
output_dir.mkdir(parents=True, exist_ok=True)
build_config = self.build_config()
with open(output_dir / "build.yaml", "w") as f:
with open(yaml_output_dir / "build.yaml", "w") as f:
yaml.safe_dump(build_config.model_dump(), f, sort_keys=False)
for yaml_pth, settings in self.run_configs.items():
print(f"Generating {yaml_pth}")
print(f"Providers: {self.providers}")
run_config = settings.run_config(
self.name, self.providers, self.docker_image
)
with open(output_dir / yaml_pth, "w") as f:
with open(yaml_output_dir / yaml_pth, "w") as f:
yaml.safe_dump(run_config.model_dump(), f, sort_keys=False)
docs = self.generate_markdown_docs()
with open(output_dir / f"{self.name}.md", "w") as f:
with open(doc_output_dir / f"{self.name}.md", "w") as f:
f.write(docs)
@classmethod
def vllm_distribution(cls) -> "DistributionTemplate":
providers = {
"inference": ["remote::vllm"],
"memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
}
inference_provider = Provider(
provider_id="vllm-inference",
provider_type="remote::vllm",
config=VLLMInferenceAdapterConfig.sample_run_config(
url="${env.VLLM_URL}",
),
)
inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}",
provider_id="vllm-inference",
)
safety_model = ModelInput(
model_id="${env.SAFETY_MODEL}",
provider_id="vllm-safety",
)
return cls(
name="remote-vllm",
description="Use (an external) vLLM server for running LLM inference",
template_path=Path(__file__).parent / "remote-vllm" / "doc_template.md",
providers=providers,
default_models=[inference_model, safety_model],
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider],
},
default_models=[inference_model],
),
"safety-run.yaml": RunConfigSettings(
provider_overrides={
"inference": [
inference_provider,
Provider(
provider_id="vllm-safety",
provider_type="remote::vllm",
config=VLLMInferenceAdapterConfig.sample_run_config(
url="${env.SAFETY_VLLM_URL}",
),
),
],
},
default_models=[
inference_model,
safety_model,
],
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
),
},
docker_compose_env_vars={
"LLAMASTACK_PORT": (
"5001",
"Port for the Llama Stack distribution server",
),
"INFERENCE_MODEL": (
"meta-llama/Llama-3.2-3B-Instruct",
"Inference model loaded into the vLLM server",
),
"VLLM_URL": (
"http://host.docker.internal:5100}/v1",
"URL of the vLLM server with the main inference model",
),
"MAX_TOKENS": (
"4096",
"Maximum number of tokens for generation",
),
"SAFETY_VLLM_URL": (
"http://host.docker.internal:5101/v1",
"URL of the vLLM server with the safety model",
),
"SAFETY_MODEL": (
"meta-llama/Llama-Guard-3-1B",
"Name of the safety (Llama-Guard) model to use",
),
},
)
if __name__ == "__main__":
import argparse
import sys
from pathlib import Path
parser = argparse.ArgumentParser(description="Generate a distribution template")
parser.add_argument(
"--type",
choices=["vllm"],
default="vllm",
help="Type of distribution template to generate",
)
parser.add_argument(
"--output-dir",
type=Path,
required=True,
help="Output directory for the distribution files",
)
args = parser.parse_args()
if args.type == "vllm":
template = DistributionTemplate.vllm_distribution()
else:
print(f"Unknown template type: {args.type}", file=sys.stderr)
sys.exit(1)
template.save_distribution(args.output_dir)