diff --git a/distributions/dependencies.json b/distributions/dependencies.json
index 7a974b917..366a2a0f2 100644
--- a/distributions/dependencies.json
+++ b/distributions/dependencies.json
@@ -1,9 +1,9 @@
 {
-  "hf-serverless": [
-    "aiohttp",
+  "bedrock": [
     "aiosqlite",
     "autoevals",
     "blobfile",
+    "boto3",
     "chardet",
     "chromadb-client",
     "datasets",
@@ -11,100 +11,6 @@
     "fastapi",
     "fire",
     "httpx",
-    "huggingface_hub",
-    "matplotlib",
-    "nltk",
-    "numpy",
-    "openai",
-    "opentelemetry-exporter-otlp-proto-http",
-    "opentelemetry-sdk",
-    "pandas",
-    "pillow",
-    "psycopg2-binary",
-    "pypdf",
-    "redis",
-    "scikit-learn",
-    "scipy",
-    "sentencepiece",
-    "tqdm",
-    "transformers",
-    "uvicorn",
-    "sentence-transformers --no-deps",
-    "torch --index-url https://download.pytorch.org/whl/cpu"
-  ],
-  "together": [
-    "aiosqlite",
-    "autoevals",
-    "blobfile",
-    "chardet",
-    "chromadb-client",
-    "datasets",
-    "faiss-cpu",
-    "fastapi",
-    "fire",
-    "httpx",
-    "matplotlib",
-    "nltk",
-    "numpy",
-    "openai",
-    "opentelemetry-exporter-otlp-proto-http",
-    "opentelemetry-sdk",
-    "pandas",
-    "pillow",
-    "psycopg2-binary",
-    "pypdf",
-    "redis",
-    "scikit-learn",
-    "scipy",
-    "sentencepiece",
-    "together",
-    "tqdm",
-    "transformers",
-    "uvicorn",
-    "sentence-transformers --no-deps",
-    "torch --index-url https://download.pytorch.org/whl/cpu"
-  ],
-  "vllm-gpu": [
-    "aiosqlite",
-    "autoevals",
-    "blobfile",
-    "chardet",
-    "chromadb-client",
-    "datasets",
-    "faiss-cpu",
-    "fastapi",
-    "fire",
-    "httpx",
-    "matplotlib",
-    "nltk",
-    "numpy",
-    "openai",
-    "opentelemetry-exporter-otlp-proto-http",
-    "opentelemetry-sdk",
-    "pandas",
-    "pillow",
-    "psycopg2-binary",
-    "pypdf",
-    "redis",
-    "scikit-learn",
-    "scipy",
-    "sentencepiece",
-    "tqdm",
-    "transformers",
-    "uvicorn",
-    "vllm",
-    "sentence-transformers --no-deps",
-    "torch --index-url https://download.pytorch.org/whl/cpu"
-  ],
-  "remote-vllm": [
-    "aiosqlite",
-    "blobfile",
-    "chardet",
-    "chromadb-client",
-    "faiss-cpu",
-    "fastapi",
-    "fire",
-    "httpx",
     "matplotlib",
     "nltk",
     "numpy",
@@ -157,7 +63,7 @@
     "sentence-transformers --no-deps",
     "torch --index-url https://download.pytorch.org/whl/cpu"
   ],
-  "tgi": [
+  "hf-endpoint": [
     "aiohttp",
     "aiosqlite",
     "autoevals",
@@ -190,11 +96,11 @@
     "sentence-transformers --no-deps",
     "torch --index-url https://download.pytorch.org/whl/cpu"
   ],
-  "bedrock": [
+  "hf-serverless": [
+    "aiohttp",
     "aiosqlite",
     "autoevals",
     "blobfile",
-    "boto3",
     "chardet",
     "chromadb-client",
     "datasets",
@@ -202,6 +108,7 @@
     "fastapi",
     "fire",
     "httpx",
+    "huggingface_hub",
     "matplotlib",
     "nltk",
     "numpy",
@@ -300,34 +207,6 @@
     "sentence-transformers --no-deps",
     "torch --index-url https://download.pytorch.org/whl/cpu"
   ],
-  "cerebras": [
-    "aiosqlite",
-    "blobfile",
-    "cerebras_cloud_sdk",
-    "chardet",
-    "faiss-cpu",
-    "fastapi",
-    "fire",
-    "httpx",
-    "matplotlib",
-    "nltk",
-    "numpy",
-    "opentelemetry-exporter-otlp-proto-http",
-    "opentelemetry-sdk",
-    "pandas",
-    "pillow",
-    "psycopg2-binary",
-    "pypdf",
-    "redis",
-    "scikit-learn",
-    "scipy",
-    "sentencepiece",
-    "tqdm",
-    "transformers",
-    "uvicorn",
-    "sentence-transformers --no-deps",
-    "torch --index-url https://download.pytorch.org/whl/cpu"
-  ],
   "ollama": [
     "aiohttp",
     "aiosqlite",
@@ -361,7 +240,7 @@
     "sentence-transformers --no-deps",
     "torch --index-url https://download.pytorch.org/whl/cpu"
   ],
-  "hf-endpoint": [
+  "tgi": [
     "aiohttp",
     "aiosqlite",
     "autoevals",
@@ -393,5 +272,126 @@
     "uvicorn",
     "sentence-transformers --no-deps",
     "torch --index-url https://download.pytorch.org/whl/cpu"
+  ],
+  "together": [
"aiosqlite", + "autoevals", + "blobfile", + "chardet", + "chromadb-client", + "datasets", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "matplotlib", + "nltk", + "numpy", + "openai", + "opentelemetry-exporter-otlp-proto-http", + "opentelemetry-sdk", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "together", + "tqdm", + "transformers", + "uvicorn", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" + ], + "remote-vllm": [ + "aiosqlite", + "blobfile", + "chardet", + "chromadb-client", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "matplotlib", + "nltk", + "numpy", + "openai", + "opentelemetry-exporter-otlp-proto-http", + "opentelemetry-sdk", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "tqdm", + "transformers", + "uvicorn", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" + ], + "vllm-gpu": [ + "aiosqlite", + "autoevals", + "blobfile", + "chardet", + "chromadb-client", + "datasets", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "matplotlib", + "nltk", + "numpy", + "openai", + "opentelemetry-exporter-otlp-proto-http", + "opentelemetry-sdk", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "tqdm", + "transformers", + "uvicorn", + "vllm", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" + ], + "cerebras": [ + "aiosqlite", + "blobfile", + "cerebras_cloud_sdk", + "chardet", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "matplotlib", + "nltk", + "numpy", + "opentelemetry-exporter-otlp-proto-http", + "opentelemetry-sdk", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "tqdm", + "transformers", + "uvicorn", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" ] } diff --git a/llama_stack/providers/inline/inference/meta_reference/config.py b/llama_stack/providers/inline/inference/meta_reference/config.py index 04058d55d..33af33fcd 100644 --- a/llama_stack/providers/inline/inference/meta_reference/config.py +++ b/llama_stack/providers/inline/inference/meta_reference/config.py @@ -7,19 +7,19 @@ from typing import Any, Dict, Optional from llama_models.datatypes import * # noqa: F403 -from llama_models.sku_list import resolve_model from llama_stack.apis.inference import * # noqa: F401, F403 -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, field_validator from llama_stack.providers.utils.inference import supported_inference_models class MetaReferenceInferenceConfig(BaseModel): - model: str = Field( - default="Llama3.2-3B-Instruct", - description="Model descriptor from `llama model list`", - ) + # this is a placeholder to indicate inference model id + # the actual inference model id is dtermined by the moddel id in the request + # Note: you need to register the model before using it for inference + # models in the resouce list in the run.yaml config will be registered automatically + model: Optional[str] = None torch_seed: Optional[int] = None max_seq_len: int = 4096 max_batch_size: int = 1 @@ -46,11 +46,6 @@ class MetaReferenceInferenceConfig(BaseModel): ) return model - @property - def model_parallel_size(self) -> int: - resolved = resolve_model(self.model) - return resolved.pth_file_count - 
     @classmethod
     def sample_run_config(
         cls,
diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py
index 5ea7e1ad5..c89183cb7 100644
--- a/llama_stack/providers/inline/inference/meta_reference/generation.py
+++ b/llama_stack/providers/inline/inference/meta_reference/generation.py
@@ -25,6 +25,7 @@ from fairscale.nn.model_parallel.initialize import (
 )
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
+from llama_models.llama3.api.datatypes import Model
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.llama3.reference_impl.multimodal.model import (
@@ -53,16 +54,17 @@ from .config import (
 
 log = logging.getLogger(__name__)
 
 
-def model_checkpoint_dir(model) -> str:
-    checkpoint_dir = Path(model_local_dir(model.descriptor()))
+def model_checkpoint_dir(model_id) -> str:
+    checkpoint_dir = Path(model_local_dir(model_id))
 
     paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
     if not any(p.exists() for p in paths):
         checkpoint_dir = checkpoint_dir / "original"
 
     assert checkpoint_dir.exists(), (
-        f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
-        f"Please download model using `llama download --model-id {model.descriptor()}`"
+        f"Could not find checkpoints in: {model_local_dir(model_id)}. "
+        f"If you are trying to use a native Llama model, please download it using `llama download --model-id {model_id}`. "
+        f"Otherwise, please save your model checkpoint under {model_local_dir(model_id)}"
     )
     return str(checkpoint_dir)
 
@@ -79,6 +81,8 @@ class Llama:
         config: Union[
             MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
         ],
+        model_id: str,
+        llama_model: Model,
     ):
         """
         Build a Llama instance by initializing and loading a model checkpoint.
@@ -87,13 +91,11 @@ class Llama:
         This method initializes the distributed process group, sets the
         device to CUDA, and loads the pre-trained model and tokenizer.
""" - model = resolve_model(config.model) - llama_model = model.core_model_id.value - + llama_model_id = llama_model.core_model_id.value if not torch.distributed.is_initialized(): torch.distributed.init_process_group("nccl") - model_parallel_size = config.model_parallel_size + model_parallel_size = llama_model.pth_file_count if not model_parallel_is_initialized(): initialize_model_parallel(model_parallel_size) @@ -112,7 +114,13 @@ class Llama: if config.checkpoint_dir and config.checkpoint_dir != "null": ckpt_dir = config.checkpoint_dir else: - ckpt_dir = model_checkpoint_dir(model) + resolved_model = resolve_model(model_id) + if resolved_model is None: + # if the model is not a native llama model, get the default checkpoint_dir based on model id + ckpt_dir = model_checkpoint_dir(model_id) + else: + # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value + ckpt_dir = model_checkpoint_dir(resolved_model.descriptor()) checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" @@ -188,7 +196,7 @@ class Llama: model.load_state_dict(state_dict, strict=False) log.info(f"Loaded in {time.time() - start_time:.2f} seconds") - return Llama(model, tokenizer, model_args, llama_model) + return Llama(model, tokenizer, model_args, llama_model_id) def __init__( self, diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 92d96ab65..d89bb21f7 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -9,8 +9,6 @@ import logging from typing import AsyncGenerator, List, Optional, Union -from llama_models.datatypes import Model - from llama_models.llama3.api.datatypes import ( SamplingParams, StopReason, @@ -40,7 +38,7 @@ from llama_stack.apis.inference import ( ToolChoice, ) -from llama_stack.apis.models import ModelType +from llama_stack.apis.models import Model, ModelType from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.embedding_mixin import ( SentenceTransformerEmbeddingMixin, @@ -54,6 +52,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_messages, convert_request_to_raw, ) + from .config import MetaReferenceInferenceConfig from .generation import Llama from .model_parallel import LlamaModelParallelGenerator @@ -71,50 +70,69 @@ class MetaReferenceInferenceImpl( ): def __init__(self, config: MetaReferenceInferenceConfig) -> None: self.config = config - model = resolve_model(config.model) - if model is None: - raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`") - self.model_registry_helper = ModelRegistryHelper( - [ - build_model_alias( - model.descriptor(), - model.core_model_id.value, - ) - ], - ) - self.model = model - # verify that the checkpoint actually is for this model lol + self.model_id = None + self.llama_model = None async def initialize(self) -> None: - log.info(f"Loading model `{self.model.descriptor()}`") + pass + + async def load_model(self, model_id, llama_model) -> None: + log.info(f"Loading model `{model_id}`") if self.config.create_distributed_process_group: - self.generator = LlamaModelParallelGenerator(self.config) + self.generator = LlamaModelParallelGenerator( + self.config, model_id, llama_model + ) self.generator.start() else: - self.generator = 
+            self.generator = Llama.build(self.config, model_id, llama_model)
+
+        self.model_id = model_id
+        self.llama_model = llama_model
 
     async def shutdown(self) -> None:
         if self.config.create_distributed_process_group:
             self.generator.stop()
 
     def check_model(self, request) -> None:
-        model = resolve_model(request.model)
-        if model is None:
+        if self.model_id is None or self.llama_model is None:
             raise RuntimeError(
-                f"Unknown model: {request.model}, Run `llama model list`"
+                "No model is loaded yet; please register the requested model or add it to the resources list first"
             )
-        elif model.descriptor() != self.model.descriptor():
+        elif request.model != self.model_id:
             raise RuntimeError(
-                f"Model mismatch: {request.model} != {self.model.descriptor()}"
+                f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}"
            )
 
     async def unregister_model(self, model_id: str) -> None:
         pass
 
     async def register_model(self, model: Model) -> Model:
+        llama_model = (
+            resolve_model(model.metadata["llama_model"])
+            if "llama_model" in model.metadata
+            else resolve_model(model.identifier)
+        )
+        if llama_model is None:
+            raise ValueError(
+                "Please make sure the `llama_model` in the model metadata or the model identifier is in the llama-models SKU list"
+            )
+
+        self.model_registry_helper = ModelRegistryHelper(
+            [
+                build_model_alias(
+                    llama_model.descriptor(),
+                    llama_model.core_model_id.value,
+                )
+            ],
+        )
         model = await self.model_registry_helper.register_model(model)
+
         if model.model_type == ModelType.embedding:
             self._load_sentence_transformer_model(model.provider_resource_id)
+
+        if "skip_load" in model.metadata and model.metadata["skip_load"]:
+            return model
+        await self.load_model(model.identifier, llama_model)
         return model
 
     async def completion(
@@ -267,7 +285,7 @@
 
         # augment and rewrite messages depending on the model
         request.messages = chat_completion_request_to_messages(
-            request, self.model.core_model_id.value
+            request, self.llama_model.core_model_id.value
         )
         # download media and convert to raw content so we can send it to the model
         request = await convert_request_to_raw(request)
diff --git a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
index 7e7831185..cb422b9b6 100644
--- a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
+++ b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
@@ -10,6 +10,7 @@
 from functools import partial
 from typing import Any, Generator
 
 from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.datatypes import Model
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 
@@ -34,8 +35,12 @@
         raise ValueError(f"Unexpected task type {type(req)}")
 
 
-def init_model_cb(config: MetaReferenceInferenceConfig):
-    llama = Llama.build(config)
+def init_model_cb(
+    config: MetaReferenceInferenceConfig,
+    model_id: str,
+    llama_model: Model,
+):
+    llama = Llama.build(config, model_id, llama_model)
     return ModelRunner(llama)
 
 
@@ -50,12 +55,25 @@
     clear at the callsite why we need to use a context manager.
""" - def __init__(self, config: MetaReferenceInferenceConfig): + def __init__( + self, + config: MetaReferenceInferenceConfig, + model_id: str, + llama_model: Model, + ): self.config = config - self.model = resolve_model(self.config.model) + self.model_id = model_id + self.llama_model = llama_model + # this is a hack because Agent's loop uses this to tokenize and check if input is too long # while the tool-use loop is going - checkpoint_dir = model_checkpoint_dir(self.model) + resolved_model = resolve_model(model_id) + if resolved_model is None: + # if the model is not a native llama model, get the default checkpoint_dir based on model id + checkpoint_dir = model_checkpoint_dir(model_id) + else: + # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value + checkpoint_dir = model_checkpoint_dir(resolved_model.descriptor()) tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model") self.formatter = ChatFormat(Tokenizer(tokenizer_path)) @@ -66,9 +84,13 @@ class LlamaModelParallelGenerator: self.__exit__(None, None, None) def __enter__(self): + model_parallel_size = self.llama_model.pth_file_count + self.group = ModelParallelProcessGroup( - self.config.model_parallel_size, - init_model_cb=partial(init_model_cb, self.config), + model_parallel_size, + init_model_cb=partial( + init_model_cb, self.config, self.model_id, self.llama_model + ), ) self.group.start() return self diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index 076e39729..830160578 100644 --- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -300,7 +300,7 @@ def start_model_parallel_process( main_process_url = request_socket.getsockopt_string(zmq.LAST_ENDPOINT) - ctx = multiprocessing.get_context("fork") + ctx = multiprocessing.get_context("spawn") process = ctx.Process( target=launch_dist_group, args=( diff --git a/llama_stack/providers/tests/inference/test_model_registration.py b/llama_stack/providers/tests/inference/test_model_registration.py index 1471bc369..3cd7b2496 100644 --- a/llama_stack/providers/tests/inference/test_model_registration.py +++ b/llama_stack/providers/tests/inference/test_model_registration.py @@ -4,13 +4,15 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+from unittest.mock import AsyncMock, patch
+
 import pytest
 
 # How to run this test:
 #
-# pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py
-#   -m "meta_reference"
+# torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="Llama3.1-8B-Instruct"
+#   ./llama_stack/providers/tests/inference/test_model_registration.py
 
 
 class TestModelRegistration:
@@ -51,16 +53,37 @@
 
         _ = await models_impl.register_model(
             model_id="custom-model",
-            metadata={"llama_model": "meta-llama/Llama-2-7b"},
+            metadata={
+                "llama_model": "meta-llama/Llama-2-7b",
+                "skip_load": True,
+            },
         )
 
-        with pytest.raises(ValueError) as exc_info:
+        with pytest.raises(AssertionError) as exc_info:
             await models_impl.register_model(
                 model_id="custom-model-2",
-                metadata={"llama_model": "meta-llama/Llama-2-7b"},
+                metadata={
+                    "llama_model": "meta-llama/Llama-2-7b",
+                },
                 provider_model_id="custom-model",
             )
 
+    @pytest.mark.asyncio
+    async def test_initialize_model_during_registering(self, inference_stack):
+        _, models_impl = inference_stack
+
+        with patch(
+            "llama_stack.providers.inline.inference.meta_reference.inference.MetaReferenceInferenceImpl.load_model",
+            new_callable=AsyncMock,
+        ) as mock_load_model:
+            _ = await models_impl.register_model(
+                model_id="Llama3.1-8B-Instruct",
+                metadata={
+                    "llama_model": "meta-llama/Llama-3.1-8B-Instruct",
+                },
+            )
+            mock_load_model.assert_called_once()
+
     @pytest.mark.asyncio
     async def test_register_with_invalid_llama_model(self, inference_stack):
         _, models_impl = inference_stack
diff --git a/llama_stack/templates/experimental-post-training/run.yaml b/llama_stack/templates/experimental-post-training/run.yaml
index 4bdde7aa6..113c3a793 100644
--- a/llama_stack/templates/experimental-post-training/run.yaml
+++ b/llama_stack/templates/experimental-post-training/run.yaml
@@ -3,10 +3,17 @@ image_name: experimental-post-training
 docker_image: null
 conda_env: experimental-post-training
 apis:
+- inference
 - telemetry
 - datasetio
 - post_training
 providers:
+  inference:
+  - provider_id: meta-reference-inference
+    provider_type: inline::meta-reference
+    config:
+      max_seq_len: 4096
+      checkpoint_dir: null
   datasetio:
   - provider_id: huggingface-0
     provider_type: remote::huggingface
@@ -24,11 +31,7 @@ metadata_store:
   namespace: null
   type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
-models:
-- metadata: {}
-  model_id: ${env.POST_TRAINING_MODEL}
-  provider_id: meta-reference-inference
-  provider_model_id: null
+models: []
 shields: []
 memory_banks: []
 datasets:
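
Usage note: with this change the meta-reference provider no longer loads a fixed model at startup. Instead, register_model resolves the Llama SKU from the "llama_model" metadata (or from the identifier) and calls load_model, unless "skip_load" is set. A minimal sketch of the resulting flow, modeled on the tests above — the models_impl handle comes from the inference_stack test fixture, the calls run inside an async context, and the identifiers are illustrative:

    # Registering a model resolves the Llama SKU and loads it immediately.
    await models_impl.register_model(
        model_id="Llama3.1-8B-Instruct",
        metadata={"llama_model": "meta-llama/Llama-3.1-8B-Instruct"},
    )

    # Register a checkpoint without loading it yet; "skip_load" defers load_model().
    await models_impl.register_model(
        model_id="custom-model",
        metadata={
            "llama_model": "meta-llama/Llama-2-7b",
            "skip_load": True,
        },
    )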