mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-23 06:52:28 +00:00
query_available_models() -> list[str] -> check_model_availability(model) -> bool
This commit is contained in:
parent
d035fe93c6
commit
770c20e36b
2 changed files with 27 additions and 27 deletions
|
|
@ -82,35 +82,35 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||||
def get_llama_model(self, provider_model_id: str) -> str | None:
|
def get_llama_model(self, provider_model_id: str) -> str | None:
|
||||||
return self.provider_id_to_llama_model_map.get(provider_model_id, None)
|
return self.provider_id_to_llama_model_map.get(provider_model_id, None)
|
||||||
|
|
||||||
async def query_available_models(self) -> list[str]:
|
async def check_model_availability(self, model: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Return a list of available models.
|
Check if a specific model is available from the provider (non-static check).
|
||||||
|
|
||||||
This is for subclassing purposes, so providers can lookup a list of
|
This is for subclassing purposes, so providers can check if a specific
|
||||||
of currently available models.
|
model is currently available for use through dynamic means (e.g., API calls).
|
||||||
|
|
||||||
This is combined with the statically configured model entries in
|
This method should NOT check statically configured model entries in
|
||||||
`self.alias_to_provider_id_map` to determine which models are
|
`self.alias_to_provider_id_map` - that is handled separately in register_model.
|
||||||
available for registration.
|
|
||||||
|
|
||||||
Default implementation returns no models.
|
Default implementation returns False (no dynamic models available).
|
||||||
|
|
||||||
:return: A list of model identifiers (provider_model_ids).
|
:param model: The model identifier to check.
|
||||||
|
:return: True if the model is available dynamically, False otherwise.
|
||||||
"""
|
"""
|
||||||
return []
|
return False
|
||||||
|
|
||||||
async def register_model(self, model: Model) -> Model:
|
async def register_model(self, model: Model) -> Model:
|
||||||
# Check if model is supported in static configuration
|
# Check if model is supported in static configuration
|
||||||
supported_model_id = self.get_provider_model_id(model.provider_resource_id)
|
supported_model_id = self.get_provider_model_id(model.provider_resource_id)
|
||||||
|
|
||||||
# If not found in static config, check if it's available from provider
|
# If not found in static config, check if it's available dynamically from provider
|
||||||
if not supported_model_id:
|
if not supported_model_id:
|
||||||
available_models = await self.query_available_models()
|
if await self.check_model_availability(model.provider_resource_id):
|
||||||
if model.provider_resource_id in available_models:
|
|
||||||
supported_model_id = model.provider_resource_id
|
supported_model_id = model.provider_resource_id
|
||||||
else:
|
else:
|
||||||
# Combine static and dynamic models for error message
|
# note: we cannot provide a complete list of supported models without
|
||||||
all_supported_models = list(self.alias_to_provider_id_map.keys()) + available_models
|
# getting a complete list from the provider, so we return "..."
|
||||||
|
all_supported_models = [*self.alias_to_provider_id_map.keys(), "..."]
|
||||||
raise UnsupportedModelError(model.provider_resource_id, all_supported_models)
|
raise UnsupportedModelError(model.provider_resource_id, all_supported_models)
|
||||||
|
|
||||||
provider_resource_id = self.get_provider_model_id(model.model_id)
|
provider_resource_id = self.get_provider_model_id(model.model_id)
|
||||||
|
|
|
||||||
|
|
@ -94,8 +94,8 @@ class MockModelRegistryHelperWithDynamicModels(ModelRegistryHelper):
|
||||||
super().__init__(model_entries)
|
super().__init__(model_entries)
|
||||||
self._available_models = available_models
|
self._available_models = available_models
|
||||||
|
|
||||||
async def query_available_models(self) -> list[str]:
|
async def check_model_availability(self, model: str) -> bool:
|
||||||
return self._available_models
|
return model in self._available_models
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -195,16 +195,16 @@ async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_m
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_register_model_from_query_available_models(
|
async def test_register_model_from_check_model_availability(
|
||||||
helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model
|
helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that models returned by query_available_models can be registered."""
|
"""Test that models returned by check_model_availability can be registered."""
|
||||||
# Verify the model is not in static config
|
# Verify the model is not in static config
|
||||||
assert helper_with_dynamic_models.get_provider_model_id(dynamic_model.provider_resource_id) is None
|
assert helper_with_dynamic_models.get_provider_model_id(dynamic_model.provider_resource_id) is None
|
||||||
|
|
||||||
# But it should be available via query_available_models
|
# But it should be available via check_model_availability
|
||||||
available_models = await helper_with_dynamic_models.query_available_models()
|
is_available = await helper_with_dynamic_models.check_model_availability(dynamic_model.provider_resource_id)
|
||||||
assert dynamic_model.provider_resource_id in available_models
|
assert is_available
|
||||||
|
|
||||||
# Registration should succeed
|
# Registration should succeed
|
||||||
registered_model = await helper_with_dynamic_models.register_model(dynamic_model)
|
registered_model = await helper_with_dynamic_models.register_model(dynamic_model)
|
||||||
|
|
@ -224,17 +224,17 @@ async def test_register_model_not_in_static_or_dynamic(
|
||||||
# Verify the model is not in static config
|
# Verify the model is not in static config
|
||||||
assert helper_with_dynamic_models.get_provider_model_id(unknown_model.provider_resource_id) is None
|
assert helper_with_dynamic_models.get_provider_model_id(unknown_model.provider_resource_id) is None
|
||||||
|
|
||||||
# And not in dynamic models
|
# And not available via check_model_availability
|
||||||
available_models = await helper_with_dynamic_models.query_available_models()
|
is_available = await helper_with_dynamic_models.check_model_availability(unknown_model.provider_resource_id)
|
||||||
assert unknown_model.provider_resource_id not in available_models
|
assert not is_available
|
||||||
|
|
||||||
# Registration should fail with comprehensive error message
|
# Registration should fail with comprehensive error message
|
||||||
with pytest.raises(Exception) as exc_info: # UnsupportedModelError
|
with pytest.raises(Exception) as exc_info: # UnsupportedModelError
|
||||||
await helper_with_dynamic_models.register_model(unknown_model)
|
await helper_with_dynamic_models.register_model(unknown_model)
|
||||||
|
|
||||||
# Error should include both static and dynamic models
|
# Error should include static models and "..." for dynamic models
|
||||||
error_str = str(exc_info.value)
|
error_str = str(exc_info.value)
|
||||||
assert "dynamic-provider-id" in error_str # dynamic model should be in error
|
assert "..." in error_str # "..." should be in error message
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue