mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-25 05:39:47 +00:00
feat: add huggingface post_training impl (#2132)
# What does this PR do? Adds an inline HF SFTTrainer provider. Alongside torchtune, this is a very popular option for running training jobs. The config allows a user to specify some key fields such as a model, chat_template, device, etc. The provider comes with one recipe, `finetune_single_device`, which works both with and without LoRA. Any model that is a valid HF identifier can be given, and the model will be pulled. This has been tested so far with CPU and MPS device types, but should be compatible with CUDA out of the box. The provider processes the given dataset into the proper format, establishes the various steps per epoch, steps per save, and steps per eval, sets a sane SFTConfig, and runs n_epochs of training. If checkpoint_dir is none, no model is saved. If there is a checkpoint dir, a model is saved every `save_steps` and at the end of training. ## Test Plan Re-enabled the post_training integration test suite with a single test that loads the simpleqa dataset (https://huggingface.co/datasets/llamastack/simpleqa) and a tiny granite model (https://huggingface.co/ibm-granite/granite-3.3-2b-instruct). The test now uses the llama stack client and the proper post_training API, and runs one step with a batch_size of 1. This test runs on CPU on the Ubuntu runner, so it needs to be a small batch and a single step. [//]: # (## Documentation) --------- Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
8f9964f46b
commit
f02f7b28c1
20 changed files with 1181 additions and 201 deletions
137
requirements.txt
137
requirements.txt
|
@ -1,206 +1,69 @@
|
|||
# This file was autogenerated by uv via the following command:
|
||||
# uv export --frozen --no-hashes --no-emit-project --output-file=requirements.txt
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
anyio==4.8.0
|
||||
# via
|
||||
# httpx
|
||||
# llama-stack-client
|
||||
# openai
|
||||
attrs==25.1.0
|
||||
# via
|
||||
# jsonschema
|
||||
# referencing
|
||||
blobfile==3.0.0
|
||||
# via llama-stack
|
||||
cachetools==5.5.2
|
||||
# via google-auth
|
||||
certifi==2025.1.31
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
# kubernetes
|
||||
# requests
|
||||
charset-normalizer==3.4.1
|
||||
# via requests
|
||||
click==8.1.8
|
||||
# via llama-stack-client
|
||||
colorama==0.4.6 ; sys_platform == 'win32'
|
||||
# via
|
||||
# click
|
||||
# tqdm
|
||||
distro==1.9.0
|
||||
# via
|
||||
# llama-stack-client
|
||||
# openai
|
||||
durationpy==0.9
|
||||
# via kubernetes
|
||||
exceptiongroup==1.2.2 ; python_full_version < '3.11'
|
||||
# via anyio
|
||||
filelock==3.17.0
|
||||
# via
|
||||
# blobfile
|
||||
# huggingface-hub
|
||||
fire==0.7.0
|
||||
# via llama-stack
|
||||
fsspec==2024.12.0
|
||||
# via huggingface-hub
|
||||
google-auth==2.38.0
|
||||
# via kubernetes
|
||||
h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
# llama-stack
|
||||
httpcore==1.0.9
|
||||
# via httpx
|
||||
httpx==0.28.1
|
||||
# via
|
||||
# llama-stack
|
||||
# llama-stack-client
|
||||
# openai
|
||||
huggingface-hub==0.29.0
|
||||
# via llama-stack
|
||||
idna==3.10
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
# requests
|
||||
jinja2==3.1.6
|
||||
# via llama-stack
|
||||
jiter==0.8.2
|
||||
# via openai
|
||||
jsonschema==4.23.0
|
||||
# via llama-stack
|
||||
jsonschema-specifications==2024.10.1
|
||||
# via jsonschema
|
||||
kubernetes==32.0.1
|
||||
# via llama-stack
|
||||
llama-stack-client==0.2.7
|
||||
# via llama-stack
|
||||
lxml==5.3.1
|
||||
# via blobfile
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
markupsafe==3.0.2
|
||||
# via jinja2
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
numpy==2.2.3
|
||||
# via pandas
|
||||
oauthlib==3.2.2
|
||||
# via
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
openai==1.71.0
|
||||
# via llama-stack
|
||||
packaging==24.2
|
||||
# via huggingface-hub
|
||||
pandas==2.2.3
|
||||
# via llama-stack-client
|
||||
pillow==11.1.0
|
||||
# via llama-stack
|
||||
prompt-toolkit==3.0.50
|
||||
# via
|
||||
# llama-stack
|
||||
# llama-stack-client
|
||||
pyaml==25.1.0
|
||||
# via llama-stack-client
|
||||
pyasn1==0.6.1
|
||||
# via
|
||||
# pyasn1-modules
|
||||
# rsa
|
||||
pyasn1-modules==0.4.2
|
||||
# via google-auth
|
||||
pycryptodomex==3.21.0
|
||||
# via blobfile
|
||||
pydantic==2.10.6
|
||||
# via
|
||||
# llama-stack
|
||||
# llama-stack-client
|
||||
# openai
|
||||
pydantic-core==2.27.2
|
||||
# via pydantic
|
||||
pygments==2.19.1
|
||||
# via rich
|
||||
python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# kubernetes
|
||||
# pandas
|
||||
python-dotenv==1.0.1
|
||||
# via llama-stack
|
||||
pytz==2025.1
|
||||
# via pandas
|
||||
pyyaml==6.0.2
|
||||
# via
|
||||
# huggingface-hub
|
||||
# kubernetes
|
||||
# pyaml
|
||||
referencing==0.36.2
|
||||
# via
|
||||
# jsonschema
|
||||
# jsonschema-specifications
|
||||
regex==2024.11.6
|
||||
# via tiktoken
|
||||
requests==2.32.3
|
||||
# via
|
||||
# huggingface-hub
|
||||
# kubernetes
|
||||
# llama-stack
|
||||
# requests-oauthlib
|
||||
# tiktoken
|
||||
requests-oauthlib==2.0.0
|
||||
# via kubernetes
|
||||
rich==13.9.4
|
||||
# via
|
||||
# llama-stack
|
||||
# llama-stack-client
|
||||
rpds-py==0.22.3
|
||||
# via
|
||||
# jsonschema
|
||||
# referencing
|
||||
rsa==4.9
|
||||
# via google-auth
|
||||
setuptools==75.8.0
|
||||
# via llama-stack
|
||||
six==1.17.0
|
||||
# via
|
||||
# kubernetes
|
||||
# python-dateutil
|
||||
sniffio==1.3.1
|
||||
# via
|
||||
# anyio
|
||||
# llama-stack-client
|
||||
# openai
|
||||
termcolor==2.5.0
|
||||
# via
|
||||
# fire
|
||||
# llama-stack
|
||||
# llama-stack-client
|
||||
tiktoken==0.9.0
|
||||
# via llama-stack
|
||||
tqdm==4.67.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# llama-stack-client
|
||||
# openai
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# anyio
|
||||
# huggingface-hub
|
||||
# llama-stack-client
|
||||
# openai
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# referencing
|
||||
# rich
|
||||
tzdata==2025.1
|
||||
# via pandas
|
||||
urllib3==2.3.0
|
||||
# via
|
||||
# blobfile
|
||||
# kubernetes
|
||||
# requests
|
||||
wcwidth==0.2.13
|
||||
# via prompt-toolkit
|
||||
websocket-client==1.8.0
|
||||
# via kubernetes
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue