feat(models): list models available via provider_data header (#3968)

## Summary

When users provide API keys via `X-LlamaStack-Provider-Data` header,
`models.list()` now returns models they can access from those providers,
not just pre-registered models from the registry.

This complements the routing fix from f88416ef8 which enabled inference
calls with `provider_id/model_id` format for unregistered models. Users
can now discover which models are available to them before making
inference requests.

The implementation reuses
`NeedsRequestProviderData.get_request_provider_data()` to validate
credentials, then dynamically fetches models from providers without
caching them since they're user-specific. Registry models take
precedence to respect any pre-configured aliases.

## Test Script

```python
#!/usr/bin/env python3
import json
import os
from openai import OpenAI

# Test 1: Without provider_data header
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="dummy")
models = client.models.list()
anthropic_without = [m.id for m in models.data if m.id and "anthropic" in m.id]
print(f"Without header: {len(models.data)} models, {len(anthropic_without)} anthropic")

# Test 2: With provider_data header containing Anthropic API key
anthropic_api_key = os.environ["ANTHROPIC_API_KEY"]
client_with_key = OpenAI(
    base_url="http://localhost:8321/v1/openai/v1",
    api_key="dummy",
    default_headers={
        "X-LlamaStack-Provider-Data": json.dumps({"anthropic_api_key": anthropic_api_key})
    }
)
models_with_key = client_with_key.models.list()
anthropic_with = [m.id for m in models_with_key.data if m.id and "anthropic" in m.id]
print(f"With header: {len(models_with_key.data)} models, {len(anthropic_with)} anthropic")
print(f"Anthropic models: {anthropic_with}")

assert len(anthropic_with) > len(anthropic_without), "Should have more anthropic models with API key"
print("\n✓ Test passed!")
```

Run with a stack that has Anthropic provider configured (but without API
key in config):
```bash
ANTHROPIC_API_KEY=sk-ant-... python test_provider_data_models.py
```
This commit is contained in:
Ashwin Bharambe 2025-10-29 14:03:03 -07:00 committed by GitHub
parent c9d4b6c54f
commit da8f014b96
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 88 additions and 5 deletions

View file

@ -13,6 +13,8 @@ from llama_stack.core.datatypes import (
ModelWithOwner, ModelWithOwner,
RegistryEntrySource, RegistryEntrySource,
) )
from llama_stack.core.request_headers import PROVIDER_DATA_VAR, NeedsRequestProviderData
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model from .common import CommonRoutingTableImpl, lookup_model
@ -42,11 +44,90 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
await self.update_registered_models(provider_id, models) await self.update_registered_models(provider_id, models)
async def _get_dynamic_models_from_provider_data(self) -> list[Model]:
"""
Fetch models from providers that have credentials in the current request's provider_data.
This allows users to see models available to them from providers that require
per-request API keys (via X-LlamaStack-Provider-Data header).
Returns models with fully qualified identifiers (provider_id/model_id) but does NOT
cache them in the registry since they are user-specific.
"""
provider_data = PROVIDER_DATA_VAR.get()
if not provider_data:
return []
dynamic_models = []
for provider_id, provider in self.impls_by_provider_id.items():
# Check if this provider supports provider_data
if not isinstance(provider, NeedsRequestProviderData):
continue
# Check if provider has a validator (some providers like ollama don't need per-request credentials)
spec = getattr(provider, "__provider_spec__", None)
if not spec or not getattr(spec, "provider_data_validator", None):
continue
# Validate provider_data silently - we're speculatively checking all providers
# so validation failures are expected when user didn't provide keys for this provider
try:
validator = instantiate_class_type(spec.provider_data_validator)
validator(**provider_data)
except Exception:
# User didn't provide credentials for this provider - skip silently
continue
# Validation succeeded! User has credentials for this provider
# Now try to list models
try:
models = await provider.list_models()
if not models:
continue
# Ensure models have fully qualified identifiers with provider_id prefix
for model in models:
# Only add prefix if model identifier doesn't already have it
if not model.identifier.startswith(f"{provider_id}/"):
model.identifier = f"{provider_id}/{model.provider_resource_id}"
dynamic_models.append(model)
logger.debug(f"Fetched {len(models)} models from provider {provider_id} using provider_data")
except Exception as e:
logger.debug(f"Failed to list models from provider {provider_id} with provider_data: {e}")
continue
return dynamic_models
async def list_models(self) -> ListModelsResponse: async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model")) # Get models from registry
registry_models = await self.get_all_with_type("model")
# Get additional models available via provider_data (user-specific, not cached)
dynamic_models = await self._get_dynamic_models_from_provider_data()
# Combine, avoiding duplicates (registry takes precedence)
registry_identifiers = {m.identifier for m in registry_models}
unique_dynamic_models = [m for m in dynamic_models if m.identifier not in registry_identifiers]
return ListModelsResponse(data=registry_models + unique_dynamic_models)
async def openai_list_models(self) -> OpenAIListModelsResponse: async def openai_list_models(self) -> OpenAIListModelsResponse:
models = await self.get_all_with_type("model") # Get models from registry
registry_models = await self.get_all_with_type("model")
# Get additional models available via provider_data (user-specific, not cached)
dynamic_models = await self._get_dynamic_models_from_provider_data()
# Combine, avoiding duplicates (registry takes precedence)
registry_identifiers = {m.identifier for m in registry_models}
unique_dynamic_models = [m for m in dynamic_models if m.identifier not in registry_identifiers]
all_models = registry_models + unique_dynamic_models
openai_models = [ openai_models = [
OpenAIModel( OpenAIModel(
id=model.identifier, id=model.identifier,
@ -54,7 +135,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
created=int(time.time()), created=int(time.time()),
owned_by="llama_stack", owned_by="llama_stack",
) )
for model in models for model in all_models
] ]
return OpenAIListModelsResponse(data=openai_models) return OpenAIListModelsResponse(data=openai_models)

View file

@ -33,4 +33,5 @@ class AnthropicInferenceAdapter(OpenAIMixin):
return "https://api.anthropic.com/v1" return "https://api.anthropic.com/v1"
async def list_provider_model_ids(self) -> Iterable[str]: async def list_provider_model_ids(self) -> Iterable[str]:
return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()] api_key = self._get_api_key_from_config_or_provider_data()
return [m.id async for m in AsyncAnthropic(api_key=api_key).models.list()]

View file

@ -33,10 +33,11 @@ class DatabricksInferenceAdapter(OpenAIMixin):
async def list_provider_model_ids(self) -> Iterable[str]: async def list_provider_model_ids(self) -> Iterable[str]:
# Filter out None values from endpoint names # Filter out None values from endpoint names
api_token = self._get_api_key_from_config_or_provider_data()
return [ return [
endpoint.name # type: ignore[misc] endpoint.name # type: ignore[misc]
for endpoint in WorkspaceClient( for endpoint in WorkspaceClient(
host=self.config.url, token=self.get_api_key() host=self.config.url, token=api_token
).serving_endpoints.list() # TODO: this is not async ).serving_endpoints.list() # TODO: this is not async
] ]