From 36b4fe02ccddcfd3f0aff82c08c51974436b4a8e Mon Sep 17 00:00:00 2001 From: Botao Chen Date: Wed, 18 Dec 2024 16:30:53 -0800 Subject: [PATCH] [4/n][torchtune integration] support lazy load model during inference (#620) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What does this PR do? In this PR, we refactor the meta reference inference logic to support - load the model during registering model instead of during spinning up server - support inference finetuned model checkpoint on top of native llama model ## Why need these changes To solve the existing pain points that - user cannot lazy load the model and hot switch the inference checkpoint after spinning up the server - this blocks us doing inference and eval on the same sever for a finetuned checkpoint after post training - user cannot do inference on a finetuned checkpoint on top of native llama models ## Expect user experience change - The inference model won't be loaded when spinning up server. Instead, it will be loaded during register model. If user add the model as models resource in run.yaml, it will be registered and loaded automatically when starting server. There is an optional flag 'skip_initialize' in model metadata to skip model loading during registration. - There is an optional flag 'llama_model' in model metadata to identify the base model of the Model class for validation and initialize model arch. model identifier no longer needs to be a native llama model - the default inference model name updates from 'meta-llama/Llama-3.2-3B-Instruct' to 'Llama3.2-3B-Instruct' - It aligns with the checkpoint folder name after running 'llama model download' - It aligns with the descriptor name defined in llama-models SKU list https://github.com/meta-llama/llama-models/blob/bf5b0c4fe74e3b51ed5904ab65e3f671b194d2a9/models/datatypes.py#L95 ## test run python llama_stack/scripts/distro_codegen.py **run unit test** - torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="Llama3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_text_inference.py - torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="Llama3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_model_registration.py **test post training experience** on server side run: llama stack run llama_stack/templates/experimental-post-training/run.yaml server is spinning up without model loaded Screenshot 2024-12-17 at 1 24 50 PM on client side, run: llama-stack-client --endpoint http://devgpu018.nha2.facebook.com:5000 models register Llama3.2-3B-Instruct register model successfully and the model is loaded Screenshot 2024-12-17 at 1 26 30 PM Screenshot 2024-12-17 at 1 26 09 PM if add "skip_initialize" in metadata, model is registered but isn't loaded on client side, run: llama-stack-client --endpoint http://devgpu018.nha2.facebook.com:5000 inference chat-completion --message "hello, what model are you?" Inference the model succesfully Screenshot 2024-12-17 at 1 27 33 PM **test inference experience** run: llama stack run llama_stack/templates/meta-reference-gpu/run.yaml model is loaded since the model is in resouce list in run.yaml Screenshot 2024-12-17 at 1 30 19 PM on client side, run: llama-stack-client --endpoint http://devgpu018.nha2.facebook.com:5000 inference chat-completion --message "hello, what model are you?" inference successfully Screenshot 2024-12-17 at 1 31 08 PM ## inference on a finetuned model **register a finetuned model that finetuned by post training api (torchtune)** - the model is registered and loaded successfully - the model is shown up in the model list Screenshot 2024-12-18 at 3 56 33 PM **run inference** Screenshot 2024-12-18 at 3 57 59 PM --- distributions/dependencies.json | 256 +++++++++--------- .../inline/inference/meta_reference/config.py | 17 +- .../inference/meta_reference/generation.py | 28 +- .../inference/meta_reference/inference.py | 68 +++-- .../meta_reference/model_parallel.py | 36 ++- .../meta_reference/parallel_utils.py | 2 +- .../inference/test_model_registration.py | 33 ++- .../experimental-post-training/run.yaml | 13 +- 8 files changed, 261 insertions(+), 192 deletions(-) diff --git a/distributions/dependencies.json b/distributions/dependencies.json index 7a974b917..366a2a0f2 100644 --- a/distributions/dependencies.json +++ b/distributions/dependencies.json @@ -1,9 +1,9 @@ { - "hf-serverless": [ - "aiohttp", + "bedrock": [ "aiosqlite", "autoevals", "blobfile", + "boto3", "chardet", "chromadb-client", "datasets", @@ -11,100 +11,6 @@ "fastapi", "fire", "httpx", - "huggingface_hub", - "matplotlib", - "nltk", - "numpy", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pypdf", - "redis", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "uvicorn", - "sentence-transformers --no-deps", - "torch --index-url https://download.pytorch.org/whl/cpu" - ], - "together": [ - "aiosqlite", - "autoevals", - "blobfile", - "chardet", - "chromadb-client", - "datasets", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "matplotlib", - "nltk", - "numpy", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pypdf", - "redis", - "scikit-learn", - "scipy", - "sentencepiece", - "together", - "tqdm", - "transformers", - "uvicorn", - "sentence-transformers --no-deps", - "torch --index-url https://download.pytorch.org/whl/cpu" - ], - "vllm-gpu": [ - "aiosqlite", - "autoevals", - "blobfile", - "chardet", - "chromadb-client", - "datasets", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "matplotlib", - "nltk", - "numpy", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pypdf", - "redis", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "uvicorn", - "vllm", - "sentence-transformers --no-deps", - "torch --index-url https://download.pytorch.org/whl/cpu" - ], - "remote-vllm": [ - "aiosqlite", - "blobfile", - "chardet", - "chromadb-client", - "faiss-cpu", - "fastapi", - "fire", - "httpx", "matplotlib", "nltk", "numpy", @@ -157,7 +63,7 @@ "sentence-transformers --no-deps", "torch --index-url https://download.pytorch.org/whl/cpu" ], - "tgi": [ + "hf-endpoint": [ "aiohttp", "aiosqlite", "autoevals", @@ -190,11 +96,11 @@ "sentence-transformers --no-deps", "torch --index-url https://download.pytorch.org/whl/cpu" ], - "bedrock": [ + "hf-serverless": [ + "aiohttp", "aiosqlite", "autoevals", "blobfile", - "boto3", "chardet", "chromadb-client", "datasets", @@ -202,6 +108,7 @@ "fastapi", "fire", "httpx", + "huggingface_hub", "matplotlib", "nltk", "numpy", @@ -300,34 +207,6 @@ "sentence-transformers --no-deps", "torch --index-url https://download.pytorch.org/whl/cpu" ], - "cerebras": [ - "aiosqlite", - "blobfile", - "cerebras_cloud_sdk", - "chardet", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "matplotlib", - "nltk", - "numpy", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pypdf", - "redis", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "uvicorn", - "sentence-transformers --no-deps", - "torch --index-url https://download.pytorch.org/whl/cpu" - ], "ollama": [ "aiohttp", "aiosqlite", @@ -361,7 +240,7 @@ "sentence-transformers --no-deps", "torch --index-url https://download.pytorch.org/whl/cpu" ], - "hf-endpoint": [ + "tgi": [ "aiohttp", "aiosqlite", "autoevals", @@ -393,5 +272,126 @@ "uvicorn", "sentence-transformers --no-deps", "torch --index-url https://download.pytorch.org/whl/cpu" + ], + "together": [ + "aiosqlite", + "autoevals", + "blobfile", + "chardet", + "chromadb-client", + "datasets", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "matplotlib", + "nltk", + "numpy", + "openai", + "opentelemetry-exporter-otlp-proto-http", + "opentelemetry-sdk", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "together", + "tqdm", + "transformers", + "uvicorn", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" + ], + "remote-vllm": [ + "aiosqlite", + "blobfile", + "chardet", + "chromadb-client", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "matplotlib", + "nltk", + "numpy", + "openai", + "opentelemetry-exporter-otlp-proto-http", + "opentelemetry-sdk", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "tqdm", + "transformers", + "uvicorn", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" + ], + "vllm-gpu": [ + "aiosqlite", + "autoevals", + "blobfile", + "chardet", + "chromadb-client", + "datasets", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "matplotlib", + "nltk", + "numpy", + "openai", + "opentelemetry-exporter-otlp-proto-http", + "opentelemetry-sdk", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "tqdm", + "transformers", + "uvicorn", + "vllm", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" + ], + "cerebras": [ + "aiosqlite", + "blobfile", + "cerebras_cloud_sdk", + "chardet", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "matplotlib", + "nltk", + "numpy", + "opentelemetry-exporter-otlp-proto-http", + "opentelemetry-sdk", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "tqdm", + "transformers", + "uvicorn", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" ] } diff --git a/llama_stack/providers/inline/inference/meta_reference/config.py b/llama_stack/providers/inline/inference/meta_reference/config.py index 04058d55d..33af33fcd 100644 --- a/llama_stack/providers/inline/inference/meta_reference/config.py +++ b/llama_stack/providers/inline/inference/meta_reference/config.py @@ -7,19 +7,19 @@ from typing import Any, Dict, Optional from llama_models.datatypes import * # noqa: F403 -from llama_models.sku_list import resolve_model from llama_stack.apis.inference import * # noqa: F401, F403 -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, field_validator from llama_stack.providers.utils.inference import supported_inference_models class MetaReferenceInferenceConfig(BaseModel): - model: str = Field( - default="Llama3.2-3B-Instruct", - description="Model descriptor from `llama model list`", - ) + # this is a placeholder to indicate inference model id + # the actual inference model id is dtermined by the moddel id in the request + # Note: you need to register the model before using it for inference + # models in the resouce list in the run.yaml config will be registered automatically + model: Optional[str] = None torch_seed: Optional[int] = None max_seq_len: int = 4096 max_batch_size: int = 1 @@ -46,11 +46,6 @@ class MetaReferenceInferenceConfig(BaseModel): ) return model - @property - def model_parallel_size(self) -> int: - resolved = resolve_model(self.model) - return resolved.pth_file_count - @classmethod def sample_run_config( cls, diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py index 5ea7e1ad5..c89183cb7 100644 --- a/llama_stack/providers/inline/inference/meta_reference/generation.py +++ b/llama_stack/providers/inline/inference/meta_reference/generation.py @@ -25,6 +25,7 @@ from fairscale.nn.model_parallel.initialize import ( ) from llama_models.llama3.api.args import ModelArgs from llama_models.llama3.api.chat_format import ChatFormat, LLMInput +from llama_models.llama3.api.datatypes import Model from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.reference_impl.model import Transformer from llama_models.llama3.reference_impl.multimodal.model import ( @@ -53,16 +54,17 @@ from .config import ( log = logging.getLogger(__name__) -def model_checkpoint_dir(model) -> str: - checkpoint_dir = Path(model_local_dir(model.descriptor())) +def model_checkpoint_dir(model_id) -> str: + checkpoint_dir = Path(model_local_dir(model_id)) paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]] if not any(p.exists() for p in paths): checkpoint_dir = checkpoint_dir / "original" assert checkpoint_dir.exists(), ( - f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. " - f"Please download model using `llama download --model-id {model.descriptor()}`" + f"Could not find checkpoints in: {model_local_dir(model_id)}. " + f"If you try to use the native llama model, Please download model using `llama download --model-id {model_id}`" + f"Otherwise, please save you model checkpoint under {model_local_dir(model_id)}" ) return str(checkpoint_dir) @@ -79,6 +81,8 @@ class Llama: config: Union[ MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig ], + model_id: str, + llama_model: Model, ): """ Build a Llama instance by initializing and loading a model checkpoint. @@ -87,13 +91,11 @@ class Llama: This method initializes the distributed process group, sets the device to CUDA, and loads the pre-trained model and tokenizer. """ - model = resolve_model(config.model) - llama_model = model.core_model_id.value - + llama_model_id = llama_model.core_model_id.value if not torch.distributed.is_initialized(): torch.distributed.init_process_group("nccl") - model_parallel_size = config.model_parallel_size + model_parallel_size = llama_model.pth_file_count if not model_parallel_is_initialized(): initialize_model_parallel(model_parallel_size) @@ -112,7 +114,13 @@ class Llama: if config.checkpoint_dir and config.checkpoint_dir != "null": ckpt_dir = config.checkpoint_dir else: - ckpt_dir = model_checkpoint_dir(model) + resolved_model = resolve_model(model_id) + if resolved_model is None: + # if the model is not a native llama model, get the default checkpoint_dir based on model id + ckpt_dir = model_checkpoint_dir(model_id) + else: + # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value + ckpt_dir = model_checkpoint_dir(resolved_model.descriptor()) checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" @@ -188,7 +196,7 @@ class Llama: model.load_state_dict(state_dict, strict=False) log.info(f"Loaded in {time.time() - start_time:.2f} seconds") - return Llama(model, tokenizer, model_args, llama_model) + return Llama(model, tokenizer, model_args, llama_model_id) def __init__( self, diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 92d96ab65..d89bb21f7 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -9,8 +9,6 @@ import logging from typing import AsyncGenerator, List, Optional, Union -from llama_models.datatypes import Model - from llama_models.llama3.api.datatypes import ( SamplingParams, StopReason, @@ -40,7 +38,7 @@ from llama_stack.apis.inference import ( ToolChoice, ) -from llama_stack.apis.models import ModelType +from llama_stack.apis.models import Model, ModelType from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.embedding_mixin import ( SentenceTransformerEmbeddingMixin, @@ -54,6 +52,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_messages, convert_request_to_raw, ) + from .config import MetaReferenceInferenceConfig from .generation import Llama from .model_parallel import LlamaModelParallelGenerator @@ -71,50 +70,69 @@ class MetaReferenceInferenceImpl( ): def __init__(self, config: MetaReferenceInferenceConfig) -> None: self.config = config - model = resolve_model(config.model) - if model is None: - raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`") - self.model_registry_helper = ModelRegistryHelper( - [ - build_model_alias( - model.descriptor(), - model.core_model_id.value, - ) - ], - ) - self.model = model - # verify that the checkpoint actually is for this model lol + self.model_id = None + self.llama_model = None async def initialize(self) -> None: - log.info(f"Loading model `{self.model.descriptor()}`") + pass + + async def load_model(self, model_id, llama_model) -> None: + log.info(f"Loading model `{model_id}`") if self.config.create_distributed_process_group: - self.generator = LlamaModelParallelGenerator(self.config) + self.generator = LlamaModelParallelGenerator( + self.config, model_id, llama_model + ) self.generator.start() else: - self.generator = Llama.build(self.config) + self.generator = Llama.build(self.config, model_id, llama_model) + + self.model_id = model_id + self.llama_model = llama_model async def shutdown(self) -> None: if self.config.create_distributed_process_group: self.generator.stop() def check_model(self, request) -> None: - model = resolve_model(request.model) - if model is None: + if self.model_id is None or self.llama_model is None: raise RuntimeError( - f"Unknown model: {request.model}, Run `llama model list`" + "No avaible model yet, please register your requested model or add your model in the resouces first" ) - elif model.descriptor() != self.model.descriptor(): + elif request.model != self.model_id: raise RuntimeError( - f"Model mismatch: {request.model} != {self.model.descriptor()}" + f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}" ) async def unregister_model(self, model_id: str) -> None: pass async def register_model(self, model: Model) -> Model: + llama_model = ( + resolve_model(model.metadata["llama_model"]) + if "llama_model" in model.metadata + else resolve_model(model.identifier) + ) + if llama_model is None: + raise ValueError( + "Please make sure your llama_model in model metadata or model identifier is in llama-models SKU list" + ) + + self.model_registry_helper = ModelRegistryHelper( + [ + build_model_alias( + llama_model.descriptor(), + llama_model.core_model_id.value, + ) + ], + ) model = await self.model_registry_helper.register_model(model) + if model.model_type == ModelType.embedding: self._load_sentence_transformer_model(model.provider_resource_id) + + if "skip_load" in model.metadata and model.metadata["skip_load"]: + return model + await self.load_model(model.identifier, llama_model) return model async def completion( @@ -267,7 +285,7 @@ class MetaReferenceInferenceImpl( # augment and rewrite messages depending on the model request.messages = chat_completion_request_to_messages( - request, self.model.core_model_id.value + request, self.llama_model.core_model_id.value ) # download media and convert to raw content so we can send it to the model request = await convert_request_to_raw(request) diff --git a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py index 7e7831185..cb422b9b6 100644 --- a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py +++ b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py @@ -10,6 +10,7 @@ from functools import partial from typing import Any, Generator from llama_models.llama3.api.chat_format import ChatFormat +from llama_models.llama3.api.datatypes import Model from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.sku_list import resolve_model @@ -34,8 +35,12 @@ class ModelRunner: raise ValueError(f"Unexpected task type {type(req)}") -def init_model_cb(config: MetaReferenceInferenceConfig): - llama = Llama.build(config) +def init_model_cb( + config: MetaReferenceInferenceConfig, + model_id: str, + llama_model: Model, +): + llama = Llama.build(config, model_id, llama_model) return ModelRunner(llama) @@ -50,12 +55,25 @@ class LlamaModelParallelGenerator: clear at the callsite why we need to use a context manager. """ - def __init__(self, config: MetaReferenceInferenceConfig): + def __init__( + self, + config: MetaReferenceInferenceConfig, + model_id: str, + llama_model: Model, + ): self.config = config - self.model = resolve_model(self.config.model) + self.model_id = model_id + self.llama_model = llama_model + # this is a hack because Agent's loop uses this to tokenize and check if input is too long # while the tool-use loop is going - checkpoint_dir = model_checkpoint_dir(self.model) + resolved_model = resolve_model(model_id) + if resolved_model is None: + # if the model is not a native llama model, get the default checkpoint_dir based on model id + checkpoint_dir = model_checkpoint_dir(model_id) + else: + # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value + checkpoint_dir = model_checkpoint_dir(resolved_model.descriptor()) tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model") self.formatter = ChatFormat(Tokenizer(tokenizer_path)) @@ -66,9 +84,13 @@ class LlamaModelParallelGenerator: self.__exit__(None, None, None) def __enter__(self): + model_parallel_size = self.llama_model.pth_file_count + self.group = ModelParallelProcessGroup( - self.config.model_parallel_size, - init_model_cb=partial(init_model_cb, self.config), + model_parallel_size, + init_model_cb=partial( + init_model_cb, self.config, self.model_id, self.llama_model + ), ) self.group.start() return self diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index 076e39729..830160578 100644 --- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -300,7 +300,7 @@ def start_model_parallel_process( main_process_url = request_socket.getsockopt_string(zmq.LAST_ENDPOINT) - ctx = multiprocessing.get_context("fork") + ctx = multiprocessing.get_context("spawn") process = ctx.Process( target=launch_dist_group, args=( diff --git a/llama_stack/providers/tests/inference/test_model_registration.py b/llama_stack/providers/tests/inference/test_model_registration.py index 1471bc369..3cd7b2496 100644 --- a/llama_stack/providers/tests/inference/test_model_registration.py +++ b/llama_stack/providers/tests/inference/test_model_registration.py @@ -4,13 +4,15 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from unittest.mock import AsyncMock, patch + import pytest # How to run this test: # -# pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py -# -m "meta_reference" +# torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="Llama3.1-8B-Instruct" +# ./llama_stack/providers/tests/inference/test_model_registration.py class TestModelRegistration: @@ -51,16 +53,37 @@ class TestModelRegistration: _ = await models_impl.register_model( model_id="custom-model", - metadata={"llama_model": "meta-llama/Llama-2-7b"}, + metadata={ + "llama_model": "meta-llama/Llama-2-7b", + "skip_load": True, + }, ) - with pytest.raises(ValueError) as exc_info: + with pytest.raises(AssertionError) as exc_info: await models_impl.register_model( model_id="custom-model-2", - metadata={"llama_model": "meta-llama/Llama-2-7b"}, + metadata={ + "llama_model": "meta-llama/Llama-2-7b", + }, provider_model_id="custom-model", ) + @pytest.mark.asyncio + async def test_initialize_model_during_registering(self, inference_stack): + _, models_impl = inference_stack + + with patch( + "llama_stack.providers.inline.inference.meta_reference.inference.MetaReferenceInferenceImpl.load_model", + new_callable=AsyncMock, + ) as mock_load_model: + _ = await models_impl.register_model( + model_id="Llama3.1-8B-Instruct", + metadata={ + "llama_model": "meta-llama/Llama-3.1-8B-Instruct", + }, + ) + mock_load_model.assert_called_once() + @pytest.mark.asyncio async def test_register_with_invalid_llama_model(self, inference_stack): _, models_impl = inference_stack diff --git a/llama_stack/templates/experimental-post-training/run.yaml b/llama_stack/templates/experimental-post-training/run.yaml index 4bdde7aa6..113c3a793 100644 --- a/llama_stack/templates/experimental-post-training/run.yaml +++ b/llama_stack/templates/experimental-post-training/run.yaml @@ -3,10 +3,17 @@ image_name: experimental-post-training docker_image: null conda_env: experimental-post-training apis: +- inference - telemetry - datasetio - post_training providers: + inference: + - provider_id: meta-reference-inference + provider_type: inline::meta-reference + config: + max_seq_len: 4096 + checkpoint_dir: null datasetio: - provider_id: huggingface-0 provider_type: remote::huggingface @@ -24,11 +31,7 @@ metadata_store: namespace: null type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db -models: -- metadata: {} - model_id: ${env.POST_TRAINING_MODEL} - provider_id: meta-reference-inference - provider_model_id: null +models: [] shields: [] memory_banks: [] datasets: