mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-11 19:56:03 +00:00
feat(models): list models available via provider_data header
When users provide API keys via X-LlamaStack-Provider-Data header, models.list()
now returns models they can access from those providers, not just pre-registered
models from the registry.
This complements the routing fix from f88416ef8 which enabled inference calls with
provider_id/model_id format for unregistered models. Users can now discover which
models are available to them before making inference requests.
The implementation reuses NeedsRequestProviderData.get_request_provider_data() to
validate credentials, then dynamically fetches models from providers without caching
them (since they're user-specific). Registry models take precedence over dynamic ones
to respect any pre-configured aliases.
This commit is contained in:
parent
b90c6a2c8b
commit
de6072fda7
1 changed files with 75 additions and 3 deletions
|
|
@ -13,6 +13,7 @@ from llama_stack.core.datatypes import (
|
|||
ModelWithOwner,
|
||||
RegistryEntrySource,
|
||||
)
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData, PROVIDER_DATA_VAR
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .common import CommonRoutingTableImpl, lookup_model
|
||||
|
|
@ -42,11 +43,82 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
|
||||
await self.update_registered_models(provider_id, models)
|
||||
|
||||
async def _get_dynamic_models_from_provider_data(self) -> list[Model]:
|
||||
"""
|
||||
Fetch models from providers that have credentials in the current request's provider_data.
|
||||
|
||||
This allows users to see models available to them from providers that require
|
||||
per-request API keys (via X-LlamaStack-Provider-Data header).
|
||||
|
||||
Returns models with fully qualified identifiers (provider_id/model_id) but does NOT
|
||||
cache them in the registry since they are user-specific.
|
||||
"""
|
||||
provider_data = PROVIDER_DATA_VAR.get()
|
||||
if not provider_data:
|
||||
return []
|
||||
|
||||
dynamic_models = []
|
||||
|
||||
for provider_id, provider in self.impls_by_provider_id.items():
|
||||
# Check if this provider supports provider_data
|
||||
if not isinstance(provider, NeedsRequestProviderData):
|
||||
continue
|
||||
|
||||
# Try to get validated provider_data for this provider
|
||||
# Returns None if validation fails (missing keys) or if no provider_data exists
|
||||
validated_data = provider.get_request_provider_data()
|
||||
if validated_data is None:
|
||||
continue
|
||||
|
||||
# Validation succeeded! User has credentials for this provider
|
||||
# Now try to list models
|
||||
try:
|
||||
models = await provider.list_models()
|
||||
if not models:
|
||||
continue
|
||||
|
||||
# Ensure models have fully qualified identifiers with provider_id prefix
|
||||
for model in models:
|
||||
# Only add prefix if model identifier doesn't already have it
|
||||
if not model.identifier.startswith(f"{provider_id}/"):
|
||||
model.identifier = f"{provider_id}/{model.provider_resource_id}"
|
||||
|
||||
dynamic_models.append(model)
|
||||
|
||||
logger.debug(f"Fetched {len(models)} models from provider {provider_id} using provider_data")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to list models from provider {provider_id} with provider_data: {e}")
|
||||
continue
|
||||
|
||||
return dynamic_models
|
||||
|
||||
async def list_models(self) -> ListModelsResponse:
|
||||
return ListModelsResponse(data=await self.get_all_with_type("model"))
|
||||
# Get models from registry
|
||||
registry_models = await self.get_all_with_type("model")
|
||||
|
||||
# Get additional models available via provider_data (user-specific, not cached)
|
||||
dynamic_models = await self._get_dynamic_models_from_provider_data()
|
||||
|
||||
# Combine, avoiding duplicates (registry takes precedence)
|
||||
registry_identifiers = {m.identifier for m in registry_models}
|
||||
unique_dynamic_models = [m for m in dynamic_models if m.identifier not in registry_identifiers]
|
||||
|
||||
return ListModelsResponse(data=registry_models + unique_dynamic_models)
|
||||
|
||||
async def openai_list_models(self) -> OpenAIListModelsResponse:
|
||||
models = await self.get_all_with_type("model")
|
||||
# Get models from registry
|
||||
registry_models = await self.get_all_with_type("model")
|
||||
|
||||
# Get additional models available via provider_data (user-specific, not cached)
|
||||
dynamic_models = await self._get_dynamic_models_from_provider_data()
|
||||
|
||||
# Combine, avoiding duplicates (registry takes precedence)
|
||||
registry_identifiers = {m.identifier for m in registry_models}
|
||||
unique_dynamic_models = [m for m in dynamic_models if m.identifier not in registry_identifiers]
|
||||
|
||||
all_models = registry_models + unique_dynamic_models
|
||||
|
||||
openai_models = [
|
||||
OpenAIModel(
|
||||
id=model.identifier,
|
||||
|
|
@ -54,7 +126,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
created=int(time.time()),
|
||||
owned_by="llama_stack",
|
||||
)
|
||||
for model in models
|
||||
for model in all_models
|
||||
]
|
||||
return OpenAIListModelsResponse(data=openai_models)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue