mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 01:03:59 +00:00
address comments
This commit is contained in:
parent
d021983b0e
commit
0000e1e8c6
2 changed files with 7 additions and 7 deletions
|
@ -73,7 +73,7 @@ class MetaReferenceInferenceImpl(
|
||||||
self.model_id = None
|
self.model_id = None
|
||||||
self.llama_model = None
|
self.llama_model = None
|
||||||
|
|
||||||
async def initialize(self, model_id, llama_model) -> None:
|
async def load_model(self, model_id, llama_model) -> None:
|
||||||
log.info(f"Loading model `{model_id}`")
|
log.info(f"Loading model `{model_id}`")
|
||||||
if self.config.create_distributed_process_group:
|
if self.config.create_distributed_process_group:
|
||||||
self.generator = LlamaModelParallelGenerator(
|
self.generator = LlamaModelParallelGenerator(
|
||||||
|
@ -127,9 +127,9 @@ class MetaReferenceInferenceImpl(
|
||||||
if model.model_type == ModelType.embedding:
|
if model.model_type == ModelType.embedding:
|
||||||
self._load_sentence_transformer_model(model.provider_resource_id)
|
self._load_sentence_transformer_model(model.provider_resource_id)
|
||||||
|
|
||||||
if "skip_initialize" in model.metadata and model.metadata["skip_initialize"]:
|
if "skip_load" in model.metadata and model.metadata["skip_load"]:
|
||||||
return model
|
return model
|
||||||
await self.initialize(model.identifier, llama_model)
|
await self.load_model(model.identifier, llama_model)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
|
|
|
@ -55,7 +55,7 @@ class TestModelRegistration:
|
||||||
model_id="custom-model",
|
model_id="custom-model",
|
||||||
metadata={
|
metadata={
|
||||||
"llama_model": "meta-llama/Llama-2-7b",
|
"llama_model": "meta-llama/Llama-2-7b",
|
||||||
"skip_initialize": True,
|
"skip_load": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -73,16 +73,16 @@ class TestModelRegistration:
|
||||||
_, models_impl = inference_stack
|
_, models_impl = inference_stack
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"llama_stack.providers.inline.inference.meta_reference.inference.MetaReferenceInferenceImpl.initialize",
|
"llama_stack.providers.inline.inference.meta_reference.inference.MetaReferenceInferenceImpl.load_model",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
) as mock_initialize:
|
) as mock_load_model:
|
||||||
_ = await models_impl.register_model(
|
_ = await models_impl.register_model(
|
||||||
model_id="Llama3.1-8B-Instruct",
|
model_id="Llama3.1-8B-Instruct",
|
||||||
metadata={
|
metadata={
|
||||||
"llama_model": "meta-llama/Llama-3.1-8B-Instruct",
|
"llama_model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
mock_initialize.assert_called_once()
|
mock_load_model.assert_called_once()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_register_with_invalid_llama_model(self, inference_stack):
|
async def test_register_with_invalid_llama_model(self, inference_stack):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue