mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
feat(models): list models available via provider_data header (#3968)
## Summary
When users provide API keys via `X-LlamaStack-Provider-Data` header,
`models.list()` now returns models they can access from those providers,
not just pre-registered models from the registry.
This complements the routing fix from f88416ef8, which enabled inference
calls using the `provider_id/model_id` format for unregistered models. Users
can now discover which models are available to them before making
inference requests.
The implementation reuses
`NeedsRequestProviderData.get_request_provider_data()` to validate
credentials, then dynamically fetches models from providers without
caching them since they're user-specific. Registry models take
precedence to respect any pre-configured aliases.
## Test Script
```python
#!/usr/bin/env python3
"""Manual smoke test for model discovery via the X-LlamaStack-Provider-Data header.

Requires a running Llama Stack server on localhost:8321 with the Anthropic
provider configured (no API key in the static config) and ANTHROPIC_API_KEY
set in the environment. Run: ANTHROPIC_API_KEY=sk-ant-... python test_provider_data_models.py
"""
import json
import os
from openai import OpenAI

# Test 1: Without the provider_data header — only registry models should appear.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="dummy")
models = client.models.list()
anthropic_without = [m.id for m in models.data if m.id and "anthropic" in m.id]
print(f"Without header: {len(models.data)} models, {len(anthropic_without)} anthropic")

# Test 2: With a provider_data header carrying the Anthropic API key — the
# server should dynamically surface Anthropic's models for this request.
anthropic_api_key = os.environ["ANTHROPIC_API_KEY"]
client_with_key = OpenAI(
    base_url="http://localhost:8321/v1/openai/v1",
    api_key="dummy",
    default_headers={
        "X-LlamaStack-Provider-Data": json.dumps({"anthropic_api_key": anthropic_api_key})
    }
)
models_with_key = client_with_key.models.list()
anthropic_with = [m.id for m in models_with_key.data if m.id and "anthropic" in m.id]
print(f"With header: {len(models_with_key.data)} models, {len(anthropic_with)} anthropic")
print(f"Anthropic models: {anthropic_with}")
# The keyed client must see strictly more Anthropic models than the anonymous one.
assert len(anthropic_with) > len(anthropic_without), "Should have more anthropic models with API key"
print("\n✓ Test passed!")
```
Run with a stack that has Anthropic provider configured (but without API
key in config):
```bash
ANTHROPIC_API_KEY=sk-ant-... python test_provider_data_models.py
```
This commit is contained in:
parent
c9d4b6c54f
commit
da8f014b96
3 changed files with 88 additions and 5 deletions
|
|
@ -13,6 +13,8 @@ from llama_stack.core.datatypes import (
|
||||||
ModelWithOwner,
|
ModelWithOwner,
|
||||||
RegistryEntrySource,
|
RegistryEntrySource,
|
||||||
)
|
)
|
||||||
|
from llama_stack.core.request_headers import PROVIDER_DATA_VAR, NeedsRequestProviderData
|
||||||
|
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
from .common import CommonRoutingTableImpl, lookup_model
|
from .common import CommonRoutingTableImpl, lookup_model
|
||||||
|
|
@ -42,11 +44,90 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
|
|
||||||
await self.update_registered_models(provider_id, models)
|
await self.update_registered_models(provider_id, models)
|
||||||
|
|
||||||
|
async def _get_dynamic_models_from_provider_data(self) -> list[Model]:
    """Discover models offered by providers whose credentials arrive per-request.

    When the caller supplies API keys through the X-LlamaStack-Provider-Data
    header, walk every registered provider, check whether the header carries
    valid credentials for it, and if so ask the provider for its model list.

    Identifiers are namespaced as ``provider_id/model_id``. Results are NOT
    cached in the registry because they are specific to the requesting
    user's credentials.
    """
    provider_data = PROVIDER_DATA_VAR.get()
    if not provider_data:
        return []

    discovered: list[Model] = []

    for provider_id, provider in self.impls_by_provider_id.items():
        # Only providers that can consume per-request provider_data qualify.
        if not isinstance(provider, NeedsRequestProviderData):
            continue

        # Providers without a provider_data validator (e.g. ones needing no
        # per-request credentials, like ollama) are not discovery candidates.
        spec = getattr(provider, "__provider_spec__", None)
        validator_cls = getattr(spec, "provider_data_validator", None) if spec else None
        if not validator_cls:
            continue

        # Speculatively validate the header against this provider's schema;
        # a failure simply means the user supplied no keys for this provider,
        # so skip it silently.
        try:
            instantiate_class_type(validator_cls)(**provider_data)
        except Exception:
            continue

        # Credentials look valid — attempt to enumerate this provider's models.
        try:
            models = await provider.list_models()
            if not models:
                continue

            for model in models:
                # Namespace the identifier unless it is already prefixed.
                if not model.identifier.startswith(f"{provider_id}/"):
                    model.identifier = f"{provider_id}/{model.provider_resource_id}"
                discovered.append(model)

            logger.debug(f"Fetched {len(models)} models from provider {provider_id} using provider_data")
        except Exception as e:
            logger.debug(f"Failed to list models from provider {provider_id} with provider_data: {e}")
            continue

    return discovered
|
||||||
|
|
||||||
async def list_models(self) -> ListModelsResponse:
    """Return registry models plus any discoverable via per-request credentials.

    Registry entries win on identifier collisions so pre-configured aliases
    are preserved; dynamically discovered models are never cached.
    """
    registry_models = await self.get_all_with_type("model")

    # User-specific models surfaced through the provider_data header.
    dynamic_models = await self._get_dynamic_models_from_provider_data()

    # Registry takes precedence: drop dynamic entries already registered.
    known_identifiers = {m.identifier for m in registry_models}
    extras = [m for m in dynamic_models if m.identifier not in known_identifiers]

    return ListModelsResponse(data=registry_models + extras)
|
||||||
|
|
||||||
async def openai_list_models(self) -> OpenAIListModelsResponse:
|
async def openai_list_models(self) -> OpenAIListModelsResponse:
|
||||||
models = await self.get_all_with_type("model")
|
# Get models from registry
|
||||||
|
registry_models = await self.get_all_with_type("model")
|
||||||
|
|
||||||
|
# Get additional models available via provider_data (user-specific, not cached)
|
||||||
|
dynamic_models = await self._get_dynamic_models_from_provider_data()
|
||||||
|
|
||||||
|
# Combine, avoiding duplicates (registry takes precedence)
|
||||||
|
registry_identifiers = {m.identifier for m in registry_models}
|
||||||
|
unique_dynamic_models = [m for m in dynamic_models if m.identifier not in registry_identifiers]
|
||||||
|
|
||||||
|
all_models = registry_models + unique_dynamic_models
|
||||||
|
|
||||||
openai_models = [
|
openai_models = [
|
||||||
OpenAIModel(
|
OpenAIModel(
|
||||||
id=model.identifier,
|
id=model.identifier,
|
||||||
|
|
@ -54,7 +135,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
created=int(time.time()),
|
created=int(time.time()),
|
||||||
owned_by="llama_stack",
|
owned_by="llama_stack",
|
||||||
)
|
)
|
||||||
for model in models
|
for model in all_models
|
||||||
]
|
]
|
||||||
return OpenAIListModelsResponse(data=openai_models)
|
return OpenAIListModelsResponse(data=openai_models)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,4 +33,5 @@ class AnthropicInferenceAdapter(OpenAIMixin):
|
||||||
return "https://api.anthropic.com/v1"
|
return "https://api.anthropic.com/v1"
|
||||||
|
|
||||||
async def list_provider_model_ids(self) -> Iterable[str]:
    """Enumerate model ids available from the Anthropic API.

    The API key may come from static config or from the per-request
    X-LlamaStack-Provider-Data header.
    """
    client = AsyncAnthropic(api_key=self._get_api_key_from_config_or_provider_data())
    return [model.id async for model in client.models.list()]
|
||||||
|
|
|
||||||
|
|
@ -33,10 +33,11 @@ class DatabricksInferenceAdapter(OpenAIMixin):
|
||||||
|
|
||||||
async def list_provider_model_ids(self) -> Iterable[str]:
    """Enumerate serving-endpoint names from the Databricks workspace.

    The token may come from static config or from the per-request
    X-LlamaStack-Provider-Data header.
    """
    workspace = WorkspaceClient(
        host=self.config.url,
        token=self._get_api_key_from_config_or_provider_data(),
    )
    # TODO: this is not async — WorkspaceClient is a synchronous SDK client.
    return [
        endpoint.name  # type: ignore[misc]
        for endpoint in workspace.serving_endpoints.list()
    ]
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue