fix: improve model availability checks

- Add has_model method to ModelsRoutingTable for checking pre-registered models
- Update check_model_availability to check model_store before provider APIs
This commit is contained in:
Akram Ben Aissi 2025-10-06 23:20:17 +02:00
parent 696fefbf17
commit 0d5ce21764
4 changed files with 64 additions and 4 deletions

View file

@ -67,6 +67,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
raise ValueError(f"Provider {model.provider_id} not found in the routing table")
return self.impls_by_provider_id[model.provider_id]
async def has_model(self, model_id: str) -> bool:
"""
Check if a model exists in the routing table.
:param model_id: The model identifier to check
:return: True if the model exists, False otherwise
"""
try:
await lookup_model(self, model_id)
return True
except ModelNotFoundError:
return False
async def register_model(
self,
model_id: str,

View file

@ -471,11 +471,17 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
async def check_model_availability(self, model: str) -> bool:
"""
Check if a specific model is available from the provider's /v1/models.
Check if a specific model is available from the provider's /v1/models or pre-registered.
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
:return: True if the model is available dynamically or pre-registered, False otherwise.
"""
# First check if the model is pre-registered in the model store
if hasattr(self, "model_store") and self.model_store:
if await self.model_store.has_model(model):
return True
# Then check the provider's dynamic model cache
if not self._model_cache:
await self.list_models()
return model in self._model_cache

View file

@ -201,6 +201,12 @@ async def test_models_routing_table(cached_disk_dist_registry):
non_existent = await table.get_object_by_identifier("model", "non-existent-model")
assert non_existent is None
# Test has_model
assert await table.has_model("test_provider/test-model")
assert await table.has_model("test_provider/test-model-2")
assert not await table.has_model("non-existent-model")
assert not await table.has_model("test_provider/non-existent-model")
await table.unregister_model(model_id="test_provider/test-model")
await table.unregister_model(model_id="test_provider/test-model-2")

View file

@ -44,11 +44,12 @@ def mixin():
config = RemoteInferenceProviderConfig()
mixin_instance = OpenAIMixinImpl(config=config)
# just enough to satisfy _get_provider_model_id calls
mock_model_store = MagicMock()
# Mock model_store with async methods
mock_model_store = AsyncMock()
mock_model = MagicMock()
mock_model.provider_resource_id = "test-provider-resource-id"
mock_model_store.get_model = AsyncMock(return_value=mock_model)
mock_model_store.has_model = AsyncMock(return_value=False) # Default to False, tests can override
mixin_instance.model_store = mock_model_store
return mixin_instance
@ -189,6 +190,40 @@ class TestOpenAIMixinCheckModelAvailability:
assert len(mixin._model_cache) == 3
async def test_check_model_availability_with_pre_registered_model(
self, mixin, mock_client_with_models, mock_client_context
):
"""Test that check_model_availability returns True for pre-registered models in model_store"""
# Mock model_store.has_model to return True for a specific model
mock_model_store = AsyncMock()
mock_model_store.has_model = AsyncMock(return_value=True)
mixin.model_store = mock_model_store
# Test that pre-registered model is found without calling the provider's API
with mock_client_context(mixin, mock_client_with_models):
mock_client_with_models.models.list.assert_not_called()
assert await mixin.check_model_availability("pre-registered-model")
# Should not call the provider's list_models since model was found in store
mock_client_with_models.models.list.assert_not_called()
mock_model_store.has_model.assert_called_once_with("pre-registered-model")
async def test_check_model_availability_fallback_to_provider_when_not_in_store(
self, mixin, mock_client_with_models, mock_client_context
):
"""Test that check_model_availability falls back to provider when model not in store"""
# Mock model_store.has_model to return False
mock_model_store = AsyncMock()
mock_model_store.has_model = AsyncMock(return_value=False)
mixin.model_store = mock_model_store
# Test that it falls back to provider's model cache
with mock_client_context(mixin, mock_client_with_models):
mock_client_with_models.models.list.assert_not_called()
assert await mixin.check_model_availability("some-mock-model-id")
# Should call the provider's list_models since model was not found in store
mock_client_with_models.models.list.assert_called_once()
mock_model_store.has_model.assert_called_once_with("some-mock-model-id")
class TestOpenAIMixinCacheBehavior:
"""Test cases for cache behavior and edge cases"""