mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 01:03:59 +00:00
temp commit
This commit is contained in:
parent
81e1957446
commit
415b8f2dbd
3 changed files with 28 additions and 5 deletions
|
@ -119,7 +119,7 @@ class Llama:
|
||||||
if config.checkpoint_dir and config.checkpoint_dir != "null":
|
if config.checkpoint_dir and config.checkpoint_dir != "null":
|
||||||
ckpt_dir = config.checkpoint_dir
|
ckpt_dir = config.checkpoint_dir
|
||||||
else:
|
else:
|
||||||
ckpt_dir = model_checkpoint_dir(model_id) # true model id
|
ckpt_dir = model_checkpoint_dir(model_id)
|
||||||
|
|
||||||
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||||
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
|
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||||
|
|
|
@ -85,7 +85,7 @@ class MetaReferenceInferenceImpl(
|
||||||
else resolve_model(model.identifier)
|
else resolve_model(model.identifier)
|
||||||
)
|
)
|
||||||
if llama_model is None:
|
if llama_model is None:
|
||||||
raise RuntimeError(
|
raise ValueError(
|
||||||
"Please make sure your llama_model in model metadata or model identifier is in llama-models SKU list"
|
"Please make sure your llama_model in model metadata or model identifier is in llama-models SKU list"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,8 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@ -51,16 +53,37 @@ class TestModelRegistration:
|
||||||
|
|
||||||
_ = await models_impl.register_model(
|
_ = await models_impl.register_model(
|
||||||
model_id="custom-model",
|
model_id="custom-model",
|
||||||
metadata={"llama_model": "meta-llama/Llama-2-7b"},
|
metadata={
|
||||||
|
"llama_model": "meta-llama/Llama-2-7b",
|
||||||
|
"skip_initialize": True,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError) as exc_info:
|
with pytest.raises(AssertionError) as exc_info:
|
||||||
await models_impl.register_model(
|
await models_impl.register_model(
|
||||||
model_id="custom-model-2",
|
model_id="custom-model-2",
|
||||||
metadata={"llama_model": "meta-llama/Llama-2-7b"},
|
metadata={
|
||||||
|
"llama_model": "meta-llama/Llama-2-7b",
|
||||||
|
},
|
||||||
provider_model_id="custom-model",
|
provider_model_id="custom-model",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_initialize_model_during_registering(self, inference_stack):
|
||||||
|
_, models_impl = inference_stack
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"llama_stack.providers.inline.inference.meta_reference.inference.MetaReferenceInferenceImpl.initialize",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_initialize:
|
||||||
|
_ = await models_impl.register_model(
|
||||||
|
model_id="Llama3.1-8B-Instruct",
|
||||||
|
metadata={
|
||||||
|
"llama_model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
mock_initialize.assert_called_once()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_register_with_invalid_llama_model(self, inference_stack):
|
async def test_register_with_invalid_llama_model(self, inference_stack):
|
||||||
_, models_impl = inference_stack
|
_, models_impl = inference_stack
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue