diff --git a/src/llama_stack/providers/utils/inference/model_registry.py b/src/llama_stack/providers/utils/inference/model_registry.py index d60d00f87..8a120b698 100644 --- a/src/llama_stack/providers/utils/inference/model_registry.py +++ b/src/llama_stack/providers/utils/inference/model_registry.py @@ -20,7 +20,7 @@ logger = get_logger(name=__name__, category="providers::utils") class RemoteInferenceProviderConfig(BaseModel): - allowed_models: list[str] | None = Field( # TODO: make this non-optional and give a list() default + allowed_models: list[str] | None = Field( default=None, description="List of models that should be registered with the model registry. If None, all models are allowed.", ) diff --git a/src/llama_stack/providers/utils/inference/openai_mixin.py b/src/llama_stack/providers/utils/inference/openai_mixin.py index 941772b0f..09059da09 100644 --- a/src/llama_stack/providers/utils/inference/openai_mixin.py +++ b/src/llama_stack/providers/utils/inference/openai_mixin.py @@ -83,9 +83,6 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): # This is set in list_models() and used in check_model_availability() _model_cache: dict[str, Model] = {} - # List of allowed models for this provider, if empty all models allowed - allowed_models: list[str] = [] - # Optional field name in provider data to look for API key, which takes precedence provider_data_api_key_field: str | None = None @@ -441,7 +438,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): for provider_model_id in provider_models_ids: if not isinstance(provider_model_id, str): raise ValueError(f"Model ID {provider_model_id} from list_provider_model_ids() is not a string") - if self.allowed_models and provider_model_id not in self.allowed_models: + if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models: logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list") continue model = self.construct_model_from_identifier(provider_model_id) diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py index d98c096aa..0b5ea078b 100644 --- a/tests/unit/providers/utils/inference/test_openai_mixin.py +++ b/tests/unit/providers/utils/inference/test_openai_mixin.py @@ -455,8 +455,8 @@ class TestOpenAIMixinAllowedModels: """Test cases for allowed_models filtering functionality""" async def test_list_models_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context): - """Test that list_models filters models based on allowed_models set""" - mixin.allowed_models = {"some-mock-model-id", "another-mock-model-id"} + """Test that list_models filters models based on allowed_models""" + mixin.config.allowed_models = ["some-mock-model-id", "another-mock-model-id"] with mock_client_context(mixin, mock_client_with_models): result = await mixin.list_models() @@ -470,8 +470,18 @@ class TestOpenAIMixinAllowedModels: assert "final-mock-model-id" not in model_ids async def test_list_models_with_empty_allowed_models(self, mixin, mock_client_with_models, mock_client_context): - """Test that empty allowed_models set allows all models""" - assert len(mixin.allowed_models) == 0 + """Test that empty allowed_models allows no models""" + mixin.config.allowed_models = [] + + with mock_client_context(mixin, mock_client_with_models): + result = await mixin.list_models() + + assert result is not None + assert len(result) == 0 # No models should be included + + async def test_list_models_with_omitted_allowed_models(self, mixin, mock_client_with_models, mock_client_context): + """Test that omitted allowed_models allows all models""" + assert mixin.config.allowed_models is None with mock_client_context(mixin, mock_client_with_models): result = await mixin.list_models() @@ -488,7 +498,7 @@ class TestOpenAIMixinAllowedModels: self, mixin, mock_client_with_models, mock_client_context ): """Test that check_model_availability respects allowed_models""" - mixin.allowed_models = {"final-mock-model-id"} + mixin.config.allowed_models = ["final-mock-model-id"] with mock_client_context(mixin, mock_client_with_models): assert await mixin.check_model_availability("final-mock-model-id") @@ -536,7 +546,7 @@ class TestOpenAIMixinModelRegistration: async def test_register_model_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context): """Test model registration with allowed_models filtering""" - mixin.allowed_models = {"some-mock-model-id"} + mixin.config.allowed_models = ["some-mock-model-id"] # Test with allowed model allowed_model = Model( @@ -690,7 +700,7 @@ class TestOpenAIMixinCustomListProviderModelIds: mixin = CustomListProviderModelIdsImplementation( config=config, custom_model_ids=["model-1", "model-2", "model-3"] ) - mixin.allowed_models = ["model-1"] + mixin.config.allowed_models = ["model-1"] result = await mixin.list_models()