fix unit tests

This commit is contained in:
Ashwin Bharambe 2025-10-15 21:49:27 -07:00
parent d8be3111db
commit ce7aa53935

View file

@ -11,6 +11,7 @@ from unittest.mock import AsyncMock
import pytest import pytest
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.apis.datatypes import Api from llama_stack.apis.datatypes import Api
@ -450,6 +451,7 @@ async def test_models_alias_registration_and_lookup(cached_disk_dist_registry):
await table.initialize() await table.initialize()
# Register model with alias (model_id different from provider_model_id) # Register model with alias (model_id different from provider_model_id)
# NOTE: Aliases are not supported anymore, so this is a no-op
await table.register_model( await table.register_model(
model_id="my-alias", provider_model_id="actual-provider-model", provider_id="test_provider" model_id="my-alias", provider_model_id="actual-provider-model", provider_id="test_provider"
) )
@ -458,12 +460,15 @@ async def test_models_alias_registration_and_lookup(cached_disk_dist_registry):
models = await table.list_models() models = await table.list_models()
assert len(models.data) == 1 assert len(models.data) == 1
model = models.data[0] model = models.data[0]
assert model.identifier == "my-alias" # Uses alias as identifier assert model.identifier == "test_provider/actual-provider-model"
assert model.provider_resource_id == "actual-provider-model" assert model.provider_resource_id == "actual-provider-model"
# Test lookup by alias works # Test lookup by alias fails
retrieved_model = await table.get_model("my-alias") with pytest.raises(ModelNotFoundError, match="Model 'my-alias' not found"):
assert retrieved_model.identifier == "my-alias" await table.get_model("my-alias")
retrieved_model = await table.get_model("test_provider/actual-provider-model")
assert retrieved_model.identifier == "test_provider/actual-provider-model"
assert retrieved_model.provider_resource_id == "actual-provider-model" assert retrieved_model.provider_resource_id == "actual-provider-model"
@ -494,12 +499,8 @@ async def test_models_multi_provider_disambiguation(cached_disk_dist_registry):
assert model2.provider_resource_id == "common-model" assert model2.provider_resource_id == "common-model"
# Test lookup by unscoped provider_model_id fails with multiple providers error # Test lookup by unscoped provider_model_id fails with multiple providers error
try: with pytest.raises(ModelNotFoundError, match="Model 'common-model' not found"):
await table.get_model("common-model") await table.get_model("common-model")
raise AssertionError("Should have raised ValueError for multiple providers")
except ValueError as e:
assert "Multiple providers found" in str(e)
assert "provider1" in str(e) and "provider2" in str(e)
async def test_models_fallback_lookup_behavior(cached_disk_dist_registry): async def test_models_fallback_lookup_behavior(cached_disk_dist_registry):
@ -522,16 +523,12 @@ async def test_models_fallback_lookup_behavior(cached_disk_dist_registry):
assert retrieved_model.identifier == "test_provider/test-model" assert retrieved_model.identifier == "test_provider/test-model"
# Test lookup by unscoped provider_model_id (fallback via iteration) # Test lookup by unscoped provider_model_id (fallback via iteration)
retrieved_model = await table.get_model("test-model") with pytest.raises(ModelNotFoundError, match="Model 'test-model' not found"):
assert retrieved_model.identifier == "test_provider/test-model" await table.get_model("test-model")
assert retrieved_model.provider_resource_id == "test-model"
# Test lookup of non-existent model fails # Test lookup of non-existent model fails
try: with pytest.raises(ModelNotFoundError, match="Model 'non-existent' not found"):
await table.get_model("non-existent") await table.get_model("non-existent")
raise AssertionError("Should have raised ValueError for non-existent model")
except ValueError as e:
assert "not found" in str(e)
async def test_models_source_tracking_default(cached_disk_dist_registry): async def test_models_source_tracking_default(cached_disk_dist_registry):
@ -603,7 +600,7 @@ async def test_models_source_interaction_preserves_default(cached_disk_dist_regi
assert len(models.data) == 1 assert len(models.data) == 1
user_model = models.data[0] user_model = models.data[0]
assert user_model.source == RegistryEntrySource.via_register_api assert user_model.source == RegistryEntrySource.via_register_api
assert user_model.identifier == "my-custom-alias" assert user_model.identifier == "test_provider/provider-model-1"
assert user_model.provider_resource_id == "provider-model-1" assert user_model.provider_resource_id == "provider-model-1"
# Now simulate provider refresh # Now simulate provider refresh
@ -630,7 +627,7 @@ async def test_models_source_interaction_preserves_default(cached_disk_dist_regi
assert len(models.data) == 2 assert len(models.data) == 2
# Find the user model and provider model # Find the user model and provider model
user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None) user_model = next((m for m in models.data if m.identifier == "test_provider/provider-model-1"), None)
provider_model = next((m for m in models.data if m.identifier == "test_provider/different-model"), None) provider_model = next((m for m in models.data if m.identifier == "test_provider/different-model"), None)
assert user_model is not None assert user_model is not None