mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
Start auto-generating { build, run, doc.md } for distributions
This commit is contained in:
parent
20bf2f50c2
commit
cfa913fdd5
11 changed files with 362 additions and 23 deletions
|
@ -13,20 +13,15 @@ apis:
|
||||||
- safety
|
- safety
|
||||||
providers:
|
providers:
|
||||||
inference:
|
inference:
|
||||||
- provider_id: ollama0
|
- provider_id: ollama
|
||||||
provider_type: remote::ollama
|
provider_type: remote::ollama
|
||||||
config:
|
config:
|
||||||
url: http://127.0.0.1:14343
|
url: ${env.OLLAMA_URL:http://127.0.0.1:11434}
|
||||||
safety:
|
safety:
|
||||||
- provider_id: meta0
|
- provider_id: meta0
|
||||||
provider_type: inline::llama-guard
|
provider_type: inline::llama-guard
|
||||||
config:
|
config:
|
||||||
model: Llama-Guard-3-1B
|
|
||||||
excluded_categories: []
|
excluded_categories: []
|
||||||
- provider_id: meta1
|
|
||||||
provider_type: inline::prompt-guard
|
|
||||||
config:
|
|
||||||
model: Prompt-Guard-86M
|
|
||||||
memory:
|
memory:
|
||||||
- provider_id: meta0
|
- provider_id: meta0
|
||||||
provider_type: inline::meta-reference
|
provider_type: inline::meta-reference
|
||||||
|
@ -43,3 +38,10 @@ providers:
|
||||||
- provider_id: meta0
|
- provider_id: meta0
|
||||||
provider_type: inline::meta-reference
|
provider_type: inline::meta-reference
|
||||||
config: {}
|
config: {}
|
||||||
|
models:
|
||||||
|
- model_id: ${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}
|
||||||
|
provider_id: ollama
|
||||||
|
- model_id: ${env.SAFETY_MODEL:Llama-Guard-3-1B}
|
||||||
|
provider_id: ollama
|
||||||
|
shields:
|
||||||
|
- shield_id: ${env.SAFETY_MODEL:Llama-Guard-3-1B}
|
||||||
|
|
|
@ -13,20 +13,15 @@ apis:
|
||||||
- safety
|
- safety
|
||||||
providers:
|
providers:
|
||||||
inference:
|
inference:
|
||||||
- provider_id: ollama0
|
- provider_id: ollama
|
||||||
provider_type: remote::ollama
|
provider_type: remote::ollama
|
||||||
config:
|
config:
|
||||||
url: http://127.0.0.1:14343
|
url: ${env.LLAMA_INFERENCE_OLLAMA_URL:http://127.0.0.1:11434}
|
||||||
safety:
|
safety:
|
||||||
- provider_id: meta0
|
- provider_id: meta0
|
||||||
provider_type: inline::llama-guard
|
provider_type: inline::llama-guard
|
||||||
config:
|
config:
|
||||||
model: Llama-Guard-3-1B
|
|
||||||
excluded_categories: []
|
excluded_categories: []
|
||||||
- provider_id: meta1
|
|
||||||
provider_type: inline::prompt-guard
|
|
||||||
config:
|
|
||||||
model: Prompt-Guard-86M
|
|
||||||
memory:
|
memory:
|
||||||
- provider_id: meta0
|
- provider_id: meta0
|
||||||
provider_type: inline::meta-reference
|
provider_type: inline::meta-reference
|
||||||
|
@ -43,3 +38,10 @@ providers:
|
||||||
- provider_id: meta0
|
- provider_id: meta0
|
||||||
provider_type: inline::meta-reference
|
provider_type: inline::meta-reference
|
||||||
config: {}
|
config: {}
|
||||||
|
models:
|
||||||
|
- model_id: ${env.LLAMA_INFERENCE_MODEL:Llama3.2-3B-Instruct}
|
||||||
|
provider_id: ollama
|
||||||
|
- model_id: ${env.LLAMA_SAFETY_MODEL:Llama-Guard-3-1B}
|
||||||
|
provider_id: ollama
|
||||||
|
shields:
|
||||||
|
- shield_id: ${env.LLAMA_SAFETY_MODEL:Llama-Guard-3-1B}
|
||||||
|
|
|
@ -16,7 +16,7 @@ providers:
|
||||||
provider_type: remote::vllm
|
provider_type: remote::vllm
|
||||||
config:
|
config:
|
||||||
# NOTE: replace with "localhost" if you are running in "host" network mode
|
# NOTE: replace with "localhost" if you are running in "host" network mode
|
||||||
url: ${env.LLAMA_INFERENCE_VLLM_URL:http://host.docker.internal:5100/v1}
|
url: ${env.VLLM_URL:http://host.docker.internal:5100/v1}
|
||||||
max_tokens: ${env.MAX_TOKENS:4096}
|
max_tokens: ${env.MAX_TOKENS:4096}
|
||||||
api_token: fake
|
api_token: fake
|
||||||
# serves safety llama_guard model
|
# serves safety llama_guard model
|
||||||
|
@ -24,7 +24,7 @@ providers:
|
||||||
provider_type: remote::vllm
|
provider_type: remote::vllm
|
||||||
config:
|
config:
|
||||||
# NOTE: replace with "localhost" if you are running in "host" network mode
|
# NOTE: replace with "localhost" if you are running in "host" network mode
|
||||||
url: ${env.LLAMA_SAFETY_VLLM_URL:http://host.docker.internal:5101/v1}
|
url: ${env.SAFETY_VLLM_URL:http://host.docker.internal:5101/v1}
|
||||||
max_tokens: ${env.MAX_TOKENS:4096}
|
max_tokens: ${env.MAX_TOKENS:4096}
|
||||||
api_token: fake
|
api_token: fake
|
||||||
memory:
|
memory:
|
||||||
|
@ -34,7 +34,7 @@ providers:
|
||||||
kvstore:
|
kvstore:
|
||||||
namespace: null
|
namespace: null
|
||||||
type: sqlite
|
type: sqlite
|
||||||
db_path: "${env.SQLITE_STORE_DIR:/home/ashwin/.llama/distributions/remote-vllm}/faiss_store.db"
|
db_path: "${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/faiss_store.db"
|
||||||
safety:
|
safety:
|
||||||
- provider_id: llama-guard
|
- provider_id: llama-guard
|
||||||
provider_type: inline::llama-guard
|
provider_type: inline::llama-guard
|
||||||
|
@ -50,7 +50,7 @@ providers:
|
||||||
persistence_store:
|
persistence_store:
|
||||||
namespace: null
|
namespace: null
|
||||||
type: sqlite
|
type: sqlite
|
||||||
db_path: "${env.SQLITE_STORE_DIR:/home/ashwin/.llama/distributions/remote-vllm}/agents_store.db"
|
db_path: "${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/agents_store.db"
|
||||||
telemetry:
|
telemetry:
|
||||||
- provider_id: meta0
|
- provider_id: meta0
|
||||||
provider_type: inline::meta-reference
|
provider_type: inline::meta-reference
|
||||||
|
@ -58,11 +58,11 @@ providers:
|
||||||
metadata_store:
|
metadata_store:
|
||||||
namespace: null
|
namespace: null
|
||||||
type: sqlite
|
type: sqlite
|
||||||
db_path: "${env.SQLITE_STORE_DIR:/home/ashwin/.llama/distributions/remote-vllm}/registry.db"
|
db_path: "${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db"
|
||||||
models:
|
models:
|
||||||
- model_id: ${env.LLAMA_INFERENCE_MODEL:Llama3.1-8B-Instruct}
|
- model_id: ${env.INFERENCE_MODEL:Llama3.1-8B-Instruct}
|
||||||
provider_id: vllm-0
|
provider_id: vllm-0
|
||||||
- model_id: ${env.LLAMA_SAFETY_MODEL:Llama-Guard-3-1B}
|
- model_id: ${env.SAFETY_MODEL:Llama-Guard-3-1B}
|
||||||
provider_id: vllm-1
|
provider_id: vllm-1
|
||||||
shields:
|
shields:
|
||||||
- shield_id: ${env.LLAMA_SAFETY_MODEL:Llama-Guard-3-1B}
|
- shield_id: ${env.SAFETY_MODEL:Llama-Guard-3-1B}
|
||||||
|
|
|
@ -313,7 +313,8 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
|
||||||
else:
|
else:
|
||||||
value = default_val
|
value = default_val
|
||||||
|
|
||||||
return value
|
# expand "~" from the values
|
||||||
|
return os.path.expanduser(value)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return re.sub(pattern, get_env_var, config)
|
return re.sub(pattern, get_env_var, config)
|
||||||
|
|
|
@ -12,3 +12,11 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||||
|
|
||||||
class MetaReferenceAgentsImplConfig(BaseModel):
|
class MetaReferenceAgentsImplConfig(BaseModel):
|
||||||
persistence_store: KVStoreConfig = Field(default=SqliteKVStoreConfig())
|
persistence_store: KVStoreConfig = Field(default=SqliteKVStoreConfig())
|
||||||
|
|
||||||
|
@classmethod
def sample_dict(cls):
    """Return a sample config dict for this provider.

    The ``persistence_store`` entry is the sqlite KV-store sample, asked to
    use ``agents_store.db`` as its database file name.
    """
    store_sample = SqliteKVStoreConfig.sample_dict(db_name="agents_store.db")
    return {"persistence_store": store_sample}
|
||||||
|
|
|
@ -34,6 +34,16 @@ class VLLMConfig(BaseModel):
|
||||||
default=0.3,
|
default=0.3,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
def sample_dict(cls):
    """Return a sample config dict made of ``${env.VAR:default}`` placeholders.

    Every value is a string placeholder naming an environment variable and
    the fallback used when that variable is unset.
    """
    sample = {}
    sample["model"] = "${env.VLLM_INFERENCE_MODEL:Llama3.2-3B-Instruct}"
    sample["tensor_parallel_size"] = "${env.VLLM_TENSOR_PARALLEL_SIZE:1}"
    sample["max_tokens"] = "${env.VLLM_MAX_TOKENS:4096}"
    sample["enforce_eager"] = "${env.VLLM_ENFORCE_EAGER:False}"
    sample["gpu_memory_utilization"] = "${env.VLLM_GPU_MEMORY_UTILIZATION:0.3}"
    return sample
|
||||||
|
|
||||||
@field_validator("model")
|
@field_validator("model")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_model(cls, model: str) -> str:
|
def validate_model(cls, model: str) -> str:
|
||||||
|
|
|
@ -11,3 +11,9 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
class LlamaGuardConfig(BaseModel):
|
class LlamaGuardConfig(BaseModel):
|
||||||
excluded_categories: List[str] = []
|
excluded_categories: List[str] = []
|
||||||
|
|
||||||
|
@classmethod
def sample_dict(cls):
    """Return a sample config dict with an empty ``excluded_categories`` list."""
    # A fresh list on every call so callers can mutate the sample safely.
    no_exclusions = []
    return {"excluded_categories": no_exclusions}
|
||||||
|
|
|
@ -24,3 +24,12 @@ class VLLMInferenceAdapterConfig(BaseModel):
|
||||||
default="fake",
|
default="fake",
|
||||||
description="The API token",
|
description="The API token",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
def sample_dict(cls):
    """Return a sample config dict made of ``${env.VAR:default}`` placeholders.

    The default URL points at ``host.docker.internal``.
    """
    # TODO: we may need two modes, one for conda and one for docker
    sample = {}
    sample["url"] = "${env.VLLM_URL:http://host.docker.internal:5100/v1}"
    sample["max_tokens"] = "${env.VLLM_MAX_TOKENS:4096}"
    sample["api_token"] = "${env.VLLM_API_TOKEN:fake}"
    return sample
|
||||||
|
|
|
@ -36,6 +36,15 @@ class RedisKVStoreConfig(CommonConfig):
|
||||||
def url(self) -> str:
|
def url(self) -> str:
|
||||||
return f"redis://{self.host}:{self.port}"
|
return f"redis://{self.host}:{self.port}"
|
||||||
|
|
||||||
|
@classmethod
def sample_dict(cls):
    """Return a sample redis KV-store config dict.

    ``host`` and ``port`` are ``${env.VAR:default}`` placeholders;
    ``namespace`` is left as ``None``.
    """
    sample = {"type": "redis", "namespace": None}
    sample["host"] = "${env.REDIS_HOST:localhost}"
    sample["port"] = "${env.REDIS_PORT:6379}"
    return sample
|
||||||
|
|
||||||
|
|
||||||
class SqliteKVStoreConfig(CommonConfig):
|
class SqliteKVStoreConfig(CommonConfig):
|
||||||
type: Literal[KVStoreType.sqlite.value] = KVStoreType.sqlite.value
|
type: Literal[KVStoreType.sqlite.value] = KVStoreType.sqlite.value
|
||||||
|
@ -44,6 +53,14 @@ class SqliteKVStoreConfig(CommonConfig):
|
||||||
description="File path for the sqlite database",
|
description="File path for the sqlite database",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
def sample_dict(cls, db_name: str = "kvstore.db"):
    """Return a sample sqlite KV-store config dict.

    Args:
        db_name: File name of the sqlite database inside the store directory.

    Returns:
        A dict with ``type``, ``namespace`` and ``db_path``. ``db_path`` is
        ``${env.SQLITE_STORE_DIR:~/.llama/runtime}/<db_name>``.
    """
    # The placeholder must close before the file name. SQLITE_STORE_DIR names
    # a directory, so with the file name inside the default, overriding the
    # variable would drop the file name and point every store at one path.
    store_dir = "${env.SQLITE_STORE_DIR:~/.llama/runtime}"
    return {
        "type": "sqlite",
        "namespace": None,
        "db_path": store_dir + "/" + db_name,
    }
|
||||||
|
|
||||||
|
|
||||||
class PostgresKVStoreConfig(CommonConfig):
|
class PostgresKVStoreConfig(CommonConfig):
|
||||||
type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value
|
type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value
|
||||||
|
@ -54,6 +71,19 @@ class PostgresKVStoreConfig(CommonConfig):
|
||||||
password: Optional[str] = None
|
password: Optional[str] = None
|
||||||
table_name: str = "llamastack_kvstore"
|
table_name: str = "llamastack_kvstore"
|
||||||
|
|
||||||
|
@classmethod
def sample_dict(cls, table_name: str = "llamastack_kvstore"):
    """Return a sample postgres KV-store config dict.

    Connection fields are ``${env.VAR}`` / ``${env.VAR:default}``
    placeholders. ``db``, ``user`` and ``password`` carry no default.

    Args:
        table_name: Default table name placed in the ``POSTGRES_TABLE_NAME``
            placeholder.
    """
    table_placeholder = "".join(("${env.POSTGRES_TABLE_NAME:", table_name, "}"))
    sample = {"type": "postgres", "namespace": None}
    sample["host"] = "${env.POSTGRES_HOST:localhost}"
    sample["port"] = "${env.POSTGRES_PORT:5432}"
    sample["db"] = "${env.POSTGRES_DB}"
    sample["user"] = "${env.POSTGRES_USER}"
    sample["password"] = "${env.POSTGRES_PASSWORD}"
    sample["table_name"] = table_placeholder
    return sample
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@field_validator("table_name")
|
@field_validator("table_name")
|
||||||
def validate_table_name(cls, v: str) -> str:
|
def validate_table_name(cls, v: str) -> str:
|
||||||
|
|
5
llama_stack/templates/__init__.py
Normal file
5
llama_stack/templates/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
266
llama_stack/templates/template.py
Normal file
266
llama_stack/templates/template.py
Normal file
|
@ -0,0 +1,266 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from io import StringIO
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Set
|
||||||
|
|
||||||
|
import jinja2
|
||||||
|
import yaml
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import (
|
||||||
|
BuildConfig,
|
||||||
|
DistributionSpec,
|
||||||
|
KVStoreConfig,
|
||||||
|
ModelInput,
|
||||||
|
Provider,
|
||||||
|
ShieldInput,
|
||||||
|
StackRunConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DistributionTemplate(BaseModel):
    """
    Represents a Llama Stack distribution instance that can generate configuration
    and documentation files.

    A template holds the provider list, default models/shields and documented
    environment variables of one distribution. `save_distribution` writes the
    matching `build.yaml`, `run.yaml` and `<name>.md` files.
    """

    name: str
    description: str
    # API name -> provider types offered for that API,
    # e.g. {"inference": ["remote::vllm"]}.
    providers: Dict[str, List[str]]
    default_models: List[ModelInput]
    default_shields: Optional[List[ShieldInput]] = None

    # Optional configuration
    metadata_store: Optional[KVStoreConfig] = None
    # Environment variable name -> human-readable description (used in the docs).
    env_vars: Optional[Dict[str, str]] = None
    docker_image: Optional[str] = None

    @property
    def distribution_spec(self) -> DistributionSpec:
        """Build the `DistributionSpec` from this template's description, image and providers."""
        return DistributionSpec(
            description=self.description,
            docker_image=self.docker_image,
            providers=self.providers,
        )

    def build_config(self) -> BuildConfig:
        """Build the `BuildConfig` that is serialized to `build.yaml`."""
        return BuildConfig(
            name=self.name,
            distribution_spec=self.distribution_spec,
            image_type="conda",  # default to conda, can be overridden
        )

    def run_config(self, provider_configs: Dict[str, List[Provider]]) -> StackRunConfig:
        """Build the `StackRunConfig` that is serialized to `run.yaml`.

        Args:
            provider_configs: Concrete provider entries, keyed by API name.

        Returns:
            A run config stamped with the current time, listing one API per
            key of `self.providers`.
        """
        from datetime import datetime

        return StackRunConfig(
            image_name=self.name,
            docker_image=self.docker_image,
            built_at=datetime.now(),
            # Dict keys are already unique. Listing them directly keeps the
            # declaration order, so the generated file is stable across runs
            # (a set has no guaranteed order).
            apis=list(self.providers),
            providers=provider_configs,
            metadata_store=self.metadata_store,
            models=self.default_models,
            shields=self.default_shields or [],
        )

    def generate_markdown_docs(self) -> str:
        """Generate the Markdown documentation page for this distribution.

        Returns:
            The rendered Markdown text: description, provider table, optional
            environment-variable list, usage examples, default models and
            optional safety shields.
        """
        # Emit the providers table as a Markdown pipe table. A terminal table
        # is not Markdown and its layout depends on the console width.
        table_lines = ["| API | Provider(s) |", "|-----|-------------|"]
        for api, provider_types in sorted(self.providers.items()):
            provider_cell = ", ".join(f"`{p}`" for p in provider_types)
            table_lines.append(f"| {api} | {provider_cell} |")
        providers_table = "\n".join(table_lines)

        # Main documentation template. The environment below uses trim_blocks
        # and lstrip_blocks, so a line holding only a block tag renders to
        # nothing. No "-" whitespace stripping is used on the tags, so each
        # list item keeps its own line.
        doc_source = """# {{ name }} Distribution

{{ description }}

## Provider Configuration

The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations:

{{ providers_table }}

{% if env_vars %}
## Environment Variables

The following environment variables can be configured:

{% for var, description in env_vars.items() %}
- `{{ var }}`: {{ description }}
{% endfor %}

{% endif %}
## Example Usage

### Using Docker Compose

```bash
$ cd distributions/{{ name }}
$ docker compose up
```

### Manual Configuration

You can also configure the distribution manually by creating a `run.yaml` file:

```yaml
version: '2'
image_name: {{ name }}
apis:
{% for api in providers.keys() %}
  - {{ api }}
{% endfor %}

providers:
{% for api, provider_list in providers.items() %}
  {{ api }}:
  {% for provider in provider_list %}
    - provider_id: {{ provider.lower() }}-0
      provider_type: {{ provider }}
      config: {}
  {% endfor %}
{% endfor %}
```

## Models

The following models are configured by default:

{% for model in default_models %}
- `{{ model.model_id }}`
{% endfor %}
{% if default_shields %}

## Safety Shields

The following safety shields are configured:

{% for shield in default_shields %}
- `{{ shield.shield_id }}`
{% endfor %}
{% endif %}
"""
        env = jinja2.Environment(trim_blocks=True, lstrip_blocks=True)
        doc_template = env.from_string(doc_source)
        return doc_template.render(
            name=self.name,
            description=self.description,
            providers=self.providers,
            providers_table=providers_table,
            env_vars=self.env_vars,
            default_models=self.default_models,
            default_shields=self.default_shields,
        )

    def save_distribution(self, output_dir: Path) -> None:
        """Write `build.yaml`, `run.yaml` and `<name>.md` into `output_dir`.

        The directory is created (with parents) if it does not exist.

        Args:
            output_dir: Directory that receives the generated files.
        """
        output_dir.mkdir(parents=True, exist_ok=True)

        # Save build.yaml
        build_config = self.build_config()
        with open(output_dir / "build.yaml", "w", encoding="utf-8") as f:
            yaml.safe_dump(build_config.model_dump(), f, sort_keys=False)

        # Save run.yaml template
        # Create a minimal provider config for the template: one entry per
        # provider type, with an empty config and a "<type>-0" id.
        provider_configs = {
            api: [
                Provider(
                    provider_id=f"{provider.lower()}-0",
                    provider_type=provider,
                    config={},
                )
                for provider in providers
            ]
            for api, providers in self.providers.items()
        }
        run_config = self.run_config(provider_configs)
        with open(output_dir / "run.yaml", "w", encoding="utf-8") as f:
            yaml.safe_dump(run_config.model_dump(), f, sort_keys=False)

        # Save documentation
        docs = self.generate_markdown_docs()
        with open(output_dir / f"{self.name}.md", "w", encoding="utf-8") as f:
            f.write(docs)

    @classmethod
    def vllm_distribution(cls) -> "DistributionTemplate":
        """Return the template for the `remote-vllm` distribution.

        The environment variable names match the ones used by the remote-vllm
        `run.yaml` (`VLLM_URL`, `SAFETY_VLLM_URL`, `MAX_TOKENS`,
        `INFERENCE_MODEL`, `SAFETY_MODEL`).
        """
        return cls(
            name="remote-vllm",
            description="Use (an external) vLLM server for running LLM inference",
            providers={
                "inference": ["remote::vllm"],
                "memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
                "safety": ["inline::llama-guard"],
                "agents": ["inline::meta-reference"],
                "telemetry": ["inline::meta-reference"],
            },
            default_models=[
                ModelInput(model_id="${env.INFERENCE_MODEL:Llama3.1-8B-Instruct}"),
                ModelInput(model_id="${env.SAFETY_MODEL:Llama-Guard-3-1B}"),
            ],
            default_shields=[
                ShieldInput(shield_id="${env.SAFETY_MODEL:Llama-Guard-3-1B}")
            ],
            env_vars={
                "VLLM_URL": "URL of the vLLM inference server",
                "SAFETY_VLLM_URL": "URL of the vLLM safety server",
                "MAX_TOKENS": "Maximum number of tokens for generation",
                "INFERENCE_MODEL": "Name of the inference model to use",
                "SAFETY_MODEL": "Name of the safety model to use",
            },
        )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    import argparse
    import sys
    from pathlib import Path

    # Command line: which template to build and where to write its files.
    parser = argparse.ArgumentParser(description="Generate a distribution template")
    parser.add_argument(
        "--type",
        choices=["vllm"],
        default="vllm",
        help="Type of distribution template to generate",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        required=True,
        help="Output directory for the distribution files",
    )
    args = parser.parse_args()

    # Guard clause: bail out on anything other than the one known type.
    if args.type != "vllm":
        print(f"Unknown template type: {args.type}", file=sys.stderr)
        sys.exit(1)

    template = DistributionTemplate.vllm_distribution()
    template.save_distribution(args.output_dir)
|
Loading…
Add table
Add a link
Reference in a new issue