temp commit

Botao Chen 2024-12-16 21:43:30 -08:00
parent 30f6eb282f
commit 81e1957446
10 changed files with 54 additions and 39 deletions

=0.10.9   Normal file, 0 lines
==0.10.9  Normal file, 0 lines

@@ -46,8 +46,10 @@ class MetaReferenceInferenceConfig(BaseModel):
         return model

     @property
-    def model_parallel_size(self) -> int:
+    def model_parallel_size(self) -> Optional[int]:
         resolved = resolve_model(self.model)
+        if resolved is None:
+            return None
         return resolved.pth_file_count

     @classmethod
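Context for the hunk above: resolve_model() returns None when self.model is not a known llama-models SKU (for example, a model registered purely by provider identifier), so the property now advertises Optional[int] instead of promising an int. A minimal sketch of how a caller is expected to consume the Optional value, assuming `config` is a MetaReferenceInferenceConfig and `llama_model` is an already-resolved llama-models Model object (the same fallback appears in the Llama.build hunk further down):

    # Prefer an explicit setting; otherwise fall back to the SKU's checkpoint shard count.
    model_parallel_size = config.model_parallel_size or llama_model.pth_file_count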


@@ -25,12 +25,12 @@ from fairscale.nn.model_parallel.initialize import (
 )
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
+from llama_models.llama3.api.datatypes import Model
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.llama3.reference_impl.multimodal.model import (
     CrossAttentionTransformer,
 )
-from llama_models.sku_list import resolve_model
 from pydantic import BaseModel

 from llama_stack.apis.inference import *  # noqa: F403
@@ -53,16 +53,16 @@ from .config import (

 log = logging.getLogger(__name__)


-def model_checkpoint_dir(model) -> str:
-    checkpoint_dir = Path(model_local_dir(model.descriptor()))
+def model_checkpoint_dir(model_id) -> str:
+    checkpoint_dir = Path(model_local_dir(model_id))

     paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
     if not any(p.exists() for p in paths):
         checkpoint_dir = checkpoint_dir / "original"

     assert checkpoint_dir.exists(), (
-        f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
-        f"Please download model using `llama download --model-id {model.descriptor()}`"
+        f"Could not find checkpoints in: {model_local_dir(model_id)}. "
+        f"Please download model using `llama download --model-id {model_id}`"
     )
     return str(checkpoint_dir)
@@ -80,6 +80,7 @@ class Llama:
             MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
         ],
         model_id: str,
+        llama_model: Model,
     ):
         """
         Build a Llama instance by initializing and loading a model checkpoint.
@@ -88,14 +89,16 @@ class Llama:
         This method initializes the distributed process group, sets the device to CUDA,
         and loads the pre-trained model and tokenizer.
         """
-        model = resolve_model(model_id)
-        llama_model = model.core_model_id.value
+        llama_model_id = llama_model.core_model_id.value

         if not torch.distributed.is_initialized():
+            print("I reach torch.distributed.init_process_group")
             torch.distributed.init_process_group("nccl")

-        model_parallel_size = config.model_parallel_size
+        model_parallel_size = (
+            config.model_parallel_size
+            if config.model_parallel_size
+            else llama_model.pth_file_count
+        )

         if not model_parallel_is_initialized():
             initialize_model_parallel(model_parallel_size)
@@ -103,6 +106,8 @@ class Llama:
         local_rank = int(os.environ.get("LOCAL_RANK", 0))

         torch.cuda.set_device(local_rank)
+        print("torch.cuda.set_device")
+
         # seed must be the same in all processes
         if config.torch_seed is not None:
             torch.manual_seed(config.torch_seed)
@@ -114,7 +119,7 @@ class Llama:
         if config.checkpoint_dir and config.checkpoint_dir != "null":
             ckpt_dir = config.checkpoint_dir
         else:
-            ckpt_dir = model_checkpoint_dir(model)
+            ckpt_dir = model_checkpoint_dir(model_id)  # true model id

         checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
         assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
@@ -190,7 +195,7 @@ class Llama:
         model.load_state_dict(state_dict, strict=False)
         log.info(f"Loaded in {time.time() - start_time:.2f} seconds")

-        return Llama(model, tokenizer, model_args, llama_model)
+        return Llama(model, tokenizer, model_args, llama_model_id)

     def __init__(
         self,
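Taken together, the llama.py hunks stop Llama.build from calling resolve_model itself: the caller now supplies both the registered model_id (used only to locate checkpoints on disk) and the already-resolved llama-models Model object (used for core_model_id and pth_file_count). A hedged sketch of the new call shape, with an illustrative descriptor and a `config` assumed to be a MetaReferenceInferenceConfig:

    from llama_models.sku_list import resolve_model

    # resolve_model still runs, just on the caller's side (see inference.py below).
    llama_model = resolve_model("Llama3.1-8B-Instruct")
    generator = Llama.build(config, model_id="Llama3.1-8B-Instruct", llama_model=llama_model)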


@@ -46,13 +46,16 @@ class MetaReferenceInferenceImpl(
         self.config = config
         self.model = None

-    async def initialize(self, model_id) -> None:
+    async def initialize(self, model_id, llama_model) -> None:
         log.info(f"Loading model `{model_id}`")
         if self.config.create_distributed_process_group:
-            self.generator = LlamaModelParallelGenerator(self.config, model_id)
+            print("I reach create_distributed_process_group")
+            self.generator = LlamaModelParallelGenerator(
+                self.config, model_id, llama_model
+            )
             self.generator.start()
         else:
-            self.generator = Llama.build(self.config, model_id)
+            self.generator = Llama.build(self.config, model_id, llama_model)

         self.model = model_id
@@ -65,26 +68,27 @@ class MetaReferenceInferenceImpl(
             raise RuntimeError(
                 "Inference model hasn't been initialized yet, please register your requested model or add your model in the resouces first"
             )
-        inference_model = resolve_model(self.model)
-        requested_model = resolve_model(request.model)
-        if requested_model is None:
+        if request.model is None:
             raise RuntimeError(
                 f"Unknown model: {request.model}, Run `llama model list`"
             )
-        elif requested_model.descriptor() != inference_model.descriptor():
-            raise RuntimeError(
-                f"Model mismatch: {request.model} != {inference_model.descriptor()}"
-            )
+        elif request.model != self.model:
+            raise RuntimeError(f"Model mismatch: {request.model} != {self.model}")

     async def unregister_model(self, model_id: str) -> None:
         pass

     async def register_model(self, model: LlamaStackModel) -> LlamaStackModel:
-        llama_model = resolve_model(model.identifier)
+        llama_model = (
+            resolve_model(model.metadata["llama_model"])
+            if "llama_model" in model.metadata
+            else resolve_model(model.identifier)
+        )
         if llama_model is None:
             raise RuntimeError(
-                f"Unknown model: {model.identifier}, Please make sure your model is in llama-models SKU list"
+                "Please make sure your llama_model in model metadata or model identifier is in llama-models SKU list"
             )
         self.model_registry_helper = ModelRegistryHelper(
             [
                 build_model_alias(
@@ -94,6 +98,7 @@ class MetaReferenceInferenceImpl(
             ],
         )
         model = await self.model_registry_helper.register_model(model)
+
         if model.model_type == ModelType.embedding:
             self._load_sentence_transformer_model(model.provider_resource_id)
@@ -103,7 +108,7 @@ class MetaReferenceInferenceImpl(
             and model.metadata["skip_initialize"]
         ):
             return model
-        await self.initialize(model.identifier)
+        await self.initialize(model.identifier, llama_model)
         return model

     async def completion(
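The register_model change above lets a model whose Llama Stack identifier is not itself a llama-models SKU still resolve to one through metadata. A hedged sketch of a registration that this enables, assuming the Models API's register_model signature of this period; the identifier, provider id, and SKU descriptor are illustrative:

    # `models_impl` is assumed to be the Models API implementation of a running stack.
    registered = await models_impl.register_model(
        model_id="my-finetuned-llama",
        provider_id="meta-reference-inference",
        metadata={"llama_model": "Llama3.1-8B-Instruct"},
    )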


@@ -10,8 +10,8 @@ from functools import partial
 from typing import Any, Generator

 from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.datatypes import Model
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import resolve_model

 from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
@@ -37,8 +37,9 @@ class ModelRunner:
 def init_model_cb(
     config: MetaReferenceInferenceConfig,
     model_id: str,
+    llama_model: Model,
 ):
-    llama = Llama.build(config, model_id)
+    llama = Llama.build(config, model_id, llama_model)
     return ModelRunner(llama)
@@ -57,14 +58,15 @@ class LlamaModelParallelGenerator:
         self,
         config: MetaReferenceInferenceConfig,
         model_id: str,
+        llama_model: Model,
     ):
         self.config = config
         self.model_id = model_id
-        self.model = resolve_model(model_id)
+        self.llama_model = llama_model

         # this is a hack because Agent's loop uses this to tokenize and check if input is too long
         # while the tool-use loop is going
-        checkpoint_dir = model_checkpoint_dir(self.model)
+        checkpoint_dir = model_checkpoint_dir(self.model_id)
         tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model")
         self.formatter = ChatFormat(Tokenizer(tokenizer_path))
@@ -78,11 +80,15 @@ class LlamaModelParallelGenerator:
         if self.config.model_parallel_size:
             model_parallel_size = self.config.model_parallel_size
         else:
-            model_parallel_size = resolve_model(self.model).pth_file_count
+            model_parallel_size = self.llama_model.pth_file_count
+        print(f"model_parallel_size: {model_parallel_size}")

         self.group = ModelParallelProcessGroup(
             model_parallel_size,
-            init_model_cb=partial(init_model_cb, self.config, self.model_id),
+            init_model_cb=partial(
+                init_model_cb, self.config, self.model_id, self.llama_model
+            ),
         )
         self.group.start()
         return self


@@ -300,7 +300,7 @@ def start_model_parallel_process(
     main_process_url = request_socket.getsockopt_string(zmq.LAST_ENDPOINT)

-    ctx = multiprocessing.get_context("fork")
+    ctx = multiprocessing.get_context("spawn")
     process = ctx.Process(
         target=launch_dist_group,
         args=(
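The fork-to-spawn switch matters because the parent process may already have initialized CUDA by the time the model-parallel group is launched, and a CUDA context cannot be safely reused in a fork()ed child; spawn starts a fresh interpreter instead. A minimal, self-contained illustration of the spawn start method (not code from this repo):

    import multiprocessing


    def worker(msg: str) -> None:
        # Runs in a fresh interpreter: module state from the parent (such as an
        # initialized CUDA context) is not inherited; only the pickled args arrive.
        print(msg)


    if __name__ == "__main__":
        ctx = multiprocessing.get_context("spawn")
        p = ctx.Process(target=worker, args=("hello from a spawned worker",))
        p.start()
        p.join()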


@@ -69,7 +69,6 @@ def pytest_generate_tests(metafunc):
         else:
             params = MODEL_PARAMS

-        # print("params", params)
         metafunc.parametrize(
             "inference_model",
             params,
@@ -83,7 +82,5 @@ def pytest_generate_tests(metafunc):
                 "inference": INFERENCE_FIXTURES,
             },
         ):
-            # print("I reach here")
            fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
-            print("fixtures", fixtures)
        metafunc.parametrize("inference_stack", fixtures, indirect=True)


@@ -9,8 +9,8 @@ import pytest

 # How to run this test:
 #
-#   pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py
-#   -m "meta_reference"
+# torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="Llama3.1-8B-Instruct"
+# ./llama_stack/providers/tests/inference/test_model_registration.py


 class TestModelRegistration:


@@ -75,7 +75,7 @@ metadata_store:
   namespace: null
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
-models: []
+models:
+- metadata: {}
+  model_id: ${env.INFERENCE_MODEL}
+  provider_id: meta-reference-inference