From f216eb99be610deb883f2cb3b801229876ba5578 Mon Sep 17 00:00:00 2001
From: "mergify[bot]" <37929162+mergify[bot]@users.noreply.github.com>
Date: Mon, 24 Nov 2025 11:29:53 -0800
Subject: [PATCH] fix: allowed_models config did not filter models (backport
#4030) (#4223)
# What does this PR do?
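Make the `allowed_models` provider config setting actually filter the models returned by `list_models()`.
`OpenAIMixin` previously checked its own `allowed_models` attribute instead of `config.allowed_models`; that attribute is removed and the config value is used.
`None` (the default) allows all models; an empty list now allows no models.
Closes #4022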
## Test Plan
CI, with new and updated unit tests in tests/unit/providers/utils/inference/test_openai_mixin.py
This is an automatic backport of pull request #4030
done by [Mergify](https://mergify.com).
Co-authored-by: Matthew Farrellee
Co-authored-by: Ashwin Bharambe
---
.../utils/inference/model_registry.py | 2 +-
.../providers/utils/inference/openai_mixin.py | 5 +---
.../utils/inference/test_openai_mixin.py | 24 +++++++++++++------
3 files changed, 19 insertions(+), 12 deletions(-)
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index d60d00f87..8a120b698 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -20,7 +20,7 @@ logger = get_logger(name=__name__, category="providers::utils")
class RemoteInferenceProviderConfig(BaseModel):
- allowed_models: list[str] | None = Field( # TODO: make this non-optional and give a list() default
+ allowed_models: list[str] | None = Field(
default=None,
description="List of models that should be registered with the model registry. If None, all models are allowed.",
)
diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py
index 0a283780f..a12b506db 100644
--- a/llama_stack/providers/utils/inference/openai_mixin.py
+++ b/llama_stack/providers/utils/inference/openai_mixin.py
@@ -82,9 +82,6 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
# This is set in list_models() and used in check_model_availability()
_model_cache: dict[str, Model] = {}
- # List of allowed models for this provider, if empty all models allowed
- allowed_models: list[str] = []
-
# Optional field name in provider data to look for API key, which takes precedence
provider_data_api_key_field: str | None = None
@@ -416,7 +413,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
for provider_model_id in provider_models_ids:
if not isinstance(provider_model_id, str):
raise ValueError(f"Model ID {provider_model_id} from list_provider_model_ids() is not a string")
- if self.allowed_models and provider_model_id not in self.allowed_models:
+ if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list")
continue
if metadata := self.embedding_model_metadata.get(provider_model_id):
diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py
index 61a1f8f61..540e8746b 100644
--- a/tests/unit/providers/utils/inference/test_openai_mixin.py
+++ b/tests/unit/providers/utils/inference/test_openai_mixin.py
@@ -363,8 +363,8 @@ class TestOpenAIMixinAllowedModels:
"""Test cases for allowed_models filtering functionality"""
async def test_list_models_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context):
- """Test that list_models filters models based on allowed_models set"""
- mixin.allowed_models = {"some-mock-model-id", "another-mock-model-id"}
+ """Test that list_models filters models based on allowed_models"""
+ mixin.config.allowed_models = ["some-mock-model-id", "another-mock-model-id"]
with mock_client_context(mixin, mock_client_with_models):
result = await mixin.list_models()
@@ -378,8 +378,18 @@ class TestOpenAIMixinAllowedModels:
assert "final-mock-model-id" not in model_ids
async def test_list_models_with_empty_allowed_models(self, mixin, mock_client_with_models, mock_client_context):
- """Test that empty allowed_models set allows all models"""
- assert len(mixin.allowed_models) == 0
+ """Test that empty allowed_models allows no models"""
+ mixin.config.allowed_models = []
+
+ with mock_client_context(mixin, mock_client_with_models):
+ result = await mixin.list_models()
+
+ assert result is not None
+ assert len(result) == 0 # No models should be included
+
+ async def test_list_models_with_omitted_allowed_models(self, mixin, mock_client_with_models, mock_client_context):
+ """Test that omitted allowed_models allows all models"""
+ assert mixin.config.allowed_models is None
with mock_client_context(mixin, mock_client_with_models):
result = await mixin.list_models()
@@ -396,7 +406,7 @@ class TestOpenAIMixinAllowedModels:
self, mixin, mock_client_with_models, mock_client_context
):
"""Test that check_model_availability respects allowed_models"""
- mixin.allowed_models = {"final-mock-model-id"}
+ mixin.config.allowed_models = ["final-mock-model-id"]
with mock_client_context(mixin, mock_client_with_models):
assert await mixin.check_model_availability("final-mock-model-id")
@@ -444,7 +454,7 @@ class TestOpenAIMixinModelRegistration:
async def test_register_model_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context):
"""Test model registration with allowed_models filtering"""
- mixin.allowed_models = {"some-mock-model-id"}
+ mixin.config.allowed_models = ["some-mock-model-id"]
# Test with allowed model
allowed_model = Model(
@@ -598,7 +608,7 @@ class TestOpenAIMixinCustomListProviderModelIds:
mixin = CustomListProviderModelIdsImplementation(
config=config, custom_model_ids=["model-1", "model-2", "model-3"]
)
- mixin.allowed_models = ["model-1"]
+ mixin.config.allowed_models = ["model-1"]
result = await mixin.list_models()