mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
[4/n][torchtune integration] support lazy load model during inference (#620)
## What does this PR do?
In this PR, we refactor the meta reference inference logic to support
- load the model during registering model instead of during spinning up
server
- support inference finetuned model checkpoint on top of native llama
model
## Why need these changes
To solve the existing pain points that
- user cannot lazy load the model and hot switch the inference
checkpoint after spinning up the server
- this blocks us doing inference and eval on the same sever for a
finetuned checkpoint after post training
- user cannot do inference on a finetuned checkpoint on top of native
llama models
## Expect user experience change
- The inference model won't be loaded when spinning up server. Instead,
it will be loaded during register model. If user add the model as models
resource in run.yaml, it will be registered and loaded automatically
when starting server. There is an optional flag 'skip_initialize' in
model metadata to skip model loading during registration.
- There is an optional flag 'llama_model' in model metadata to identify
the base model of the Model class for validation and initialize model
arch. model identifier no longer needs to be a native llama model
- the default inference model name updates from
'meta-llama/Llama-3.2-3B-Instruct' to 'Llama3.2-3B-Instruct'
- It aligns with the checkpoint folder name after running 'llama model
download'
- It aligns with the descriptor name defined in llama-models SKU list
bf5b0c4fe7/models/datatypes.py (L95)
## test
run python llama_stack/scripts/distro_codegen.py
**run unit test**
- torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference"
--inference-model="Llama3.1-8B-Instruct"
./llama_stack/providers/tests/inference/test_text_inference.py
- torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference"
--inference-model="Llama3.1-8B-Instruct"
./llama_stack/providers/tests/inference/test_model_registration.py
**test post training experience**
on server side run: llama stack run
llama_stack/templates/experimental-post-training/run.yaml
server is spinning up without model loaded
<img width="812" alt="Screenshot 2024-12-17 at 1 24 50 PM"
src="https://github.com/user-attachments/assets/ce1f606b-3b6f-452f-b48e-b3761ffd90f3"
/>
on client side, run: llama-stack-client --endpoint
http://devgpu018.nha2.facebook.com:5000 models register
Llama3.2-3B-Instruct
register model successfully and the model is loaded
<img width="1111" alt="Screenshot 2024-12-17 at 1 26 30 PM"
src="https://github.com/user-attachments/assets/56e02131-cf7d-4de5-8f63-fbdcb8c55c26"
/>
<img width="1541" alt="Screenshot 2024-12-17 at 1 26 09 PM"
src="https://github.com/user-attachments/assets/a83255a1-20f5-40a2-af51-55641410a115"
/>
if add "skip_initialize" in metadata, model is registered but isn't
loaded
on client side, run: llama-stack-client --endpoint
http://devgpu018.nha2.facebook.com:5000 inference chat-completion
--message "hello, what model are you?"
Inference the model succesfully
<img width="1121" alt="Screenshot 2024-12-17 at 1 27 33 PM"
src="https://github.com/user-attachments/assets/8e708545-3fe7-4a73-8754-1470fa5f1e75"
/>
**test inference experience**
run: llama stack run llama_stack/templates/meta-reference-gpu/run.yaml
model is loaded since the model is in resouce list in run.yaml
<img width="1537" alt="Screenshot 2024-12-17 at 1 30 19 PM"
src="https://github.com/user-attachments/assets/5c8af817-66eb-43f8-bf4c-f5e24b0a12c6"
/>
on client side, run: llama-stack-client --endpoint
http://devgpu018.nha2.facebook.com:5000 inference chat-completion
--message "hello, what model are you?"
inference successfully
<img width="1123" alt="Screenshot 2024-12-17 at 1 31 08 PM"
src="https://github.com/user-attachments/assets/471809aa-c65e-46dc-a37e-7094fb857f97"
/>
## inference on a finetuned model
**register a finetuned model that finetuned by post training api
(torchtune)**
- the model is registered and loaded successfully
- the model is shown up in the model list
<img width="974" alt="Screenshot 2024-12-18 at 3 56 33 PM"
src="https://github.com/user-attachments/assets/2994b4f5-4fa9-40c6-acc6-4b971479f3e2"
/>
**run inference**
<img width="977" alt="Screenshot 2024-12-18 at 3 57 59 PM"
src="https://github.com/user-attachments/assets/d117abbc-b2a0-41d8-a028-1a13128787b2"
/>
This commit is contained in:
parent
3b4b2ea30c
commit
36b4fe02cc
8 changed files with 261 additions and 192 deletions
|
@ -9,8 +9,6 @@ import logging
|
|||
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from llama_models.datatypes import Model
|
||||
|
||||
from llama_models.llama3.api.datatypes import (
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
|
@ -40,7 +38,7 @@ from llama_stack.apis.inference import (
|
|||
ToolChoice,
|
||||
)
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
|
@ -54,6 +52,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
chat_completion_request_to_messages,
|
||||
convert_request_to_raw,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
from .generation import Llama
|
||||
from .model_parallel import LlamaModelParallelGenerator
|
||||
|
@ -71,50 +70,69 @@ class MetaReferenceInferenceImpl(
|
|||
):
|
||||
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
|
||||
self.config = config
|
||||
model = resolve_model(config.model)
|
||||
if model is None:
|
||||
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
|
||||
self.model_registry_helper = ModelRegistryHelper(
|
||||
[
|
||||
build_model_alias(
|
||||
model.descriptor(),
|
||||
model.core_model_id.value,
|
||||
)
|
||||
],
|
||||
)
|
||||
self.model = model
|
||||
# verify that the checkpoint actually is for this model lol
|
||||
self.model_id = None
|
||||
self.llama_model = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
log.info(f"Loading model `{self.model.descriptor()}`")
|
||||
pass
|
||||
|
||||
async def load_model(self, model_id, llama_model) -> None:
|
||||
log.info(f"Loading model `{model_id}`")
|
||||
if self.config.create_distributed_process_group:
|
||||
self.generator = LlamaModelParallelGenerator(self.config)
|
||||
self.generator = LlamaModelParallelGenerator(
|
||||
self.config, model_id, llama_model
|
||||
)
|
||||
self.generator.start()
|
||||
else:
|
||||
self.generator = Llama.build(self.config)
|
||||
self.generator = Llama.build(self.config, model_id, llama_model)
|
||||
|
||||
self.model_id = model_id
|
||||
self.llama_model = llama_model
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self.config.create_distributed_process_group:
|
||||
self.generator.stop()
|
||||
|
||||
def check_model(self, request) -> None:
|
||||
model = resolve_model(request.model)
|
||||
if model is None:
|
||||
if self.model_id is None or self.llama_model is None:
|
||||
raise RuntimeError(
|
||||
f"Unknown model: {request.model}, Run `llama model list`"
|
||||
"No avaible model yet, please register your requested model or add your model in the resouces first"
|
||||
)
|
||||
elif model.descriptor() != self.model.descriptor():
|
||||
elif request.model != self.model_id:
|
||||
raise RuntimeError(
|
||||
f"Model mismatch: {request.model} != {self.model.descriptor()}"
|
||||
f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}"
|
||||
)
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
llama_model = (
|
||||
resolve_model(model.metadata["llama_model"])
|
||||
if "llama_model" in model.metadata
|
||||
else resolve_model(model.identifier)
|
||||
)
|
||||
if llama_model is None:
|
||||
raise ValueError(
|
||||
"Please make sure your llama_model in model metadata or model identifier is in llama-models SKU list"
|
||||
)
|
||||
|
||||
self.model_registry_helper = ModelRegistryHelper(
|
||||
[
|
||||
build_model_alias(
|
||||
llama_model.descriptor(),
|
||||
llama_model.core_model_id.value,
|
||||
)
|
||||
],
|
||||
)
|
||||
model = await self.model_registry_helper.register_model(model)
|
||||
|
||||
if model.model_type == ModelType.embedding:
|
||||
self._load_sentence_transformer_model(model.provider_resource_id)
|
||||
|
||||
if "skip_load" in model.metadata and model.metadata["skip_load"]:
|
||||
return model
|
||||
await self.load_model(model.identifier, llama_model)
|
||||
return model
|
||||
|
||||
async def completion(
|
||||
|
@ -267,7 +285,7 @@ class MetaReferenceInferenceImpl(
|
|||
|
||||
# augment and rewrite messages depending on the model
|
||||
request.messages = chat_completion_request_to_messages(
|
||||
request, self.model.core_model_id.value
|
||||
request, self.llama_model.core_model_id.value
|
||||
)
|
||||
# download media and convert to raw content so we can send it to the model
|
||||
request = await convert_request_to_raw(request)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue