mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 20:12:33 +00:00
fix: improve model availability checks
- Add has_model method to ModelsRoutingTable for checking pre-registered models - Update check_model_availability to check model_store before provider APIs
This commit is contained in:
parent
696fefbf17
commit
0d5ce21764
4 changed files with 64 additions and 4 deletions
|
|
@ -67,6 +67,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
raise ValueError(f"Provider {model.provider_id} not found in the routing table")
|
||||
return self.impls_by_provider_id[model.provider_id]
|
||||
|
||||
async def has_model(self, model_id: str) -> bool:
|
||||
"""
|
||||
Check if a model exists in the routing table.
|
||||
|
||||
:param model_id: The model identifier to check
|
||||
:return: True if the model exists, False otherwise
|
||||
"""
|
||||
try:
|
||||
await lookup_model(self, model_id)
|
||||
return True
|
||||
except ModelNotFoundError:
|
||||
return False
|
||||
|
||||
async def register_model(
|
||||
self,
|
||||
model_id: str,
|
||||
|
|
|
|||
|
|
@ -471,11 +471,17 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|||
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
"""
|
||||
Check if a specific model is available from the provider's /v1/models.
|
||||
Check if a specific model is available from the provider's /v1/models or pre-registered.
|
||||
|
||||
:param model: The model identifier to check.
|
||||
:return: True if the model is available dynamically, False otherwise.
|
||||
:return: True if the model is available dynamically or pre-registered, False otherwise.
|
||||
"""
|
||||
# First check if the model is pre-registered in the model store
|
||||
if hasattr(self, "model_store") and self.model_store:
|
||||
if await self.model_store.has_model(model):
|
||||
return True
|
||||
|
||||
# Then check the provider's dynamic model cache
|
||||
if not self._model_cache:
|
||||
await self.list_models()
|
||||
return model in self._model_cache
|
||||
|
|
|
|||
|
|
@ -201,6 +201,12 @@ async def test_models_routing_table(cached_disk_dist_registry):
|
|||
non_existent = await table.get_object_by_identifier("model", "non-existent-model")
|
||||
assert non_existent is None
|
||||
|
||||
# Test has_model
|
||||
assert await table.has_model("test_provider/test-model")
|
||||
assert await table.has_model("test_provider/test-model-2")
|
||||
assert not await table.has_model("non-existent-model")
|
||||
assert not await table.has_model("test_provider/non-existent-model")
|
||||
|
||||
await table.unregister_model(model_id="test_provider/test-model")
|
||||
await table.unregister_model(model_id="test_provider/test-model-2")
|
||||
|
||||
|
|
|
|||
|
|
@ -44,11 +44,12 @@ def mixin():
|
|||
config = RemoteInferenceProviderConfig()
|
||||
mixin_instance = OpenAIMixinImpl(config=config)
|
||||
|
||||
# just enough to satisfy _get_provider_model_id calls
|
||||
mock_model_store = MagicMock()
|
||||
# Mock model_store with async methods
|
||||
mock_model_store = AsyncMock()
|
||||
mock_model = MagicMock()
|
||||
mock_model.provider_resource_id = "test-provider-resource-id"
|
||||
mock_model_store.get_model = AsyncMock(return_value=mock_model)
|
||||
mock_model_store.has_model = AsyncMock(return_value=False) # Default to False, tests can override
|
||||
mixin_instance.model_store = mock_model_store
|
||||
|
||||
return mixin_instance
|
||||
|
|
@ -189,6 +190,40 @@ class TestOpenAIMixinCheckModelAvailability:
|
|||
|
||||
assert len(mixin._model_cache) == 3
|
||||
|
||||
async def test_check_model_availability_with_pre_registered_model(
|
||||
self, mixin, mock_client_with_models, mock_client_context
|
||||
):
|
||||
"""Test that check_model_availability returns True for pre-registered models in model_store"""
|
||||
# Mock model_store.has_model to return True for a specific model
|
||||
mock_model_store = AsyncMock()
|
||||
mock_model_store.has_model = AsyncMock(return_value=True)
|
||||
mixin.model_store = mock_model_store
|
||||
|
||||
# Test that pre-registered model is found without calling the provider's API
|
||||
with mock_client_context(mixin, mock_client_with_models):
|
||||
mock_client_with_models.models.list.assert_not_called()
|
||||
assert await mixin.check_model_availability("pre-registered-model")
|
||||
# Should not call the provider's list_models since model was found in store
|
||||
mock_client_with_models.models.list.assert_not_called()
|
||||
mock_model_store.has_model.assert_called_once_with("pre-registered-model")
|
||||
|
||||
async def test_check_model_availability_fallback_to_provider_when_not_in_store(
|
||||
self, mixin, mock_client_with_models, mock_client_context
|
||||
):
|
||||
"""Test that check_model_availability falls back to provider when model not in store"""
|
||||
# Mock model_store.has_model to return False
|
||||
mock_model_store = AsyncMock()
|
||||
mock_model_store.has_model = AsyncMock(return_value=False)
|
||||
mixin.model_store = mock_model_store
|
||||
|
||||
# Test that it falls back to provider's model cache
|
||||
with mock_client_context(mixin, mock_client_with_models):
|
||||
mock_client_with_models.models.list.assert_not_called()
|
||||
assert await mixin.check_model_availability("some-mock-model-id")
|
||||
# Should call the provider's list_models since model was not found in store
|
||||
mock_client_with_models.models.list.assert_called_once()
|
||||
mock_model_store.has_model.assert_called_once_with("some-mock-model-id")
|
||||
|
||||
|
||||
class TestOpenAIMixinCacheBehavior:
|
||||
"""Test cases for cache behavior and edge cases"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue