mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 11:09:47 +00:00

commit 0218e68849 (parent f38e76ee98)
Write a script to perform the codegen

9 changed files with 223 additions and 142 deletions
@@ -9,7 +9,7 @@ from datetime import datetime
 from io import StringIO
 
 from pathlib import Path
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Dict, List, Literal, Optional, Set, Tuple
 
 import jinja2
 import yaml
@@ -29,9 +29,6 @@ from llama_stack.distribution.datatypes import (
 )
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
-from llama_stack.providers.remote.inference.vllm.config import (
-    VLLMInferenceAdapterConfig,
-)
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 
 
@@ -70,7 +67,9 @@ class RunConfigSettings(BaseModel):
 
             config_class = instantiate_class_type(config_class)
             if hasattr(config_class, "sample_run_config"):
-                config = config_class.sample_run_config()
+                config = config_class.sample_run_config(
+                    __distro_dir__=f"distributions/{name}"
+                )
             else:
                 config = {}
 
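The new `__distro_dir__` kwarg lets each provider config scope generated paths under a per-distribution directory instead of a shared default; providers without the hook still fall back to `{}`. A minimal sketch of the `sample_run_config` hook this hunk calls — the class and field names are illustrative, not an actual llama-stack provider:

```python
# Illustrative only: a hypothetical provider config showing the
# sample_run_config(__distro_dir__=...) convention used by the hunk above.
from typing import Any, Dict

from pydantic import BaseModel


class ExampleKVStoreConfig(BaseModel):
    db_path: str

    @classmethod
    def sample_run_config(cls, __distro_dir__: str = "runtime", **kwargs: Any) -> Dict[str, Any]:
        # Scoping on-disk state under the distro directory (e.g.
        # "distributions/remote-vllm") keeps generated distros isolated.
        return {"db_path": f"~/.llama/{__distro_dir__}/kvstore.db"}


print(ExampleKVStoreConfig.sample_run_config(__distro_dir__="distributions/demo"))
# -> {'db_path': '~/.llama/distributions/demo/kvstore.db'}
```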
@@ -108,6 +107,7 @@ class DistributionTemplate(BaseModel):
 
     name: str
     description: str
+    distro_type: Literal["self_hosted", "remote_hosted", "ondevice"]
 
     providers: Dict[str, List[str]]
     run_configs: Dict[str, RunConfigSettings]
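Declaring `distro_type` as a `Literal` (the reason for the new `typing` import in the first hunk) means Pydantic validates the field at construction time; a quick standalone check:

```python
# Standalone check: Pydantic rejects values outside the Literal choices.
from typing import Literal

from pydantic import BaseModel, ValidationError


class Demo(BaseModel):
    distro_type: Literal["self_hosted", "remote_hosted", "ondevice"]


print(Demo(distro_type="self_hosted"))  # ok

try:
    Demo(distro_type="cloud")  # not an allowed literal
except ValidationError as err:
    print(err)
```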
@@ -159,140 +159,21 @@ class DistributionTemplate(BaseModel):
             default_models=self.default_models,
         )
 
-    def save_distribution(self, output_dir: Path) -> None:
-        output_dir.mkdir(parents=True, exist_ok=True)
+    def save_distribution(self, yaml_output_dir: Path, doc_output_dir: Path) -> None:
+        for output_dir in [yaml_output_dir, doc_output_dir]:
+            output_dir.mkdir(parents=True, exist_ok=True)
 
         build_config = self.build_config()
-        with open(output_dir / "build.yaml", "w") as f:
+        with open(yaml_output_dir / "build.yaml", "w") as f:
             yaml.safe_dump(build_config.model_dump(), f, sort_keys=False)
 
         for yaml_pth, settings in self.run_configs.items():
-            print(f"Generating {yaml_pth}")
-            print(f"Providers: {self.providers}")
             run_config = settings.run_config(
                 self.name, self.providers, self.docker_image
             )
-            with open(output_dir / yaml_pth, "w") as f:
+            with open(yaml_output_dir / yaml_pth, "w") as f:
                 yaml.safe_dump(run_config.model_dump(), f, sort_keys=False)
 
         docs = self.generate_markdown_docs()
-        with open(output_dir / f"{self.name}.md", "w") as f:
+        with open(doc_output_dir / f"{self.name}.md", "w") as f:
             f.write(docs)
-
-    @classmethod
-    def vllm_distribution(cls) -> "DistributionTemplate":
-        providers = {
-            "inference": ["remote::vllm"],
-            "memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
-            "safety": ["inline::llama-guard"],
-            "agents": ["inline::meta-reference"],
-            "telemetry": ["inline::meta-reference"],
-        }
-
-        inference_provider = Provider(
-            provider_id="vllm-inference",
-            provider_type="remote::vllm",
-            config=VLLMInferenceAdapterConfig.sample_run_config(
-                url="${env.VLLM_URL}",
-            ),
-        )
-
-        inference_model = ModelInput(
-            model_id="${env.INFERENCE_MODEL}",
-            provider_id="vllm-inference",
-        )
-        safety_model = ModelInput(
-            model_id="${env.SAFETY_MODEL}",
-            provider_id="vllm-safety",
-        )
-
-        return cls(
-            name="remote-vllm",
-            description="Use (an external) vLLM server for running LLM inference",
-            template_path=Path(__file__).parent / "remote-vllm" / "doc_template.md",
-            providers=providers,
-            default_models=[inference_model, safety_model],
-            run_configs={
-                "run.yaml": RunConfigSettings(
-                    provider_overrides={
-                        "inference": [inference_provider],
-                    },
-                    default_models=[inference_model],
-                ),
-                "safety-run.yaml": RunConfigSettings(
-                    provider_overrides={
-                        "inference": [
-                            inference_provider,
-                            Provider(
-                                provider_id="vllm-safety",
-                                provider_type="remote::vllm",
-                                config=VLLMInferenceAdapterConfig.sample_run_config(
-                                    url="${env.SAFETY_VLLM_URL}",
-                                ),
-                            ),
-                        ],
-                    },
-                    default_models=[
-                        inference_model,
-                        safety_model,
-                    ],
-                    default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
-                ),
-            },
-            docker_compose_env_vars={
-                "LLAMASTACK_PORT": (
-                    "5001",
-                    "Port for the Llama Stack distribution server",
-                ),
-                "INFERENCE_MODEL": (
-                    "meta-llama/Llama-3.2-3B-Instruct",
-                    "Inference model loaded into the vLLM server",
-                ),
-                "VLLM_URL": (
-                    "http://host.docker.internal:5100/v1",
-                    "URL of the vLLM server with the main inference model",
-                ),
-                "MAX_TOKENS": (
-                    "4096",
-                    "Maximum number of tokens for generation",
-                ),
-                "SAFETY_VLLM_URL": (
-                    "http://host.docker.internal:5101/v1",
-                    "URL of the vLLM server with the safety model",
-                ),
-                "SAFETY_MODEL": (
-                    "meta-llama/Llama-Guard-3-1B",
-                    "Name of the safety (Llama-Guard) model to use",
-                ),
-            },
-        )
-
-
-if __name__ == "__main__":
-    import argparse
-    import sys
-    from pathlib import Path
-
-    parser = argparse.ArgumentParser(description="Generate a distribution template")
-    parser.add_argument(
-        "--type",
-        choices=["vllm"],
-        default="vllm",
-        help="Type of distribution template to generate",
-    )
-    parser.add_argument(
-        "--output-dir",
-        type=Path,
-        required=True,
-        help="Output directory for the distribution files",
-    )
-
-    args = parser.parse_args()
-
-    if args.type == "vllm":
-        template = DistributionTemplate.vllm_distribution()
-    else:
-        print(f"Unknown template type: {args.type}", file=sys.stderr)
-        sys.exit(1)
-
-    template.save_distribution(args.output_dir)
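The deleted `__main__` block is what the commit title refers to: generation moves out of this module into a dedicated codegen script, and `save_distribution` now takes separate YAML and doc output directories. A sketch of what such a driver could look like — the module path, CLI flags, and template discovery here are assumptions, not necessarily the script this commit adds:

```python
# Hypothetical codegen driver (the script actually added by this commit may
# differ); it assumes templates are collected somewhere and writes YAML
# configs and markdown docs into separate trees, per the new API above.
import argparse
from pathlib import Path

from llama_stack.templates.template import DistributionTemplate  # assumed module path


def main() -> None:
    parser = argparse.ArgumentParser(description="Generate distribution configs and docs")
    parser.add_argument("--yaml-output-dir", type=Path, required=True)
    parser.add_argument("--doc-output-dir", type=Path, required=True)
    args = parser.parse_args()

    templates: list[DistributionTemplate] = []  # e.g. one template per distro package
    for template in templates:
        template.save_distribution(args.yaml_output_dir, args.doc_output_dir)


if __name__ == "__main__":
    main()
```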