diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index f4f28c1f3..bb6fe162f 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -82,35 +82,35 @@ class ModelRegistryHelper(ModelsProtocolPrivate): def get_llama_model(self, provider_model_id: str) -> str | None: return self.provider_id_to_llama_model_map.get(provider_model_id, None) - async def query_available_models(self) -> list[str]: + async def check_model_availability(self, model: str) -> bool: """ - Return a list of available models. + Check if a specific model is available from the provider (non-static check). - This is for subclassing purposes, so providers can lookup a list of - of currently available models. + This is for subclassing purposes, so providers can check if a specific + model is currently available for use through dynamic means (e.g., API calls). - This is combined with the statically configured model entries in - `self.alias_to_provider_id_map` to determine which models are - available for registration. + This method should NOT check statically configured model entries in + `self.alias_to_provider_id_map` - that is handled separately in register_model. - Default implementation returns no models. + Default implementation returns False (no dynamic models available). - :return: A list of model identifiers (provider_model_ids). + :param model: The model identifier to check. + :return: True if the model is available dynamically, False otherwise. """ - return [] + return False async def register_model(self, model: Model) -> Model: # Check if model is supported in static configuration supported_model_id = self.get_provider_model_id(model.provider_resource_id) - # If not found in static config, check if it's available from provider + # If not found in static config, check if it's available dynamically from provider if not supported_model_id: - available_models = await self.query_available_models() - if model.provider_resource_id in available_models: + if await self.check_model_availability(model.provider_resource_id): supported_model_id = model.provider_resource_id else: - # Combine static and dynamic models for error message - all_supported_models = list(self.alias_to_provider_id_map.keys()) + available_models + # note: we cannot provide a complete list of supported models without + # getting a complete list from the provider, so we return "..." + all_supported_models = [*self.alias_to_provider_id_map.keys(), "..."] raise UnsupportedModelError(model.provider_resource_id, all_supported_models) provider_resource_id = self.get_provider_model_id(model.model_id) diff --git a/tests/unit/providers/utils/test_model_registry.py b/tests/unit/providers/utils/test_model_registry.py index c7f7eb299..768c3572b 100644 --- a/tests/unit/providers/utils/test_model_registry.py +++ b/tests/unit/providers/utils/test_model_registry.py @@ -94,8 +94,8 @@ class MockModelRegistryHelperWithDynamicModels(ModelRegistryHelper): super().__init__(model_entries) self._available_models = available_models - async def query_available_models(self) -> list[str]: - return self._available_models + async def check_model_availability(self, model: str) -> bool: + return model in self._available_models @pytest.fixture @@ -195,16 +195,16 @@ async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_m @pytest.mark.asyncio -async def test_register_model_from_query_available_models( +async def test_register_model_from_check_model_availability( helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model ) -> None: - """Test that models returned by query_available_models can be registered.""" + """Test that models returned by check_model_availability can be registered.""" # Verify the model is not in static config assert helper_with_dynamic_models.get_provider_model_id(dynamic_model.provider_resource_id) is None - # But it should be available via query_available_models - available_models = await helper_with_dynamic_models.query_available_models() - assert dynamic_model.provider_resource_id in available_models + # But it should be available via check_model_availability + is_available = await helper_with_dynamic_models.check_model_availability(dynamic_model.provider_resource_id) + assert is_available # Registration should succeed registered_model = await helper_with_dynamic_models.register_model(dynamic_model) @@ -224,17 +224,17 @@ async def test_register_model_not_in_static_or_dynamic( # Verify the model is not in static config assert helper_with_dynamic_models.get_provider_model_id(unknown_model.provider_resource_id) is None - # And not in dynamic models - available_models = await helper_with_dynamic_models.query_available_models() - assert unknown_model.provider_resource_id not in available_models + # And not available via check_model_availability + is_available = await helper_with_dynamic_models.check_model_availability(unknown_model.provider_resource_id) + assert not is_available # Registration should fail with comprehensive error message with pytest.raises(Exception) as exc_info: # UnsupportedModelError await helper_with_dynamic_models.register_model(unknown_model) - # Error should include both static and dynamic models + # Error should include static models and "..." for dynamic models error_str = str(exc_info.value) - assert "dynamic-provider-id" in error_str # dynamic model should be in error + assert "..." in error_str # "..." should be in error message @pytest.mark.asyncio