From 0000e1e8c6f55e693caa7ea345a4823f6bc085a4 Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Wed, 18 Dec 2024 14:12:57 -0800
Subject: [PATCH] address comments

---
 .../inline/inference/meta_reference/inference.py          | 6 +++---
 .../providers/tests/inference/test_model_registration.py  | 8 ++++----
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index 6ab78357a..f2354aebb 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -73,7 +73,7 @@ class MetaReferenceInferenceImpl(
         self.model_id = None
         self.llama_model = None
 
-    async def initialize(self, model_id, llama_model) -> None:
+    async def load_model(self, model_id, llama_model) -> None:
         log.info(f"Loading model `{model_id}`")
         if self.config.create_distributed_process_group:
             self.generator = LlamaModelParallelGenerator(
@@ -127,9 +127,9 @@ class MetaReferenceInferenceImpl(
         if model.model_type == ModelType.embedding:
             self._load_sentence_transformer_model(model.provider_resource_id)
 
-        if "skip_initialize" in model.metadata and model.metadata["skip_initialize"]:
+        if "skip_load" in model.metadata and model.metadata["skip_load"]:
             return model
-        await self.initialize(model.identifier, llama_model)
+        await self.load_model(model.identifier, llama_model)
         return model
 
     async def completion(
diff --git a/llama_stack/providers/tests/inference/test_model_registration.py b/llama_stack/providers/tests/inference/test_model_registration.py
index 6ac09601c..3cd7b2496 100644
--- a/llama_stack/providers/tests/inference/test_model_registration.py
+++ b/llama_stack/providers/tests/inference/test_model_registration.py
@@ -55,7 +55,7 @@ class TestModelRegistration:
             model_id="custom-model",
             metadata={
                 "llama_model": "meta-llama/Llama-2-7b",
-                "skip_initialize": True,
+                "skip_load": True,
             },
         )
 
@@ -73,16 +73,16 @@
         _, models_impl = inference_stack
 
         with patch(
-            "llama_stack.providers.inline.inference.meta_reference.inference.MetaReferenceInferenceImpl.initialize",
+            "llama_stack.providers.inline.inference.meta_reference.inference.MetaReferenceInferenceImpl.load_model",
             new_callable=AsyncMock,
-        ) as mock_initialize:
+        ) as mock_load_model:
             _ = await models_impl.register_model(
                 model_id="Llama3.1-8B-Instruct",
                 metadata={
                     "llama_model": "meta-llama/Llama-3.1-8B-Instruct",
                 },
            )
-            mock_initialize.assert_called_once()
+            mock_load_model.assert_called_once()
 
     @pytest.mark.asyncio
     async def test_register_with_invalid_llama_model(self, inference_stack):
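
For reference, here is how a caller would exercise the renamed flag after this patch. This is a minimal sketch, not part of the patch itself: `models_impl` is assumed to be the Models API handle (the test above obtains it from its `inference_stack` fixture), and `register_without_loading` is a hypothetical wrapper for illustration; `register_model`, the `skip_load` key, and the metadata shape come straight from the diff.

    # Minimal sketch (not part of the patch). `models_impl` is assumed to be
    # the Models API handle, as obtained from the inference_stack fixture in
    # the test above; register_without_loading is a hypothetical wrapper.
    async def register_without_loading(models_impl) -> None:
        # With the renamed "skip_load" flag (formerly "skip_initialize"),
        # register_model() returns before load_model() is ever invoked, so
        # no weights are loaded at registration time.
        await models_impl.register_model(
            model_id="custom-model",
            metadata={
                "llama_model": "meta-llama/Llama-2-7b",
                "skip_load": True,
            },
        )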