mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
feat: allow user to register model alias explicitly, tests
# What does this PR do? Context: https://github.com/llamastack/llama-stack/discussions/3483 This PR enables registering `provider_model_id` as the model identifier without breaking backward compatibility. ## Test Plan TODO
This commit is contained in:
parent
ac1414b571
commit
83a229554b
20 changed files with 236 additions and 92 deletions
|
@ -645,3 +645,88 @@ async def test_models_source_interaction_cleanup_provider_models(cached_disk_dis
|
|||
|
||||
# Cleanup
|
||||
await table.shutdown()
|
||||
|
||||
|
||||
async def test_models_register_with_use_provider_model_id_as_id(cached_disk_dist_registry):
    """Check ``register_model`` when called with ``use_provider_model_id_as_id=True``.

    The single registered model is expected to carry the provider model id as
    both its ``identifier`` and its ``provider_resource_id``, and to be
    retrievable under that same id.
    """
    expected_id = "actual-provider-model"
    expected_provider = "test_provider"

    routing_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
    await routing_table.initialize()

    # Register without a model_id; only the provider-side id and the flag are passed.
    await routing_table.register_model(
        provider_model_id=expected_id,
        provider_id=expected_provider,
        use_provider_model_id_as_id=True,
    )

    # Exactly one model should be listed, identified by the provider model id.
    listing = await routing_table.list_models()
    assert len(listing.data) == 1
    registered = listing.data[0]
    assert registered.identifier == expected_id
    assert registered.provider_resource_id == expected_id
    assert registered.provider_id == expected_provider

    # Looking the model up by the provider model id returns the same record.
    fetched = await routing_table.get_model(expected_id)
    assert fetched.identifier == expected_id
    assert fetched.provider_resource_id == expected_id

    # Cleanup
    await routing_table.shutdown()
|
||||
|
||||
|
||||
async def test_models_register_provider_model_id_only(cached_disk_dist_registry):
    """Check ``register_model`` when only ``provider_model_id`` is given (new recommended usage).

    The registered model is expected to get the identifier
    ``test_provider/llama-3.1-8b`` while keeping ``llama-3.1-8b`` as its
    ``provider_resource_id``.
    """
    provider_side_id = "llama-3.1-8b"
    namespaced_id = "test_provider/llama-3.1-8b"

    routing_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
    await routing_table.initialize()

    # No model_id is supplied here, only the provider-side id.
    await routing_table.register_model(
        provider_model_id=provider_side_id,
        provider_id="test_provider",
        model_type=ModelType.llm,
    )

    # The listing should contain one model under the namespaced identifier.
    listing = await routing_table.list_models()
    assert len(listing.data) == 1
    registered = listing.data[0]
    assert registered.identifier == namespaced_id
    assert registered.provider_resource_id == provider_side_id
    assert registered.provider_id == "test_provider"

    # Lookup by the namespaced identifier works.
    fetched = await routing_table.get_model(namespaced_id)
    assert fetched.identifier == namespaced_id

    # Cleanup
    await routing_table.shutdown()
|
||||
|
||||
|
||||
async def test_models_register_validation_errors(cached_disk_dist_registry):
    """Check that ``register_model`` raises ``ValueError`` when no model id is supplied."""
    routing_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
    await routing_table.initialize()

    # Only provider_id is passed: neither model_id nor provider_model_id is provided.
    expect_missing_id_error = pytest.raises(ValueError, match="provider_model_id must be provided")
    with expect_missing_id_error:
        await routing_table.register_model(provider_id="test_provider")

    # Cleanup
    await routing_table.shutdown()
|
||||
|
||||
|
||||
async def test_models_register_backward_compatibility_warning(cached_disk_dist_registry):
    """Check that ``register_model`` logs one warning when ``model_id`` equals ``provider_model_id``.

    The routing-table module's ``logger`` is replaced with a mock so the
    warning text can be inspected.
    """
    from unittest.mock import patch

    routing_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
    await routing_table.initialize()

    # Register with identical model_id and provider_model_id while the logger is mocked.
    logger_patch = patch("llama_stack.core.routing_tables.models.logger")
    with logger_patch as mock_logger:
        await routing_table.register_model(
            model_id="same-model",
            provider_model_id="same-model",
            provider_id="test_provider",
        )

        # Exactly one warning is expected; its first positional argument is the message.
        mock_logger.warning.assert_called_once()
        logged_text = mock_logger.warning.call_args.args[0]
        assert "model_id` is now optional" in logged_text
        assert "provider_model_id='same-model'" in logged_text

    # Cleanup
    await routing_table.shutdown()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue