Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-17 11:12:36 +00:00)
Adding docker-compose.yaml, starting to simplify
This commit is contained in:
parent e4509cb568
commit f38e76ee98
14 changed files with 516 additions and 386 deletions
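At a high level, this commit splits per-run-config settings out of `DistributionTemplate` into a new `RunConfigSettings` model, so one template can emit several run YAMLs. A minimal usage sketch, assuming only names visible in the diff below (the import path and output directory are guesses, not shown in this commit):

```python
# Sketch only: the module path is assumed, not confirmed by this diff.
from pathlib import Path

from llama_stack.templates.template import DistributionTemplate  # assumed path

template = DistributionTemplate.vllm_distribution()
# Writes build.yaml, one YAML per run_configs entry ("run.yaml",
# "safety-run.yaml"), and the generated remote-vllm.md docs.
template.save_distribution(Path("distributions/remote-vllm"))
```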
@@ -9,7 +9,7 @@ from datetime import datetime
 from io import StringIO
-from typing import Any, Dict, List, Optional, Set, Tuple
+from pathlib import Path
+from typing import Dict, List, Optional, Set, Tuple
 
 import jinja2
 import yaml
 
@@ -22,7 +22,6 @@ from llama_stack.distribution.datatypes import (
     Api,
     BuildConfig,
     DistributionSpec,
-    KVStoreConfig,
     ModelInput,
     Provider,
     ShieldInput,
@@ -33,53 +32,26 @@ from llama_stack.distribution.utils.dynamic import instantiate_class_type
 from llama_stack.providers.remote.inference.vllm.config import (
     VLLMInferenceAdapterConfig,
 )
-from llama_stack.providers.utils.docker.service_config import DockerComposeServiceConfig
+from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 
 
-class DistributionTemplate(BaseModel):
-    """
-    Represents a Llama Stack distribution instance that can generate configuration
-    and documentation files.
-    """
-
-    name: str
-    description: str
-    providers: Dict[str, List[str]]
-    run_config_overrides: Dict[str, List[Provider]] = Field(default_factory=dict)
-    compose_config_overrides: Dict[str, Dict[str, DockerComposeServiceConfig]] = Field(
-        default_factory=dict
-    )
+class RunConfigSettings(BaseModel):
+    provider_overrides: Dict[str, List[Provider]] = Field(default_factory=dict)
     default_models: List[ModelInput]
     default_shields: Optional[List[ShieldInput]] = None
 
-    # Optional configuration
-    metadata_store: Optional[KVStoreConfig] = None
-    docker_compose_env_vars: Optional[Dict[str, Tuple[str, str]]] = None
-    docker_image: Optional[str] = None
-
-    @property
-    def distribution_spec(self) -> DistributionSpec:
-        return DistributionSpec(
-            description=self.description,
-            docker_image=self.docker_image,
-            providers=self.providers,
-        )
-
-    def build_config(self) -> BuildConfig:
-        return BuildConfig(
-            name=self.name,
-            distribution_spec=self.distribution_spec,
-            image_type="conda",  # default to conda, can be overridden
-        )
-
-    def run_config(self) -> StackRunConfig:
+    def run_config(
+        self,
+        name: str,
+        providers: Dict[str, List[str]],
+        docker_image: Optional[str] = None,
+    ) -> StackRunConfig:
         provider_registry = get_provider_registry()
 
         provider_configs = {}
-        for api_str, provider_types in self.providers.items():
-            if providers := self.run_config_overrides.get(api_str):
-                provider_configs[api_str] = providers
+        for api_str, provider_types in providers.items():
+            if api_providers := self.provider_overrides.get(api_str):
+                provider_configs[api_str] = api_providers
                 continue
 
             provider_type = provider_types[0]
@@ -111,83 +83,53 @@ class DistributionTemplate(BaseModel):
             ]
 
         # Get unique set of APIs from providers
-        apis: Set[str] = set(self.providers.keys())
+        apis: Set[str] = set(providers.keys())
 
         return StackRunConfig(
-            image_name=self.name,
-            docker_image=self.docker_image,
+            image_name=name,
+            docker_image=docker_image,
             built_at=datetime.now(),
             apis=list(apis),
             providers=provider_configs,
-            metadata_store=self.metadata_store,
+            metadata_store=SqliteKVStoreConfig.sample_run_config(
+                dir=f"distributions/{name}",
+                db_name="registry.db",
+            ),
             models=self.default_models,
             shields=self.default_shields or [],
         )
 
-    def docker_compose_config(self) -> Dict[str, Any]:
-        services = {}
-        provider_registry = get_provider_registry()
-
-        # Add provider services based on their sample_compose_config
-        for api_str, api_providers in self.providers.items():
-            if overrides := self.compose_config_overrides.get(api_str):
-                services |= overrides
-                continue
+class DistributionTemplate(BaseModel):
+    """
+    Represents a Llama Stack distribution instance that can generate configuration
+    and documentation files.
+    """
 
-            # only look at the first provider to get the compose config for now
-            # we may want to use `docker compose profiles` in the future
-            provider_type = api_providers[0]
-            provider_id = provider_type.split("::")[-1]
-            api = Api(api_str)
-            if provider_type not in provider_registry[api]:
-                raise ValueError(
-                    f"Unknown provider type: {provider_type} for API: {api_str}"
-                )
+    name: str
+    description: str
 
-            config_class = provider_registry[api][provider_type].config_class
-            assert (
-                config_class is not None
-            ), f"No config class for provider type: {provider_type} for API: {api_str}"
+    providers: Dict[str, List[str]]
+    run_configs: Dict[str, RunConfigSettings]
+    template_path: Path
 
-            config_class = instantiate_class_type(config_class)
-            if not hasattr(config_class, "sample_docker_compose_config"):
-                continue
+    # Optional configuration
+    docker_compose_env_vars: Optional[Dict[str, Tuple[str, str]]] = None
+    docker_image: Optional[str] = None
 
-            compose_config = config_class.sample_docker_compose_config()
-            services[provider_id] = compose_config
+    default_models: Optional[List[ModelInput]] = None
 
-        port = "${LLAMASTACK_PORT:-5001}"
-        # Add main llamastack service
-        llamastack_config = DockerComposeServiceConfig(
-            image=f"llamastack/distribution-{self.name}:latest",
-            depends_on=list(services.keys()),
-            volumes=[
-                "~/.llama:/root/.llama",
-                f"~/local/llama-stack/distributions/{self.name}/run.yaml:/root/llamastack-run-{self.name}.yaml",
-            ],
-            ports=[f"{port}:{port}"],
-            environment={
-                k: v[0] for k, v in (self.docker_compose_env_vars or {}).items()
-            },
-            entrypoint=(
-                f'bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-{self.name}.yaml --port {port}"'
-            ),
-            deploy={
-                "restart_policy": {
-                    "condition": "on-failure",
-                    "delay": "3s",
-                    "max_attempts": 5,
-                    "window": "60s",
-                }
-            },
-        )
+    def build_config(self) -> BuildConfig:
+        return BuildConfig(
+            name=self.name,
+            distribution_spec=DistributionSpec(
+                description=self.description,
+                docker_image=self.docker_image,
+                providers=self.providers,
+            ),
+            image_type="conda",  # default to conda, can be overridden
+        )
 
-        services["llamastack"] = llamastack_config
-        return {
-            "services": {k: v.model_dump() for k, v in services.items()},
-            "volumes": {service_name: None for service_name in services.keys()},
-        }
-
     def generate_markdown_docs(self) -> str:
         """Generate markdown documentation using both Jinja2 templates and rich tables."""
         # First generate the providers table using rich
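The run-config generation above resolves each API either from an explicit `provider_overrides` entry or, failing that, from the first listed provider type. A self-contained sketch of that resolution rule, with plain strings standing in for the `Provider` objects the real code builds from the provider registry:

```python
from typing import Dict, List

def resolve_providers(
    providers: Dict[str, List[str]],
    provider_overrides: Dict[str, List[str]],
) -> Dict[str, List[str]]:
    resolved: Dict[str, List[str]] = {}
    for api_str, provider_types in providers.items():
        # An explicit override wins for the whole API...
        if api_providers := provider_overrides.get(api_str):
            resolved[api_str] = api_providers
            continue
        # ...otherwise only the first listed provider type is used.
        resolved[api_str] = [provider_types[0]]
    return resolved

print(resolve_providers(
    {"inference": ["remote::vllm"], "safety": ["inline::llama-guard"]},
    {"inference": ["vllm-inference"]},
))
# {'inference': ['vllm-inference'], 'safety': ['inline::llama-guard']}
```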
@@ -204,53 +146,7 @@ class DistributionTemplate(BaseModel):
         console.print(table)
         providers_table = output.getvalue()
 
-        # Main documentation template
-        template = """# {{ name }} Distribution
-
-{{ description }}
-
-## Provider Configuration
-
-The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations:
-
-{{ providers_table }}
-
-{%- if env_vars %}
-## Environment Variables
-
-The following environment variables can be configured:
-
-{% for var, (value, description) in docker_compose_env_vars.items() %}
-- `{{ var }}`: {{ description }}
-{% endfor %}
-{%- endif %}
-
-## Example Usage
-
-### Using Docker Compose
-
-```bash
-$ cd distributions/{{ name }}
-$ docker compose up
-```
-
-## Models
-
-The following models are configured by default:
-{% for model in default_models %}
-- `{{ model.model_id }}`
-{% endfor %}
-
-{%- if default_shields %}
-
-## Safety Shields
-
-The following safety shields are configured:
-{% for shield in default_shields %}
-- `{{ shield.shield_id }}`
-{%- endfor %}
-{%- endif %}
-"""
+        template = self.template_path.read_text()
         # Render template with rich-generated table
         env = jinja2.Environment(trim_blocks=True, lstrip_blocks=True)
         template = env.from_string(template)
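With the inline string gone, docs rendering is just `template_path.read_text()` plus the same Jinja2 environment. A minimal sketch of that step, with a made-up two-line template standing in for `remote-vllm/doc_template.md`:

```python
import jinja2

env = jinja2.Environment(trim_blocks=True, lstrip_blocks=True)
# Stand-in template text; the real content now lives in doc_template.md.
template = env.from_string("# {{ name }} Distribution\n\n{{ description }}\n")
print(template.render(
    name="remote-vllm",
    description="Use (an external) vLLM server for running LLM inference",
))
```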
@@ -261,7 +157,6 @@ The following safety shields are configured:
             providers_table=providers_table,
-            docker_compose_env_vars=self.docker_compose_env_vars,
             default_models=self.default_models,
             default_shields=self.default_shields,
         )
 
     def save_distribution(self, output_dir: Path) -> None:
@@ -271,19 +166,14 @@ The following safety shields are configured:
         with open(output_dir / "build.yaml", "w") as f:
             yaml.safe_dump(build_config.model_dump(), f, sort_keys=False)
 
-        run_config = self.run_config()
-        serialized = run_config.model_dump()
-        with open(output_dir / "run.yaml", "w") as f:
-            yaml.safe_dump(serialized, f, sort_keys=False)
-
-        # serialized_str = yaml.dump(serialized, sort_keys=False)
-        # env_vars = set()
-        # for match in re.finditer(r"\${env\.([A-Za-z0-9_-]+)}", serialized_str):
-        #     env_vars.add(match.group(1))
-
-        docker_compose = self.docker_compose_config()
-        with open(output_dir / "compose.yaml", "w") as f:
-            yaml.safe_dump(docker_compose, f, sort_keys=False, default_flow_style=False)
+        for yaml_pth, settings in self.run_configs.items():
+            print(f"Generating {yaml_pth}")
+            print(f"Providers: {self.providers}")
+            run_config = settings.run_config(
+                self.name, self.providers, self.docker_image
+            )
+            with open(output_dir / yaml_pth, "w") as f:
+                yaml.safe_dump(run_config.model_dump(), f, sort_keys=False)
 
         docs = self.generate_markdown_docs()
         with open(output_dir / f"{self.name}.md", "w") as f:
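So `save_distribution` now writes one YAML per `run_configs` entry instead of a single run.yaml plus compose.yaml. A runnable sketch of the new loop, with a plain dict standing in for `run_config.model_dump()` and an assumed output location:

```python
from pathlib import Path

import yaml

# Stand-in for settings.run_config(...).model_dump(); the real value is
# a full StackRunConfig serialization.
fake_run_config = {"image_name": "remote-vllm", "apis": ["inference"]}

output_dir = Path("distributions/remote-vllm")  # assumed output location
output_dir.mkdir(parents=True, exist_ok=True)
for yaml_pth in ("run.yaml", "safety-run.yaml"):
    with open(output_dir / yaml_pth, "w") as f:
        yaml.safe_dump(fake_run_config, f, sort_keys=False)
```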
@@ -291,87 +181,89 @@ The following safety shields are configured:
 
     @classmethod
     def vllm_distribution(cls) -> "DistributionTemplate":
+        providers = {
+            "inference": ["remote::vllm"],
+            "memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
+            "safety": ["inline::llama-guard"],
+            "agents": ["inline::meta-reference"],
+            "telemetry": ["inline::meta-reference"],
+        }
+
+        inference_provider = Provider(
+            provider_id="vllm-inference",
+            provider_type="remote::vllm",
+            config=VLLMInferenceAdapterConfig.sample_run_config(
+                url="${env.VLLM_URL}",
+            ),
+        )
+
+        inference_model = ModelInput(
+            model_id="${env.INFERENCE_MODEL}",
+            provider_id="vllm-inference",
+        )
+        safety_model = ModelInput(
+            model_id="${env.SAFETY_MODEL}",
+            provider_id="vllm-safety",
+        )
+
         return cls(
             name="remote-vllm",
             description="Use (an external) vLLM server for running LLM inference",
-            providers={
-                "inference": ["remote::vllm"],
-                "memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
-                "safety": ["inline::llama-guard"],
-                "agents": ["inline::meta-reference"],
-                "telemetry": ["inline::meta-reference"],
-            },
-            run_config_overrides={
-                "inference": [
-                    Provider(
-                        provider_id="vllm-0",
-                        provider_type="remote::vllm",
-                        config=VLLMInferenceAdapterConfig.sample_run_config(
-                            url="${env.VLLM_URL:http://host.docker.internal:5100/v1}",
-                        ),
-                    ),
-                    Provider(
-                        provider_id="vllm-1",
-                        provider_type="remote::vllm",
-                        config=VLLMInferenceAdapterConfig.sample_run_config(
-                            url="${env.SAFETY_VLLM_URL:http://host.docker.internal:5101/v1}",
-                        ),
-                    ),
-                ]
-            },
-            compose_config_overrides={
-                "inference": {
-                    "vllm-0": VLLMInferenceAdapterConfig.sample_docker_compose_config(
-                        port=5100,
-                        cuda_visible_devices="0",
-                        model="${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}",
-                    ),
-                    "vllm-1": VLLMInferenceAdapterConfig.sample_docker_compose_config(
-                        port=5100,
-                        cuda_visible_devices="1",
-                        model="${env.SAFETY_MODEL:Llama-Guard-3-1B}",
-                    ),
-                }
-            },
-            default_models=[
-                ModelInput(
-                    model_id="${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}",
-                    provider_id="vllm-0",
-                ),
-                ModelInput(
-                    model_id="${env.SAFETY_MODEL:Llama-Guard-3-1B}",
-                    provider_id="vllm-1",
-                ),
-            ],
-            default_shields=[
-                ShieldInput(shield_id="${env.SAFETY_MODEL:Llama-Guard-3-1B}")
-            ],
+            template_path=Path(__file__).parent / "remote-vllm" / "doc_template.md",
+            providers=providers,
+            default_models=[inference_model, safety_model],
+            run_configs={
+                "run.yaml": RunConfigSettings(
+                    provider_overrides={
+                        "inference": [inference_provider],
+                    },
+                    default_models=[inference_model],
+                ),
+                "safety-run.yaml": RunConfigSettings(
+                    provider_overrides={
+                        "inference": [
+                            inference_provider,
+                            Provider(
+                                provider_id="vllm-safety",
+                                provider_type="remote::vllm",
+                                config=VLLMInferenceAdapterConfig.sample_run_config(
+                                    url="${env.SAFETY_VLLM_URL}",
+                                ),
+                            ),
+                        ],
+                    },
+                    default_models=[
+                        inference_model,
+                        safety_model,
+                    ],
+                    default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
+                ),
+            },
             docker_compose_env_vars={
-                # these defaults are for the Docker Compose configuration
-                "VLLM_URL": (
-                    "http://host.docker.internal:${VLLM_PORT:-5100}/v1",
-                    "URL of the vLLM server with the main inference model",
-                ),
-                "SAFETY_VLLM_URL": (
-                    "http://host.docker.internal:${SAFETY_VLLM_PORT:-5101}/v1",
-                    "URL of the vLLM server with the safety model",
-                ),
-                "MAX_TOKENS": (
-                    "${MAX_TOKENS:-4096}",
-                    "Maximum number of tokens for generation",
-                ),
-                "INFERENCE_MODEL": (
-                    "${INFERENCE_MODEL:-Llama3.2-3B-Instruct}",
-                    "Name of the inference model to use",
-                ),
-                "SAFETY_MODEL": (
-                    "${SAFETY_MODEL:-Llama-Guard-3-1B}",
-                    "Name of the safety (Llama-Guard) model to use",
-                ),
-                "LLAMASTACK_PORT": (
-                    "${LLAMASTACK_PORT:-5001}",
-                    "Port for the Llama Stack distribution server",
-                ),
+                "LLAMASTACK_PORT": (
+                    "5001",
+                    "Port for the Llama Stack distribution server",
+                ),
+                "INFERENCE_MODEL": (
+                    "meta-llama/Llama-3.2-3B-Instruct",
+                    "Inference model loaded into the vLLM server",
+                ),
+                "VLLM_URL": (
+                    "http://host.docker.internal:5100/v1",
+                    "URL of the vLLM server with the main inference model",
+                ),
+                "MAX_TOKENS": (
+                    "4096",
+                    "Maximum number of tokens for generation",
+                ),
+                "SAFETY_VLLM_URL": (
+                    "http://host.docker.internal:5101/v1",
+                    "URL of the vLLM server with the safety model",
+                ),
+                "SAFETY_MODEL": (
+                    "meta-llama/Llama-Guard-3-1B",
+                    "Name of the safety (Llama-Guard) model to use",
+                ),
             },
         )
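The new run configs lean on `${env.VAR}` placeholders (with `${env.VAR:default}` carrying an inline default in the old style above), and the commented-out regex removed from `save_distribution` hints at how they can be discovered. A hedged sketch of resolving such placeholders; this is illustrative only, not the stack's actual resolver:

```python
import os
import re

ENV_PATTERN = re.compile(r"\$\{env\.([A-Za-z0-9_-]+)(?::([^}]*))?\}")

def resolve_env(value: str) -> str:
    # Replace ${env.VAR} with $VAR from the environment, falling back to
    # the inline default in ${env.VAR:default} when the variable is unset.
    def sub(match: re.Match) -> str:
        var, default = match.group(1), match.group(2)
        resolved = os.environ.get(var, default)
        if resolved is None:
            raise ValueError(f"environment variable {var} is not set")
        return resolved

    return ENV_PATTERN.sub(sub, value)

print(resolve_env("${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}"))
# -> $INFERENCE_MODEL if set, else "Llama3.2-3B-Instruct"
```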