Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
docs: add post training to providers list (#2280)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 6s
Integration Tests / test-matrix (http, inference) (push) Failing after 11s
Integration Tests / test-matrix (http, datasets) (push) Failing after 11s
Integration Tests / test-matrix (http, providers) (push) Failing after 10s
Integration Tests / test-matrix (http, inspect) (push) Failing after 12s
Integration Tests / test-matrix (http, agents) (push) Failing after 13s
Integration Tests / test-matrix (http, tool_runtime) (push) Failing after 10s
Integration Tests / test-matrix (library, agents) (push) Failing after 10s
Integration Tests / test-matrix (http, scoring) (push) Failing after 11s
Integration Tests / test-matrix (http, post_training) (push) Failing after 11s
Integration Tests / test-matrix (library, datasets) (push) Failing after 10s
Integration Tests / test-matrix (library, inference) (push) Failing after 8s
Test External Providers / test-external-providers (venv) (push) Failing after 6s
Integration Tests / test-matrix (library, inspect) (push) Failing after 9s
Integration Tests / test-matrix (library, post_training) (push) Failing after 10s
Integration Tests / test-matrix (library, tool_runtime) (push) Failing after 9s
Integration Tests / test-matrix (library, scoring) (push) Failing after 9s
Unit Tests / unit-tests (3.10) (push) Failing after 8s
Integration Tests / test-matrix (library, providers) (push) Failing after 10s
Unit Tests / unit-tests (3.11) (push) Failing after 8s
Unit Tests / unit-tests (3.12) (push) Failing after 9s
Update ReadTheDocs / update-readthedocs (push) Failing after 6s
Unit Tests / unit-tests (3.13) (push) Failing after 1m18s
Pre-commit / pre-commit (push) Successful in 3m0s
# What does this PR do?

The providers list is missing post_training. Add that column, with `HuggingFace`, `TorchTune`, and `NVIDIA NEMO` as supported providers. Also point to these providers in docs/source/providers/index.md and describe their basic functionality.

There are other missing provider types here as well, but this starts with post training.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
Co-authored-by: Francisco Arceo <arceofrancisco@gmail.com>
This commit is contained in: parent 9b7f9db05c, commit a7ecc92be1
5 changed files with 445 additions and 20 deletions
README.md (43 lines changed)
```diff
@@ -107,26 +107,29 @@ By reducing friction and complexity, Llama Stack empowers developers to focus on
 ### API Providers
 
 Here is a list of the various API providers and available distributions that can help developers get started easily with Llama Stack.
 
-| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
-|:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|
-| Meta Reference | Single Node | ✅ | ✅ | ✅ | ✅ | ✅ |
-| SambaNova | Hosted | | ✅ | | ✅ | |
-| Cerebras | Hosted | | ✅ | | | |
-| Fireworks | Hosted | ✅ | ✅ | ✅ | | |
-| AWS Bedrock | Hosted | | ✅ | | ✅ | |
-| Together | Hosted | ✅ | ✅ | | ✅ | |
-| Groq | Hosted | | ✅ | | | |
-| Ollama | Single Node | | ✅ | | | |
-| TGI | Hosted and Single Node | | ✅ | | | |
-| NVIDIA NIM | Hosted and Single Node | | ✅ | | | |
-| Chroma | Single Node | | | ✅ | | |
-| PG Vector | Single Node | | | ✅ | | |
-| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | |
-| vLLM | Hosted and Single Node | | ✅ | | | |
-| OpenAI | Hosted | | ✅ | | | |
-| Anthropic | Hosted | | ✅ | | | |
-| Gemini | Hosted | | ✅ | | | |
-| watsonx | Hosted | | ✅ | | | |
+| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** | **Post Training** |
+|:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|:-----------------:|
+| Meta Reference | Single Node | ✅ | ✅ | ✅ | ✅ | ✅ | |
+| SambaNova | Hosted | | ✅ | | ✅ | | |
+| Cerebras | Hosted | | ✅ | | | | |
+| Fireworks | Hosted | ✅ | ✅ | ✅ | | | |
+| AWS Bedrock | Hosted | | ✅ | | ✅ | | |
+| Together | Hosted | ✅ | ✅ | | ✅ | | |
+| Groq | Hosted | | ✅ | | | | |
+| Ollama | Single Node | | ✅ | | | | |
+| TGI | Hosted and Single Node | | ✅ | | | | |
+| NVIDIA NIM | Hosted and Single Node | | ✅ | | | | |
+| Chroma | Single Node | | | ✅ | | | |
+| PG Vector | Single Node | | | ✅ | | | |
+| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | | |
+| vLLM | Hosted and Single Node | | ✅ | | | | |
+| OpenAI | Hosted | | ✅ | | | | |
+| Anthropic | Hosted | | ✅ | | | | |
+| Gemini | Hosted | | ✅ | | | | |
+| watsonx | Hosted | | ✅ | | | | |
+| HuggingFace | Single Node | | | | | | ✅ |
+| TorchTune | Single Node | | | | | | ✅ |
+| NVIDIA NEMO | Hosted | | | | | | ✅ |
 
 ### Distributions
```
docs/source/providers/index.md (12 lines added)

````diff
@@ -30,6 +30,18 @@ Runs inference with an LLM.
 ## Post Training
 
 Fine-tunes a model.
 
+#### Post Training Providers
+
+The following providers are available for Post Training:
+
+```{toctree}
+:maxdepth: 1
+
+external
+post_training/huggingface
+post_training/torchtune
+post_training/nvidia_nemo
+```
+
 ## Safety
 
 Applies safety policies to the output at a Systems (not only model) level.
````
docs/source/providers/post_training/huggingface.md (new file, 122 lines)
---
orphan: true
---
# HuggingFace SFTTrainer

[HuggingFace SFTTrainer](https://huggingface.co/docs/trl/en/sft_trainer) is an inline post training provider for Llama Stack. It allows you to run supervised fine-tuning on a variety of models using many datasets.

## Features

- Simple access through the post_training API
- Fully integrated with Llama Stack
- GPU, CPU, and MPS support (macOS Metal Performance Shaders)

## Usage

To use the HF SFTTrainer in your Llama Stack project, follow these steps:

1. Configure your Llama Stack project to use this provider.
2. Kick off an SFT job using the Llama Stack post_training API.

## Setup

You can access the HuggingFace trainer via the `ollama` distribution:

```bash
llama stack build --template ollama --image-type venv
llama stack run --image-type venv ~/.llama/distributions/ollama/ollama-run.yaml
```
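If you would rather wire the trainer into a stack of your own than use the `ollama` template, the provider entry in your run config would look much like the TorchTune one shown later in this PR. This is a sketch only; `inline::huggingface` is the provider type assumed here:

```yaml
post_training:
  - provider_id: huggingface   # any identifier you choose
    provider_type: inline::huggingface   # assumed provider type
    config: {}
```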
## Run Training

You can access the provider and the `supervised_fine_tune` method via the post_training API:

```python
import time
import uuid

from llama_stack_client.types import (
    post_training_supervised_fine_tune_params,
    algorithm_config_param,
)


def create_http_client():
    from llama_stack_client import LlamaStackClient

    return LlamaStackClient(base_url="http://localhost:8321")


client = create_http_client()

# Example Dataset
client.datasets.register(
    purpose="post-training/messages",
    source={
        "type": "uri",
        "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
    },
    dataset_id="simpleqa",
)

training_config = post_training_supervised_fine_tune_params.TrainingConfig(
    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
        batch_size=32,
        data_format="instruct",
        dataset_id="simpleqa",
        shuffle=True,
    ),
    gradient_accumulation_steps=1,
    max_steps_per_epoch=0,
    max_validation_steps=1,
    n_epochs=4,
)

algorithm_config = algorithm_config_param.LoraFinetuningConfig(  # this config is currently mandatory but should not be
    alpha=1,
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    lora_attn_modules=["q_proj"],
    rank=1,
    type="LoRA",
)

job_uuid = f"test-job{uuid.uuid4()}"

# Example Model
training_model = "ibm-granite/granite-3.3-8b-instruct"

start_time = time.time()
response = client.post_training.supervised_fine_tune(
    job_uuid=job_uuid,
    logger_config={},
    model=training_model,
    hyperparam_search_config={},
    training_config=training_config,
    algorithm_config=algorithm_config,
    checkpoint_dir="output",
)
print("Job: ", job_uuid)

# Wait for the job to complete!
while True:
    status = client.post_training.job.status(job_uuid=job_uuid)
    if not status:
        print("Job not found")
        break

    print(status)
    if status.status == "completed":
        break

    print("Waiting for job to complete...")
    time.sleep(5)

end_time = time.time()
print("Job completed in", end_time - start_time, "seconds!")

print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```
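If a run needs to be stopped before it finishes, the same job API can also cancel it. A minimal sketch, assuming the standard `job.cancel` and `job.list` client methods and the `client`/`job_uuid` objects from the example above:

```python
# Cancel the in-flight job by its UUID (has no effect if it already finished).
client.post_training.job.cancel(job_uuid=job_uuid)

# List all known training jobs and their current statuses.
for job in client.post_training.job.list():
    print(job)
```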
docs/source/providers/post_training/nvidia_nemo.md (new file, 163 lines)
---
orphan: true
---
# NVIDIA NEMO

[NVIDIA NEMO](https://developer.nvidia.com/nemo-framework) is a remote post training provider for Llama Stack. It provides enterprise-grade fine-tuning capabilities through NVIDIA's NeMo Customizer service.

## Features

- Enterprise-grade fine-tuning capabilities
- Support for LoRA and SFT fine-tuning
- Integration with NVIDIA's NeMo Customizer service
- Support for various NVIDIA-optimized models
- Efficient training with NVIDIA hardware acceleration

## Usage

To use NVIDIA NEMO in your Llama Stack project, follow these steps:

1. Configure your Llama Stack project to use this provider.
2. Set up your NVIDIA API credentials.
3. Kick off a fine-tuning job using the Llama Stack post_training API.

## Setup

You'll need to set the following environment variables:

```bash
export NVIDIA_API_KEY="your-api-key"
export NVIDIA_DATASET_NAMESPACE="default"
export NVIDIA_CUSTOMIZER_URL="your-customizer-url"
export NVIDIA_PROJECT_ID="your-project-id"
export NVIDIA_OUTPUT_MODEL_DIR="your-output-model-dir"
```
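These variables are then referenced from the provider's entry in your run config. A sketch only: the `remote::nvidia` provider type and the config field names below are assumptions chosen to mirror the environment variables above, not a confirmed schema:

```yaml
post_training:
  - provider_id: nvidia
    provider_type: remote::nvidia   # assumed provider type
    config:
      # Field names assumed to mirror the NVIDIA_* variables above.
      api_key: ${env.NVIDIA_API_KEY}
      dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE}
      customizer_url: ${env.NVIDIA_CUSTOMIZER_URL}
      project_id: ${env.NVIDIA_PROJECT_ID}
```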
## Run Training

You can access the provider and the `supervised_fine_tune` method via the post_training API:

```python
import time
import uuid

from llama_stack_client.types import (
    post_training_supervised_fine_tune_params,
    algorithm_config_param,
)


def create_http_client():
    from llama_stack_client import LlamaStackClient

    return LlamaStackClient(base_url="http://localhost:8321")


client = create_http_client()

# Example Dataset
client.datasets.register(
    purpose="post-training/messages",
    source={
        "type": "uri",
        "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
    },
    dataset_id="simpleqa",
)

training_config = post_training_supervised_fine_tune_params.TrainingConfig(
    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
        batch_size=8,  # Default batch size for NEMO
        data_format="instruct",
        dataset_id="simpleqa",
        shuffle=True,
    ),
    n_epochs=50,  # Default epochs for NEMO
    optimizer_config=post_training_supervised_fine_tune_params.TrainingConfigOptimizerConfig(
        lr=0.0001,  # Default learning rate
        weight_decay=0.01,  # NEMO-specific parameter
    ),
    # NEMO-specific parameters
    log_every_n_steps=None,
    val_check_interval=0.25,
    sequence_packing_enabled=False,
    hidden_dropout=None,
    attention_dropout=None,
    ffn_dropout=None,
)

algorithm_config = algorithm_config_param.LoraFinetuningConfig(
    alpha=16,  # Default alpha for NEMO
    type="LoRA",
)

job_uuid = f"test-job{uuid.uuid4()}"

# Example Model - must be a supported NEMO model
training_model = "meta/llama-3.1-8b-instruct"

start_time = time.time()
response = client.post_training.supervised_fine_tune(
    job_uuid=job_uuid,
    logger_config={},
    model=training_model,
    hyperparam_search_config={},
    training_config=training_config,
    algorithm_config=algorithm_config,
    checkpoint_dir="output",
)
print("Job: ", job_uuid)

# Wait for the job to complete!
while True:
    status = client.post_training.job.status(job_uuid=job_uuid)
    if not status:
        print("Job not found")
        break

    print(status)
    if status.status == "completed":
        break

    print("Waiting for job to complete...")
    time.sleep(5)

end_time = time.time()
print("Job completed in", end_time - start_time, "seconds!")

print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```
## Supported Models

Currently supports the following models:
- meta/llama-3.1-8b-instruct
- meta/llama-3.2-1b-instruct

## Supported Parameters

### TrainingConfig
- n_epochs (default: 50)
- data_config
- optimizer_config
- log_every_n_steps
- val_check_interval (default: 0.25)
- sequence_packing_enabled (default: False)
- hidden_dropout (0.0-1.0)
- attention_dropout (0.0-1.0)
- ffn_dropout (0.0-1.0)

### DataConfig
- dataset_id
- batch_size (default: 8)

### OptimizerConfig
- lr (default: 0.0001)
- weight_decay (default: 0.01)

### LoRA Config
- alpha (default: 16)
- type (must be "LoRA")

Note: Some parameters from the standard Llama Stack API are not supported and will be ignored with a warning.
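As an illustration of the NEMO-specific knobs, the dropout fields from the parameter list above can be set like so. This is a sketch that reuses the config types from the full example earlier on this page; values must fall in the documented 0.0-1.0 range:

```python
# Minimal TrainingConfig sketch exercising the NEMO dropout parameters.
training_config = post_training_supervised_fine_tune_params.TrainingConfig(
    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
        batch_size=8,
        data_format="instruct",
        dataset_id="simpleqa",
        shuffle=True,
    ),
    n_epochs=50,
    hidden_dropout=0.1,  # 0.0-1.0
    attention_dropout=0.1,  # 0.0-1.0
    ffn_dropout=0.1,  # 0.0-1.0
)
```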
docs/source/providers/post_training/torchtune.md (new file, 125 lines)
---
orphan: true
---
# TorchTune

[TorchTune](https://github.com/pytorch/torchtune) is an inline post training provider for Llama Stack. It provides a simple and efficient way to fine-tune language models using PyTorch.

## Features

- Simple access through the post_training API
- Fully integrated with Llama Stack
- GPU support and single-device capabilities
- Support for LoRA

## Usage

To use TorchTune in your Llama Stack project, follow these steps:

1. Configure your Llama Stack project to use this provider.
2. Kick off a fine-tuning job using the Llama Stack post_training API.

## Setup

You can access the TorchTune trainer by writing your own YAML config pointing to the provider:

```yaml
post_training:
  - provider_id: torchtune
    provider_type: inline::torchtune
    config: {}
```

You can then build and run your own stack with this provider.
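For example (a sketch: the build/run flags mirror the HuggingFace setup shown earlier in this PR, and both file paths are placeholders for your own build and run configs):

```bash
# Build a venv image from a build config that lists inline::torchtune,
# then launch the stack from the corresponding run config.
llama stack build --config my-build.yaml --image-type venv
llama stack run --image-type venv my-run.yaml
```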
## Run Training

You can access the provider and the `supervised_fine_tune` method via the post_training API:

```python
import time
import uuid

from llama_stack_client.types import (
    post_training_supervised_fine_tune_params,
    algorithm_config_param,
)


def create_http_client():
    from llama_stack_client import LlamaStackClient

    return LlamaStackClient(base_url="http://localhost:8321")


client = create_http_client()

# Example Dataset
client.datasets.register(
    purpose="post-training/messages",
    source={
        "type": "uri",
        "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
    },
    dataset_id="simpleqa",
)

training_config = post_training_supervised_fine_tune_params.TrainingConfig(
    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
        batch_size=32,
        data_format="instruct",
        dataset_id="simpleqa",
        shuffle=True,
    ),
    gradient_accumulation_steps=1,
    max_steps_per_epoch=0,
    max_validation_steps=1,
    n_epochs=4,
)

algorithm_config = algorithm_config_param.LoraFinetuningConfig(
    alpha=1,
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    lora_attn_modules=["q_proj"],
    rank=1,
    type="LoRA",
)

job_uuid = f"test-job{uuid.uuid4()}"

# Example Model
training_model = "meta-llama/Llama-2-7b-hf"

start_time = time.time()
response = client.post_training.supervised_fine_tune(
    job_uuid=job_uuid,
    logger_config={},
    model=training_model,
    hyperparam_search_config={},
    training_config=training_config,
    algorithm_config=algorithm_config,
    checkpoint_dir="output",
)
print("Job: ", job_uuid)

# Wait for the job to complete!
while True:
    status = client.post_training.job.status(job_uuid=job_uuid)
    if not status:
        print("Job not found")
        break

    print(status)
    if status.status == "completed":
        break

    print("Waiting for job to complete...")
    time.sleep(5)

end_time = time.time()
print("Job completed in", end_time - start_time, "seconds!")

print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```