[4/n][torchtune integration] support lazy load model during inference (#620)

## What does this PR do?
In this PR, we refactor the meta reference inference logic to support 
- load the model during registering model instead of during spinning up
server
- support inference finetuned model checkpoint on top of native llama
model

## Why need these changes
To solve the existing pain points that 
- user cannot lazy load the model and hot switch the inference
checkpoint after spinning up the server
- this blocks us doing inference and eval on the same sever for a
finetuned checkpoint after post training
- user cannot do inference on a finetuned checkpoint on top of native
llama models

## Expect user experience change
- The inference model won't be loaded when spinning up server. Instead,
it will be loaded during register model. If user add the model as models
resource in run.yaml, it will be registered and loaded automatically
when starting server. There is an optional flag 'skip_initialize' in
model metadata to skip model loading during registration.
- There is an optional flag 'llama_model' in model metadata to identify
the base model of the Model class for validation and initialize model
arch. model identifier no longer needs to be a native llama model
- the default inference model name updates from
'meta-llama/Llama-3.2-3B-Instruct' to 'Llama3.2-3B-Instruct'
- It aligns with the checkpoint folder name after running 'llama model
download'
- It aligns with the descriptor name defined in llama-models SKU list
bf5b0c4fe7/models/datatypes.py (L95)


## test
run python llama_stack/scripts/distro_codegen.py


**run unit test**
- torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference"
--inference-model="Llama3.1-8B-Instruct"
./llama_stack/providers/tests/inference/test_text_inference.py
- torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference"
--inference-model="Llama3.1-8B-Instruct"
./llama_stack/providers/tests/inference/test_model_registration.py


**test post training experience**
on server side run: llama stack run
llama_stack/templates/experimental-post-training/run.yaml
server is spinning up without model loaded

<img width="812" alt="Screenshot 2024-12-17 at 1 24 50 PM"
src="https://github.com/user-attachments/assets/ce1f606b-3b6f-452f-b48e-b3761ffd90f3"
/>

on client side, run: llama-stack-client --endpoint
http://devgpu018.nha2.facebook.com:5000 models register
Llama3.2-3B-Instruct
register model successfully and the model is loaded 
<img width="1111" alt="Screenshot 2024-12-17 at 1 26 30 PM"
src="https://github.com/user-attachments/assets/56e02131-cf7d-4de5-8f63-fbdcb8c55c26"
/>


<img width="1541" alt="Screenshot 2024-12-17 at 1 26 09 PM"
src="https://github.com/user-attachments/assets/a83255a1-20f5-40a2-af51-55641410a115"
/>

if add "skip_initialize" in metadata, model is registered but isn't
loaded

on client side, run: llama-stack-client --endpoint
http://devgpu018.nha2.facebook.com:5000 inference chat-completion
--message "hello, what model are you?"

Inference the model succesfully
<img width="1121" alt="Screenshot 2024-12-17 at 1 27 33 PM"
src="https://github.com/user-attachments/assets/8e708545-3fe7-4a73-8754-1470fa5f1e75"
/>

**test inference experience**
run: llama stack run llama_stack/templates/meta-reference-gpu/run.yaml
model is loaded since the model is in resouce list in run.yaml 
<img width="1537" alt="Screenshot 2024-12-17 at 1 30 19 PM"
src="https://github.com/user-attachments/assets/5c8af817-66eb-43f8-bf4c-f5e24b0a12c6"
/>

on client side, run: llama-stack-client --endpoint
http://devgpu018.nha2.facebook.com:5000 inference chat-completion
--message "hello, what model are you?"
inference successfully 
<img width="1123" alt="Screenshot 2024-12-17 at 1 31 08 PM"
src="https://github.com/user-attachments/assets/471809aa-c65e-46dc-a37e-7094fb857f97"
/>



## inference on a finetuned model
**register a finetuned model that finetuned by post training api
(torchtune)**
- the model is registered and loaded successfully 
- the model is shown up in the model list 
<img width="974" alt="Screenshot 2024-12-18 at 3 56 33 PM"
src="https://github.com/user-attachments/assets/2994b4f5-4fa9-40c6-acc6-4b971479f3e2"
/>

**run inference**

<img width="977" alt="Screenshot 2024-12-18 at 3 57 59 PM"
src="https://github.com/user-attachments/assets/d117abbc-b2a0-41d8-a028-1a13128787b2"
/>
This commit is contained in:
Botao Chen 2024-12-18 16:30:53 -08:00 committed by GitHub
parent 3b4b2ea30c
commit 36b4fe02cc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 261 additions and 192 deletions

View file

@ -1,9 +1,9 @@
{
"hf-serverless": [
"aiohttp",
"bedrock": [
"aiosqlite",
"autoevals",
"blobfile",
"boto3",
"chardet",
"chromadb-client",
"datasets",
@ -11,100 +11,6 @@
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"together": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"together",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"vllm-gpu": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"vllm",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"remote-vllm": [
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
@ -157,7 +63,7 @@
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"tgi": [
"hf-endpoint": [
"aiohttp",
"aiosqlite",
"autoevals",
@ -190,11 +96,11 @@
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"bedrock": [
"hf-serverless": [
"aiohttp",
"aiosqlite",
"autoevals",
"blobfile",
"boto3",
"chardet",
"chromadb-client",
"datasets",
@ -202,6 +108,7 @@
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"matplotlib",
"nltk",
"numpy",
@ -300,34 +207,6 @@
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"cerebras": [
"aiosqlite",
"blobfile",
"cerebras_cloud_sdk",
"chardet",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"ollama": [
"aiohttp",
"aiosqlite",
@ -361,7 +240,7 @@
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"hf-endpoint": [
"tgi": [
"aiohttp",
"aiosqlite",
"autoevals",
@ -393,5 +272,126 @@
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"together": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"together",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"remote-vllm": [
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"vllm-gpu": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"vllm",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"cerebras": [
"aiosqlite",
"blobfile",
"cerebras_cloud_sdk",
"chardet",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
]
}

View file

@ -7,19 +7,19 @@
from typing import Any, Dict, Optional
from llama_models.datatypes import * # noqa: F403
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import * # noqa: F401, F403
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, field_validator
from llama_stack.providers.utils.inference import supported_inference_models
class MetaReferenceInferenceConfig(BaseModel):
model: str = Field(
default="Llama3.2-3B-Instruct",
description="Model descriptor from `llama model list`",
)
# this is a placeholder to indicate inference model id
# the actual inference model id is dtermined by the moddel id in the request
# Note: you need to register the model before using it for inference
# models in the resouce list in the run.yaml config will be registered automatically
model: Optional[str] = None
torch_seed: Optional[int] = None
max_seq_len: int = 4096
max_batch_size: int = 1
@ -46,11 +46,6 @@ class MetaReferenceInferenceConfig(BaseModel):
)
return model
@property
def model_parallel_size(self) -> int:
resolved = resolve_model(self.model)
return resolved.pth_file_count
@classmethod
def sample_run_config(
cls,

View file

@ -25,6 +25,7 @@ from fairscale.nn.model_parallel.initialize import (
)
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
from llama_models.llama3.api.datatypes import Model
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer
from llama_models.llama3.reference_impl.multimodal.model import (
@ -53,16 +54,17 @@ from .config import (
log = logging.getLogger(__name__)
def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model.descriptor()))
def model_checkpoint_dir(model_id) -> str:
checkpoint_dir = Path(model_local_dir(model_id))
paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
if not any(p.exists() for p in paths):
checkpoint_dir = checkpoint_dir / "original"
assert checkpoint_dir.exists(), (
f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
f"Please download model using `llama download --model-id {model.descriptor()}`"
f"Could not find checkpoints in: {model_local_dir(model_id)}. "
f"If you try to use the native llama model, Please download model using `llama download --model-id {model_id}`"
f"Otherwise, please save you model checkpoint under {model_local_dir(model_id)}"
)
return str(checkpoint_dir)
@ -79,6 +81,8 @@ class Llama:
config: Union[
MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
],
model_id: str,
llama_model: Model,
):
"""
Build a Llama instance by initializing and loading a model checkpoint.
@ -87,13 +91,11 @@ class Llama:
This method initializes the distributed process group, sets the device to CUDA,
and loads the pre-trained model and tokenizer.
"""
model = resolve_model(config.model)
llama_model = model.core_model_id.value
llama_model_id = llama_model.core_model_id.value
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
model_parallel_size = config.model_parallel_size
model_parallel_size = llama_model.pth_file_count
if not model_parallel_is_initialized():
initialize_model_parallel(model_parallel_size)
@ -112,7 +114,13 @@ class Llama:
if config.checkpoint_dir and config.checkpoint_dir != "null":
ckpt_dir = config.checkpoint_dir
else:
ckpt_dir = model_checkpoint_dir(model)
resolved_model = resolve_model(model_id)
if resolved_model is None:
# if the model is not a native llama model, get the default checkpoint_dir based on model id
ckpt_dir = model_checkpoint_dir(model_id)
else:
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
@ -188,7 +196,7 @@ class Llama:
model.load_state_dict(state_dict, strict=False)
log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer, model_args, llama_model)
return Llama(model, tokenizer, model_args, llama_model_id)
def __init__(
self,

View file

@ -9,8 +9,6 @@ import logging
from typing import AsyncGenerator, List, Optional, Union
from llama_models.datatypes import Model
from llama_models.llama3.api.datatypes import (
SamplingParams,
StopReason,
@ -40,7 +38,7 @@ from llama_stack.apis.inference import (
ToolChoice,
)
from llama_stack.apis.models import ModelType
from llama_stack.apis.models import Model, ModelType
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
@ -54,6 +52,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
convert_request_to_raw,
)
from .config import MetaReferenceInferenceConfig
from .generation import Llama
from .model_parallel import LlamaModelParallelGenerator
@ -71,50 +70,69 @@ class MetaReferenceInferenceImpl(
):
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
self.config = config
model = resolve_model(config.model)
if model is None:
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
self.model_registry_helper = ModelRegistryHelper(
[
build_model_alias(
model.descriptor(),
model.core_model_id.value,
)
],
)
self.model = model
# verify that the checkpoint actually is for this model lol
self.model_id = None
self.llama_model = None
async def initialize(self) -> None:
log.info(f"Loading model `{self.model.descriptor()}`")
pass
async def load_model(self, model_id, llama_model) -> None:
log.info(f"Loading model `{model_id}`")
if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator(self.config)
self.generator = LlamaModelParallelGenerator(
self.config, model_id, llama_model
)
self.generator.start()
else:
self.generator = Llama.build(self.config)
self.generator = Llama.build(self.config, model_id, llama_model)
self.model_id = model_id
self.llama_model = llama_model
async def shutdown(self) -> None:
if self.config.create_distributed_process_group:
self.generator.stop()
def check_model(self, request) -> None:
model = resolve_model(request.model)
if model is None:
if self.model_id is None or self.llama_model is None:
raise RuntimeError(
f"Unknown model: {request.model}, Run `llama model list`"
"No avaible model yet, please register your requested model or add your model in the resouces first"
)
elif model.descriptor() != self.model.descriptor():
elif request.model != self.model_id:
raise RuntimeError(
f"Model mismatch: {request.model} != {self.model.descriptor()}"
f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}"
)
async def unregister_model(self, model_id: str) -> None:
pass
async def register_model(self, model: Model) -> Model:
llama_model = (
resolve_model(model.metadata["llama_model"])
if "llama_model" in model.metadata
else resolve_model(model.identifier)
)
if llama_model is None:
raise ValueError(
"Please make sure your llama_model in model metadata or model identifier is in llama-models SKU list"
)
self.model_registry_helper = ModelRegistryHelper(
[
build_model_alias(
llama_model.descriptor(),
llama_model.core_model_id.value,
)
],
)
model = await self.model_registry_helper.register_model(model)
if model.model_type == ModelType.embedding:
self._load_sentence_transformer_model(model.provider_resource_id)
if "skip_load" in model.metadata and model.metadata["skip_load"]:
return model
await self.load_model(model.identifier, llama_model)
return model
async def completion(
@ -267,7 +285,7 @@ class MetaReferenceInferenceImpl(
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(
request, self.model.core_model_id.value
request, self.llama_model.core_model_id.value
)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)

View file

@ -10,6 +10,7 @@ from functools import partial
from typing import Any, Generator
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Model
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
@ -34,8 +35,12 @@ class ModelRunner:
raise ValueError(f"Unexpected task type {type(req)}")
def init_model_cb(config: MetaReferenceInferenceConfig):
llama = Llama.build(config)
def init_model_cb(
config: MetaReferenceInferenceConfig,
model_id: str,
llama_model: Model,
):
llama = Llama.build(config, model_id, llama_model)
return ModelRunner(llama)
@ -50,12 +55,25 @@ class LlamaModelParallelGenerator:
clear at the callsite why we need to use a context manager.
"""
def __init__(self, config: MetaReferenceInferenceConfig):
def __init__(
self,
config: MetaReferenceInferenceConfig,
model_id: str,
llama_model: Model,
):
self.config = config
self.model = resolve_model(self.config.model)
self.model_id = model_id
self.llama_model = llama_model
# this is a hack because Agent's loop uses this to tokenize and check if input is too long
# while the tool-use loop is going
checkpoint_dir = model_checkpoint_dir(self.model)
resolved_model = resolve_model(model_id)
if resolved_model is None:
# if the model is not a native llama model, get the default checkpoint_dir based on model id
checkpoint_dir = model_checkpoint_dir(model_id)
else:
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
checkpoint_dir = model_checkpoint_dir(resolved_model.descriptor())
tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model")
self.formatter = ChatFormat(Tokenizer(tokenizer_path))
@ -66,9 +84,13 @@ class LlamaModelParallelGenerator:
self.__exit__(None, None, None)
def __enter__(self):
model_parallel_size = self.llama_model.pth_file_count
self.group = ModelParallelProcessGroup(
self.config.model_parallel_size,
init_model_cb=partial(init_model_cb, self.config),
model_parallel_size,
init_model_cb=partial(
init_model_cb, self.config, self.model_id, self.llama_model
),
)
self.group.start()
return self

View file

@ -300,7 +300,7 @@ def start_model_parallel_process(
main_process_url = request_socket.getsockopt_string(zmq.LAST_ENDPOINT)
ctx = multiprocessing.get_context("fork")
ctx = multiprocessing.get_context("spawn")
process = ctx.Process(
target=launch_dist_group,
args=(

View file

@ -4,13 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, patch
import pytest
# How to run this test:
#
# pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py
# -m "meta_reference"
# torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="Llama3.1-8B-Instruct"
# ./llama_stack/providers/tests/inference/test_model_registration.py
class TestModelRegistration:
@ -51,16 +53,37 @@ class TestModelRegistration:
_ = await models_impl.register_model(
model_id="custom-model",
metadata={"llama_model": "meta-llama/Llama-2-7b"},
metadata={
"llama_model": "meta-llama/Llama-2-7b",
"skip_load": True,
},
)
with pytest.raises(ValueError) as exc_info:
with pytest.raises(AssertionError) as exc_info:
await models_impl.register_model(
model_id="custom-model-2",
metadata={"llama_model": "meta-llama/Llama-2-7b"},
metadata={
"llama_model": "meta-llama/Llama-2-7b",
},
provider_model_id="custom-model",
)
@pytest.mark.asyncio
async def test_initialize_model_during_registering(self, inference_stack):
_, models_impl = inference_stack
with patch(
"llama_stack.providers.inline.inference.meta_reference.inference.MetaReferenceInferenceImpl.load_model",
new_callable=AsyncMock,
) as mock_load_model:
_ = await models_impl.register_model(
model_id="Llama3.1-8B-Instruct",
metadata={
"llama_model": "meta-llama/Llama-3.1-8B-Instruct",
},
)
mock_load_model.assert_called_once()
@pytest.mark.asyncio
async def test_register_with_invalid_llama_model(self, inference_stack):
_, models_impl = inference_stack

View file

@ -3,10 +3,17 @@ image_name: experimental-post-training
docker_image: null
conda_env: experimental-post-training
apis:
- inference
- telemetry
- datasetio
- post_training
providers:
inference:
- provider_id: meta-reference-inference
provider_type: inline::meta-reference
config:
max_seq_len: 4096
checkpoint_dir: null
datasetio:
- provider_id: huggingface-0
provider_type: remote::huggingface
@ -24,11 +31,7 @@ metadata_store:
namespace: null
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
models:
- metadata: {}
model_id: ${env.POST_TRAINING_MODEL}
provider_id: meta-reference-inference
provider_model_id: null
models: []
shields: []
memory_banks: []
datasets: