diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index d60d00f87..8a120b698 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -20,7 +20,7 @@ logger = get_logger(name=__name__, category="providers::utils") class RemoteInferenceProviderConfig(BaseModel): - allowed_models: list[str] | None = Field( # TODO: make this non-optional and give a list() default + allowed_models: list[str] | None = Field( default=None, description="List of models that should be registered with the model registry. If None, all models are allowed.", ) diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py index 0a283780f..a12b506db 100644 --- a/llama_stack/providers/utils/inference/openai_mixin.py +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -82,9 +82,6 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): # This is set in list_models() and used in check_model_availability() _model_cache: dict[str, Model] = {} - # List of allowed models for this provider, if empty all models allowed - allowed_models: list[str] = [] - # Optional field name in provider data to look for API key, which takes precedence provider_data_api_key_field: str | None = None @@ -416,7 +413,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): for provider_model_id in provider_models_ids: if not isinstance(provider_model_id, str): raise ValueError(f"Model ID {provider_model_id} from list_provider_model_ids() is not a string") - if self.allowed_models and provider_model_id not in self.allowed_models: + if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models: logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list") continue if metadata := self.embedding_model_metadata.get(provider_model_id): diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py index 61a1f8f61..540e8746b 100644 --- a/tests/unit/providers/utils/inference/test_openai_mixin.py +++ b/tests/unit/providers/utils/inference/test_openai_mixin.py @@ -363,8 +363,8 @@ class TestOpenAIMixinAllowedModels: """Test cases for allowed_models filtering functionality""" async def test_list_models_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context): - """Test that list_models filters models based on allowed_models set""" - mixin.allowed_models = {"some-mock-model-id", "another-mock-model-id"} + """Test that list_models filters models based on allowed_models""" + mixin.config.allowed_models = ["some-mock-model-id", "another-mock-model-id"] with mock_client_context(mixin, mock_client_with_models): result = await mixin.list_models() @@ -378,8 +378,18 @@ class TestOpenAIMixinAllowedModels: assert "final-mock-model-id" not in model_ids async def test_list_models_with_empty_allowed_models(self, mixin, mock_client_with_models, mock_client_context): - """Test that empty allowed_models set allows all models""" - assert len(mixin.allowed_models) == 0 + """Test that empty allowed_models allows no models""" + mixin.config.allowed_models = [] + + with mock_client_context(mixin, mock_client_with_models): + result = await mixin.list_models() + + assert result is not None + assert len(result) == 0 # No models should be included + + async def test_list_models_with_omitted_allowed_models(self, mixin, mock_client_with_models, mock_client_context): + """Test that omitted allowed_models allows all models""" + assert mixin.config.allowed_models is None with mock_client_context(mixin, mock_client_with_models): result = await mixin.list_models() @@ -396,7 +406,7 @@ class TestOpenAIMixinAllowedModels: self, mixin, mock_client_with_models, mock_client_context ): """Test that check_model_availability respects allowed_models""" - mixin.allowed_models = {"final-mock-model-id"} + mixin.config.allowed_models = ["final-mock-model-id"] with mock_client_context(mixin, mock_client_with_models): assert await mixin.check_model_availability("final-mock-model-id") @@ -444,7 +454,7 @@ class TestOpenAIMixinModelRegistration: async def test_register_model_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context): """Test model registration with allowed_models filtering""" - mixin.allowed_models = {"some-mock-model-id"} + mixin.config.allowed_models = ["some-mock-model-id"] # Test with allowed model allowed_model = Model( @@ -598,7 +608,7 @@ class TestOpenAIMixinCustomListProviderModelIds: mixin = CustomListProviderModelIdsImplementation( config=config, custom_model_ids=["model-1", "model-2", "model-3"] ) - mixin.allowed_models = ["model-1"] + mixin.config.allowed_models = ["model-1"] result = await mixin.list_models()