feat(models): list models available via provider_data header

When users provide API keys via X-LlamaStack-Provider-Data header, models.list()
now returns models they can access from those providers, not just pre-registered
models from the registry.

This complements the routing fix from f88416ef8 which enabled inference calls with
provider_id/model_id format for unregistered models. Users can now discover which
models are available to them before making inference requests.

The implementation reuses NeedsRequestProviderData.get_request_provider_data() to
validate credentials, then dynamically fetches models from providers without caching
them (since they're user-specific). Registry models take precedence over dynamic ones
to respect any pre-configured aliases.
This commit is contained in:
Ashwin Bharambe 2025-10-29 10:49:19 -07:00
parent b90c6a2c8b
commit de6072fda7

View file

@ -13,6 +13,7 @@ from llama_stack.core.datatypes import (
ModelWithOwner, ModelWithOwner,
RegistryEntrySource, RegistryEntrySource,
) )
from llama_stack.core.request_headers import NeedsRequestProviderData, PROVIDER_DATA_VAR
from llama_stack.log import get_logger from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model from .common import CommonRoutingTableImpl, lookup_model
@ -42,11 +43,82 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
await self.update_registered_models(provider_id, models) await self.update_registered_models(provider_id, models)
async def _get_dynamic_models_from_provider_data(self) -> list[Model]:
"""
Fetch models from providers that have credentials in the current request's provider_data.
This allows users to see models available to them from providers that require
per-request API keys (via X-LlamaStack-Provider-Data header).
Returns models with fully qualified identifiers (provider_id/model_id) but does NOT
cache them in the registry since they are user-specific.
"""
provider_data = PROVIDER_DATA_VAR.get()
if not provider_data:
return []
dynamic_models = []
for provider_id, provider in self.impls_by_provider_id.items():
# Check if this provider supports provider_data
if not isinstance(provider, NeedsRequestProviderData):
continue
# Try to get validated provider_data for this provider
# Returns None if validation fails (missing keys) or if no provider_data exists
validated_data = provider.get_request_provider_data()
if validated_data is None:
continue
# Validation succeeded! User has credentials for this provider
# Now try to list models
try:
models = await provider.list_models()
if not models:
continue
# Ensure models have fully qualified identifiers with provider_id prefix
for model in models:
# Only add prefix if model identifier doesn't already have it
if not model.identifier.startswith(f"{provider_id}/"):
model.identifier = f"{provider_id}/{model.provider_resource_id}"
dynamic_models.append(model)
logger.debug(f"Fetched {len(models)} models from provider {provider_id} using provider_data")
except Exception as e:
logger.debug(f"Failed to list models from provider {provider_id} with provider_data: {e}")
continue
return dynamic_models
async def list_models(self) -> ListModelsResponse: async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model")) # Get models from registry
registry_models = await self.get_all_with_type("model")
# Get additional models available via provider_data (user-specific, not cached)
dynamic_models = await self._get_dynamic_models_from_provider_data()
# Combine, avoiding duplicates (registry takes precedence)
registry_identifiers = {m.identifier for m in registry_models}
unique_dynamic_models = [m for m in dynamic_models if m.identifier not in registry_identifiers]
return ListModelsResponse(data=registry_models + unique_dynamic_models)
async def openai_list_models(self) -> OpenAIListModelsResponse: async def openai_list_models(self) -> OpenAIListModelsResponse:
models = await self.get_all_with_type("model") # Get models from registry
registry_models = await self.get_all_with_type("model")
# Get additional models available via provider_data (user-specific, not cached)
dynamic_models = await self._get_dynamic_models_from_provider_data()
# Combine, avoiding duplicates (registry takes precedence)
registry_identifiers = {m.identifier for m in registry_models}
unique_dynamic_models = [m for m in dynamic_models if m.identifier not in registry_identifiers]
all_models = registry_models + unique_dynamic_models
openai_models = [ openai_models = [
OpenAIModel( OpenAIModel(
id=model.identifier, id=model.identifier,
@ -54,7 +126,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
created=int(time.time()), created=int(time.time()),
owned_by="llama_stack", owned_by="llama_stack",
) )
for model in models for model in all_models
] ]
return OpenAIListModelsResponse(data=openai_models) return OpenAIListModelsResponse(data=openai_models)