diff --git a/llama_stack/providers/tests/inference/test_model_registration.py b/llama_stack/providers/tests/inference/test_model_registration.py index 4b20e519c..f2fa4bb9a 100644 --- a/llama_stack/providers/tests/inference/test_model_registration.py +++ b/llama_stack/providers/tests/inference/test_model_registration.py @@ -33,3 +33,22 @@ class TestModelRegistration: await models_impl.register_model( model_id="Llama3-NonExistent-Model", ) + + @pytest.mark.asyncio + async def test_update_model(self, inference_stack): + _, models_impl = inference_stack + + # Register a model to update + model_id = "Llama3.1-8B-Instruct" + await models_impl.register_model(model_id=model_id) + + # Update the model + new_provider_id = "updated_provider" + await models_impl.update_model(model_id=model_id, provider_id=new_provider_id) + + # Retrieve the updated model to verify changes + updated_model = await models_impl.get_model(model_id) + assert updated_model.provider_id == new_provider_id + + # Cleanup + await models_impl.delete_model(model_id=model_id)