diff --git a/llama_stack/providers/inline/inference/meta_reference/config.py b/llama_stack/providers/inline/inference/meta_reference/config.py
index ffc27c08c..dc54df2b8 100644
--- a/llama_stack/providers/inline/inference/meta_reference/config.py
+++ b/llama_stack/providers/inline/inference/meta_reference/config.py
@@ -46,8 +46,10 @@ class MetaReferenceInferenceConfig(BaseModel):
         return model
 
     @property
-    def model_parallel_size(self) -> int:
+    def model_parallel_size(self) -> Optional[int]:
         resolved = resolve_model(self.model)
+        if resolved is None:
+            return None
         return resolved.pth_file_count
 
     @classmethod
diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py
index eebe7b61d..9bb1bbdaf 100644
--- a/llama_stack/providers/inline/inference/meta_reference/generation.py
+++ b/llama_stack/providers/inline/inference/meta_reference/generation.py
@@ -25,12 +25,12 @@ from fairscale.nn.model_parallel.initialize import (
 )
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
+from llama_models.llama3.api.datatypes import Model
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.llama3.reference_impl.multimodal.model import (
     CrossAttentionTransformer,
 )
-from llama_models.sku_list import resolve_model
 from pydantic import BaseModel
 
 from llama_stack.apis.inference import *  # noqa: F403
@@ -53,16 +53,16 @@ from .config import (
 log = logging.getLogger(__name__)
 
 
-def model_checkpoint_dir(model) -> str:
-    checkpoint_dir = Path(model_local_dir(model.descriptor()))
+def model_checkpoint_dir(model_id) -> str:
+    checkpoint_dir = Path(model_local_dir(model_id))
 
     paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
     if not any(p.exists() for p in paths):
         checkpoint_dir = checkpoint_dir / "original"
 
     assert checkpoint_dir.exists(), (
-        f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
-        f"Please download model using `llama download --model-id {model.descriptor()}`"
+        f"Could not find checkpoints in: {model_local_dir(model_id)}. "
+        f"Please download model using `llama download --model-id {model_id}`"
     )
     return str(checkpoint_dir)
 
@@ -80,6 +80,7 @@ class Llama:
             MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
         ],
         model_id: str,
+        llama_model: Model,
     ):
         """
         Build a Llama instance by initializing and loading a model checkpoint.
@@ -88,14 +89,16 @@ class Llama:
         This method initializes the distributed process group, sets the
         device to CUDA, and loads the pre-trained model and tokenizer.
""" - model = resolve_model(model_id) - - llama_model = model.core_model_id.value - + llama_model_id = llama_model.core_model_id.value if not torch.distributed.is_initialized(): + print("I reach torch.distributed.init_process_group") torch.distributed.init_process_group("nccl") - model_parallel_size = config.model_parallel_size + model_parallel_size = ( + config.model_parallel_size + if config.model_parallel_size + else llama_model.pth_file_count + ) if not model_parallel_is_initialized(): initialize_model_parallel(model_parallel_size) @@ -103,6 +106,8 @@ class Llama: local_rank = int(os.environ.get("LOCAL_RANK", 0)) torch.cuda.set_device(local_rank) + print("torch.cuda.set_device") + # seed must be the same in all processes if config.torch_seed is not None: torch.manual_seed(config.torch_seed) @@ -114,7 +119,7 @@ class Llama: if config.checkpoint_dir and config.checkpoint_dir != "null": ckpt_dir = config.checkpoint_dir else: - ckpt_dir = model_checkpoint_dir(model) + ckpt_dir = model_checkpoint_dir(model_id) # true model id checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" @@ -190,7 +195,7 @@ class Llama: model.load_state_dict(state_dict, strict=False) log.info(f"Loaded in {time.time() - start_time:.2f} seconds") - return Llama(model, tokenizer, model_args, llama_model) + return Llama(model, tokenizer, model_args, llama_model_id) def __init__( self, diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index d86eba797..930032ceb 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -46,13 +46,16 @@ class MetaReferenceInferenceImpl( self.config = config self.model = None - async def initialize(self, model_id) -> None: + async def initialize(self, model_id, llama_model) -> None: log.info(f"Loading model `{model_id}`") if self.config.create_distributed_process_group: - self.generator = LlamaModelParallelGenerator(self.config, model_id) + print("I reach create_distributed_process_group") + self.generator = LlamaModelParallelGenerator( + self.config, model_id, llama_model + ) self.generator.start() else: - self.generator = Llama.build(self.config, model_id) + self.generator = Llama.build(self.config, model_id, llama_model) self.model = model_id @@ -65,26 +68,27 @@ class MetaReferenceInferenceImpl( raise RuntimeError( "Inference model hasn't been initialized yet, please register your requested model or add your model in the resouces first" ) - inference_model = resolve_model(self.model) - requested_model = resolve_model(request.model) - if requested_model is None: + if request.model is None: raise RuntimeError( f"Unknown model: {request.model}, Run `llama model list`" ) - elif requested_model.descriptor() != inference_model.descriptor(): - raise RuntimeError( - f"Model mismatch: {request.model} != {inference_model.descriptor()}" - ) + elif request.model != self.model: + raise RuntimeError(f"Model mismatch: {request.model} != {self.model}") async def unregister_model(self, model_id: str) -> None: pass async def register_model(self, model: LlamaStackModel) -> LlamaStackModel: - llama_model = resolve_model(model.identifier) + llama_model = ( + resolve_model(model.metadata["llama_model"]) + if "llama_model" in model.metadata + else resolve_model(model.identifier) + ) if llama_model is None: raise RuntimeError( - f"Unknown model: 
{model.identifier}, Please make sure your model is in llama-models SKU list" + "Please make sure your llama_model in model metadata or model identifier is in llama-models SKU list" ) + self.model_registry_helper = ModelRegistryHelper( [ build_model_alias( @@ -94,6 +98,7 @@ class MetaReferenceInferenceImpl( ], ) model = await self.model_registry_helper.register_model(model) + if model.model_type == ModelType.embedding: self._load_sentence_transformer_model(model.provider_resource_id) @@ -103,7 +108,7 @@ class MetaReferenceInferenceImpl( and model.metadata["skip_initialize"] ): return model - await self.initialize(model.identifier) + await self.initialize(model.identifier, llama_model) return model async def completion( diff --git a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py index f5d5cc567..d9116475f 100644 --- a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py +++ b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py @@ -10,8 +10,8 @@ from functools import partial from typing import Any, Generator from llama_models.llama3.api.chat_format import ChatFormat +from llama_models.llama3.api.datatypes import Model from llama_models.llama3.api.tokenizer import Tokenizer -from llama_models.sku_list import resolve_model from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest @@ -37,8 +37,9 @@ class ModelRunner: def init_model_cb( config: MetaReferenceInferenceConfig, model_id: str, + llama_model: Model, ): - llama = Llama.build(config, model_id) + llama = Llama.build(config, model_id, llama_model) return ModelRunner(llama) @@ -57,14 +58,15 @@ class LlamaModelParallelGenerator: self, config: MetaReferenceInferenceConfig, model_id: str, + llama_model: Model, ): self.config = config self.model_id = model_id - self.model = resolve_model(model_id) + self.llama_model = llama_model # this is a hack because Agent's loop uses this to tokenize and check if input is too long # while the tool-use loop is going - checkpoint_dir = model_checkpoint_dir(self.model) + checkpoint_dir = model_checkpoint_dir(self.model_id) tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model") self.formatter = ChatFormat(Tokenizer(tokenizer_path)) @@ -78,11 +80,15 @@ class LlamaModelParallelGenerator: if self.config.model_parallel_size: model_parallel_size = self.config.model_parallel_size else: - model_parallel_size = resolve_model(self.model).pth_file_count + model_parallel_size = self.llama_model.pth_file_count + + print(f"model_parallel_size: {model_parallel_size}") self.group = ModelParallelProcessGroup( model_parallel_size, - init_model_cb=partial(init_model_cb, self.config, self.model_id), + init_model_cb=partial( + init_model_cb, self.config, self.model_id, self.llama_model + ), ) self.group.start() return self diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index 076e39729..830160578 100644 --- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -300,7 +300,7 @@ def start_model_parallel_process( main_process_url = request_socket.getsockopt_string(zmq.LAST_ENDPOINT) - ctx = multiprocessing.get_context("fork") + ctx = multiprocessing.get_context("spawn") process = ctx.Process( target=launch_dist_group, args=( diff --git 
diff --git a/llama_stack/providers/tests/inference/conftest.py b/llama_stack/providers/tests/inference/conftest.py
index 9f6bf5d67..54ebcd83a 100644
--- a/llama_stack/providers/tests/inference/conftest.py
+++ b/llama_stack/providers/tests/inference/conftest.py
@@ -69,7 +69,6 @@ def pytest_generate_tests(metafunc):
         else:
             params = MODEL_PARAMS
 
-        # print("params", params)
         metafunc.parametrize(
             "inference_model",
             params,
@@ -83,7 +82,5 @@ def pytest_generate_tests(metafunc):
                 "inference": INFERENCE_FIXTURES,
             },
         ):
-            # print("I reach here")
             fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
-            print("fixtures", fixtures)
         metafunc.parametrize("inference_stack", fixtures, indirect=True)
diff --git a/llama_stack/providers/tests/inference/test_model_registration.py b/llama_stack/providers/tests/inference/test_model_registration.py
index 1471bc369..74076fb28 100644
--- a/llama_stack/providers/tests/inference/test_model_registration.py
+++ b/llama_stack/providers/tests/inference/test_model_registration.py
@@ -9,8 +9,8 @@ import pytest
 
 # How to run this test:
 #
-# pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py
-#   -m "meta_reference"
+# torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="Llama3.1-8B-Instruct"
+#   ./llama_stack/providers/tests/inference/test_model_registration.py
 
 
 class TestModelRegistration:
diff --git a/llama_stack/templates/meta-reference-gpu/run.yaml b/llama_stack/templates/meta-reference-gpu/run.yaml
index c148e8108..0763d0c36 100644
--- a/llama_stack/templates/meta-reference-gpu/run.yaml
+++ b/llama_stack/templates/meta-reference-gpu/run.yaml
@@ -75,7 +75,7 @@ metadata_store:
   namespace: null
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
-models: []
+models:
+- metadata: {}
+  model_id: ${env.INFERENCE_MODEL}
+  provider_id: meta-reference-inference
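
Note on the register_model change above: the provider no longer requires the registered identifier itself to be a llama-models SKU; it first looks for a "llama_model" entry in the model's metadata and falls back to the identifier. The following is a minimal standalone sketch of that resolution logic, for illustration only; the helper name resolve_llama_model and the example identifiers are hypothetical and not part of this patch.

    from llama_models.sku_list import resolve_model

    def resolve_llama_model(identifier: str, metadata: dict):
        # Prefer an explicit "llama_model" hint in the metadata; otherwise treat
        # the registered identifier itself as a llama-models SKU descriptor.
        sku = metadata.get("llama_model", identifier)
        llama_model = resolve_model(sku)
        if llama_model is None:
            raise RuntimeError(
                "Please make sure your llama_model in model metadata "
                "or model identifier is in llama-models SKU list"
            )
        return llama_model

    # Example (hypothetical): a custom identifier mapped onto a known SKU via metadata.
    # resolve_llama_model("my-finetuned-llama", {"llama_model": "Llama3.1-8B-Instruct"})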