mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
# What does this PR do? closes #4022 ## Test Plan ci w/ new tests<hr>This is an automatic backport of pull request #4030 done by [Mergify](https://mergify.com). Co-authored-by: Matthew Farrellee <matt@cs.wisc.edu> Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
parent
49a290e53e
commit
f216eb99be
3 changed files with 19 additions and 12 deletions
|
|
@ -20,7 +20,7 @@ logger = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
|
||||||
class RemoteInferenceProviderConfig(BaseModel):
|
class RemoteInferenceProviderConfig(BaseModel):
|
||||||
allowed_models: list[str] | None = Field( # TODO: make this non-optional and give a list() default
|
allowed_models: list[str] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="List of models that should be registered with the model registry. If None, all models are allowed.",
|
description="List of models that should be registered with the model registry. If None, all models are allowed.",
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -82,9 +82,6 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
||||||
# This is set in list_models() and used in check_model_availability()
|
# This is set in list_models() and used in check_model_availability()
|
||||||
_model_cache: dict[str, Model] = {}
|
_model_cache: dict[str, Model] = {}
|
||||||
|
|
||||||
# List of allowed models for this provider, if empty all models allowed
|
|
||||||
allowed_models: list[str] = []
|
|
||||||
|
|
||||||
# Optional field name in provider data to look for API key, which takes precedence
|
# Optional field name in provider data to look for API key, which takes precedence
|
||||||
provider_data_api_key_field: str | None = None
|
provider_data_api_key_field: str | None = None
|
||||||
|
|
||||||
|
|
@ -416,7 +413,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
||||||
for provider_model_id in provider_models_ids:
|
for provider_model_id in provider_models_ids:
|
||||||
if not isinstance(provider_model_id, str):
|
if not isinstance(provider_model_id, str):
|
||||||
raise ValueError(f"Model ID {provider_model_id} from list_provider_model_ids() is not a string")
|
raise ValueError(f"Model ID {provider_model_id} from list_provider_model_ids() is not a string")
|
||||||
if self.allowed_models and provider_model_id not in self.allowed_models:
|
if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
|
||||||
logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list")
|
logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list")
|
||||||
continue
|
continue
|
||||||
if metadata := self.embedding_model_metadata.get(provider_model_id):
|
if metadata := self.embedding_model_metadata.get(provider_model_id):
|
||||||
|
|
|
||||||
|
|
@ -363,8 +363,8 @@ class TestOpenAIMixinAllowedModels:
|
||||||
"""Test cases for allowed_models filtering functionality"""
|
"""Test cases for allowed_models filtering functionality"""
|
||||||
|
|
||||||
async def test_list_models_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context):
|
async def test_list_models_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context):
|
||||||
"""Test that list_models filters models based on allowed_models set"""
|
"""Test that list_models filters models based on allowed_models"""
|
||||||
mixin.allowed_models = {"some-mock-model-id", "another-mock-model-id"}
|
mixin.config.allowed_models = ["some-mock-model-id", "another-mock-model-id"]
|
||||||
|
|
||||||
with mock_client_context(mixin, mock_client_with_models):
|
with mock_client_context(mixin, mock_client_with_models):
|
||||||
result = await mixin.list_models()
|
result = await mixin.list_models()
|
||||||
|
|
@ -378,8 +378,18 @@ class TestOpenAIMixinAllowedModels:
|
||||||
assert "final-mock-model-id" not in model_ids
|
assert "final-mock-model-id" not in model_ids
|
||||||
|
|
||||||
async def test_list_models_with_empty_allowed_models(self, mixin, mock_client_with_models, mock_client_context):
|
async def test_list_models_with_empty_allowed_models(self, mixin, mock_client_with_models, mock_client_context):
|
||||||
"""Test that empty allowed_models set allows all models"""
|
"""Test that empty allowed_models allows no models"""
|
||||||
assert len(mixin.allowed_models) == 0
|
mixin.config.allowed_models = []
|
||||||
|
|
||||||
|
with mock_client_context(mixin, mock_client_with_models):
|
||||||
|
result = await mixin.list_models()
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert len(result) == 0 # No models should be included
|
||||||
|
|
||||||
|
async def test_list_models_with_omitted_allowed_models(self, mixin, mock_client_with_models, mock_client_context):
|
||||||
|
"""Test that omitted allowed_models allows all models"""
|
||||||
|
assert mixin.config.allowed_models is None
|
||||||
|
|
||||||
with mock_client_context(mixin, mock_client_with_models):
|
with mock_client_context(mixin, mock_client_with_models):
|
||||||
result = await mixin.list_models()
|
result = await mixin.list_models()
|
||||||
|
|
@ -396,7 +406,7 @@ class TestOpenAIMixinAllowedModels:
|
||||||
self, mixin, mock_client_with_models, mock_client_context
|
self, mixin, mock_client_with_models, mock_client_context
|
||||||
):
|
):
|
||||||
"""Test that check_model_availability respects allowed_models"""
|
"""Test that check_model_availability respects allowed_models"""
|
||||||
mixin.allowed_models = {"final-mock-model-id"}
|
mixin.config.allowed_models = ["final-mock-model-id"]
|
||||||
|
|
||||||
with mock_client_context(mixin, mock_client_with_models):
|
with mock_client_context(mixin, mock_client_with_models):
|
||||||
assert await mixin.check_model_availability("final-mock-model-id")
|
assert await mixin.check_model_availability("final-mock-model-id")
|
||||||
|
|
@ -444,7 +454,7 @@ class TestOpenAIMixinModelRegistration:
|
||||||
|
|
||||||
async def test_register_model_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context):
|
async def test_register_model_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context):
|
||||||
"""Test model registration with allowed_models filtering"""
|
"""Test model registration with allowed_models filtering"""
|
||||||
mixin.allowed_models = {"some-mock-model-id"}
|
mixin.config.allowed_models = ["some-mock-model-id"]
|
||||||
|
|
||||||
# Test with allowed model
|
# Test with allowed model
|
||||||
allowed_model = Model(
|
allowed_model = Model(
|
||||||
|
|
@ -598,7 +608,7 @@ class TestOpenAIMixinCustomListProviderModelIds:
|
||||||
mixin = CustomListProviderModelIdsImplementation(
|
mixin = CustomListProviderModelIdsImplementation(
|
||||||
config=config, custom_model_ids=["model-1", "model-2", "model-3"]
|
config=config, custom_model_ids=["model-1", "model-2", "model-3"]
|
||||||
)
|
)
|
||||||
mixin.allowed_models = ["model-1"]
|
mixin.config.allowed_models = ["model-1"]
|
||||||
|
|
||||||
result = await mixin.list_models()
|
result = await mixin.list_models()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue