[torchtune integration] post training + eval (#670)

## What does this PR do?

- Add the related APIs to the experimental-post-training template to enable eval
on the finetuned checkpoint produced by the template
- A small bug fix in meta reference eval
- A small error-handling improvement in post training


## Test Plan
Issued an E2E post training request from the client side
(https://github.com/meta-llama/llama-stack-client-python/pull/70) and got
eval results successfully.

<img width="1315" alt="Screenshot 2024-12-20 at 12 06 59 PM"
src="https://github.com/user-attachments/assets/a09bd524-59ae-490c-908f-2e36ccf27c0a"
/>
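The E2E flow exercised above can be sketched as a request payload for a supervised fine-tuning job followed by an eval on the resulting checkpoint. This is an illustrative sketch only: the field names and the `build_finetune_request` helper are assumptions, not the actual llama-stack-client-python API.

```python
from typing import Optional


def build_finetune_request(model_id: str, checkpoint_dir: Optional[str] = None) -> dict:
    """Assemble an illustrative supervised fine-tuning request payload.

    The keys here are hypothetical; the real client exposes its own
    request schema. The template in this PR wires post training to the
    inline::torchtune provider, which runs LoRA fine-tuning.
    """
    return {
        "job_uuid": "post-training-demo",
        "model": model_id,
        "algorithm": "lora",  # the template's torchtune provider is LoRA-based
        "checkpoint_dir": checkpoint_dir,
    }


request = build_finetune_request("Llama3.2-3B-Instruct")
```

Once the job finishes, the eval APIs added to the template can be pointed at the finetuned checkpoint instead of the base model.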
Commit 06cb0c837e (parent c8be0bf1c9) by Botao Chen, 2024-12-20 13:43:13 -08:00, committed by GitHub
4 changed files with 52 additions and 3 deletions


@@ -15,7 +15,7 @@ from llama_stack.apis.agents import Agents
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.eval_tasks import EvalTask
-from llama_stack.apis.inference import Inference
+from llama_stack.apis.inference import Inference, UserMessage
 from llama_stack.apis.scoring import Scoring
 from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl


@@ -110,6 +110,10 @@ class LoraFinetuningSingleDevice:
             self.checkpoint_dir = config.checkpoint_dir
         else:
             model = resolve_model(self.model_id)
+            if model is None:
+                raise ValueError(
+                    f"{self.model_id} not found. Your model id should be in the llama models SKU list"
+                )
             self.checkpoint_dir = model_checkpoint_dir(model)
         self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
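The guard added in this hunk follows a common fail-fast pattern: resolve an id against a known registry and raise a descriptive error before any expensive work starts. A minimal standalone sketch of that pattern (the registry below is hypothetical; the real code calls `resolve_model()` against the official llama models SKU list):

```python
from typing import Optional

# Hypothetical registry for illustration; the real lookup is
# resolve_model() over the llama models SKU list.
_MODEL_REGISTRY = {
    "Llama3.2-3B-Instruct": "/checkpoints/llama3.2-3b-instruct",
}


def resolve_checkpoint_dir(model_id: str) -> str:
    """Return the checkpoint dir for a known model id, or fail fast."""
    model_dir: Optional[str] = _MODEL_REGISTRY.get(model_id)
    if model_dir is None:
        raise ValueError(
            f"{model_id} not found. Your model id should be in the llama models SKU list"
        )
    return model_dir
```

Raising at resolution time surfaces a typo'd model id immediately, rather than as an opaque failure deep inside checkpoint loading.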


@@ -4,10 +4,22 @@ distribution_spec:
  description: Experimental template for post training
  docker_image: null
  providers:
    inference:
    - inline::meta-reference
    eval:
    - inline::meta-reference
    scoring:
    - inline::basic
    post_training:
    - inline::torchtune
    datasetio:
    - remote::huggingface
    telemetry:
    - inline::meta-reference
    agents:
    - inline::meta-reference
    safety:
    - inline::llama-guard
    memory:
    - inline::faiss
image_type: conda


@@ -3,9 +3,14 @@ image_name: experimental-post-training
docker_image: null
conda_env: experimental-post-training
apis:
- inference
- telemetry
- agents
- datasetio
- eval
- inference
- memory
- safety
- scoring
- telemetry
- post_training
providers:
  inference:
@@ -14,6 +19,14 @@ providers:
    config:
      max_seq_len: 4096
      checkpoint_dir: null
  eval:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config: {}
  scoring:
  - provider_id: basic
    provider_type: inline::basic
    config: {}
  datasetio:
  - provider_id: huggingface-0
    provider_type: remote::huggingface
@@ -26,6 +39,26 @@ providers:
  - provider_id: torchtune-post-training
    provider_type: inline::torchtune
    config: {}
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/agents_store.db
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config: {}
  memory:
  - provider_id: faiss
    provider_type: inline::faiss
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/faiss_store.db
metadata_store:
  namespace: null
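The `db_path` entries in the run config use llama-stack's `${env.VAR:default}` placeholder syntax, so the SQLite store location can be overridden via `SQLITE_STORE_DIR` at runtime. A minimal sketch of how such a placeholder can be expanded (an illustrative reimplementation, not the stack's actual resolver):

```python
import os
import re

# Matches ${env.VAR:default} placeholders as used in the run config above.
_ENV_PATTERN = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*):([^}]*)\}")


def expand_env_placeholders(value: str) -> str:
    """Substitute each ${env.VAR:default} with os.environ[VAR] or the default."""

    def _sub(match: re.Match) -> str:
        var, default = match.group(1), match.group(2)
        return os.environ.get(var, default)

    return _ENV_PATTERN.sub(_sub, value)


path = expand_env_placeholders(
    "${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/agents_store.db"
)
```

With `SQLITE_STORE_DIR` unset, the path falls back to the `~/.llama/distributions/meta-reference-gpu` default; setting the variable redirects both the agents and faiss stores without editing the YAML.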