mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-24 05:14:30 +00:00)
docs: add post training to providers list (#2280)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 6s
Integration Tests / test-matrix (http, inference) (push) Failing after 11s
Integration Tests / test-matrix (http, datasets) (push) Failing after 11s
Integration Tests / test-matrix (http, providers) (push) Failing after 10s
Integration Tests / test-matrix (http, inspect) (push) Failing after 12s
Integration Tests / test-matrix (http, agents) (push) Failing after 13s
Integration Tests / test-matrix (http, tool_runtime) (push) Failing after 10s
Integration Tests / test-matrix (library, agents) (push) Failing after 10s
Integration Tests / test-matrix (http, scoring) (push) Failing after 11s
Integration Tests / test-matrix (http, post_training) (push) Failing after 11s
Integration Tests / test-matrix (library, datasets) (push) Failing after 10s
Integration Tests / test-matrix (library, inference) (push) Failing after 8s
Test External Providers / test-external-providers (venv) (push) Failing after 6s
Integration Tests / test-matrix (library, inspect) (push) Failing after 9s
Integration Tests / test-matrix (library, post_training) (push) Failing after 10s
Integration Tests / test-matrix (library, tool_runtime) (push) Failing after 9s
Integration Tests / test-matrix (library, scoring) (push) Failing after 9s
Unit Tests / unit-tests (3.10) (push) Failing after 8s
Integration Tests / test-matrix (library, providers) (push) Failing after 10s
Unit Tests / unit-tests (3.11) (push) Failing after 8s
Unit Tests / unit-tests (3.12) (push) Failing after 9s
Update ReadTheDocs / update-readthedocs (push) Failing after 6s
Unit Tests / unit-tests (3.13) (push) Failing after 1m18s
Pre-commit / pre-commit (push) Successful in 3m0s
# What does this PR do?

The providers list is missing post_training. Add that column and `HuggingFace`, `TorchTune`, and `NVIDIA NEMO` as supported providers. Also point to these providers in docs/source/providers/index.md, and describe basic functionality.

There are other missing provider types here as well, but starting with this.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
Co-authored-by: Francisco Arceo <arceofrancisco@gmail.com>
parent 9b7f9db05c
commit a7ecc92be1
5 changed files with 445 additions and 20 deletions
125
docs/source/providers/post_training/torchtune.md
Normal file
@ -0,0 +1,125 @
---
orphan: true
---
# TorchTune

[TorchTune](https://github.com/pytorch/torchtune) is an inline post-training provider for Llama Stack. It provides a simple and efficient way to fine-tune language models using PyTorch.

## Features

- Simple access through the `post_training` API
- Fully integrated with Llama Stack
- GPU support and single-device capabilities
- Support for LoRA
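
For readers new to LoRA, the following is a minimal sketch of why it is cheaper than full fine-tuning. It assumes the standard LoRA formulation (a frozen weight matrix plus a low-rank update scaled by `alpha / rank`). The helper names `lora_param_count` and `lora_scale` and the 4096 x 4096 layer size are hypothetical, chosen for illustration, and are not part of TorchTune or Llama Stack.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters LoRA adds to one d_out x d_in linear layer.

    LoRA learns two small matrices, A (rank x d_in) and B (d_out x rank),
    instead of updating the full weight matrix.
    """
    return rank * (d_in + d_out)


def lora_scale(alpha: float, rank: int) -> float:
    """Scaling factor applied to the low-rank update B @ A."""
    return alpha / rank


# Hypothetical layer size, chosen only for illustration.
d_in = d_out = 4096
full = d_in * d_out
for rank in (1, 8, 64):
    added = lora_param_count(d_in, d_out, rank)
    print(f"rank={rank}: {added} trainable vs {full} frozen ({added / full:.4%})")
```

A smaller `rank` means fewer trainable parameters, which is part of what makes single-device fine-tuning practical.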

## Usage

To use TorchTune in your Llama Stack project, follow these steps:

1. Configure your Llama Stack project to use this provider.
2. Kick off a fine-tuning job using the Llama Stack post_training API.

## Setup

You can access the TorchTune trainer by writing your own YAML pointing to the provider:

```yaml
post_training:
  - provider_id: torchtune
    provider_type: inline::torchtune
    config: {}
```

You can then build and run your own stack with this provider.
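
Once the stack is running, one way to confirm that the provider was picked up is to list the providers and filter for the `post_training` API. This is a sketch under assumptions: it assumes the client exposes `client.providers.list()` and that each entry carries `api`, `provider_id`, and `provider_type` attributes. The helpers `post_training_providers` and `list_from_server` are hypothetical names, and the sample entries are stand-ins so the filtering logic can run without a server.

```python
from types import SimpleNamespace


def post_training_providers(providers):
    """Return the entries whose `api` field is "post_training"."""
    return [p for p in providers if getattr(p, "api", None) == "post_training"]


def list_from_server(base_url="http://localhost:8321"):
    """Fetch the provider list from a running stack (needs llama_stack_client).

    Assumption: the client offers `providers.list()`; adjust if your version differs.
    """
    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url=base_url)
    return client.providers.list()


# Offline demonstration with stand-in objects; swap in list_from_server()
# when a stack is actually running.
sample = [
    SimpleNamespace(api="inference", provider_id="ollama", provider_type="remote::ollama"),
    SimpleNamespace(api="post_training", provider_id="torchtune", provider_type="inline::torchtune"),
]
for p in post_training_providers(sample):
    print(p.provider_id, p.provider_type)
```

If the filtered list is empty against a live server, the YAML above was most likely not loaded by the running stack.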

## Run Training

You can access the provider and the `supervised_fine_tune` method via the post_training API:

```python
import time
import uuid

from llama_stack_client.types import (
    post_training_supervised_fine_tune_params,
    algorithm_config_param,
)


def create_http_client():
    from llama_stack_client import LlamaStackClient

    return LlamaStackClient(base_url="http://localhost:8321")


client = create_http_client()

# Example Dataset
client.datasets.register(
    purpose="post-training/messages",
    source={
        "type": "uri",
        "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
    },
    dataset_id="simpleqa",
)

training_config = post_training_supervised_fine_tune_params.TrainingConfig(
    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
        batch_size=32,
        data_format="instruct",
        dataset_id="simpleqa",
        shuffle=True,
    ),
    gradient_accumulation_steps=1,
    max_steps_per_epoch=0,
    max_validation_steps=1,
    n_epochs=4,
)

algorithm_config = algorithm_config_param.LoraFinetuningConfig(
    alpha=1,
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    lora_attn_modules=["q_proj"],
    rank=1,
    type="LoRA",
)

job_uuid = f"test-job{uuid.uuid4()}"

# Example Model
training_model = "meta-llama/Llama-2-7b-hf"

start_time = time.time()
response = client.post_training.supervised_fine_tune(
    job_uuid=job_uuid,
    logger_config={},
    model=training_model,
    hyperparam_search_config={},
    training_config=training_config,
    algorithm_config=algorithm_config,
    checkpoint_dir="output",
)
print("Job: ", job_uuid)

# Wait for the job to complete!
while True:
    status = client.post_training.job.status(job_uuid=job_uuid)
    if not status:
        print("Job not found")
        break

    print(status)
    if status.status == "completed":
        break

    print("Waiting for job to complete...")
    time.sleep(5)

end_time = time.time()
print("Job completed in", end_time - start_time, "seconds!")

print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```
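
The polling loop above only exits on `"completed"`, so a job that fails would keep it waiting indefinitely. A possible variant with a timeout and explicit terminal states is sketched below. The helper `wait_for_job` is a hypothetical name, and the status strings other than `"completed"` (`"failed"`, `"cancelled"`, `"scheduled"`, `"in_progress"`) are assumptions about what the server reports; check them against your Llama Stack version. The demonstration uses a fake status source so it runs without a server.

```python
import time
from types import SimpleNamespace


def wait_for_job(get_status, job_uuid, timeout_s=3600.0, poll_s=5.0,
                 terminal=("completed", "failed", "cancelled"), sleep=time.sleep):
    """Poll get_status(job_uuid=...) until a terminal state or the timeout.

    Returns the final status string. Raises LookupError if the job is not
    found and TimeoutError if the deadline passes first.
    The `terminal` names are assumptions, not confirmed API values.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        status = get_status(job_uuid=job_uuid)
        if not status:
            raise LookupError(f"job {job_uuid!r} not found")
        if status.status in terminal:
            return status.status
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f"job {job_uuid!r} still {status.status!r} after {timeout_s}s"
            )
        sleep(poll_s)


# Offline demonstration with a fake status source (no server needed).
_states = iter(["scheduled", "in_progress", "completed"])


def fake_status(job_uuid):
    return SimpleNamespace(status=next(_states))


print(wait_for_job(fake_status, "demo-job", poll_s=0.0))
```

Against a live stack, the call would look like `wait_for_job(client.post_training.job.status, job_uuid)`, after which the caller can decide whether to fetch artifacts based on the returned status.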