From da8f014b96917a32aa0369535d6edb1d1680e46d Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 29 Oct 2025 14:03:03 -0700 Subject: [PATCH] feat(models): list models available via provider_data header (#3968) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary When users provide API keys via `X-LlamaStack-Provider-Data` header, `models.list()` now returns models they can access from those providers, not just pre-registered models from the registry. This complements the routing fix from f88416ef8 which enabled inference calls with `provider_id/model_id` format for unregistered models. Users can now discover which models are available to them before making inference requests. The implementation reuses `NeedsRequestProviderData.get_request_provider_data()` to validate credentials, then dynamically fetches models from providers without caching them since they're user-specific. Registry models take precedence to respect any pre-configured aliases. ## Test Script ```python #!/usr/bin/env python3 import json import os from openai import OpenAI # Test 1: Without provider_data header client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="dummy") models = client.models.list() anthropic_without = [m.id for m in models.data if m.id and "anthropic" in m.id] print(f"Without header: {len(models.data)} models, {len(anthropic_without)} anthropic") # Test 2: With provider_data header containing Anthropic API key anthropic_api_key = os.environ["ANTHROPIC_API_KEY"] client_with_key = OpenAI( base_url="http://localhost:8321/v1/openai/v1", api_key="dummy", default_headers={ "X-LlamaStack-Provider-Data": json.dumps({"anthropic_api_key": anthropic_api_key}) } ) models_with_key = client_with_key.models.list() anthropic_with = [m.id for m in models_with_key.data if m.id and "anthropic" in m.id] print(f"With header: {len(models_with_key.data)} models, {len(anthropic_with)} anthropic") print(f"Anthropic models: {anthropic_with}") assert len(anthropic_with) > len(anthropic_without), "Should have more anthropic models with API key" print("\n✓ Test passed!") ``` Run with a stack that has Anthropic provider configured (but without API key in config): ```bash ANTHROPIC_API_KEY=sk-ant-... python test_provider_data_models.py ``` --- src/llama_stack/core/routing_tables/models.py | 87 ++++++++++++++++++- .../remote/inference/anthropic/anthropic.py | 3 +- .../remote/inference/databricks/databricks.py | 3 +- 3 files changed, 88 insertions(+), 5 deletions(-) diff --git a/src/llama_stack/core/routing_tables/models.py b/src/llama_stack/core/routing_tables/models.py index 7e43d7273..be17be3d4 100644 --- a/src/llama_stack/core/routing_tables/models.py +++ b/src/llama_stack/core/routing_tables/models.py @@ -13,6 +13,8 @@ from llama_stack.core.datatypes import ( ModelWithOwner, RegistryEntrySource, ) +from llama_stack.core.request_headers import PROVIDER_DATA_VAR, NeedsRequestProviderData +from llama_stack.core.utils.dynamic import instantiate_class_type from llama_stack.log import get_logger from .common import CommonRoutingTableImpl, lookup_model @@ -42,11 +44,90 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): await self.update_registered_models(provider_id, models) + async def _get_dynamic_models_from_provider_data(self) -> list[Model]: + """ + Fetch models from providers that have credentials in the current request's provider_data. + + This allows users to see models available to them from providers that require + per-request API keys (via X-LlamaStack-Provider-Data header). + + Returns models with fully qualified identifiers (provider_id/model_id) but does NOT + cache them in the registry since they are user-specific. + """ + provider_data = PROVIDER_DATA_VAR.get() + if not provider_data: + return [] + + dynamic_models = [] + + for provider_id, provider in self.impls_by_provider_id.items(): + # Check if this provider supports provider_data + if not isinstance(provider, NeedsRequestProviderData): + continue + + # Check if provider has a validator (some providers like ollama don't need per-request credentials) + spec = getattr(provider, "__provider_spec__", None) + if not spec or not getattr(spec, "provider_data_validator", None): + continue + + # Validate provider_data silently - we're speculatively checking all providers + # so validation failures are expected when user didn't provide keys for this provider + try: + validator = instantiate_class_type(spec.provider_data_validator) + validator(**provider_data) + except Exception: + # User didn't provide credentials for this provider - skip silently + continue + + # Validation succeeded! User has credentials for this provider + # Now try to list models + try: + models = await provider.list_models() + if not models: + continue + + # Ensure models have fully qualified identifiers with provider_id prefix + for model in models: + # Only add prefix if model identifier doesn't already have it + if not model.identifier.startswith(f"{provider_id}/"): + model.identifier = f"{provider_id}/{model.provider_resource_id}" + + dynamic_models.append(model) + + logger.debug(f"Fetched {len(models)} models from provider {provider_id} using provider_data") + + except Exception as e: + logger.debug(f"Failed to list models from provider {provider_id} with provider_data: {e}") + continue + + return dynamic_models + async def list_models(self) -> ListModelsResponse: - return ListModelsResponse(data=await self.get_all_with_type("model")) + # Get models from registry + registry_models = await self.get_all_with_type("model") + + # Get additional models available via provider_data (user-specific, not cached) + dynamic_models = await self._get_dynamic_models_from_provider_data() + + # Combine, avoiding duplicates (registry takes precedence) + registry_identifiers = {m.identifier for m in registry_models} + unique_dynamic_models = [m for m in dynamic_models if m.identifier not in registry_identifiers] + + return ListModelsResponse(data=registry_models + unique_dynamic_models) async def openai_list_models(self) -> OpenAIListModelsResponse: - models = await self.get_all_with_type("model") + # Get models from registry + registry_models = await self.get_all_with_type("model") + + # Get additional models available via provider_data (user-specific, not cached) + dynamic_models = await self._get_dynamic_models_from_provider_data() + + # Combine, avoiding duplicates (registry takes precedence) + registry_identifiers = {m.identifier for m in registry_models} + unique_dynamic_models = [m for m in dynamic_models if m.identifier not in registry_identifiers] + + all_models = registry_models + unique_dynamic_models + openai_models = [ OpenAIModel( id=model.identifier, @@ -54,7 +135,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): created=int(time.time()), owned_by="llama_stack", ) - for model in models + for model in all_models ] return OpenAIListModelsResponse(data=openai_models) diff --git a/src/llama_stack/providers/remote/inference/anthropic/anthropic.py b/src/llama_stack/providers/remote/inference/anthropic/anthropic.py index dc9d8fb40..112b70524 100644 --- a/src/llama_stack/providers/remote/inference/anthropic/anthropic.py +++ b/src/llama_stack/providers/remote/inference/anthropic/anthropic.py @@ -33,4 +33,5 @@ class AnthropicInferenceAdapter(OpenAIMixin): return "https://api.anthropic.com/v1" async def list_provider_model_ids(self) -> Iterable[str]: - return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()] + api_key = self._get_api_key_from_config_or_provider_data() + return [m.id async for m in AsyncAnthropic(api_key=api_key).models.list()] diff --git a/src/llama_stack/providers/remote/inference/databricks/databricks.py b/src/llama_stack/providers/remote/inference/databricks/databricks.py index 8a8c5d4e3..636241383 100644 --- a/src/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/src/llama_stack/providers/remote/inference/databricks/databricks.py @@ -33,10 +33,11 @@ class DatabricksInferenceAdapter(OpenAIMixin): async def list_provider_model_ids(self) -> Iterable[str]: # Filter out None values from endpoint names + api_token = self._get_api_key_from_config_or_provider_data() return [ endpoint.name # type: ignore[misc] for endpoint in WorkspaceClient( - host=self.config.url, token=self.get_api_key() + host=self.config.url, token=api_token ).serving_endpoints.list() # TODO: this is not async ]