[4/n][torchtune integration] support lazy load model during inference (#620)

## What does this PR do?
In this PR, we refactor the meta reference inference logic to support 
- load the model during registering model instead of during spinning up
server
- support inference finetuned model checkpoint on top of native llama
model

## Why need these changes
To solve the existing pain points that 
- user cannot lazy load the model and hot switch the inference
checkpoint after spinning up the server
- this blocks us doing inference and eval on the same sever for a
finetuned checkpoint after post training
- user cannot do inference on a finetuned checkpoint on top of native
llama models

## Expect user experience change
- The inference model won't be loaded when spinning up server. Instead,
it will be loaded during register model. If user add the model as models
resource in run.yaml, it will be registered and loaded automatically
when starting server. There is an optional flag 'skip_initialize' in
model metadata to skip model loading during registration.
- There is an optional flag 'llama_model' in model metadata to identify
the base model of the Model class for validation and initialize model
arch. model identifier no longer needs to be a native llama model
- the default inference model name updates from
'meta-llama/Llama-3.2-3B-Instruct' to 'Llama3.2-3B-Instruct'
- It aligns with the checkpoint folder name after running 'llama model
download'
- It aligns with the descriptor name defined in llama-models SKU list
bf5b0c4fe7/models/datatypes.py (L95)


## test
run python llama_stack/scripts/distro_codegen.py


**run unit test**
- torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference"
--inference-model="Llama3.1-8B-Instruct"
./llama_stack/providers/tests/inference/test_text_inference.py
- torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference"
--inference-model="Llama3.1-8B-Instruct"
./llama_stack/providers/tests/inference/test_model_registration.py


**test post training experience**
on server side run: llama stack run
llama_stack/templates/experimental-post-training/run.yaml
server is spinning up without model loaded

<img width="812" alt="Screenshot 2024-12-17 at 1 24 50 PM"
src="https://github.com/user-attachments/assets/ce1f606b-3b6f-452f-b48e-b3761ffd90f3"
/>

on client side, run: llama-stack-client --endpoint
http://devgpu018.nha2.facebook.com:5000 models register
Llama3.2-3B-Instruct
register model successfully and the model is loaded 
<img width="1111" alt="Screenshot 2024-12-17 at 1 26 30 PM"
src="https://github.com/user-attachments/assets/56e02131-cf7d-4de5-8f63-fbdcb8c55c26"
/>


<img width="1541" alt="Screenshot 2024-12-17 at 1 26 09 PM"
src="https://github.com/user-attachments/assets/a83255a1-20f5-40a2-af51-55641410a115"
/>

if add "skip_initialize" in metadata, model is registered but isn't
loaded

on client side, run: llama-stack-client --endpoint
http://devgpu018.nha2.facebook.com:5000 inference chat-completion
--message "hello, what model are you?"

Inference the model succesfully
<img width="1121" alt="Screenshot 2024-12-17 at 1 27 33 PM"
src="https://github.com/user-attachments/assets/8e708545-3fe7-4a73-8754-1470fa5f1e75"
/>

**test inference experience**
run: llama stack run llama_stack/templates/meta-reference-gpu/run.yaml
model is loaded since the model is in resouce list in run.yaml 
<img width="1537" alt="Screenshot 2024-12-17 at 1 30 19 PM"
src="https://github.com/user-attachments/assets/5c8af817-66eb-43f8-bf4c-f5e24b0a12c6"
/>

on client side, run: llama-stack-client --endpoint
http://devgpu018.nha2.facebook.com:5000 inference chat-completion
--message "hello, what model are you?"
inference successfully 
<img width="1123" alt="Screenshot 2024-12-17 at 1 31 08 PM"
src="https://github.com/user-attachments/assets/471809aa-c65e-46dc-a37e-7094fb857f97"
/>



## inference on a finetuned model
**register a finetuned model that finetuned by post training api
(torchtune)**
- the model is registered and loaded successfully 
- the model is shown up in the model list 
<img width="974" alt="Screenshot 2024-12-18 at 3 56 33 PM"
src="https://github.com/user-attachments/assets/2994b4f5-4fa9-40c6-acc6-4b971479f3e2"
/>

**run inference**

<img width="977" alt="Screenshot 2024-12-18 at 3 57 59 PM"
src="https://github.com/user-attachments/assets/d117abbc-b2a0-41d8-a028-1a13128787b2"
/>
This commit is contained in:
Botao Chen 2024-12-18 16:30:53 -08:00 committed by GitHub
parent 3b4b2ea30c
commit 36b4fe02cc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 261 additions and 192 deletions

View file

@ -25,6 +25,7 @@ from fairscale.nn.model_parallel.initialize import (
)
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
from llama_models.llama3.api.datatypes import Model
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer
from llama_models.llama3.reference_impl.multimodal.model import (
@ -53,16 +54,17 @@ from .config import (
log = logging.getLogger(__name__)
def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model.descriptor()))
def model_checkpoint_dir(model_id) -> str:
checkpoint_dir = Path(model_local_dir(model_id))
paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
if not any(p.exists() for p in paths):
checkpoint_dir = checkpoint_dir / "original"
assert checkpoint_dir.exists(), (
f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
f"Please download model using `llama download --model-id {model.descriptor()}`"
f"Could not find checkpoints in: {model_local_dir(model_id)}. "
f"If you try to use the native llama model, Please download model using `llama download --model-id {model_id}`"
f"Otherwise, please save you model checkpoint under {model_local_dir(model_id)}"
)
return str(checkpoint_dir)
@ -79,6 +81,8 @@ class Llama:
config: Union[
MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
],
model_id: str,
llama_model: Model,
):
"""
Build a Llama instance by initializing and loading a model checkpoint.
@ -87,13 +91,11 @@ class Llama:
This method initializes the distributed process group, sets the device to CUDA,
and loads the pre-trained model and tokenizer.
"""
model = resolve_model(config.model)
llama_model = model.core_model_id.value
llama_model_id = llama_model.core_model_id.value
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
model_parallel_size = config.model_parallel_size
model_parallel_size = llama_model.pth_file_count
if not model_parallel_is_initialized():
initialize_model_parallel(model_parallel_size)
@ -112,7 +114,13 @@ class Llama:
if config.checkpoint_dir and config.checkpoint_dir != "null":
ckpt_dir = config.checkpoint_dir
else:
ckpt_dir = model_checkpoint_dir(model)
resolved_model = resolve_model(model_id)
if resolved_model is None:
# if the model is not a native llama model, get the default checkpoint_dir based on model id
ckpt_dir = model_checkpoint_dir(model_id)
else:
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
@ -188,7 +196,7 @@ class Llama:
model.load_state_dict(state_dict, strict=False)
log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer, model_args, llama_model)
return Llama(model, tokenizer, model_args, llama_model_id)
def __init__(
self,