fix tests

This commit is contained in:
Dinesh Yeduguru 2024-11-13 21:54:43 -08:00
parent 43af05d851
commit 05535698e2

View file

@ -6,6 +6,8 @@
import pytest
from llama_models.datatypes import CoreModelId
# How to run this test:
#
# pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py
@ -39,16 +41,17 @@ class TestModelRegistration:
_, models_impl = inference_stack
# Register a model to update
model_id = "Llama3.1-8B-Instruct"
await models_impl.register_model(model_id=model_id)
model_id = CoreModelId.llama3_1_8b_instruct.value
old_model = await models_impl.register_model(model_id=model_id)
# Update the model
new_provider_id = "updated_provider"
await models_impl.update_model(model_id=model_id, provider_id=new_provider_id)
new_model_id = CoreModelId.llama3_2_3b_instruct.value
updated_model = await models_impl.update_model(
model_id=model_id, provider_model_id=new_model_id
)
# Retrieve the updated model to verify changes
updated_model = await models_impl.get_model(model_id)
assert updated_model.provider_id == new_provider_id
assert updated_model.provider_resource_id != old_model.provider_resource_id
# Cleanup
await models_impl.delete_model(model_id=model_id)