mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-16 18:08:09 +00:00
feat: add infrastructure to allow inference model discovery (#2710)
# What does this PR do? inference providers each have a static list of supported / known models. some also have access to a dynamic list of currently available models. this change gives prodivers using the ModelRegistryHelper the ability to combine their static and dynamic lists. for instance, OpenAIInferenceAdapter can implement ``` def query_available_models(self) -> list[str]: return [entry.model for entry in self.openai_client.models.list()] ``` to augment its static list w/ a current list from openai. ## Test Plan scripts/unit-test.sh
This commit is contained in:
parent
a7ed86181c
commit
f731f369a2
2 changed files with 122 additions and 2 deletions
|
@ -83,9 +83,37 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
def get_llama_model(self, provider_model_id: str) -> str | None:
|
||||
return self.provider_id_to_llama_model_map.get(provider_model_id, None)
|
||||
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
"""
|
||||
Check if a specific model is available from the provider (non-static check).
|
||||
|
||||
This is for subclassing purposes, so providers can check if a specific
|
||||
model is currently available for use through dynamic means (e.g., API calls).
|
||||
|
||||
This method should NOT check statically configured model entries in
|
||||
`self.alias_to_provider_id_map` - that is handled separately in register_model.
|
||||
|
||||
Default implementation returns False (no dynamic models available).
|
||||
|
||||
:param model: The model identifier to check.
|
||||
:return: True if the model is available dynamically, False otherwise.
|
||||
"""
|
||||
return False
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
if not (supported_model_id := self.get_provider_model_id(model.provider_resource_id)):
|
||||
raise UnsupportedModelError(model.provider_resource_id, self.alias_to_provider_id_map.keys())
|
||||
# Check if model is supported in static configuration
|
||||
supported_model_id = self.get_provider_model_id(model.provider_resource_id)
|
||||
|
||||
# If not found in static config, check if it's available dynamically from provider
|
||||
if not supported_model_id:
|
||||
if await self.check_model_availability(model.provider_resource_id):
|
||||
supported_model_id = model.provider_resource_id
|
||||
else:
|
||||
# note: we cannot provide a complete list of supported models without
|
||||
# getting a complete list from the provider, so we return "..."
|
||||
all_supported_models = [*self.alias_to_provider_id_map.keys(), "..."]
|
||||
raise UnsupportedModelError(model.provider_resource_id, all_supported_models)
|
||||
|
||||
provider_resource_id = self.get_provider_model_id(model.model_id)
|
||||
if model.model_type == ModelType.embedding:
|
||||
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
|
||||
|
@ -114,6 +142,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
|
||||
)
|
||||
|
||||
# Register the model alias, ensuring it maps to the correct provider model id
|
||||
self.alias_to_provider_id_map[model.model_id] = supported_model_id
|
||||
|
||||
return model
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue