diff --git a/llama_stack/core/routing_tables/models.py b/llama_stack/core/routing_tables/models.py index 69d7e9b6f..716be936a 100644 --- a/llama_stack/core/routing_tables/models.py +++ b/llama_stack/core/routing_tables/models.py @@ -67,6 +67,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): raise ValueError(f"Provider {model.provider_id} not found in the routing table") return self.impls_by_provider_id[model.provider_id] + async def has_model(self, model_id: str) -> bool: + """ + Check if a model exists in the routing table. + + :param model_id: The model identifier to check + :return: True if the model exists, False otherwise + """ + try: + await lookup_model(self, model_id) + return True + except ModelNotFoundError: + return False + async def register_model( self, model_id: str, diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py index 3c5c5b4de..cba7508a2 100644 --- a/llama_stack/providers/utils/inference/openai_mixin.py +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -474,11 +474,17 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): async def check_model_availability(self, model: str) -> bool: """ - Check if a specific model is available from the provider's /v1/models. + Check if a specific model is available from the provider's /v1/models or pre-registered. :param model: The model identifier to check. - :return: True if the model is available dynamically, False otherwise. + :return: True if the model is available dynamically or pre-registered, False otherwise. """ + # First check if the model is pre-registered in the model store + if hasattr(self, "model_store") and self.model_store: + if await self.model_store.has_model(model): + return True + + # Then check the provider's dynamic model cache if not self._model_cache: await self.list_models() return model in self._model_cache diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index 54a9dd72e..a1c3d1e95 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -201,6 +201,12 @@ async def test_models_routing_table(cached_disk_dist_registry): non_existent = await table.get_object_by_identifier("model", "non-existent-model") assert non_existent is None + # Test has_model + assert await table.has_model("test_provider/test-model") + assert await table.has_model("test_provider/test-model-2") + assert not await table.has_model("non-existent-model") + assert not await table.has_model("test_provider/non-existent-model") + await table.unregister_model(model_id="test_provider/test-model") await table.unregister_model(model_id="test_provider/test-model-2") diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py index 2e3a62ca6..ad9406951 100644 --- a/tests/unit/providers/utils/inference/test_openai_mixin.py +++ b/tests/unit/providers/utils/inference/test_openai_mixin.py @@ -44,11 +44,12 @@ def mixin(): config = RemoteInferenceProviderConfig() mixin_instance = OpenAIMixinImpl(config=config) - # just enough to satisfy _get_provider_model_id calls - mock_model_store = MagicMock() + # Mock model_store with async methods + mock_model_store = AsyncMock() mock_model = MagicMock() mock_model.provider_resource_id = "test-provider-resource-id" mock_model_store.get_model = AsyncMock(return_value=mock_model) + mock_model_store.has_model = AsyncMock(return_value=False) # Default to False, tests can override mixin_instance.model_store = mock_model_store return mixin_instance @@ -189,6 +190,40 @@ class TestOpenAIMixinCheckModelAvailability: assert len(mixin._model_cache) == 3 + async def test_check_model_availability_with_pre_registered_model( + self, mixin, mock_client_with_models, mock_client_context + ): + """Test that check_model_availability returns True for pre-registered models in model_store""" + # Mock model_store.has_model to return True for a specific model + mock_model_store = AsyncMock() + mock_model_store.has_model = AsyncMock(return_value=True) + mixin.model_store = mock_model_store + + # Test that pre-registered model is found without calling the provider's API + with mock_client_context(mixin, mock_client_with_models): + mock_client_with_models.models.list.assert_not_called() + assert await mixin.check_model_availability("pre-registered-model") + # Should not call the provider's list_models since model was found in store + mock_client_with_models.models.list.assert_not_called() + mock_model_store.has_model.assert_called_once_with("pre-registered-model") + + async def test_check_model_availability_fallback_to_provider_when_not_in_store( + self, mixin, mock_client_with_models, mock_client_context + ): + """Test that check_model_availability falls back to provider when model not in store""" + # Mock model_store.has_model to return False + mock_model_store = AsyncMock() + mock_model_store.has_model = AsyncMock(return_value=False) + mixin.model_store = mock_model_store + + # Test that it falls back to provider's model cache + with mock_client_context(mixin, mock_client_with_models): + mock_client_with_models.models.list.assert_not_called() + assert await mixin.check_model_availability("some-mock-model-id") + # Should call the provider's list_models since model was not found in store + mock_client_with_models.models.list.assert_called_once() + mock_model_store.has_model.assert_called_once_with("some-mock-model-id") + class TestOpenAIMixinCacheBehavior: """Test cases for cache behavior and edge cases"""