address comments

This commit is contained in:
Botao Chen 2024-12-18 14:12:57 -08:00
parent d021983b0e
commit 0000e1e8c6
2 changed files with 7 additions and 7 deletions

View file

@ -73,7 +73,7 @@ class MetaReferenceInferenceImpl(
    self.model_id = None
    self.llama_model = None

-   async def initialize(self, model_id, llama_model) -> None:
+   async def load_model(self, model_id, llama_model) -> None:
        log.info(f"Loading model `{model_id}`")
        if self.config.create_distributed_process_group:
            self.generator = LlamaModelParallelGenerator(
@ -127,9 +127,9 @@ class MetaReferenceInferenceImpl(
    if model.model_type == ModelType.embedding:
        self._load_sentence_transformer_model(model.provider_resource_id)
-   if "skip_initialize" in model.metadata and model.metadata["skip_initialize"]:
+   if "skip_load" in model.metadata and model.metadata["skip_load"]:
        return model
-   await self.initialize(model.identifier, llama_model)
+   await self.load_model(model.identifier, llama_model)
    return model

    async def completion(

View file

@ -55,7 +55,7 @@ class TestModelRegistration:
    model_id="custom-model",
    metadata={
        "llama_model": "meta-llama/Llama-2-7b",
-       "skip_initialize": True,
+       "skip_load": True,
    },
)
@ -73,16 +73,16 @@ class TestModelRegistration:
    _, models_impl = inference_stack
    with patch(
-       "llama_stack.providers.inline.inference.meta_reference.inference.MetaReferenceInferenceImpl.initialize",
+       "llama_stack.providers.inline.inference.meta_reference.inference.MetaReferenceInferenceImpl.load_model",
        new_callable=AsyncMock,
-   ) as mock_initialize:
+   ) as mock_load_model:
        _ = await models_impl.register_model(
            model_id="Llama3.1-8B-Instruct",
            metadata={
                "llama_model": "meta-llama/Llama-3.1-8B-Instruct",
            },
        )
-   mock_initialize.assert_called_once()
+   mock_load_model.assert_called_once()
    @pytest.mark.asyncio
    async def test_register_with_invalid_llama_model(self, inference_stack):