llama-stack (mirror of https://github.com/meta-llama/llama-stack.git)

Commit 0218e68849: Write a script to perform the codegen
Parent: f38e76ee98
9 changed files with 223 additions and 142 deletions
Modified: meta-reference agents provider config (MetaReferenceAgentsImplConfig)

@@ -4,19 +4,22 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from pydantic import BaseModel, Field
+from typing import Any, Dict
+
+from pydantic import BaseModel
 
 from llama_stack.providers.utils.kvstore import KVStoreConfig
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 
 
 class MetaReferenceAgentsImplConfig(BaseModel):
-    persistence_store: KVStoreConfig = Field(default=SqliteKVStoreConfig())
+    persistence_store: KVStoreConfig
 
     @classmethod
-    def sample_run_config(cls):
+    def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
         return {
             "persistence_store": SqliteKVStoreConfig.sample_run_config(
-                db_name="agents_store.db"
-            ),
+                __distro_dir__=__distro_dir__,
+                db_name="agents_store.db",
+            ).model_dump(),
         }
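For orientation (not part of the diff): with template.py below passing `__distro_dir__=f"distributions/{name}"`, the intended rendered value of `persistence_store` for a distro named remote-vllm would be:

```python
# Sketch of the intended rendered shape, combining this hunk with the
# SqliteKVStoreConfig hunk further down; not output captured from the commit.
persistence_store = {
    "type": "sqlite",
    "namespace": None,
    "db_path": "${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm/agents_store.db}",
}
```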
Modified: Faiss memory provider config (FaissImplConfig)

@@ -4,10 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Any, Dict
+
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel
 
-from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
 from llama_stack.providers.utils.kvstore.config import (
     KVStoreConfig,
     SqliteKVStoreConfig,
@@ -16,6 +17,13 @@ from llama_stack.providers.utils.kvstore.config import (
 
 @json_schema_type
 class FaissImplConfig(BaseModel):
-    kvstore: KVStoreConfig = SqliteKVStoreConfig(
-        db_path=(RUNTIME_BASE_DIR / "faiss_store.db").as_posix()
-    )  # Uses SQLite config specific to FAISS storage
+    kvstore: KVStoreConfig
+
+    @classmethod
+    def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
+        return {
+            "kvstore": SqliteKVStoreConfig.sample_run_config(
+                __distro_dir__=__distro_dir__,
+                db_name="faiss_store.db",
+            ).model_dump(),
+        }
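The same pattern applies to Faiss; a sketch (not from the diff) of the resulting entry, which replaces the previous repo-wide default:

```python
# Sketch of the intended rendered kvstore entry for a distro named remote-vllm.
# Previously the default was hardcoded via RUNTIME_BASE_DIR (typically under
# ~/.llama) as faiss_store.db, shared by every distribution.
kvstore = {
    "type": "sqlite",
    "namespace": None,
    "db_path": "${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm/faiss_store.db}",
}
```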
Modified: llama_stack/providers/utils/kvstore/config.py (SqliteKVStoreConfig)

@@ -54,11 +54,15 @@ class SqliteKVStoreConfig(CommonConfig):
     )
 
     @classmethod
-    def sample_run_config(cls, dir: str = "runtime", db_name: str = "kvstore.db"):
+    def sample_run_config(
+        cls, __distro_dir__: str = "runtime", db_name: str = "kvstore.db"
+    ):
         return {
             "type": "sqlite",
             "namespace": None,
-            "db_path": "${env.SQLITE_STORE_DIR:~/.llama/" + f"{dir}/{db_name}" + "}",
+            "db_path": "${env.SQLITE_STORE_DIR:~/.llama/"
+            + f"{__distro_dir__}/{db_name}"
+            + "}",
         }
 
 
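The `${env.VAR:default}` strings are placeholders resolved when a run config is loaded. A minimal sketch of the assumed substitution semantics (use the environment variable if set, else the default; this helper is illustrative, not code from this commit):

```python
import os
import re


def resolve_env_template(value: str) -> str:
    """Resolve "${env.NAME:default}" placeholders: $NAME if set, else default."""

    def repl(match: re.Match) -> str:
        name, default = match.group(1), match.group(2)
        return os.environ.get(name, default)

    return re.sub(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*):([^}]*)\}", repl, value)


# With SQLITE_STORE_DIR unset, the generated db_path resolves to its default:
print(resolve_env_template(
    "${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm/kvstore.db}"
))  # -> ~/.llama/distributions/remote-vllm/kvstore.db
```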
New file: llama_stack/scripts/save_distributions.py (78 lines)

@@ -0,0 +1,78 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import concurrent.futures
+import importlib
+from functools import partial
+from pathlib import Path
+from typing import Iterator
+
+from rich.progress import Progress, SpinnerColumn, TextColumn
+
+REPO_ROOT = Path(__file__).parent.parent.parent
+
+
+def find_template_dirs(templates_dir: Path) -> Iterator[Path]:
+    """Find immediate subdirectories in the templates folder."""
+    if not templates_dir.exists():
+        raise FileNotFoundError(f"Templates directory not found: {templates_dir}")
+
+    return (d for d in templates_dir.iterdir() if d.is_dir())
+
+
+def process_template(template_dir: Path, progress) -> None:
+    """Process a single template directory."""
+    progress.print(f"Processing {template_dir.name}")
+
+    try:
+        # Import the module directly
+        module_name = f"llama_stack.templates.{template_dir.name}"
+        module = importlib.import_module(module_name)
+
+        # Get and save the distribution template
+        if template_func := getattr(module, "get_distribution_template", None):
+            template = template_func()
+
+            template.save_distribution(
+                yaml_output_dir=REPO_ROOT / "distributions" / template.name,
+                doc_output_dir=REPO_ROOT
+                / "docs/source/getting_started/distributions"
+                / f"{template.distro_type}_distro",
+            )
+        else:
+            progress.print(
+                f"[yellow]Warning: {template_dir.name} has no get_distribution_template function"
+            )
+
+    except Exception as e:
+        progress.print(f"[red]Error processing {template_dir.name}: {str(e)}")
+
+
+def main():
+    templates_dir = REPO_ROOT / "llama_stack" / "templates"
+
+    with Progress(
+        SpinnerColumn(),
+        TextColumn("[progress.description]{task.description}"),
+    ) as progress:
+        template_dirs = list(find_template_dirs(templates_dir))
+        task = progress.add_task(
+            "Processing distribution templates...", total=len(template_dirs)
+        )
+
+        # Create a partial function with the progress bar
+        process_func = partial(process_template, progress=progress)
+
+        # Process templates in parallel
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            # Submit all tasks and wait for completion
+            list(executor.map(process_func, template_dirs))
+            progress.update(task, advance=len(template_dirs))
+
+
+if __name__ == "__main__":
+    main()
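The script takes no arguments. A hedged usage sketch (the invocation assumes a repo checkout with llama_stack importable):

```python
# Usage sketch (assumed invocation; the script defines no CLI flags):
#
#   $ python llama_stack/scripts/save_distributions.py
#
# Per the paths above, each template's build.yaml and run YAMLs land in
# distributions/<name>/, and its generated doc page in
# docs/source/getting_started/distributions/<distro_type>_distro/<name>.md.
```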
New file: llama_stack/templates/remote-vllm/__init__.py (7 lines)

@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .vllm import get_distribution_template  # noqa: F401
Modified: remote-vllm distribution docs

@@ -78,7 +78,7 @@ inference:
 If you are using Conda, you can build and run the Llama Stack server with the following commands:
 ```bash
 cd distributions/remote-vllm
-llama stack build --template remote_vllm --image-type conda
+llama stack build --template remote-vllm --image-type conda
 llama stack run run.yaml
 ```
New file: llama_stack/templates/remote-vllm/vllm.py (100 lines)

@@ -0,0 +1,100 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pathlib import Path
+
+from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
+from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
+from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
+
+
+def get_distribution_template() -> DistributionTemplate:
+    providers = {
+        "inference": ["remote::vllm"],
+        "memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
+        "safety": ["inline::llama-guard"],
+        "agents": ["inline::meta-reference"],
+        "telemetry": ["inline::meta-reference"],
+    }
+
+    inference_provider = Provider(
+        provider_id="vllm-inference",
+        provider_type="remote::vllm",
+        config=VLLMInferenceAdapterConfig.sample_run_config(
+            url="${env.VLLM_URL}",
+        ),
+    )
+
+    inference_model = ModelInput(
+        model_id="${env.INFERENCE_MODEL}",
+        provider_id="vllm-inference",
+    )
+    safety_model = ModelInput(
+        model_id="${env.SAFETY_MODEL}",
+        provider_id="vllm-safety",
+    )
+
+    return DistributionTemplate(
+        name="remote-vllm",
+        distro_type="self_hosted",
+        description="Use (an external) vLLM server for running LLM inference",
+        template_path=Path(__file__).parent / "doc_template.md",
+        providers=providers,
+        default_models=[inference_model, safety_model],
+        run_configs={
+            "run.yaml": RunConfigSettings(
+                provider_overrides={
+                    "inference": [inference_provider],
+                },
+                default_models=[inference_model],
+            ),
+            "run-with-safety.yaml": RunConfigSettings(
+                provider_overrides={
+                    "inference": [
+                        inference_provider,
+                        Provider(
+                            provider_id="vllm-safety",
+                            provider_type="remote::vllm",
+                            config=VLLMInferenceAdapterConfig.sample_run_config(
+                                url="${env.SAFETY_VLLM_URL}",
+                            ),
+                        ),
+                    ],
+                },
+                default_models=[
+                    inference_model,
+                    safety_model,
+                ],
+                default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
+            ),
+        },
+        docker_compose_env_vars={
+            "LLAMASTACK_PORT": (
+                "5001",
+                "Port for the Llama Stack distribution server",
+            ),
+            "INFERENCE_MODEL": (
+                "meta-llama/Llama-3.2-3B-Instruct",
+                "Inference model loaded into the vLLM server",
+            ),
+            "VLLM_URL": (
+                "http://host.docker.internal:5100/v1",
+                "URL of the vLLM server with the main inference model",
+            ),
+            "MAX_TOKENS": (
+                "4096",
+                "Maximum number of tokens for generation",
+            ),
+            "SAFETY_VLLM_URL": (
+                "http://host.docker.internal:5101/v1",
+                "URL of the vLLM server with the safety model",
+            ),
+            "SAFETY_MODEL": (
+                "meta-llama/Llama-Guard-3-1B",
+                "Name of the safety (Llama-Guard) model to use",
+            ),
+        },
+    )
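As a sanity check, a sketch of the inference provider entry one would expect in the generated run.yaml (expressed as the Python dict that gets YAML-dumped; only `url` is set in this commit, so any further adapter config fields are assumptions):

```python
# Hedged sketch of the expected run.yaml inference block for this template.
expected_inference_providers = [
    {
        "provider_id": "vllm-inference",
        "provider_type": "remote::vllm",
        "config": {"url": "${env.VLLM_URL}"},
    }
]
```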
Modified: llama_stack/templates/template.py (DistributionTemplate / RunConfigSettings)

@@ -9,7 +9,7 @@ from datetime import datetime
 from io import StringIO
 
 from pathlib import Path
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Dict, List, Literal, Optional, Set, Tuple
 
 import jinja2
 import yaml
@@ -29,9 +29,6 @@ from llama_stack.distribution.datatypes import (
 )
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
-from llama_stack.providers.remote.inference.vllm.config import (
-    VLLMInferenceAdapterConfig,
-)
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 
 
@@ -70,7 +67,9 @@ class RunConfigSettings(BaseModel):
 
         config_class = instantiate_class_type(config_class)
         if hasattr(config_class, "sample_run_config"):
-            config = config_class.sample_run_config()
+            config = config_class.sample_run_config(
+                __distro_dir__=f"distributions/{name}"
+            )
         else:
             config = {}
 
@@ -108,6 +107,7 @@ class DistributionTemplate(BaseModel):
 
     name: str
     description: str
+    distro_type: Literal["self_hosted", "remote_hosted", "ondevice"]
 
     providers: Dict[str, List[str]]
     run_configs: Dict[str, RunConfigSettings]
@@ -159,140 +159,21 @@ class DistributionTemplate(BaseModel):
             default_models=self.default_models,
         )
 
-    def save_distribution(self, output_dir: Path) -> None:
-        output_dir.mkdir(parents=True, exist_ok=True)
+    def save_distribution(self, yaml_output_dir: Path, doc_output_dir: Path) -> None:
+        for output_dir in [yaml_output_dir, doc_output_dir]:
+            output_dir.mkdir(parents=True, exist_ok=True)
 
         build_config = self.build_config()
-        with open(output_dir / "build.yaml", "w") as f:
+        with open(yaml_output_dir / "build.yaml", "w") as f:
             yaml.safe_dump(build_config.model_dump(), f, sort_keys=False)
 
         for yaml_pth, settings in self.run_configs.items():
-            print(f"Generating {yaml_pth}")
-            print(f"Providers: {self.providers}")
             run_config = settings.run_config(
                 self.name, self.providers, self.docker_image
            )
-            with open(output_dir / yaml_pth, "w") as f:
+            with open(yaml_output_dir / yaml_pth, "w") as f:
                 yaml.safe_dump(run_config.model_dump(), f, sort_keys=False)
 
         docs = self.generate_markdown_docs()
-        with open(output_dir / f"{self.name}.md", "w") as f:
+        with open(doc_output_dir / f"{self.name}.md", "w") as f:
             f.write(docs)
-
-    @classmethod
-    def vllm_distribution(cls) -> "DistributionTemplate":
-        providers = {
-            "inference": ["remote::vllm"],
-            "memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
-            "safety": ["inline::llama-guard"],
-            "agents": ["inline::meta-reference"],
-            "telemetry": ["inline::meta-reference"],
-        }
-
-        inference_provider = Provider(
-            provider_id="vllm-inference",
-            provider_type="remote::vllm",
-            config=VLLMInferenceAdapterConfig.sample_run_config(
-                url="${env.VLLM_URL}",
-            ),
-        )
-
-        inference_model = ModelInput(
-            model_id="${env.INFERENCE_MODEL}",
-            provider_id="vllm-inference",
-        )
-        safety_model = ModelInput(
-            model_id="${env.SAFETY_MODEL}",
-            provider_id="vllm-safety",
-        )
-
-        return cls(
-            name="remote-vllm",
-            description="Use (an external) vLLM server for running LLM inference",
-            template_path=Path(__file__).parent / "remote-vllm" / "doc_template.md",
-            providers=providers,
-            default_models=[inference_model, safety_model],
-            run_configs={
-                "run.yaml": RunConfigSettings(
-                    provider_overrides={
-                        "inference": [inference_provider],
-                    },
-                    default_models=[inference_model],
-                ),
-                "safety-run.yaml": RunConfigSettings(
-                    provider_overrides={
-                        "inference": [
-                            inference_provider,
-                            Provider(
-                                provider_id="vllm-safety",
-                                provider_type="remote::vllm",
-                                config=VLLMInferenceAdapterConfig.sample_run_config(
-                                    url="${env.SAFETY_VLLM_URL}",
-                                ),
-                            ),
-                        ],
-                    },
-                    default_models=[
-                        inference_model,
-                        safety_model,
-                    ],
-                    default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
-                ),
-            },
-            docker_compose_env_vars={
-                "LLAMASTACK_PORT": (
-                    "5001",
-                    "Port for the Llama Stack distribution server",
-                ),
-                "INFERENCE_MODEL": (
-                    "meta-llama/Llama-3.2-3B-Instruct",
-                    "Inference model loaded into the vLLM server",
-                ),
-                "VLLM_URL": (
-                    "http://host.docker.internal:5100}/v1",
-                    "URL of the vLLM server with the main inference model",
-                ),
-                "MAX_TOKENS": (
-                    "4096",
-                    "Maximum number of tokens for generation",
-                ),
-                "SAFETY_VLLM_URL": (
-                    "http://host.docker.internal:5101/v1",
-                    "URL of the vLLM server with the safety model",
-                ),
-                "SAFETY_MODEL": (
-                    "meta-llama/Llama-Guard-3-1B",
-                    "Name of the safety (Llama-Guard) model to use",
-                ),
-            },
-        )
-
-
-if __name__ == "__main__":
-    import argparse
-    import sys
-    from pathlib import Path
-
-    parser = argparse.ArgumentParser(description="Generate a distribution template")
-    parser.add_argument(
-        "--type",
-        choices=["vllm"],
-        default="vllm",
-        help="Type of distribution template to generate",
-    )
-    parser.add_argument(
-        "--output-dir",
-        type=Path,
-        required=True,
-        help="Output directory for the distribution files",
-    )
-
-    args = parser.parse_args()
-
-    if args.type == "vllm":
-        template = DistributionTemplate.vllm_distribution()
-    else:
-        print(f"Unknown template type: {args.type}", file=sys.stderr)
-        sys.exit(1)
-
-    template.save_distribution(args.output_dir)
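With the hardcoded `--type vllm` entry point removed, regenerating a single distribution by hand mirrors what save_distributions.py does per template. A sketch (the hyphen in the package name is presumably why the script uses importlib rather than a plain import statement):

```python
import importlib
from pathlib import Path

# Sketch mirroring save_distributions.py for one template. A plain
# `import llama_stack.templates.remote-vllm` would not parse, hence importlib.
module = importlib.import_module("llama_stack.templates.remote-vllm")
template = module.get_distribution_template()
template.save_distribution(
    yaml_output_dir=Path("distributions") / template.name,
    doc_output_dir=Path("docs/source/getting_started/distributions")
    / f"{template.distro_type}_distro",
)
```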