From ce7aa53935a5355b274471339efa34eb4f68ffa1 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 15 Oct 2025 21:49:27 -0700 Subject: [PATCH] fix unit tests --- .../routers/test_routing_tables.py | 33 +++++++++---------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index beb0b4a95..87ebcef00 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -11,6 +11,7 @@ from unittest.mock import AsyncMock import pytest from llama_stack.apis.common.content_types import URL +from llama_stack.apis.common.errors import ModelNotFoundError from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource from llama_stack.apis.datatypes import Api @@ -450,6 +451,7 @@ async def test_models_alias_registration_and_lookup(cached_disk_dist_registry): await table.initialize() # Register model with alias (model_id different from provider_model_id) + # NOTE: Aliases are not supported anymore, so this is a no-op await table.register_model( model_id="my-alias", provider_model_id="actual-provider-model", provider_id="test_provider" ) @@ -458,12 +460,15 @@ async def test_models_alias_registration_and_lookup(cached_disk_dist_registry): models = await table.list_models() assert len(models.data) == 1 model = models.data[0] - assert model.identifier == "my-alias" # Uses alias as identifier + assert model.identifier == "test_provider/actual-provider-model" assert model.provider_resource_id == "actual-provider-model" - # Test lookup by alias works - retrieved_model = await table.get_model("my-alias") - assert retrieved_model.identifier == "my-alias" + # Test lookup by alias fails + with pytest.raises(ModelNotFoundError, match="Model 'my-alias' not found"): + await table.get_model("my-alias") + + retrieved_model = await table.get_model("test_provider/actual-provider-model") + assert retrieved_model.identifier == "test_provider/actual-provider-model" assert retrieved_model.provider_resource_id == "actual-provider-model" @@ -494,12 +499,8 @@ async def test_models_multi_provider_disambiguation(cached_disk_dist_registry): assert model2.provider_resource_id == "common-model" # Test lookup by unscoped provider_model_id fails with multiple providers error - try: + with pytest.raises(ModelNotFoundError, match="Model 'common-model' not found"): await table.get_model("common-model") - raise AssertionError("Should have raised ValueError for multiple providers") - except ValueError as e: - assert "Multiple providers found" in str(e) - assert "provider1" in str(e) and "provider2" in str(e) async def test_models_fallback_lookup_behavior(cached_disk_dist_registry): @@ -522,16 +523,12 @@ async def test_models_fallback_lookup_behavior(cached_disk_dist_registry): assert retrieved_model.identifier == "test_provider/test-model" # Test lookup by unscoped provider_model_id (fallback via iteration) - retrieved_model = await table.get_model("test-model") - assert retrieved_model.identifier == "test_provider/test-model" - assert retrieved_model.provider_resource_id == "test-model" + with pytest.raises(ModelNotFoundError, match="Model 'test-model' not found"): + await table.get_model("test-model") # Test lookup of non-existent model fails - try: + with pytest.raises(ModelNotFoundError, match="Model 'non-existent' not found"): await table.get_model("non-existent") - raise AssertionError("Should have raised ValueError for non-existent model") - except ValueError as e: - assert "not found" in str(e) async def test_models_source_tracking_default(cached_disk_dist_registry): @@ -603,7 +600,7 @@ async def test_models_source_interaction_preserves_default(cached_disk_dist_regi assert len(models.data) == 1 user_model = models.data[0] assert user_model.source == RegistryEntrySource.via_register_api - assert user_model.identifier == "my-custom-alias" + assert user_model.identifier == "test_provider/provider-model-1" assert user_model.provider_resource_id == "provider-model-1" # Now simulate provider refresh @@ -630,7 +627,7 @@ async def test_models_source_interaction_preserves_default(cached_disk_dist_regi assert len(models.data) == 2 # Find the user model and provider model - user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None) + user_model = next((m for m in models.data if m.identifier == "test_provider/provider-model-1"), None) provider_model = next((m for m in models.data if m.identifier == "test_provider/different-model"), None) assert user_model is not None