From 7519b73fcc79c9eef3e83cfd453f2087c300dc96 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Fri, 22 Aug 2025 15:47:15 -0700
Subject: [PATCH] feat(distro): fork off a starter-gpu distribution (#3240)

The starter distribution recently gained post-training support, which brought
in torch as a dependency; torch in turn pulls in all of the NVIDIA CUDA
libraries, making the starter container very large. We have worked hard to
keep the starter container small so that it serves its purpose as a starter.

This PR restores the original size by forking off duplicate "-gpu" providers
for post-training. These forked providers back a new `starter-gpu`
distribution, which is free to pull in the full set of dependencies.
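The approach, sketched here in condensed form (abbreviated, not the verbatim
registry code -- see the `llama_stack/providers/registry/post_training.py`
hunk below): keep one base definition per provider and fork it per hardware
target, overriding only `provider_type` and `pip_packages` so that the CPU
variant installs torch from the CPU-only wheel index.

```python
# Condensed sketch of the fork pattern; names abbreviated from the real diff.
huggingface_def = {
    "pip_packages": ["trl", "transformers", "peft", "datasets"],
}

huggingface_cpu = {
    **huggingface_def,
    "provider_type": "inline::huggingface-cpu",
    # CPU-only torch wheels: no NVIDIA CUDA libraries, so the image stays small.
    "pip_packages": huggingface_def["pip_packages"]
    + ["torch --index-url https://download.pytorch.org/whl/cpu"],
}

huggingface_gpu = {
    **huggingface_def,
    "provider_type": "inline::huggingface-gpu",
    # Default index: the torch wheel drags in the full CUDA stack.
    "pip_packages": huggingface_def["pip_packages"] + ["torch"],
}
```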
---
 docs/source/providers/post_training/index.md  |   6 +-
 .../post_training/inline_huggingface-cpu.md   |  41 +++
 .../post_training/inline_huggingface-gpu.md   |  41 +++
 .../post_training/inline_torchtune-cpu.md     |  20 ++
 .../post_training/inline_torchtune-gpu.md     |  20 ++
 llama_stack/distributions/ci-tests/build.yaml |   2 +-
 llama_stack/distributions/ci-tests/run.yaml   |   4 +-
 .../distributions/starter-gpu/__init__.py     |   7 +
 .../distributions/starter-gpu/build.yaml      |  59 +++++
 .../distributions/starter-gpu/run.yaml        | 238 ++++++++++++++++++
 .../distributions/starter-gpu/starter_gpu.py  |  22 ++
 llama_stack/distributions/starter/build.yaml  |   5 +-
 llama_stack/distributions/starter/run.yaml    |   4 +-
 llama_stack/distributions/starter/starter.py  |   4 +-
 .../providers/registry/post_training.py       |  80 ++++--
 15 files changed, 522 insertions(+), 31 deletions(-)
 create mode 100644 docs/source/providers/post_training/inline_huggingface-cpu.md
 create mode 100644 docs/source/providers/post_training/inline_huggingface-gpu.md
 create mode 100644 docs/source/providers/post_training/inline_torchtune-cpu.md
 create mode 100644 docs/source/providers/post_training/inline_torchtune-gpu.md
 create mode 100644 llama_stack/distributions/starter-gpu/__init__.py
 create mode 100644 llama_stack/distributions/starter-gpu/build.yaml
 create mode 100644 llama_stack/distributions/starter-gpu/run.yaml
 create mode 100644 llama_stack/distributions/starter-gpu/starter_gpu.py

diff --git a/docs/source/providers/post_training/index.md b/docs/source/providers/post_training/index.md
index c6c92c40e..5ada6f9aa 100644
--- a/docs/source/providers/post_training/index.md
+++ b/docs/source/providers/post_training/index.md
@@ -9,7 +9,9 @@ This section contains documentation for all available providers for the **post_t
 ```{toctree}
 :maxdepth: 1
 
-inline_huggingface
-inline_torchtune
+inline_huggingface-cpu
+inline_huggingface-gpu
+inline_torchtune-cpu
+inline_torchtune-gpu
 remote_nvidia
 ```
diff --git a/docs/source/providers/post_training/inline_huggingface-cpu.md b/docs/source/providers/post_training/inline_huggingface-cpu.md
new file mode 100644
index 000000000..e663fe8f8
--- /dev/null
+++ b/docs/source/providers/post_training/inline_huggingface-cpu.md
@@ -0,0 +1,41 @@
+# inline::huggingface-cpu
+
+## Description
+
+HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.
+
+## Configuration
+
+| Field | Type | Required | Default | Description |
+|-------|------|----------|---------|-------------|
+| `device` | `str` | No | cuda | |
+| `distributed_backend` | `Literal['fsdp', 'deepspeed'] \| None` | No | | |
+| `checkpoint_format` | `Literal['full_state', 'huggingface']` | No | huggingface | |
+| `chat_template` | `str` | No | `<\|user\|>\n{input}\n<\|assistant\|>\n{output}` | |
+| `model_specific_config` | `dict` | No | {'trust_remote_code': True, 'attn_implementation': 'sdpa'} | |
+| `max_seq_length` | `int` | No | 2048 | |
+| `gradient_checkpointing` | `bool` | No | False | |
+| `save_total_limit` | `int` | No | 3 | |
+| `logging_steps` | `int` | No | 10 | |
+| `warmup_ratio` | `float` | No | 0.1 | |
+| `weight_decay` | `float` | No | 0.01 | |
+| `dataloader_num_workers` | `int` | No | 4 | |
+| `dataloader_pin_memory` | `bool` | No | True | |
+| `dpo_beta` | `float` | No | 0.1 | |
+| `use_reference_model` | `bool` | No | True | |
+| `dpo_loss_type` | `Literal['sigmoid', 'hinge', 'ipo', 'kto_pair']` | No | sigmoid | |
+| `dpo_output_dir` | `str` | No | | |
+
+## Sample Configuration
+
+```yaml
+checkpoint_format: huggingface
+distributed_backend: null
+device: cpu
+dpo_output_dir: ~/.llama/dummy/dpo_output
+
+```
diff --git a/docs/source/providers/post_training/inline_huggingface-gpu.md b/docs/source/providers/post_training/inline_huggingface-gpu.md
new file mode 100644
index 000000000..21bf965fe
--- /dev/null
+++ b/docs/source/providers/post_training/inline_huggingface-gpu.md
@@ -0,0 +1,41 @@
+# inline::huggingface-gpu
+
+## Description
+
+HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.
+
+## Configuration
+
+| Field | Type | Required | Default | Description |
+|-------|------|----------|---------|-------------|
+| `device` | `str` | No | cuda | |
+| `distributed_backend` | `Literal['fsdp', 'deepspeed'] \| None` | No | | |
+| `checkpoint_format` | `Literal['full_state', 'huggingface']` | No | huggingface | |
+| `chat_template` | `str` | No | `<\|user\|>\n{input}\n<\|assistant\|>\n{output}` | |
+| `model_specific_config` | `dict` | No | {'trust_remote_code': True, 'attn_implementation': 'sdpa'} | |
+| `max_seq_length` | `int` | No | 2048 | |
+| `gradient_checkpointing` | `bool` | No | False | |
+| `save_total_limit` | `int` | No | 3 | |
+| `logging_steps` | `int` | No | 10 | |
+| `warmup_ratio` | `float` | No | 0.1 | |
+| `weight_decay` | `float` | No | 0.01 | |
+| `dataloader_num_workers` | `int` | No | 4 | |
+| `dataloader_pin_memory` | `bool` | No | True | |
+| `dpo_beta` | `float` | No | 0.1 | |
+| `use_reference_model` | `bool` | No | True | |
+| `dpo_loss_type` | `Literal['sigmoid', 'hinge', 'ipo', 'kto_pair']` | No | sigmoid | |
+| `dpo_output_dir` | `str` | No | | |
+
+## Sample Configuration
+
+```yaml
+checkpoint_format: huggingface
+distributed_backend: null
+device: cpu
+dpo_output_dir: ~/.llama/dummy/dpo_output
+
+```
diff --git a/docs/source/providers/post_training/inline_torchtune-cpu.md b/docs/source/providers/post_training/inline_torchtune-cpu.md
new file mode 100644
index 000000000..7204e56e8
--- /dev/null
+++ b/docs/source/providers/post_training/inline_torchtune-cpu.md
@@ -0,0 +1,20 @@
+# inline::torchtune-cpu
+
+## Description
+
+TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.
+
+## Configuration
+
+| Field | Type | Required | Default | Description |
+|-------|------|----------|---------|-------------|
+| `torch_seed` | `int \| None` | No | | |
+| `checkpoint_format` | `Literal['meta', 'huggingface']` | No | meta | |
+
+## Sample Configuration
+
+```yaml
+checkpoint_format: meta
+
+```
diff --git a/docs/source/providers/post_training/inline_torchtune-gpu.md b/docs/source/providers/post_training/inline_torchtune-gpu.md
new file mode 100644
index 000000000..98b94f6f6
--- /dev/null
+++ b/docs/source/providers/post_training/inline_torchtune-gpu.md
@@ -0,0 +1,20 @@
+# inline::torchtune-gpu
+
+## Description
+
+TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.
+
+## Configuration
+
+| Field | Type | Required | Default | Description |
+|-------|------|----------|---------|-------------|
+| `torch_seed` | `int \| None` | No | | |
+| `checkpoint_format` | `Literal['meta', 'huggingface']` | No | meta | |
+
+## Sample Configuration
+
+```yaml
+checkpoint_format: meta
+
+```
diff --git a/llama_stack/distributions/ci-tests/build.yaml b/llama_stack/distributions/ci-tests/build.yaml
index 0bf42e7ee..b4701cb81 100644
--- a/llama_stack/distributions/ci-tests/build.yaml
+++ b/llama_stack/distributions/ci-tests/build.yaml
@@ -34,7 +34,7 @@ distribution_spec:
     telemetry:
     - provider_type: inline::meta-reference
     post_training:
-    - provider_type: inline::huggingface
+    - provider_type: inline::huggingface-cpu
     eval:
     - provider_type: inline::meta-reference
     datasetio:
diff --git a/llama_stack/distributions/ci-tests/run.yaml b/llama_stack/distributions/ci-tests/run.yaml
index 02a268462..3acdd20f9 100644
--- a/llama_stack/distributions/ci-tests/run.yaml
+++ b/llama_stack/distributions/ci-tests/run.yaml
@@ -156,8 +156,8 @@ providers:
       sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/trace_store.db
       otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
   post_training:
-  - provider_id: huggingface
-    provider_type: inline::huggingface
+  - provider_id: huggingface-cpu
+    provider_type: inline::huggingface-cpu
     config:
       checkpoint_format: huggingface
       distributed_backend: null
diff --git a/llama_stack/distributions/starter-gpu/__init__.py b/llama_stack/distributions/starter-gpu/__init__.py
new file mode 100644
index 000000000..e762f9b6e
--- /dev/null
+++ b/llama_stack/distributions/starter-gpu/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .starter_gpu import get_distribution_template  # noqa: F401
diff --git a/llama_stack/distributions/starter-gpu/build.yaml b/llama_stack/distributions/starter-gpu/build.yaml
new file mode 100644
index 000000000..ae0680cdc
--- /dev/null
+++ b/llama_stack/distributions/starter-gpu/build.yaml
@@ -0,0 +1,59 @@
+version: 2
+distribution_spec:
+  description: Quick start template for running Llama Stack with several popular providers.
+    This distribution is intended for GPU-enabled environments.
+  providers:
+    inference:
+    - provider_type: remote::cerebras
+    - provider_type: remote::ollama
+    - provider_type: remote::vllm
+    - provider_type: remote::tgi
+    - provider_type: remote::fireworks
+    - provider_type: remote::together
+    - provider_type: remote::bedrock
+    - provider_type: remote::nvidia
+    - provider_type: remote::openai
+    - provider_type: remote::anthropic
+    - provider_type: remote::gemini
+    - provider_type: remote::vertexai
+    - provider_type: remote::groq
+    - provider_type: remote::sambanova
+    - provider_type: inline::sentence-transformers
+    vector_io:
+    - provider_type: inline::faiss
+    - provider_type: inline::sqlite-vec
+    - provider_type: inline::milvus
+    - provider_type: remote::chromadb
+    - provider_type: remote::pgvector
+    files:
+    - provider_type: inline::localfs
+    safety:
+    - provider_type: inline::llama-guard
+    - provider_type: inline::code-scanner
+    agents:
+    - provider_type: inline::meta-reference
+    telemetry:
+    - provider_type: inline::meta-reference
+    post_training:
+    - provider_type: inline::torchtune-gpu
+    eval:
+    - provider_type: inline::meta-reference
+    datasetio:
+    - provider_type: remote::huggingface
+    - provider_type: inline::localfs
+    scoring:
+    - provider_type: inline::basic
+    - provider_type: inline::llm-as-judge
+    - provider_type: inline::braintrust
+    tool_runtime:
+    - provider_type: remote::brave-search
+    - provider_type: remote::tavily-search
+    - provider_type: inline::rag-runtime
+    - provider_type: remote::model-context-protocol
+    batches:
+    - provider_type: inline::reference
+image_type: venv
+additional_pip_packages:
+- aiosqlite
+- asyncpg
+- sqlalchemy[asyncio]
diff --git a/llama_stack/distributions/starter-gpu/run.yaml b/llama_stack/distributions/starter-gpu/run.yaml
new file mode 100644
index 000000000..81c802317
--- /dev/null
+++ b/llama_stack/distributions/starter-gpu/run.yaml
@@ -0,0 +1,238 @@
+version: 2
+image_name: starter-gpu
+apis:
+- agents
+- batches
+- datasetio
+- eval
+- files
+- inference
+- post_training
+- safety
+- scoring
+- telemetry
+- tool_runtime
+- vector_io
+providers:
+  inference:
+  - provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
+    provider_type: remote::cerebras
+    config:
+      base_url: https://api.cerebras.ai
+      api_key: ${env.CEREBRAS_API_KEY:=}
+  - provider_id: ${env.OLLAMA_URL:+ollama}
+    provider_type: remote::ollama
+    config:
+      url: ${env.OLLAMA_URL:=http://localhost:11434}
+  - provider_id: ${env.VLLM_URL:+vllm}
+    provider_type: remote::vllm
+    config:
+      url: ${env.VLLM_URL:=}
+      max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
+      api_token: ${env.VLLM_API_TOKEN:=fake}
+      tls_verify: ${env.VLLM_TLS_VERIFY:=true}
+  - provider_id: ${env.TGI_URL:+tgi}
+    provider_type: remote::tgi
+    config:
+      url: ${env.TGI_URL:=}
+  - provider_id: fireworks
+    provider_type: remote::fireworks
+    config:
+      url: https://api.fireworks.ai/inference/v1
+      api_key: ${env.FIREWORKS_API_KEY:=}
+  - provider_id: together
+    provider_type: remote::together
+    config:
+      url: https://api.together.xyz/v1
+      api_key: ${env.TOGETHER_API_KEY:=}
+  - provider_id: bedrock
+    provider_type: remote::bedrock
+  - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
+    provider_type: remote::nvidia
+    config:
+      url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
+      api_key: ${env.NVIDIA_API_KEY:=}
+      append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
+  - provider_id: openai
+    provider_type: remote::openai
+    config:
+      api_key: ${env.OPENAI_API_KEY:=}
+      base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
+  - provider_id: anthropic
+    provider_type: remote::anthropic
+    config:
+      api_key: ${env.ANTHROPIC_API_KEY:=}
+  - provider_id: gemini
+    provider_type: remote::gemini
+    config:
+      api_key: ${env.GEMINI_API_KEY:=}
+  - provider_id: ${env.VERTEX_AI_PROJECT:+vertexai}
+    provider_type: remote::vertexai
+    config:
+      project: ${env.VERTEX_AI_PROJECT:=}
+      location: ${env.VERTEX_AI_LOCATION:=us-central1}
+  - provider_id: groq
+    provider_type: remote::groq
+    config:
+      url: https://api.groq.com
+      api_key: ${env.GROQ_API_KEY:=}
+  - provider_id: sambanova
+    provider_type: remote::sambanova
+    config:
+      url: https://api.sambanova.ai/v1
+      api_key: ${env.SAMBANOVA_API_KEY:=}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+  vector_io:
+  - provider_id: faiss
+    provider_type: inline::faiss
+    config:
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db
+  - provider_id: sqlite-vec
+    provider_type: inline::sqlite-vec
+    config:
+      db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db
+  - provider_id: ${env.MILVUS_URL:+milvus}
+    provider_type: inline::milvus
+    config:
+      db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db
+  - provider_id: ${env.CHROMADB_URL:+chromadb}
+    provider_type: remote::chromadb
+    config:
+      url: ${env.CHROMADB_URL:=}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db
+  - provider_id: ${env.PGVECTOR_DB:+pgvector}
+    provider_type: remote::pgvector
+    config:
+      host: ${env.PGVECTOR_HOST:=localhost}
+      port: ${env.PGVECTOR_PORT:=5432}
+      db: ${env.PGVECTOR_DB:=}
+      user: ${env.PGVECTOR_USER:=}
+      password: ${env.PGVECTOR_PASSWORD:=}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
+  files:
+  - provider_id: meta-reference-files
+    provider_type: inline::localfs
+    config:
+      storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
+      metadata_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
+  safety:
+  - provider_id: llama-guard
+    provider_type: inline::llama-guard
+    config:
+      excluded_categories: []
+  - provider_id: code-scanner
+    provider_type: inline::code-scanner
+  agents:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      persistence_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/agents_store.db
+      responses_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/responses_store.db
+  telemetry:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
+      sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
+      sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/trace_store.db
+      otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
+  post_training:
+  - provider_id: torchtune-gpu
+    provider_type: inline::torchtune-gpu
+    config:
+      checkpoint_format: meta
+  eval:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/meta_reference_eval.db
+  datasetio:
+  - provider_id: huggingface
+    provider_type: remote::huggingface
+    config:
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/huggingface_datasetio.db
+  - provider_id: localfs
+    provider_type: inline::localfs
+    config:
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/localfs_datasetio.db
+  scoring:
+  - provider_id: basic
+    provider_type: inline::basic
+  - provider_id: llm-as-judge
+    provider_type: inline::llm-as-judge
+  - provider_id: braintrust
+    provider_type: inline::braintrust
+    config:
+      openai_api_key: ${env.OPENAI_API_KEY:=}
+  tool_runtime:
+  - provider_id: brave-search
+    provider_type: remote::brave-search
+    config:
+      api_key: ${env.BRAVE_SEARCH_API_KEY:=}
+      max_results: 3
+  - provider_id: tavily-search
+    provider_type: remote::tavily-search
+    config:
+      api_key: ${env.TAVILY_SEARCH_API_KEY:=}
+      max_results: 3
+  - provider_id: rag-runtime
+    provider_type: inline::rag-runtime
+  - provider_id: model-context-protocol
+    provider_type: remote::model-context-protocol
+  batches:
+  - provider_id: reference
+    provider_type: inline::reference
+    config:
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/batches.db
+metadata_store:
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/registry.db
+inference_store:
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/inference_store.db
+models: []
+shields:
+- shield_id: llama-guard
+  provider_id: ${env.SAFETY_MODEL:+llama-guard}
+  provider_shield_id: ${env.SAFETY_MODEL:=}
+- shield_id: code-scanner
+  provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
+  provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
+vector_dbs: []
+datasets: []
+scoring_fns: []
+benchmarks: []
+tool_groups:
+- toolgroup_id: builtin::websearch
+  provider_id: tavily-search
+- toolgroup_id: builtin::rag
+  provider_id: rag-runtime
+server:
+  port: 8321
diff --git a/llama_stack/distributions/starter-gpu/starter_gpu.py b/llama_stack/distributions/starter-gpu/starter_gpu.py
new file mode 100644
index 000000000..893df6c17
--- /dev/null
+++ b/llama_stack/distributions/starter-gpu/starter_gpu.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+from llama_stack.distributions.template import BuildProvider, DistributionTemplate
+
+from ..starter.starter import get_distribution_template as get_starter_distribution_template
+
+
+def get_distribution_template() -> DistributionTemplate:
+    template = get_starter_distribution_template()
+    name = "starter-gpu"
+    template.name = name
+    template.description = "Quick start template for running Llama Stack with several popular providers. This distribution is intended for GPU-enabled environments."
+
+    template.providers["post_training"] = [
+        BuildProvider(provider_type="inline::torchtune-gpu"),
+    ]
+    return template
diff --git a/llama_stack/distributions/starter/build.yaml b/llama_stack/distributions/starter/build.yaml
index 2ad12a165..3df0eb129 100644
--- a/llama_stack/distributions/starter/build.yaml
+++ b/llama_stack/distributions/starter/build.yaml
@@ -1,6 +1,7 @@
 version: 2
 distribution_spec:
-  description: Quick start template for running Llama Stack with several popular providers
+  description: Quick start template for running Llama Stack with several popular providers.
+    This distribution is intended for CPU-only environments.
   providers:
     inference:
     - provider_type: remote::cerebras
@@ -34,7 +35,7 @@ distribution_spec:
     telemetry:
     - provider_type: inline::meta-reference
     post_training:
-    - provider_type: inline::huggingface
+    - provider_type: inline::huggingface-cpu
     eval:
     - provider_type: inline::meta-reference
     datasetio:
diff --git a/llama_stack/distributions/starter/run.yaml b/llama_stack/distributions/starter/run.yaml
index 7ac4dc6b9..7e1d46a61 100644
--- a/llama_stack/distributions/starter/run.yaml
+++ b/llama_stack/distributions/starter/run.yaml
@@ -156,8 +156,8 @@ providers:
       sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/trace_store.db
       otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
   post_training:
-  - provider_id: huggingface
-    provider_type: inline::huggingface
+  - provider_id: huggingface-cpu
+    provider_type: inline::huggingface-cpu
     config:
       checkpoint_format: huggingface
       distributed_backend: null
diff --git a/llama_stack/distributions/starter/starter.py b/llama_stack/distributions/starter/starter.py
index cad3d72d9..f49da0bb7 100644
--- a/llama_stack/distributions/starter/starter.py
+++ b/llama_stack/distributions/starter/starter.py
@@ -120,7 +120,7 @@ def get_distribution_template() -> DistributionTemplate:
         ],
         "agents": [BuildProvider(provider_type="inline::meta-reference")],
         "telemetry": [BuildProvider(provider_type="inline::meta-reference")],
-        "post_training": [BuildProvider(provider_type="inline::huggingface")],
+        "post_training": [BuildProvider(provider_type="inline::huggingface-cpu")],
         "eval": [BuildProvider(provider_type="inline::meta-reference")],
         "datasetio": [
             BuildProvider(provider_type="remote::huggingface"),
@@ -178,7 +178,7 @@ def get_distribution_template() -> DistributionTemplate:
     return DistributionTemplate(
         name=name,
         distro_type="self_hosted",
-        description="Quick start template for running Llama Stack with several popular providers",
+        description="Quick start template for running Llama Stack with several popular providers. This distribution is intended for CPU-only environments.",
         container_image=None,
         template_path=None,
         providers=providers,
diff --git a/llama_stack/providers/registry/post_training.py b/llama_stack/providers/registry/post_training.py
index ffd64ef7c..4443f4df1 100644
--- a/llama_stack/providers/registry/post_training.py
+++ b/llama_stack/providers/registry/post_training.py
@@ -5,34 +5,74 @@
 # the root directory of this source tree.
 
+from typing import cast
+
 from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
 
+# We provide two versions of these providers so that distributions can package the appropriate version of torch.
+# The CPU version is used for distributions that don't have GPU support -- they result in smaller container images.
+torchtune_def = dict(
+    api=Api.post_training,
+    pip_packages=["torchtune==0.5.0", "torchao==0.8.0", "numpy"],
+    module="llama_stack.providers.inline.post_training.torchtune",
+    config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
+    api_dependencies=[
+        Api.datasetio,
+        Api.datasets,
+    ],
+    description="TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.",
+)
+
+huggingface_def = dict(
+    api=Api.post_training,
+    pip_packages=["trl", "transformers", "peft", "datasets"],
+    module="llama_stack.providers.inline.post_training.huggingface",
+    config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig",
+    api_dependencies=[
+        Api.datasetio,
+        Api.datasets,
+    ],
+    description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.",
+)
+
 
 def available_providers() -> list[ProviderSpec]:
     return [
         InlineProviderSpec(
-            api=Api.post_training,
-            provider_type="inline::torchtune",
-            pip_packages=["torch", "torchtune==0.5.0", "torchao==0.8.0", "numpy"],
-            module="llama_stack.providers.inline.post_training.torchtune",
-            config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
-            api_dependencies=[
-                Api.datasetio,
-                Api.datasets,
-            ],
-            description="TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.",
+            **{
+                **torchtune_def,
+                "provider_type": "inline::torchtune-cpu",
+                "pip_packages": (
+                    cast(list[str], torchtune_def["pip_packages"])
+                    + ["torch torchtune==0.5.0 torchao==0.8.0 --index-url https://download.pytorch.org/whl/cpu"]
+                ),
+            },
         ),
         InlineProviderSpec(
-            api=Api.post_training,
-            provider_type="inline::huggingface",
-            pip_packages=["torch", "trl", "transformers", "peft", "datasets"],
-            module="llama_stack.providers.inline.post_training.huggingface",
-            config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig",
-            api_dependencies=[
-                Api.datasetio,
-                Api.datasets,
-            ],
-            description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.",
+            **{
+                **huggingface_def,
+                "provider_type": "inline::huggingface-cpu",
+                "pip_packages": (
+                    cast(list[str], huggingface_def["pip_packages"])
+                    + ["torch --index-url https://download.pytorch.org/whl/cpu"]
+                ),
+            },
+        ),
+        InlineProviderSpec(
+            **{
+                **torchtune_def,
+                "provider_type": "inline::torchtune-gpu",
+                "pip_packages": (
+                    cast(list[str], torchtune_def["pip_packages"]) + ["torch torchtune==0.5.0 torchao==0.8.0"]
+                ),
+            },
+        ),
+        InlineProviderSpec(
+            **{
+                **huggingface_def,
+                "provider_type": "inline::huggingface-gpu",
+                "pip_packages": (cast(list[str], huggingface_def["pip_packages"]) + ["torch"]),
+            },
         ),
         remote_provider_spec(
             api=Api.post_training,
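For reference, a quick way to verify which torch build a container image ended
up with (a hypothetical sanity check, not part of this patch; CPU-only wheels
from the PyTorch index carry a "+cpu" local version suffix):

```python
# Hypothetical check: run inside the built container to confirm the torch build.
import torch

print(torch.__version__)          # e.g. "2.5.1+cpu" in starter, a "+cuXXX" build in starter-gpu
print(torch.cuda.is_available())  # False in the CPU-only starter image
```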