From acf9af841be733a1da6e3a2365269dbfd7a0e13a Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Mon, 18 Nov 2024 12:11:53 -0800
Subject: [PATCH] more validation

---
 .../inference/test_model_registration.py      | 17 +++++++++++++
 .../utils/inference/model_registry.py         | 24 +++++++++++++++----
 2 files changed, 37 insertions(+), 4 deletions(-)

diff --git a/llama_stack/providers/tests/inference/test_model_registration.py b/llama_stack/providers/tests/inference/test_model_registration.py
index 72b55ac1c..4fa7f4e61 100644
--- a/llama_stack/providers/tests/inference/test_model_registration.py
+++ b/llama_stack/providers/tests/inference/test_model_registration.py
@@ -54,3 +54,20 @@ class TestModelRegistration:
                 model_id="custom-model",
                 metadata={"llama_model": CoreModelId.llama3_1_8b_instruct.value},
             )
+
+        with pytest.raises(ValueError) as exc_info:
+            await models_impl.register_model(
+                model_id="custom-model-2",
+                metadata={"llama_model": CoreModelId.llama3_2_3b_instruct.value},
+                provider_model_id="custom-model",
+            )
+
+    @pytest.mark.asyncio
+    async def test_register_with_invalid_llama_model(self, inference_stack):
+        _, models_impl = inference_stack
+
+        with pytest.raises(Exception) as exc_info:
+            await models_impl.register_model(
+                model_id="custom-model-2",
+                metadata={"llama_model": "invalid-llama-model"},
+            )
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index ab036e7e2..98810783d 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -7,6 +7,7 @@
 from collections import namedtuple
 from typing import List, Optional
 
+from llama_models.datatypes import CoreModelId
 from llama_models.sku_list import all_registered_models
 
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
@@ -69,9 +70,24 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
                     f"Model '{model.provider_resource_id}' is not available "
                     "and no llama_model was specified in metadata. "
                     "Please specify a llama_model in metadata or use a supported model identifier"
                 )
-        # Register the mapping from provider model id to llama model for future lookups
-        self.provider_id_to_llama_model_map[model.provider_resource_id] = (
-            model.metadata["llama_model"]
-        )
+            existing_llama_model = self.get_llama_model(model.provider_resource_id)
+            if existing_llama_model:
+                if existing_llama_model != model.metadata["llama_model"]:
+                    raise ValueError(
+                        f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
+                    )
+            else:
+                # Validate that the llama model is a valid CoreModelId
+                try:
+                    CoreModelId(model.metadata["llama_model"])
+                except ValueError as err:
+                    raise ValueError(
+                        f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. "
+                        f"Must be one of: {', '.join(m.value for m in CoreModelId)}"
+                    ) from err
+            # Register the mapping from provider model id to llama model for future lookups
+            self.provider_id_to_llama_model_map[model.provider_resource_id] = (
+                model.metadata["llama_model"]
+            )
         return model