mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-16 19:32:36 +00:00
fix: improve model availability checks
- Add has_model method to ModelsRoutingTable for checking pre-registered models
- Update check_model_availability to check model_store before provider APIs
This commit is contained in:
parent
696fefbf17
commit
0d5ce21764
4 changed files with 64 additions and 4 deletions
|
|
@ -44,11 +44,12 @@ def mixin():
|
|||
config = RemoteInferenceProviderConfig()
|
||||
mixin_instance = OpenAIMixinImpl(config=config)
|
||||
|
||||
# just enough to satisfy _get_provider_model_id calls
|
||||
mock_model_store = MagicMock()
|
||||
# Mock model_store with async methods
|
||||
mock_model_store = AsyncMock()
|
||||
mock_model = MagicMock()
|
||||
mock_model.provider_resource_id = "test-provider-resource-id"
|
||||
mock_model_store.get_model = AsyncMock(return_value=mock_model)
|
||||
mock_model_store.has_model = AsyncMock(return_value=False) # Default to False, tests can override
|
||||
mixin_instance.model_store = mock_model_store
|
||||
|
||||
return mixin_instance
|
||||
|
|
@ -189,6 +190,40 @@ class TestOpenAIMixinCheckModelAvailability:
|
|||
|
||||
assert len(mixin._model_cache) == 3
|
||||
|
||||
async def test_check_model_availability_with_pre_registered_model(
    self, mixin, mock_client_with_models, mock_client_context
):
    """A model that model_store already knows about is reported as available.

    The store is stubbed so that ``has_model`` always answers True; the test
    then verifies that ``check_model_availability`` is truthy, that the
    provider client's ``models.list`` is never invoked, and that the store was
    queried exactly once with the requested model id.
    """
    # Swap the fixture's store for one that claims every model is registered.
    store = AsyncMock()
    store.has_model = AsyncMock(return_value=True)
    mixin.model_store = store

    with mock_client_context(mixin, mock_client_with_models):
        provider_list = mock_client_with_models.models.list

        # Precondition: the provider has not been contacted yet.
        provider_list.assert_not_called()

        available = await mixin.check_model_availability("pre-registered-model")
        assert available

        # The store answered, so the provider's model listing must stay untouched.
        provider_list.assert_not_called()
        store.has_model.assert_called_once_with("pre-registered-model")
|
||||
|
||||
async def test_check_model_availability_fallback_to_provider_when_not_in_store(
    self, mixin, mock_client_with_models, mock_client_context
):
    """When model_store does not hold the model, the provider is consulted.

    The store is stubbed so that ``has_model`` always answers False; the test
    then verifies that ``check_model_availability`` is still truthy for
    "some-mock-model-id" (presumably one of the ids served by the
    ``mock_client_with_models`` fixture -- confirm against the fixture), that
    the provider client's ``models.list`` was invoked exactly once, and that
    the store was queried exactly once with the requested model id.
    """
    # Swap the fixture's store for one that reports no registered models.
    store = AsyncMock()
    store.has_model = AsyncMock(return_value=False)
    mixin.model_store = store

    with mock_client_context(mixin, mock_client_with_models):
        provider_list = mock_client_with_models.models.list

        # Precondition: the provider has not been contacted yet.
        provider_list.assert_not_called()

        found = await mixin.check_model_availability("some-mock-model-id")
        assert found

        # The store missed, so exactly one provider listing is expected.
        provider_list.assert_called_once()
        store.has_model.assert_called_once_with("some-mock-model-id")
|
||||
|
||||
|
||||
class TestOpenAIMixinCacheBehavior:
|
||||
"""Test cases for cache behavior and edge cases"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue