temp commit

This commit is contained in:
Botao Chen 2024-12-16 22:39:08 -08:00
parent 81e1957446
commit 415b8f2dbd
3 changed files with 28 additions and 5 deletions

View file

@@ -119,7 +119,7 @@ class Llama:
if config.checkpoint_dir and config.checkpoint_dir != "null": if config.checkpoint_dir and config.checkpoint_dir != "null":
ckpt_dir = config.checkpoint_dir ckpt_dir = config.checkpoint_dir
else: else:
ckpt_dir = model_checkpoint_dir(model_id) # true model id ckpt_dir = model_checkpoint_dir(model_id)
checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"

View file

@@ -85,7 +85,7 @@ class MetaReferenceInferenceImpl(
else resolve_model(model.identifier) else resolve_model(model.identifier)
) )
if llama_model is None: if llama_model is None:
raise RuntimeError( raise ValueError(
"Please make sure your llama_model in model metadata or model identifier is in llama-models SKU list" "Please make sure your llama_model in model metadata or model identifier is in llama-models SKU list"
) )

View file

@@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from unittest.mock import AsyncMock, patch
import pytest import pytest
@@ -51,16 +53,37 @@ class TestModelRegistration:
_ = await models_impl.register_model( _ = await models_impl.register_model(
model_id="custom-model", model_id="custom-model",
metadata={"llama_model": "meta-llama/Llama-2-7b"}, metadata={
"llama_model": "meta-llama/Llama-2-7b",
"skip_initialize": True,
},
) )
with pytest.raises(ValueError) as exc_info: with pytest.raises(AssertionError) as exc_info:
await models_impl.register_model( await models_impl.register_model(
model_id="custom-model-2", model_id="custom-model-2",
metadata={"llama_model": "meta-llama/Llama-2-7b"}, metadata={
"llama_model": "meta-llama/Llama-2-7b",
},
provider_model_id="custom-model", provider_model_id="custom-model",
) )
@pytest.mark.asyncio
async def test_initialize_model_during_registering(self, inference_stack):
_, models_impl = inference_stack
with patch(
"llama_stack.providers.inline.inference.meta_reference.inference.MetaReferenceInferenceImpl.initialize",
new_callable=AsyncMock,
) as mock_initialize:
_ = await models_impl.register_model(
model_id="Llama3.1-8B-Instruct",
metadata={
"llama_model": "meta-llama/Llama-3.1-8B-Instruct",
},
)
mock_initialize.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_register_with_invalid_llama_model(self, inference_stack): async def test_register_with_invalid_llama_model(self, inference_stack):
_, models_impl = inference_stack _, models_impl = inference_stack