### Context

This is the first in a series of PRs that integrate torchtune with llama-stack as the meta reference post-training implementation. For the MVP, we will focus on single-device LoRA SFT. Though this PR is still WIP, we want to get early feedback on the high-level design of this skeleton while we continue working on the details.

### Scope

To limit the scope of this PR, we focus on the skeleton of the implementation.

**What is included?**
- refine the post-training SFT APIs
- skeleton of the supervised_fine_tune implementation. We verified that we can call the supervised_fine_tune API successfully from the llama-stack client SDK (client-side PR: https://github.com/meta-llama/llama-stack-client-python/pull/51)
- a very basic single-device LoRA training recipe based on torchtune core components
- a parity check against the torchtune library and a post-training API unit test

**What is not included?**
- implementation of the other job management and training-artifact retrieval APIs (separate PR)
- refactor of the meta reference inference logic to support eval on the finetuned model (separate PR)
- several pieces of necessary functionality in the training recipe, such as logging and validation (separate PR)
- interop with telemetry for tracing and metrics logging; for now we temporarily log to local disk (separate PR)

### Testing

**e2e test**

Although we haven't added detailed tests yet, we did a simple E2E test from client to server:
1. Set up the server with `llama stack build --template experimental-post-training --image-type conda` and `llama stack run experimental-post-training`.
2. On the client, run `llama-stack-client --endpoint http://devgpu018.nha2.facebook.com:5000 post_training supervised_fine_tune`.
3. Training finishes successfully. On the server side, we get the finetune checkpoints under the output dir. On the client side, we get the job uuid.

server
<img width="1110" alt="Screenshot 2024-12-02 at 5 52 32 PM" src="https://github.com/user-attachments/assets/b548eb90-7a9b-4edc-a858-ee237cc4361d">

client
<img width="807" alt="Screenshot 2024-12-02 at 5 52 37 PM" src="https://github.com/user-attachments/assets/1138ffa8-4698-40fa-b190-3d7b99646838">

**parity check**

The torchtune dataloader output and the llama-stack post-training dataloader output are the same:
<img width="1116" alt="Screenshot 2024-12-04 at 8 18 46 PM" src="https://github.com/user-attachments/assets/5e295cdc-4c24-4ea6-82c0-ca96ef1bd6ee">

torchtune LoRA SFT and llama-stack post-training LoRA SFT on the alpaca dataset with the llama3.2 3B instruct model numerically match:
<img width="860" alt="Screenshot 2024-12-04 at 8 17 01 PM" src="https://github.com/user-attachments/assets/c05cf0a8-c674-4d2e-9f0a-c5d01b2dca99">
<img width="1049" alt="Screenshot 2024-12-04 at 8 17 06 PM" src="https://github.com/user-attachments/assets/b911d4e2-e7b1-41a9-b62c-d75529b6d443">

**unit test**

![Uploading Screenshot 2024-12-09 at 1.35.10 PM.png…]()
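For reference, here is a minimal sketch of what the client-side call could look like via the Python SDK. This is illustrative only: the `post_training.supervised_fine_tune` call mirrors the CLI command above, but the field names and values (job uuid, LoRA settings, training config) are assumptions, not the finalized API surface.

```python
# Illustrative sketch only -- field names and shapes below are assumptions,
# not the finalized post-training API surface.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")

job = client.post_training.supervised_fine_tune(
    job_uuid="sft-job-001",                    # client-chosen job identifier (assumed)
    model="meta-llama/Llama-3.2-3B-Instruct",  # the model used in the parity check
    algorithm_config={                         # assumed LoRA config shape
        "type": "LoRA",
        "lora_attn_modules": ["q_proj", "v_proj"],
        "rank": 8,
        "alpha": 16,
    },
    training_config={                          # assumed training config shape
        "n_epochs": 1,
        "data_config": {"dataset_id": "alpaca", "batch_size": 2, "shuffle": True},
    },
)
print(job)  # the server responds with a job handle / uuid
```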
# Testing Llama Stack Providers
The Llama Stack is designed as a collection of Lego blocks -- various APIs -- which are composable and can be used to quickly and reliably build an app. We need a testing setup which is relatively flexible to enable easy combinations of these providers.
We use `pytest` and all of its dynamism to enable the features needed. Specifically:

- We use `pytest_addoption` to add CLI options allowing you to override providers, models, etc.
- We use `pytest_generate_tests` to dynamically parametrize our tests. This allows us to support a default set of (providers, models, etc.) combinations but retain the flexibility to override them via the CLI if needed.
- We use `pytest_configure` to make sure we dynamically add appropriate marks based on the fixtures we make.
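As an illustration, a minimal `conftest.py` could wire these three hooks together roughly as follows. The option name, default model, and mark names below are simplified placeholders, not the exact ones used in this repo:

```python
# conftest.py -- simplified sketch of the three hooks described above.
# Option names, defaults, and marks are illustrative placeholders.

def pytest_addoption(parser):
    # CLI override, e.g. `--inference-model meta-llama/Llama3.1-70B-Instruct`
    parser.addoption(
        "--inference-model",
        default="meta-llama/Llama3.1-8B-Instruct",
        help="Model to use for inference tests",
    )


def pytest_generate_tests(metafunc):
    # Parametrize any test that takes an `inference_model` argument with the
    # value supplied on the command line (or the default above).
    if "inference_model" in metafunc.fixturenames:
        metafunc.parametrize(
            "inference_model",
            [metafunc.config.getoption("--inference-model")],
        )


def pytest_configure(config):
    # Register custom marks so filters like `-m "fireworks and llama_8b"`
    # work without unknown-mark warnings.
    for mark in ("meta_reference", "together", "fireworks", "ollama", "llama_8b", "llama_3b"):
        config.addinivalue_line("markers", f"{mark}: provider/model parametrization mark")
```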
## Common options
All tests support a `--providers` option which can be a string of the form `api1=provider_fixture1,api2=provider_fixture2`. So, when testing safety (which needs the inference and safety APIs) you can use `--providers inference=together,safety=meta_reference` to use these fixtures in concert.
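For example, a safety run with that combination might look like this (the test path is an assumption; point it at the actual safety test module):

```
pytest -s -v llama_stack/providers/tests/safety/test_safety.py \
  --providers inference=together,safety=meta_reference \
  --env TOGETHER_API_KEY=<...>
```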
Depending on the API, there are custom options enabled. For example, `inference` tests allow for an `--inference-model` override, etc.
By default, we disable warnings and enable short tracebacks. You can override them using pytest's flags as appropriate.
Some providers need special API keys or other configuration options to work. You can check out the individual fixtures (located in `tests/<api>/fixtures.py`) for what these keys are. These can be specified using the `--env` CLI option. You can also have them present in the environment (exporting in your shell) or put them in the `.env` file in the directory from which you run the test. For example, to use the Together fixture you can use `--env TOGETHER_API_KEY=<...>`.
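For instance, an `.env` file in the directory you run the tests from could look like this (the values are placeholders for your own credentials):

```
TOGETHER_API_KEY=<your-together-key>
FIREWORKS_API_KEY=<your-fireworks-key>
```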
## Inference
We have the following orthogonal parametrizations (pytest "marks") for inference tests:
- providers: (meta_reference, together, fireworks, ollama)
- models: (llama_8b, llama_3b)
If you want to run a test with the `llama_8b` model on Fireworks, you can use:
```
pytest -s -v llama_stack/providers/tests/inference/test_text_inference.py \
  -m "fireworks and llama_8b" \
  --env FIREWORKS_API_KEY=<...>
```
You can make the filter more complex, for example to run both `llama_8b` and `llama_3b` on Fireworks, but only `llama_3b` with Ollama:
```
pytest -s -v llama_stack/providers/tests/inference/test_text_inference.py \
  -m "fireworks or (ollama and llama_3b)" \
  --env FIREWORKS_API_KEY=<...>
```
Finally, you can override the model completely by doing:
```
pytest -s -v llama_stack/providers/tests/inference/test_text_inference.py \
  -m fireworks \
  --inference-model "meta-llama/Llama3.1-70B-Instruct" \
  --env FIREWORKS_API_KEY=<...>
```
## Agents
The Agents API composes three other APIs underneath:
- Inference
- Safety
- Memory
Given that each of these has several fixtures, the set of combinations is large. We provide a default set of combinations (see `tests/agents/conftest.py`) with easy-to-use "marks":

- `meta_reference` -- uses all the `meta_reference` fixtures for the dependent APIs
- `together` -- uses Together for inference, and `meta_reference` for the rest
- `ollama` -- uses Ollama for inference, and `meta_reference` for the rest
An example test with Together:
```
pytest -s -m together llama_stack/providers/tests/agents/test_agents.py \
  --env TOGETHER_API_KEY=<...>
```
If you want to override the inference model or safety model used, you can use the `--inference-model` or `--safety-shield` CLI options as appropriate.
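For example, a Together run overriding both could look like this (the `--safety-shield` value is left as a placeholder since shield identifiers depend on your setup):

```
pytest -s -m together llama_stack/providers/tests/agents/test_agents.py \
  --env TOGETHER_API_KEY=<...> \
  --inference-model "meta-llama/Llama3.1-70B-Instruct" \
  --safety-shield <shield-identifier>
```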
If you want to test a remotely hosted stack, you can use `-m remote` as follows:
```
pytest -s -m remote llama_stack/providers/tests/agents/test_agents.py \
  --env REMOTE_STACK_URL=<...>
```