diff --git a/README.md b/README.md
index e54b505cf..37f1aa0f3 100644
--- a/README.md
+++ b/README.md
@@ -107,26 +107,29 @@ By reducing friction and complexity, Llama Stack empowers developers to focus on

 ### API Providers
 Here is a list of the various API providers and available distributions that can help developers get started easily with Llama Stack.

-| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
-|:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|
-| Meta Reference | Single Node | ✅ | ✅ | ✅ | ✅ | ✅ |
-| SambaNova | Hosted | | ✅ | | ✅ | |
-| Cerebras | Hosted | | ✅ | | | |
-| Fireworks | Hosted | ✅ | ✅ | ✅ | | |
-| AWS Bedrock | Hosted | | ✅ | | ✅ | |
-| Together | Hosted | ✅ | ✅ | | ✅ | |
-| Groq | Hosted | | ✅ | | | |
-| Ollama | Single Node | | ✅ | | | |
-| TGI | Hosted and Single Node | | ✅ | | | |
-| NVIDIA NIM | Hosted and Single Node | | ✅ | | | |
-| Chroma | Single Node | | | ✅ | | |
-| PG Vector | Single Node | | | ✅ | | |
-| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | |
-| vLLM | Hosted and Single Node | | ✅ | | | |
-| OpenAI | Hosted | | ✅ | | | |
-| Anthropic | Hosted | | ✅ | | | |
-| Gemini | Hosted | | ✅ | | | |
-| watsonx | Hosted | | ✅ | | | |
+| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** | **Post Training** |
+|:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|:-----------------:|
+| Meta Reference | Single Node | ✅ | ✅ | ✅ | ✅ | ✅ | |
+| SambaNova | Hosted | | ✅ | | ✅ | | |
+| Cerebras | Hosted | | ✅ | | | | |
+| Fireworks | Hosted | ✅ | ✅ | ✅ | | | |
+| AWS Bedrock | Hosted | | ✅ | | ✅ | | |
+| Together | Hosted | ✅ | ✅ | | ✅ | | |
+| Groq | Hosted | | ✅ | | | | |
+| Ollama | Single Node | | ✅ | | | | |
+| TGI | Hosted and Single Node | | ✅ | | | | |
+| NVIDIA NIM | Hosted and Single Node | | ✅ | | | | |
+| Chroma | Single Node | | | ✅ | | | |
+| PG Vector | Single Node | | | ✅ | | | |
+| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | | |
+| vLLM | Hosted and Single Node | | ✅ | | | | |
+| OpenAI | Hosted | | ✅ | | | | |
+| Anthropic | Hosted | | ✅ | | | | |
+| Gemini | Hosted | | ✅ | | | | |
+| watsonx | Hosted | | ✅ | | | | |
+| HuggingFace | Single Node | | | | | | ✅ |
+| TorchTune | Single Node | | | | | | ✅ |
+| NVIDIA NeMo | Hosted | | | | | | ✅ |

 ### Distributions
diff --git a/docs/source/providers/index.md b/docs/source/providers/index.md
index 1d1a6e081..1f5026479 100644
--- a/docs/source/providers/index.md
+++ b/docs/source/providers/index.md
@@ -30,6 +30,18 @@ Runs inference with an LLM.

 ## Post Training
 Fine-tunes a model.
+#### Post Training Providers
+The following providers are available for Post Training:
+
+```{toctree}
+:maxdepth: 1
+
+external
+post_training/huggingface
+post_training/torchtune
+post_training/nvidia_nemo
+```
+

 ## Safety
 Applies safety policies to the output at a Systems (not only model) level.
diff --git a/docs/source/providers/post_training/huggingface.md b/docs/source/providers/post_training/huggingface.md
new file mode 100644
index 000000000..c342203a8
--- /dev/null
+++ b/docs/source/providers/post_training/huggingface.md
@@ -0,0 +1,122 @@
+---
+orphan: true
+---
+# HuggingFace SFTTrainer
+
+[HuggingFace SFTTrainer](https://huggingface.co/docs/trl/en/sft_trainer) is an inline post training provider for Llama Stack. It allows you to run supervised fine-tuning on a wide range of models and datasets.
+
+## Features
+
+- Simple access through the post_training API
+- Fully integrated with Llama Stack
+- GPU support, CPU support, and MPS support (macOS Metal Performance Shaders)
+
+## Usage
+
+To use the HF SFTTrainer in your Llama Stack project, follow these steps:
+
+1. Configure your Llama Stack project to use this provider.
+2. Kick off an SFT job using the Llama Stack post_training API.
+
+## Setup
+
+You can access the HuggingFace trainer via the `ollama` distribution:
+
+```bash
+llama stack build --template ollama --image-type venv
+llama stack run --image-type venv ~/.llama/distributions/ollama/ollama-run.yaml
+```
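+
+Under the hood, the distribution's generated `run.yaml` enables this trainer through a `post_training` provider entry. The sketch below is illustrative; the exact config fields and defaults can differ between Llama Stack releases, so verify them against the generated file:
+
+```yaml
+post_training:
+  - provider_id: huggingface
+    provider_type: inline::huggingface
+    config:
+      checkpoint_format: huggingface  # illustrative default
+      device: cpu  # or "cuda" / "mps", depending on your hardware
+```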
+
+## Run Training
+
+You can access the provider and the `supervised_fine_tune` method via the post_training API:
+
+```python
+import time
+import uuid
+
+
+from llama_stack_client.types import (
+    post_training_supervised_fine_tune_params,
+    algorithm_config_param,
+)
+
+
+def create_http_client():
+    from llama_stack_client import LlamaStackClient
+
+    return LlamaStackClient(base_url="http://localhost:8321")
+
+
+client = create_http_client()
+
+# Example Dataset
+client.datasets.register(
+    purpose="post-training/messages",
+    source={
+        "type": "uri",
+        "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
+    },
+    dataset_id="simpleqa",
+)
+
+training_config = post_training_supervised_fine_tune_params.TrainingConfig(
+    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
+        batch_size=32,
+        data_format="instruct",
+        dataset_id="simpleqa",
+        shuffle=True,
+    ),
+    gradient_accumulation_steps=1,
+    max_steps_per_epoch=0,
+    max_validation_steps=1,
+    n_epochs=4,
+)
+
+algorithm_config = algorithm_config_param.LoraFinetuningConfig(  # this config is currently required by the API even though it should be optional
+    alpha=1,
+    apply_lora_to_mlp=True,
+    apply_lora_to_output=False,
+    lora_attn_modules=["q_proj"],
+    rank=1,
+    type="LoRA",
+)
+
+job_uuid = f"test-job{uuid.uuid4()}"
+
+# Example Model
+training_model = "ibm-granite/granite-3.3-8b-instruct"
+
+start_time = time.time()
+response = client.post_training.supervised_fine_tune(
+    job_uuid=job_uuid,
+    logger_config={},
+    model=training_model,
+    hyperparam_search_config={},
+    training_config=training_config,
+    algorithm_config=algorithm_config,
+    checkpoint_dir="output",
+)
+print("Job: ", job_uuid)
+
+
+# Wait for the job to complete!
+while True:
+    status = client.post_training.job.status(job_uuid=job_uuid)
+    if not status:
+        print("Job not found")
+        break
+
+    print(status)
+    if status.status == "completed":
+        break
+
+    print("Waiting for job to complete...")
+    time.sleep(5)
+
+end_time = time.time()
+print("Job completed in", end_time - start_time, "seconds!")
+
+print("Artifacts:")
+print(client.post_training.job.artifacts(job_uuid=job_uuid))
+```
diff --git a/docs/source/providers/post_training/nvidia_nemo.md b/docs/source/providers/post_training/nvidia_nemo.md
new file mode 100644
index 000000000..1a7adbe16
--- /dev/null
+++ b/docs/source/providers/post_training/nvidia_nemo.md
@@ -0,0 +1,163 @@
+---
+orphan: true
+---
+# NVIDIA NeMo
+
+[NVIDIA NeMo](https://developer.nvidia.com/nemo-framework) is a remote post training provider for Llama Stack. It provides enterprise-grade fine-tuning capabilities through NVIDIA's NeMo Customizer service.
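+
+Within a distribution, this remote provider is declared under `post_training` in `run.yaml`. The entry below is a rough sketch: the field names mirror the environment variables described in the Setup section below, but you should confirm them against the NVIDIA distribution template for your release:
+
+```yaml
+post_training:
+  - provider_id: nvidia
+    provider_type: remote::nvidia
+    config:
+      api_key: ${env.NVIDIA_API_KEY}
+      dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE}
+      customizer_url: ${env.NVIDIA_CUSTOMIZER_URL}
+      project_id: ${env.NVIDIA_PROJECT_ID}
+```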
+
+## Features
+
+- Enterprise-grade fine-tuning capabilities
+- Support for LoRA and SFT fine-tuning
+- Integration with NVIDIA's NeMo Customizer service
+- Support for various NVIDIA-optimized models
+- Efficient training with NVIDIA hardware acceleration
+
+## Usage
+
+To use NVIDIA NeMo in your Llama Stack project, follow these steps:
+
+1. Configure your Llama Stack project to use this provider.
+2. Set up your NVIDIA API credentials.
+3. Kick off a fine-tuning job using the Llama Stack post_training API.
+
+## Setup
+
+You'll need to set the following environment variables:
+
+```bash
+export NVIDIA_API_KEY="your-api-key"
+export NVIDIA_DATASET_NAMESPACE="default"
+export NVIDIA_CUSTOMIZER_URL="your-customizer-url"
+export NVIDIA_PROJECT_ID="your-project-id"
+export NVIDIA_OUTPUT_MODEL_DIR="your-output-model-dir"
+```
+
+## Run Training
+
+You can access the provider and the `supervised_fine_tune` method via the post_training API:
+
+```python
+import time
+import uuid
+
+from llama_stack_client.types import (
+    post_training_supervised_fine_tune_params,
+    algorithm_config_param,
+)
+
+
+def create_http_client():
+    from llama_stack_client import LlamaStackClient
+
+    return LlamaStackClient(base_url="http://localhost:8321")
+
+
+client = create_http_client()
+
+# Example Dataset
+client.datasets.register(
+    purpose="post-training/messages",
+    source={
+        "type": "uri",
+        "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
+    },
+    dataset_id="simpleqa",
+)
+
+training_config = post_training_supervised_fine_tune_params.TrainingConfig(
+    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
+        batch_size=8,  # Default batch size for NeMo
+        data_format="instruct",
+        dataset_id="simpleqa",
+        shuffle=True,
+    ),
+    n_epochs=50,  # Default epochs for NeMo
+    optimizer_config=post_training_supervised_fine_tune_params.TrainingConfigOptimizerConfig(
+        lr=0.0001,  # Default learning rate
+        weight_decay=0.01,  # NeMo-specific parameter
+    ),
+    # NeMo-specific parameters
+    log_every_n_steps=None,
+    val_check_interval=0.25,
+    sequence_packing_enabled=False,
+    hidden_dropout=None,
+    attention_dropout=None,
+    ffn_dropout=None,
+)
+
+algorithm_config = algorithm_config_param.LoraFinetuningConfig(
+    alpha=16,  # Default alpha for NeMo
+    type="LoRA",
+)
+
+job_uuid = f"test-job{uuid.uuid4()}"
+
+# Example Model - must be a model supported by NeMo
+training_model = "meta/llama-3.1-8b-instruct"
+
+start_time = time.time()
+response = client.post_training.supervised_fine_tune(
+    job_uuid=job_uuid,
+    logger_config={},
+    model=training_model,
+    hyperparam_search_config={},
+    training_config=training_config,
+    algorithm_config=algorithm_config,
+    checkpoint_dir="output",
+)
+print("Job: ", job_uuid)
+
+# Wait for the job to complete!
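+# The fine-tuning itself runs remotely on the NeMo Customizer service, so this loop only
+# polls the job status endpoint; production code should also handle failed or cancelled jobs.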
+while True:
+    status = client.post_training.job.status(job_uuid=job_uuid)
+    if not status:
+        print("Job not found")
+        break
+
+    print(status)
+    if status.status == "completed":
+        break
+
+    print("Waiting for job to complete...")
+    time.sleep(5)
+
+end_time = time.time()
+print("Job completed in", end_time - start_time, "seconds!")
+
+print("Artifacts:")
+print(client.post_training.job.artifacts(job_uuid=job_uuid))
+```
+
+## Supported Models
+
+Currently supports the following models:
+- meta/llama-3.1-8b-instruct
+- meta/llama-3.2-1b-instruct
+
+## Supported Parameters
+
+### TrainingConfig
+- n_epochs (default: 50)
+- data_config
+- optimizer_config
+- log_every_n_steps
+- val_check_interval (default: 0.25)
+- sequence_packing_enabled (default: False)
+- hidden_dropout (0.0-1.0)
+- attention_dropout (0.0-1.0)
+- ffn_dropout (0.0-1.0)
+
+### DataConfig
+- dataset_id
+- batch_size (default: 8)
+
+### OptimizerConfig
+- lr (default: 0.0001)
+- weight_decay (default: 0.01)
+
+### LoRA Config
+- alpha (default: 16)
+- type (must be "LoRA")
+
+Note: Some parameters from the standard Llama Stack API are not supported and will be ignored with a warning.
diff --git a/docs/source/providers/post_training/torchtune.md b/docs/source/providers/post_training/torchtune.md
new file mode 100644
index 000000000..ef72505b1
--- /dev/null
+++ b/docs/source/providers/post_training/torchtune.md
@@ -0,0 +1,125 @@
+---
+orphan: true
+---
+# TorchTune
+
+[TorchTune](https://github.com/pytorch/torchtune) is an inline post training provider for Llama Stack. It provides a simple and efficient way to fine-tune language models using PyTorch.
+
+## Features
+
+- Simple access through the post_training API
+- Fully integrated with Llama Stack
+- GPU support and single-device training
+- Support for LoRA
+
+## Usage
+
+To use TorchTune in your Llama Stack project, follow these steps:
+
+1. Configure your Llama Stack project to use this provider.
+2. Kick off a fine-tuning job using the Llama Stack post_training API.
+
+## Setup
+
+You can access the TorchTune trainer by writing your own YAML configuration that points to the provider:
+
+```yaml
+post_training:
+  - provider_id: torchtune
+    provider_type: inline::torchtune
+    config: {}
+```
+
+You can then build and run your own stack with this provider.
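+
+For example, assuming the configuration above is saved in a run configuration file named `run.yaml` (the file name is just a placeholder), the stack can be started the same way as the prebuilt distributions:
+
+```bash
+# Start a stack that includes the torchtune post_training provider
+# (point the path at your own run configuration file)
+llama stack run --image-type venv ./run.yaml
+```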
+
+## Run Training
+
+You can access the provider and the `supervised_fine_tune` method via the post_training API:
+
+```python
+import time
+import uuid
+
+from llama_stack_client.types import (
+    post_training_supervised_fine_tune_params,
+    algorithm_config_param,
+)
+
+
+def create_http_client():
+    from llama_stack_client import LlamaStackClient
+
+    return LlamaStackClient(base_url="http://localhost:8321")
+
+
+client = create_http_client()
+
+# Example Dataset
+client.datasets.register(
+    purpose="post-training/messages",
+    source={
+        "type": "uri",
+        "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
+    },
+    dataset_id="simpleqa",
+)
+
+training_config = post_training_supervised_fine_tune_params.TrainingConfig(
+    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
+        batch_size=32,
+        data_format="instruct",
+        dataset_id="simpleqa",
+        shuffle=True,
+    ),
+    gradient_accumulation_steps=1,
+    max_steps_per_epoch=0,
+    max_validation_steps=1,
+    n_epochs=4,
+)
+
+algorithm_config = algorithm_config_param.LoraFinetuningConfig(
+    alpha=1,
+    apply_lora_to_mlp=True,
+    apply_lora_to_output=False,
+    lora_attn_modules=["q_proj"],
+    rank=1,
+    type="LoRA",
+)
+
+job_uuid = f"test-job{uuid.uuid4()}"
+
+# Example Model
+training_model = "meta-llama/Llama-2-7b-hf"
+
+start_time = time.time()
+response = client.post_training.supervised_fine_tune(
+    job_uuid=job_uuid,
+    logger_config={},
+    model=training_model,
+    hyperparam_search_config={},
+    training_config=training_config,
+    algorithm_config=algorithm_config,
+    checkpoint_dir="output",
+)
+print("Job: ", job_uuid)
+
+# Wait for the job to complete!
+while True:
+    status = client.post_training.job.status(job_uuid=job_uuid)
+    if not status:
+        print("Job not found")
+        break
+
+    print(status)
+    if status.status == "completed":
+        break
+
+    print("Waiting for job to complete...")
+    time.sleep(5)
+
+end_time = time.time()
+print("Job completed in", end_time - start_time, "seconds!")
+
+print("Artifacts:")
+print(client.post_training.job.artifacts(job_uuid=job_uuid))
+```
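+
+If you need to stop a run before it finishes, the same job API used for `status` and `artifacts` above also exposes a cancel operation. The method name below is an assumption based on that API; verify it against your installed `llama_stack_client` version:
+
+```python
+# Cancel a post-training job that is still running (method name assumed)
+client.post_training.job.cancel(job_uuid=job_uuid)
+```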