From ccb5445d2aeba47493cd93ece40bc5432f871338 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Mon, 18 Nov 2024 11:58:32 -0800
Subject: [PATCH] Allow models to be registered as long as llama model is
 provided

---
 .../inference/test_model_registration.py      | 33 +++++++++----------
 .../utils/inference/model_registry.py         | 18 +++++++---
 2 files changed, 30 insertions(+), 21 deletions(-)

diff --git a/llama_stack/providers/tests/inference/test_model_registration.py b/llama_stack/providers/tests/inference/test_model_registration.py
index 0f07badfa..72b55ac1c 100644
--- a/llama_stack/providers/tests/inference/test_model_registration.py
+++ b/llama_stack/providers/tests/inference/test_model_registration.py
@@ -8,6 +8,7 @@ import pytest
 
 from llama_models.datatypes import CoreModelId
 
+
 # How to run this test:
 #
 # pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py
@@ -17,8 +18,17 @@ from llama_models.datatypes import CoreModelId
 
 class TestModelRegistration:
     @pytest.mark.asyncio
-    async def test_register_unsupported_model(self, inference_stack):
-        _, models_impl = inference_stack
+    async def test_register_unsupported_model(self, inference_stack, inference_model):
+        inference_impl, models_impl = inference_stack
+
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
+        if provider.__provider_spec__.provider_type not in (
+            "meta-reference",
+            "remote::ollama",
+            "remote::vllm",
+            "remote::tgi",
+        ):
+            pytest.skip("70B instruct is too big; only test it on local inference providers")
 
         # Try to register a model that's too large for local inference
         with pytest.raises(Exception) as exc_info:
@@ -37,21 +47,10 @@
             )
 
     @pytest.mark.asyncio
-    async def test_update_model(self, inference_stack):
+    async def test_register_with_llama_model(self, inference_stack):
         _, models_impl = inference_stack
 
-        # Register a model to update
-        model_id = CoreModelId.llama3_1_8b_instruct.value
-        old_model = await models_impl.register_model(model_id=model_id)
-
-        # Update the model
-        new_model_id = CoreModelId.llama3_2_3b_instruct.value
-        updated_model = await models_impl.update_model(
-            model_id=model_id, provider_model_id=new_model_id
+        _ = await models_impl.register_model(
+            model_id="custom-model",
+            metadata={"llama_model": CoreModelId.llama3_1_8b_instruct.value},
         )
-
-        # Retrieve the updated model to verify changes
-        assert updated_model.provider_resource_id != old_model.provider_resource_id
-
-        # Cleanup
-        await models_impl.unregister_model(model_id=model_id)
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index 77eb5b415..ab036e7e2 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -51,7 +51,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
         if identifier in self.alias_to_provider_id_map:
             return self.alias_to_provider_id_map[identifier]
         else:
-            raise ValueError(f"Unknown model: `{identifier}`")
+            return None
 
     def get_llama_model(self, provider_model_id: str) -> str:
         if provider_model_id in self.provider_id_to_llama_model_map:
@@ -60,8 +60,18 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             return self.provider_id_to_llama_model_map[provider_model_id]
         else:
             return None
 
     async def register_model(self, model: Model) -> Model:
-        model.provider_resource_id = self.get_provider_model_id(
-            model.provider_resource_id
-        )
+        provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
+        if provider_resource_id:
+            model.provider_resource_id = provider_resource_id
+        else:
+            if model.metadata.get("llama_model") is None:
+                raise ValueError(
+                    f"Model '{model.provider_resource_id}' is not available and no llama_model was specified in metadata. "
+                    "Please specify a llama_model in metadata or use a supported model identifier"
+                )
+            # Register the mapping from provider model id to llama model for future lookups
+            self.provider_id_to_llama_model_map[model.provider_resource_id] = (
+                model.metadata["llama_model"]
+            )
         return model
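
Reviewer note: a minimal sketch of the registration flow this patch enables, mirroring the new test above. The models_impl handle is assumed to come from the inference_stack fixture used in the test; the function name and custom model id are illustrative only, not part of the patch.

    from llama_models.datatypes import CoreModelId

    async def register_custom_model(models_impl):
        # With this patch, an id the provider does not recognize can still be
        # registered, as long as metadata names the underlying llama model.
        # ModelRegistryHelper.register_model() then records the
        # provider-id -> llama-model mapping for future get_llama_model() lookups.
        return await models_impl.register_model(
            model_id="my-finetuned-llama",  # hypothetical custom id
            metadata={"llama_model": CoreModelId.llama3_1_8b_instruct.value},
        )

Registering an unknown id without a "llama_model" key still fails, but the ValueError now points at the missing metadata rather than the generic "Unknown model" message, since get_provider_model_id() returns None instead of raising.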