From 05535698e2eedd704a32a140f62611779ba8c3a8 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 13 Nov 2024 21:54:43 -0800 Subject: [PATCH] fix tests --- .../tests/inference/test_model_registration.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/llama_stack/providers/tests/inference/test_model_registration.py b/llama_stack/providers/tests/inference/test_model_registration.py index f2fa4bb9a..97f0ac576 100644 --- a/llama_stack/providers/tests/inference/test_model_registration.py +++ b/llama_stack/providers/tests/inference/test_model_registration.py @@ -6,6 +6,8 @@ import pytest +from llama_models.datatypes import CoreModelId + # How to run this test: # # pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py @@ -39,16 +41,17 @@ class TestModelRegistration: _, models_impl = inference_stack # Register a model to update - model_id = "Llama3.1-8B-Instruct" - await models_impl.register_model(model_id=model_id) + model_id = CoreModelId.llama3_1_8b_instruct.value + old_model = await models_impl.register_model(model_id=model_id) # Update the model - new_provider_id = "updated_provider" - await models_impl.update_model(model_id=model_id, provider_id=new_provider_id) + new_model_id = CoreModelId.llama3_2_3b_instruct.value + updated_model = await models_impl.update_model( + model_id=model_id, provider_model_id=new_model_id + ) # Retrieve the updated model to verify changes - updated_model = await models_impl.get_model(model_id) - assert updated_model.provider_id == new_provider_id + assert updated_model.provider_resource_id != old_model.provider_resource_id # Cleanup await models_impl.delete_model(model_id=model_id)