mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 06:28:50 +00:00
feat: add infrastructure to allow inference model discovery
inference providers each have a static list of supported / known models. some also have access to a dynamic list of currently available models. this change gives prodivers using the ModelRegistryHelper the ability to combine their static and dynamic lists. for instance, OpenAIInferenceAdapter can implement ``` def query_available_models(self) -> list[str]: return [entry.model for entry in self.openai_client.models.list()] ``` to augment its static list w/ a current list from openai.
This commit is contained in:
parent
cd0ad21111
commit
d035fe93c6
2 changed files with 125 additions and 2 deletions
|
@ -82,9 +82,37 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
def get_llama_model(self, provider_model_id: str) -> str | None:
|
||||
return self.provider_id_to_llama_model_map.get(provider_model_id, None)
|
||||
|
||||
async def query_available_models(self) -> list[str]:
|
||||
"""
|
||||
Return a list of available models.
|
||||
|
||||
This is for subclassing purposes, so providers can lookup a list of
|
||||
of currently available models.
|
||||
|
||||
This is combined with the statically configured model entries in
|
||||
`self.alias_to_provider_id_map` to determine which models are
|
||||
available for registration.
|
||||
|
||||
Default implementation returns no models.
|
||||
|
||||
:return: A list of model identifiers (provider_model_ids).
|
||||
"""
|
||||
return []
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
if not (supported_model_id := self.get_provider_model_id(model.provider_resource_id)):
|
||||
raise UnsupportedModelError(model.provider_resource_id, self.alias_to_provider_id_map.keys())
|
||||
# Check if model is supported in static configuration
|
||||
supported_model_id = self.get_provider_model_id(model.provider_resource_id)
|
||||
|
||||
# If not found in static config, check if it's available from provider
|
||||
if not supported_model_id:
|
||||
available_models = await self.query_available_models()
|
||||
if model.provider_resource_id in available_models:
|
||||
supported_model_id = model.provider_resource_id
|
||||
else:
|
||||
# Combine static and dynamic models for error message
|
||||
all_supported_models = list(self.alias_to_provider_id_map.keys()) + available_models
|
||||
raise UnsupportedModelError(model.provider_resource_id, all_supported_models)
|
||||
|
||||
provider_resource_id = self.get_provider_model_id(model.model_id)
|
||||
if model.model_type == ModelType.embedding:
|
||||
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
|
||||
|
@ -113,6 +141,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
|
||||
)
|
||||
|
||||
# Register the model alias, ensuring it maps to the correct provider model id
|
||||
self.alias_to_provider_id_map[model.model_id] = supported_model_id
|
||||
|
||||
return model
|
||||
|
|
|
@ -87,6 +87,37 @@ def helper(known_provider_model: ProviderModelEntry, known_provider_model2: Prov
|
|||
return ModelRegistryHelper([known_provider_model, known_provider_model2])
|
||||
|
||||
|
||||
class MockModelRegistryHelperWithDynamicModels(ModelRegistryHelper):
|
||||
"""Test helper that simulates a provider with dynamically available models."""
|
||||
|
||||
def __init__(self, model_entries: list[ProviderModelEntry], available_models: list[str]):
|
||||
super().__init__(model_entries)
|
||||
self._available_models = available_models
|
||||
|
||||
async def query_available_models(self) -> list[str]:
|
||||
return self._available_models
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dynamic_model() -> Model:
|
||||
"""A model that's not in static config but available dynamically."""
|
||||
return Model(
|
||||
provider_id="provider",
|
||||
identifier="dynamic-model",
|
||||
provider_resource_id="dynamic-provider-id",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def helper_with_dynamic_models(
|
||||
known_provider_model: ProviderModelEntry, known_provider_model2: ProviderModelEntry, dynamic_model: Model
|
||||
) -> MockModelRegistryHelperWithDynamicModels:
|
||||
"""Helper that includes dynamically available models."""
|
||||
return MockModelRegistryHelperWithDynamicModels(
|
||||
[known_provider_model, known_provider_model2], [dynamic_model.provider_resource_id]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lookup_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
|
||||
assert helper.get_provider_model_id(unknown_model.model_id) is None
|
||||
|
@ -161,3 +192,66 @@ async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_m
|
|||
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
|
||||
await helper.unregister_model(known_model.provider_resource_id)
|
||||
assert helper.get_provider_model_id(known_model.provider_resource_id) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_model_from_query_available_models(
|
||||
helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model
|
||||
) -> None:
|
||||
"""Test that models returned by query_available_models can be registered."""
|
||||
# Verify the model is not in static config
|
||||
assert helper_with_dynamic_models.get_provider_model_id(dynamic_model.provider_resource_id) is None
|
||||
|
||||
# But it should be available via query_available_models
|
||||
available_models = await helper_with_dynamic_models.query_available_models()
|
||||
assert dynamic_model.provider_resource_id in available_models
|
||||
|
||||
# Registration should succeed
|
||||
registered_model = await helper_with_dynamic_models.register_model(dynamic_model)
|
||||
assert registered_model == dynamic_model
|
||||
|
||||
# Model should now be registered and accessible
|
||||
assert (
|
||||
helper_with_dynamic_models.get_provider_model_id(dynamic_model.model_id) == dynamic_model.provider_resource_id
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_model_not_in_static_or_dynamic(
|
||||
helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, unknown_model: Model
|
||||
) -> None:
|
||||
"""Test that models not in static config or dynamic models are rejected."""
|
||||
# Verify the model is not in static config
|
||||
assert helper_with_dynamic_models.get_provider_model_id(unknown_model.provider_resource_id) is None
|
||||
|
||||
# And not in dynamic models
|
||||
available_models = await helper_with_dynamic_models.query_available_models()
|
||||
assert unknown_model.provider_resource_id not in available_models
|
||||
|
||||
# Registration should fail with comprehensive error message
|
||||
with pytest.raises(Exception) as exc_info: # UnsupportedModelError
|
||||
await helper_with_dynamic_models.register_model(unknown_model)
|
||||
|
||||
# Error should include both static and dynamic models
|
||||
error_str = str(exc_info.value)
|
||||
assert "dynamic-provider-id" in error_str # dynamic model should be in error
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_alias_for_dynamic_model(
|
||||
helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model
|
||||
) -> None:
|
||||
"""Test that we can register an alias that maps to a dynamically available model."""
|
||||
# Create a model with a different identifier but same provider_resource_id
|
||||
alias_model = Model(
|
||||
provider_id=dynamic_model.provider_id,
|
||||
identifier="dynamic-model-alias",
|
||||
provider_resource_id=dynamic_model.provider_resource_id,
|
||||
)
|
||||
|
||||
# Registration should succeed since the provider_resource_id is available dynamically
|
||||
registered_model = await helper_with_dynamic_models.register_model(alias_model)
|
||||
assert registered_model == alias_model
|
||||
|
||||
# Both the original provider_resource_id and the new alias should work
|
||||
assert helper_with_dynamic_models.get_provider_model_id(alias_model.model_id) == dynamic_model.provider_resource_id
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue