query_available_models() -> list[str] -> check_model_availability(model) -> bool

This commit is contained in:
Matthew Farrellee 2025-07-14 14:08:44 -04:00
parent d035fe93c6
commit 770c20e36b
2 changed files with 27 additions and 27 deletions

View file

@ -82,35 +82,35 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
def get_llama_model(self, provider_model_id: str) -> str | None: def get_llama_model(self, provider_model_id: str) -> str | None:
return self.provider_id_to_llama_model_map.get(provider_model_id, None) return self.provider_id_to_llama_model_map.get(provider_model_id, None)
async def query_available_models(self) -> list[str]: async def check_model_availability(self, model: str) -> bool:
""" """
Return a list of available models. Check if a specific model is available from the provider (non-static check).
This is for subclassing purposes, so providers can lookup a list of This is for subclassing purposes, so providers can check if a specific
of currently available models. model is currently available for use through dynamic means (e.g., API calls).
This is combined with the statically configured model entries in This method should NOT check statically configured model entries in
`self.alias_to_provider_id_map` to determine which models are `self.alias_to_provider_id_map` - that is handled separately in register_model.
available for registration.
Default implementation returns no models. Default implementation returns False (no dynamic models available).
:return: A list of model identifiers (provider_model_ids). :param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
""" """
return [] return False
async def register_model(self, model: Model) -> Model: async def register_model(self, model: Model) -> Model:
# Check if model is supported in static configuration # Check if model is supported in static configuration
supported_model_id = self.get_provider_model_id(model.provider_resource_id) supported_model_id = self.get_provider_model_id(model.provider_resource_id)
# If not found in static config, check if it's available from provider # If not found in static config, check if it's available dynamically from provider
if not supported_model_id: if not supported_model_id:
available_models = await self.query_available_models() if await self.check_model_availability(model.provider_resource_id):
if model.provider_resource_id in available_models:
supported_model_id = model.provider_resource_id supported_model_id = model.provider_resource_id
else: else:
# Combine static and dynamic models for error message # note: we cannot provide a complete list of supported models without
all_supported_models = list(self.alias_to_provider_id_map.keys()) + available_models # getting a complete list from the provider, so we return "..."
all_supported_models = [*self.alias_to_provider_id_map.keys(), "..."]
raise UnsupportedModelError(model.provider_resource_id, all_supported_models) raise UnsupportedModelError(model.provider_resource_id, all_supported_models)
provider_resource_id = self.get_provider_model_id(model.model_id) provider_resource_id = self.get_provider_model_id(model.model_id)

View file

@ -94,8 +94,8 @@ class MockModelRegistryHelperWithDynamicModels(ModelRegistryHelper):
super().__init__(model_entries) super().__init__(model_entries)
self._available_models = available_models self._available_models = available_models
async def query_available_models(self) -> list[str]: async def check_model_availability(self, model: str) -> bool:
return self._available_models return model in self._available_models
@pytest.fixture @pytest.fixture
@ -195,16 +195,16 @@ async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_m
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_register_model_from_query_available_models( async def test_register_model_from_check_model_availability(
helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model
) -> None: ) -> None:
"""Test that models returned by query_available_models can be registered.""" """Test that models returned by check_model_availability can be registered."""
# Verify the model is not in static config # Verify the model is not in static config
assert helper_with_dynamic_models.get_provider_model_id(dynamic_model.provider_resource_id) is None assert helper_with_dynamic_models.get_provider_model_id(dynamic_model.provider_resource_id) is None
# But it should be available via query_available_models # But it should be available via check_model_availability
available_models = await helper_with_dynamic_models.query_available_models() is_available = await helper_with_dynamic_models.check_model_availability(dynamic_model.provider_resource_id)
assert dynamic_model.provider_resource_id in available_models assert is_available
# Registration should succeed # Registration should succeed
registered_model = await helper_with_dynamic_models.register_model(dynamic_model) registered_model = await helper_with_dynamic_models.register_model(dynamic_model)
@ -224,17 +224,17 @@ async def test_register_model_not_in_static_or_dynamic(
# Verify the model is not in static config # Verify the model is not in static config
assert helper_with_dynamic_models.get_provider_model_id(unknown_model.provider_resource_id) is None assert helper_with_dynamic_models.get_provider_model_id(unknown_model.provider_resource_id) is None
# And not in dynamic models # And not available via check_model_availability
available_models = await helper_with_dynamic_models.query_available_models() is_available = await helper_with_dynamic_models.check_model_availability(unknown_model.provider_resource_id)
assert unknown_model.provider_resource_id not in available_models assert not is_available
# Registration should fail with comprehensive error message # Registration should fail with comprehensive error message
with pytest.raises(Exception) as exc_info: # UnsupportedModelError with pytest.raises(Exception) as exc_info: # UnsupportedModelError
await helper_with_dynamic_models.register_model(unknown_model) await helper_with_dynamic_models.register_model(unknown_model)
# Error should include both static and dynamic models # Error should include static models and "..." for dynamic models
error_str = str(exc_info.value) error_str = str(exc_info.value)
assert "dynamic-provider-id" in error_str # dynamic model should be in error assert "..." in error_str # "..." should be in error message
@pytest.mark.asyncio @pytest.mark.asyncio