Adding docker-compose.yaml, starting to simplify

Ashwin Bharambe 2024-11-16 10:56:38 -08:00
parent e4509cb568
commit f38e76ee98
14 changed files with 516 additions and 386 deletions
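The diff below splits the old all-in-one `DistributionTemplate` into two models: a new `RunConfigSettings` that holds per-file provider overrides, default models, and shields, and a slimmer `DistributionTemplate` that maps output filenames to those settings and renders its docs from an external `doc_template.md`. Compose generation is dropped from this file in favor of the docker-compose.yaml named in the commit title. A minimal sketch of the intended usage, assuming the import path shown (the path is illustrative, not part of this diff):

```python
# Sketch only: exercises the refactored API introduced in this diff.
# The import path is an assumption; the classes live in the file changed below.
from pathlib import Path

from llama_stack.templates.remote_vllm import DistributionTemplate  # hypothetical path

template = DistributionTemplate.vllm_distribution()

# Writes build.yaml, one YAML per entry in `run_configs`
# (run.yaml and safety-run.yaml here), and remote-vllm.md docs.
template.save_distribution(Path("distributions/remote-vllm"))
```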

@@ -9,7 +9,7 @@ from datetime import datetime
from io import StringIO
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple
import jinja2
import yaml
@@ -22,7 +22,6 @@ from llama_stack.distribution.datatypes import (
Api,
BuildConfig,
DistributionSpec,
KVStoreConfig,
ModelInput,
Provider,
ShieldInput,
@@ -33,53 +32,26 @@ from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.providers.remote.inference.vllm.config import (
VLLMInferenceAdapterConfig,
)
from llama_stack.providers.utils.docker.service_config import DockerComposeServiceConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class DistributionTemplate(BaseModel):
"""
Represents a Llama Stack distribution instance that can generate configuration
and documentation files.
"""
name: str
description: str
providers: Dict[str, List[str]]
run_config_overrides: Dict[str, List[Provider]] = Field(default_factory=dict)
compose_config_overrides: Dict[str, Dict[str, DockerComposeServiceConfig]] = Field(
default_factory=dict
)
class RunConfigSettings(BaseModel):
provider_overrides: Dict[str, List[Provider]] = Field(default_factory=dict)
default_models: List[ModelInput]
default_shields: Optional[List[ShieldInput]] = None
# Optional configuration
metadata_store: Optional[KVStoreConfig] = None
docker_compose_env_vars: Optional[Dict[str, Tuple[str, str]]] = None
docker_image: Optional[str] = None
@property
def distribution_spec(self) -> DistributionSpec:
return DistributionSpec(
description=self.description,
docker_image=self.docker_image,
providers=self.providers,
)
def build_config(self) -> BuildConfig:
return BuildConfig(
name=self.name,
distribution_spec=self.distribution_spec,
image_type="conda", # default to conda, can be overridden
)
def run_config(self) -> StackRunConfig:
def run_config(
self,
name: str,
providers: Dict[str, List[str]],
docker_image: Optional[str] = None,
) -> StackRunConfig:
provider_registry = get_provider_registry()
provider_configs = {}
for api_str, provider_types in self.providers.items():
if providers := self.run_config_overrides.get(api_str):
provider_configs[api_str] = providers
for api_str, provider_types in providers.items():
if api_providers := self.provider_overrides.get(api_str):
provider_configs[api_str] = api_providers
continue
provider_type = provider_types[0]
@@ -111,83 +83,53 @@ class DistributionTemplate(BaseModel):
]
# Get unique set of APIs from providers
apis: Set[str] = set(self.providers.keys())
apis: Set[str] = set(providers.keys())
return StackRunConfig(
image_name=self.name,
docker_image=self.docker_image,
image_name=name,
docker_image=docker_image,
built_at=datetime.now(),
apis=list(apis),
providers=provider_configs,
metadata_store=self.metadata_store,
metadata_store=SqliteKVStoreConfig.sample_run_config(
dir=f"distributions/{name}",
db_name="registry.db",
),
models=self.default_models,
shields=self.default_shields or [],
)
def docker_compose_config(self) -> Dict[str, Any]:
services = {}
provider_registry = get_provider_registry()
# Add provider services based on their sample_compose_config
for api_str, api_providers in self.providers.items():
if overrides := self.compose_config_overrides.get(api_str):
services |= overrides
continue
class DistributionTemplate(BaseModel):
"""
Represents a Llama Stack distribution instance that can generate configuration
and documentation files.
"""
# only look at the first provider to get the compose config for now
# we may want to use `docker compose profiles` in the future
provider_type = api_providers[0]
provider_id = provider_type.split("::")[-1]
api = Api(api_str)
if provider_type not in provider_registry[api]:
raise ValueError(
f"Unknown provider type: {provider_type} for API: {api_str}"
)
name: str
description: str
config_class = provider_registry[api][provider_type].config_class
assert (
config_class is not None
), f"No config class for provider type: {provider_type} for API: {api_str}"
providers: Dict[str, List[str]]
run_configs: Dict[str, RunConfigSettings]
template_path: Path
config_class = instantiate_class_type(config_class)
if not hasattr(config_class, "sample_docker_compose_config"):
continue
# Optional configuration
docker_compose_env_vars: Optional[Dict[str, Tuple[str, str]]] = None
docker_image: Optional[str] = None
compose_config = config_class.sample_docker_compose_config()
services[provider_id] = compose_config
default_models: Optional[List[ModelInput]] = None
port = "${LLAMASTACK_PORT:-5001}"
# Add main llamastack service
llamastack_config = DockerComposeServiceConfig(
image=f"llamastack/distribution-{self.name}:latest",
depends_on=list(services.keys()),
volumes=[
"~/.llama:/root/.llama",
f"~/local/llama-stack/distributions/{self.name}/run.yaml:/root/llamastack-run-{self.name}.yaml",
],
ports=[f"{port}:{port}"],
environment={
k: v[0] for k, v in (self.docker_compose_env_vars or {}).items()
},
entrypoint=(
f'bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-{self.name}.yaml --port {port}"'
def build_config(self) -> BuildConfig:
return BuildConfig(
name=self.name,
distribution_spec=DistributionSpec(
description=self.description,
docker_image=self.docker_image,
providers=self.providers,
),
deploy={
"restart_policy": {
"condition": "on-failure",
"delay": "3s",
"max_attempts": 5,
"window": "60s",
}
},
image_type="conda", # default to conda, can be overridden
)
services["llamastack"] = llamastack_config
return {
"services": {k: v.model_dump() for k, v in services.items()},
"volumes": {service_name: None for service_name in services.keys()},
}
def generate_markdown_docs(self) -> str:
"""Generate markdown documentation using both Jinja2 templates and rich tables."""
# First generate the providers table using rich
@@ -204,53 +146,7 @@ class DistributionTemplate(BaseModel):
console.print(table)
providers_table = output.getvalue()
# Main documentation template
template = """# {{ name }} Distribution
{{ description }}
## Provider Configuration
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations:
{{ providers_table }}
{%- if env_vars %}
## Environment Variables
The following environment variables can be configured:
{% for var, (value, description) in docker_compose_env_vars.items() %}
- `{{ var }}`: {{ description }}
{% endfor %}
{%- endif %}
## Example Usage
### Using Docker Compose
```bash
$ cd distributions/{{ name }}
$ docker compose up
```
## Models
The following models are configured by default:
{% for model in default_models %}
- `{{ model.model_id }}`
{% endfor %}
{%- if default_shields %}
## Safety Shields
The following safety shields are configured:
{% for shield in default_shields %}
- `{{ shield.shield_id }}`
{%- endfor %}
{%- endif %}
"""
template = self.template_path.read_text()
# Render template with rich-generated table
env = jinja2.Environment(trim_blocks=True, lstrip_blocks=True)
template = env.from_string(template)
@@ -261,7 +157,6 @@ The following safety shields are configured:
providers_table=providers_table,
docker_compose_env_vars=self.docker_compose_env_vars,
default_models=self.default_models,
default_shields=self.default_shields,
)
def save_distribution(self, output_dir: Path) -> None:
@@ -271,19 +166,14 @@ The following safety shields are configured:
with open(output_dir / "build.yaml", "w") as f:
yaml.safe_dump(build_config.model_dump(), f, sort_keys=False)
run_config = self.run_config()
serialized = run_config.model_dump()
with open(output_dir / "run.yaml", "w") as f:
yaml.safe_dump(serialized, f, sort_keys=False)
# serialized_str = yaml.dump(serialized, sort_keys=False)
# env_vars = set()
# for match in re.finditer(r"\${env\.([A-Za-z0-9_-]+)}", serialized_str):
# env_vars.add(match.group(1))
docker_compose = self.docker_compose_config()
with open(output_dir / "compose.yaml", "w") as f:
yaml.safe_dump(docker_compose, f, sort_keys=False, default_flow_style=False)
for yaml_pth, settings in self.run_configs.items():
print(f"Generating {yaml_pth}")
print(f"Providers: {self.providers}")
run_config = settings.run_config(
self.name, self.providers, self.docker_image
)
with open(output_dir / yaml_pth, "w") as f:
yaml.safe_dump(run_config.model_dump(), f, sort_keys=False)
docs = self.generate_markdown_docs()
with open(output_dir / f"{self.name}.md", "w") as f:
@@ -291,87 +181,89 @@ The following safety shields are configured:
@classmethod
def vllm_distribution(cls) -> "DistributionTemplate":
providers = {
"inference": ["remote::vllm"],
"memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
}
inference_provider = Provider(
provider_id="vllm-inference",
provider_type="remote::vllm",
config=VLLMInferenceAdapterConfig.sample_run_config(
url="${env.VLLM_URL}",
),
)
inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}",
provider_id="vllm-inference",
)
safety_model = ModelInput(
model_id="${env.SAFETY_MODEL}",
provider_id="vllm-safety",
)
return cls(
name="remote-vllm",
description="Use (an external) vLLM server for running LLM inference",
providers={
"inference": ["remote::vllm"],
"memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
},
run_config_overrides={
"inference": [
Provider(
provider_id="vllm-0",
provider_type="remote::vllm",
config=VLLMInferenceAdapterConfig.sample_run_config(
url="${env.VLLM_URL:http://host.docker.internal:5100/v1}",
),
),
Provider(
provider_id="vllm-1",
provider_type="remote::vllm",
config=VLLMInferenceAdapterConfig.sample_run_config(
url="${env.SAFETY_VLLM_URL:http://host.docker.internal:5101/v1}",
),
),
]
},
compose_config_overrides={
"inference": {
"vllm-0": VLLMInferenceAdapterConfig.sample_docker_compose_config(
port=5100,
cuda_visible_devices="0",
model="${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}",
),
"vllm-1": VLLMInferenceAdapterConfig.sample_docker_compose_config(
port=5100,
cuda_visible_devices="1",
model="${env.SAFETY_MODEL:Llama-Guard-3-1B}",
),
}
},
default_models=[
ModelInput(
model_id="${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}",
provider_id="vllm-0",
template_path=Path(__file__).parent / "remote-vllm" / "doc_template.md",
providers=providers,
default_models=[inference_model, safety_model],
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider],
},
default_models=[inference_model],
),
ModelInput(
model_id="${env.SAFETY_MODEL:Llama-Guard-3-1B}",
provider_id="vllm-1",
"safety-run.yaml": RunConfigSettings(
provider_overrides={
"inference": [
inference_provider,
Provider(
provider_id="vllm-safety",
provider_type="remote::vllm",
config=VLLMInferenceAdapterConfig.sample_run_config(
url="${env.SAFETY_VLLM_URL}",
),
),
],
},
default_models=[
inference_model,
safety_model,
],
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
),
],
default_shields=[
ShieldInput(shield_id="${env.SAFETY_MODEL:Llama-Guard-3-1B}")
],
},
docker_compose_env_vars={
# these defaults are for the Docker Compose configuration
"VLLM_URL": (
"http://host.docker.internal:${VLLM_PORT:-5100}/v1",
"URL of the vLLM server with the main inference model",
),
"SAFETY_VLLM_URL": (
"http://host.docker.internal:${SAFETY_VLLM_PORT:-5101}/v1",
"URL of the vLLM server with the safety model",
),
"MAX_TOKENS": (
"${MAX_TOKENS:-4096}",
"Maximum number of tokens for generation",
"LLAMASTACK_PORT": (
"5001",
"Port for the Llama Stack distribution server",
),
"INFERENCE_MODEL": (
"${INFERENCE_MODEL:-Llama3.2-3B-Instruct}",
"Name of the inference model to use",
"meta-llama/Llama-3.2-3B-Instruct",
"Inference model loaded into the vLLM server",
),
"VLLM_URL": (
"http://host.docker.internal:5100}/v1",
"URL of the vLLM server with the main inference model",
),
"MAX_TOKENS": (
"4096",
"Maximum number of tokens for generation",
),
"SAFETY_VLLM_URL": (
"http://host.docker.internal:5101/v1",
"URL of the vLLM server with the safety model",
),
"SAFETY_MODEL": (
"${SAFETY_MODEL:-Llama-Guard-3-1B}",
"meta-llama/Llama-Guard-3-1B",
"Name of the safety (Llama-Guard) model to use",
),
"LLAMASTACK_PORT": (
"${LLAMASTACK_PORT:-5001}",
"Port for the Llama Stack distribution server",
),
},
)
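For reference, the `safety-run.yaml` entry above works by overriding the entire inference provider list, so both vLLM endpoints land in one generated config. A minimal sketch of that path in isolation, using only names that appear in this diff (building a `RunConfigSettings` by hand is illustrative, not how the commit wires it up):

```python
# Sketch: mirrors what save_distribution does for the "safety-run.yaml" entry.
# RunConfigSettings is defined in the file changed above; these imports match the diff.
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
from llama_stack.providers.remote.inference.vllm.config import (
    VLLMInferenceAdapterConfig,
)

inference_provider = Provider(
    provider_id="vllm-inference",
    provider_type="remote::vllm",
    config=VLLMInferenceAdapterConfig.sample_run_config(url="${env.VLLM_URL}"),
)
safety_provider = Provider(
    provider_id="vllm-safety",
    provider_type="remote::vllm",
    config=VLLMInferenceAdapterConfig.sample_run_config(url="${env.SAFETY_VLLM_URL}"),
)

settings = RunConfigSettings(
    provider_overrides={"inference": [inference_provider, safety_provider]},
    default_models=[
        ModelInput(model_id="${env.INFERENCE_MODEL}", provider_id="vllm-inference"),
        ModelInput(model_id="${env.SAFETY_MODEL}", provider_id="vllm-safety"),
    ],
    default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
)

# run_config() picks the first provider type per API unless an override exists,
# then fills in a SQLite metadata store under distributions/<name>/registry.db.
run_config = settings.run_config(
    name="remote-vllm",
    providers={
        "inference": ["remote::vllm"],
        "memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
        "safety": ["inline::llama-guard"],
        "agents": ["inline::meta-reference"],
        "telemetry": ["inline::meta-reference"],
    },
)
```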