forked from phoenix-oss/llama-stack-mirror
Allow models to be registered as long as llama model is provided (#472)
This PR allows models to be registered with provider as long as the user specifies a llama model, even though the model does not match our prebuilt provider specific mapping. Test: pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py -m "together" --env TOGETHER_API_KEY=<KEY> --------- Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
parent
2a31163178
commit
57a9b4d57f
3 changed files with 72 additions and 21 deletions
|
@ -6,7 +6,6 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_models.datatypes import CoreModelId
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
|
@ -17,11 +16,22 @@ from llama_models.datatypes import CoreModelId
|
|||
|
||||
class TestModelRegistration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_unsupported_model(self, inference_stack):
|
||||
_, models_impl = inference_stack
|
||||
async def test_register_unsupported_model(self, inference_stack, inference_model):
|
||||
inference_impl, models_impl = inference_stack
|
||||
|
||||
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
||||
if provider.__provider_spec__.provider_type not in (
|
||||
"meta-reference",
|
||||
"remote::ollama",
|
||||
"remote::vllm",
|
||||
"remote::tgi",
|
||||
):
|
||||
pytest.skip(
|
||||
"Skipping test for remote inference providers since they can handle large models like 70B instruct"
|
||||
)
|
||||
|
||||
# Try to register a model that's too large for local inference
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await models_impl.register_model(
|
||||
model_id="Llama3.1-70B-Instruct",
|
||||
)
|
||||
|
@ -37,21 +47,27 @@ class TestModelRegistration:
|
|||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_model(self, inference_stack):
|
||||
async def test_register_with_llama_model(self, inference_stack):
|
||||
_, models_impl = inference_stack
|
||||
|
||||
# Register a model to update
|
||||
model_id = CoreModelId.llama3_1_8b_instruct.value
|
||||
old_model = await models_impl.register_model(model_id=model_id)
|
||||
|
||||
# Update the model
|
||||
new_model_id = CoreModelId.llama3_2_3b_instruct.value
|
||||
updated_model = await models_impl.update_model(
|
||||
model_id=model_id, provider_model_id=new_model_id
|
||||
_ = await models_impl.register_model(
|
||||
model_id="custom-model",
|
||||
metadata={"llama_model": "meta-llama/Llama-2-7b"},
|
||||
)
|
||||
|
||||
# Retrieve the updated model to verify changes
|
||||
assert updated_model.provider_resource_id != old_model.provider_resource_id
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await models_impl.register_model(
|
||||
model_id="custom-model-2",
|
||||
metadata={"llama_model": "meta-llama/Llama-2-7b"},
|
||||
provider_model_id="custom-model",
|
||||
)
|
||||
|
||||
# Cleanup
|
||||
await models_impl.unregister_model(model_id=model_id)
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_with_invalid_llama_model(self, inference_stack):
|
||||
_, models_impl = inference_stack
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await models_impl.register_model(
|
||||
model_id="custom-model-2",
|
||||
metadata={"llama_model": "invalid-llama-model"},
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue