From 415b8f2dbda4b8140bc3c775f37dacfed64255e5 Mon Sep 17 00:00:00 2001 From: Botao Chen Date: Mon, 16 Dec 2024 22:39:08 -0800 Subject: [PATCH] temp commit --- .../inference/meta_reference/generation.py | 2 +- .../inference/meta_reference/inference.py | 2 +- .../inference/test_model_registration.py | 29 +++++++++++++++++-- 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py index 9bb1bbdaf..e11f33503 100644 --- a/llama_stack/providers/inline/inference/meta_reference/generation.py +++ b/llama_stack/providers/inline/inference/meta_reference/generation.py @@ -119,7 +119,7 @@ class Llama: if config.checkpoint_dir and config.checkpoint_dir != "null": ckpt_dir = config.checkpoint_dir else: - ckpt_dir = model_checkpoint_dir(model_id) # true model id + ckpt_dir = model_checkpoint_dir(model_id) checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 930032ceb..ef3f92bb5 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -85,7 +85,7 @@ class MetaReferenceInferenceImpl( else resolve_model(model.identifier) ) if llama_model is None: - raise RuntimeError( + raise ValueError( "Please make sure your llama_model in model metadata or model identifier is in llama-models SKU list" ) diff --git a/llama_stack/providers/tests/inference/test_model_registration.py b/llama_stack/providers/tests/inference/test_model_registration.py index 74076fb28..6ac09601c 100644 --- a/llama_stack/providers/tests/inference/test_model_registration.py +++ b/llama_stack/providers/tests/inference/test_model_registration.py @@ -4,6 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from unittest.mock import AsyncMock, patch + import pytest @@ -51,16 +53,37 @@ class TestModelRegistration: _ = await models_impl.register_model( model_id="custom-model", - metadata={"llama_model": "meta-llama/Llama-2-7b"}, + metadata={ + "llama_model": "meta-llama/Llama-2-7b", + "skip_initialize": True, + }, ) - with pytest.raises(ValueError) as exc_info: + with pytest.raises(AssertionError) as exc_info: await models_impl.register_model( model_id="custom-model-2", - metadata={"llama_model": "meta-llama/Llama-2-7b"}, + metadata={ + "llama_model": "meta-llama/Llama-2-7b", + }, provider_model_id="custom-model", ) + @pytest.mark.asyncio + async def test_initialize_model_during_registering(self, inference_stack): + _, models_impl = inference_stack + + with patch( + "llama_stack.providers.inline.inference.meta_reference.inference.MetaReferenceInferenceImpl.initialize", + new_callable=AsyncMock, + ) as mock_initialize: + _ = await models_impl.register_model( + model_id="Llama3.1-8B-Instruct", + metadata={ + "llama_model": "meta-llama/Llama-3.1-8B-Instruct", + }, + ) + mock_initialize.assert_called_once() + @pytest.mark.asyncio async def test_register_with_invalid_llama_model(self, inference_stack): _, models_impl = inference_stack