Allow models to be registered as long as llama model is provided

This commit is contained in:
Dinesh Yeduguru 2024-11-18 11:58:32 -08:00
parent f1b9578f8d
commit ccb5445d2a
2 changed files with 30 additions and 21 deletions

View file

@ -8,6 +8,7 @@ import pytest
from llama_models.datatypes import CoreModelId
# How to run this test:
#
# pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py
@ -17,8 +18,17 @@ from llama_models.datatypes import CoreModelId
class TestModelRegistration:
@pytest.mark.asyncio
async def test_register_unsupported_model(self, inference_stack):
_, models_impl = inference_stack
async def test_register_unsupported_model(self, inference_stack, inference_model):
inference_impl, models_impl = inference_stack
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if provider.__provider_spec__.provider_type not in (
"meta-reference",
"remote::ollama",
"remote::vllm",
"remote::tgi",
):
pytest.skip("70B instruct is too big only for local inference providers")
# Try to register a model that's too large for local inference
with pytest.raises(Exception) as exc_info:
@ -37,21 +47,10 @@ class TestModelRegistration:
)
@pytest.mark.asyncio
async def test_update_model(self, inference_stack):
async def test_register_with_llama_model(self, inference_stack):
_, models_impl = inference_stack
# Register a model to update
model_id = CoreModelId.llama3_1_8b_instruct.value
old_model = await models_impl.register_model(model_id=model_id)
# Update the model
new_model_id = CoreModelId.llama3_2_3b_instruct.value
updated_model = await models_impl.update_model(
model_id=model_id, provider_model_id=new_model_id
_ = await models_impl.register_model(
model_id="custom-model",
metadata={"llama_model": CoreModelId.llama3_1_8b_instruct.value},
)
# Retrieve the updated model to verify changes
assert updated_model.provider_resource_id != old_model.provider_resource_id
# Cleanup
await models_impl.unregister_model(model_id=model_id)