diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 46c0ca7b5..f4f28c1f3 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -82,9 +82,37 @@ class ModelRegistryHelper(ModelsProtocolPrivate): def get_llama_model(self, provider_model_id: str) -> str | None: return self.provider_id_to_llama_model_map.get(provider_model_id, None) + async def query_available_models(self) -> list[str]: + """ + Return a list of available models. + + This is for subclassing purposes, so providers can lookup a list of + of currently available models. + + This is combined with the statically configured model entries in + `self.alias_to_provider_id_map` to determine which models are + available for registration. + + Default implementation returns no models. + + :return: A list of model identifiers (provider_model_ids). + """ + return [] + async def register_model(self, model: Model) -> Model: - if not (supported_model_id := self.get_provider_model_id(model.provider_resource_id)): - raise UnsupportedModelError(model.provider_resource_id, self.alias_to_provider_id_map.keys()) + # Check if model is supported in static configuration + supported_model_id = self.get_provider_model_id(model.provider_resource_id) + + # If not found in static config, check if it's available from provider + if not supported_model_id: + available_models = await self.query_available_models() + if model.provider_resource_id in available_models: + supported_model_id = model.provider_resource_id + else: + # Combine static and dynamic models for error message + all_supported_models = list(self.alias_to_provider_id_map.keys()) + available_models + raise UnsupportedModelError(model.provider_resource_id, all_supported_models) + provider_resource_id = self.get_provider_model_id(model.model_id) if model.model_type == ModelType.embedding: # embedding models are always registered by their provider model id and does not need to be mapped to a llama model @@ -113,6 +141,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate): ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model] ) + # Register the model alias, ensuring it maps to the correct provider model id self.alias_to_provider_id_map[model.model_id] = supported_model_id return model diff --git a/tests/unit/providers/utils/test_model_registry.py b/tests/unit/providers/utils/test_model_registry.py index 10fa1e075..c7f7eb299 100644 --- a/tests/unit/providers/utils/test_model_registry.py +++ b/tests/unit/providers/utils/test_model_registry.py @@ -87,6 +87,37 @@ def helper(known_provider_model: ProviderModelEntry, known_provider_model2: Prov return ModelRegistryHelper([known_provider_model, known_provider_model2]) +class MockModelRegistryHelperWithDynamicModels(ModelRegistryHelper): + """Test helper that simulates a provider with dynamically available models.""" + + def __init__(self, model_entries: list[ProviderModelEntry], available_models: list[str]): + super().__init__(model_entries) + self._available_models = available_models + + async def query_available_models(self) -> list[str]: + return self._available_models + + +@pytest.fixture +def dynamic_model() -> Model: + """A model that's not in static config but available dynamically.""" + return Model( + provider_id="provider", + identifier="dynamic-model", + provider_resource_id="dynamic-provider-id", + ) + + +@pytest.fixture +def helper_with_dynamic_models( + known_provider_model: ProviderModelEntry, known_provider_model2: ProviderModelEntry, dynamic_model: Model +) -> MockModelRegistryHelperWithDynamicModels: + """Helper that includes dynamically available models.""" + return MockModelRegistryHelperWithDynamicModels( + [known_provider_model, known_provider_model2], [dynamic_model.provider_resource_id] + ) + + @pytest.mark.asyncio async def test_lookup_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None: assert helper.get_provider_model_id(unknown_model.model_id) is None @@ -161,3 +192,66 @@ async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_m assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id await helper.unregister_model(known_model.provider_resource_id) assert helper.get_provider_model_id(known_model.provider_resource_id) is None + + +@pytest.mark.asyncio +async def test_register_model_from_query_available_models( + helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model +) -> None: + """Test that models returned by query_available_models can be registered.""" + # Verify the model is not in static config + assert helper_with_dynamic_models.get_provider_model_id(dynamic_model.provider_resource_id) is None + + # But it should be available via query_available_models + available_models = await helper_with_dynamic_models.query_available_models() + assert dynamic_model.provider_resource_id in available_models + + # Registration should succeed + registered_model = await helper_with_dynamic_models.register_model(dynamic_model) + assert registered_model == dynamic_model + + # Model should now be registered and accessible + assert ( + helper_with_dynamic_models.get_provider_model_id(dynamic_model.model_id) == dynamic_model.provider_resource_id + ) + + +@pytest.mark.asyncio +async def test_register_model_not_in_static_or_dynamic( + helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, unknown_model: Model +) -> None: + """Test that models not in static config or dynamic models are rejected.""" + # Verify the model is not in static config + assert helper_with_dynamic_models.get_provider_model_id(unknown_model.provider_resource_id) is None + + # And not in dynamic models + available_models = await helper_with_dynamic_models.query_available_models() + assert unknown_model.provider_resource_id not in available_models + + # Registration should fail with comprehensive error message + with pytest.raises(Exception) as exc_info: # UnsupportedModelError + await helper_with_dynamic_models.register_model(unknown_model) + + # Error should include both static and dynamic models + error_str = str(exc_info.value) + assert "dynamic-provider-id" in error_str # dynamic model should be in error + + +@pytest.mark.asyncio +async def test_register_alias_for_dynamic_model( + helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model +) -> None: + """Test that we can register an alias that maps to a dynamically available model.""" + # Create a model with a different identifier but same provider_resource_id + alias_model = Model( + provider_id=dynamic_model.provider_id, + identifier="dynamic-model-alias", + provider_resource_id=dynamic_model.provider_resource_id, + ) + + # Registration should succeed since the provider_resource_id is available dynamically + registered_model = await helper_with_dynamic_models.register_model(alias_model) + assert registered_model == alias_model + + # Both the original provider_resource_id and the new alias should work + assert helper_with_dynamic_models.get_provider_model_id(alias_model.model_id) == dynamic_model.provider_resource_id