## Summary
When users provide API keys via the `X-LlamaStack-Provider-Data` header,
`models.list()` now returns the models they can access from those providers,
not just pre-registered models from the registry.
This complements the routing fix from f88416ef8, which enabled inference
calls using the `provider_id/model_id` format for unregistered models. Users
can now discover which models are available to them before making
inference requests.
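For example, with that routing fix in place, an unregistered model can be called directly in the `provider_id/model_id` form. A minimal sketch, where the `anthropic/claude-3-5-haiku-latest` model id is only illustrative:

```python
import json
import os

from openai import OpenAI

# Assumes a local stack with the Anthropic provider configured; the model id
# below is illustrative and does not need to be pre-registered.
client = OpenAI(
    base_url="http://localhost:8321/v1/openai/v1",
    api_key="dummy",
    default_headers={
        "X-LlamaStack-Provider-Data": json.dumps(
            {"anthropic_api_key": os.environ["ANTHROPIC_API_KEY"]}
        )
    },
)
response = client.chat.completions.create(
    model="anthropic/claude-3-5-haiku-latest",
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)
```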
The implementation reuses
`NeedsRequestProviderData.get_request_provider_data()` to validate
credentials, then dynamically fetches models from providers without
caching them, since the results are user-specific. Registry models take
precedence so that any pre-configured aliases are respected.
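A rough sketch of that precedence logic (the function and attribute names here are hypothetical, chosen only to illustrate the merge order, not the actual code):

```python
async def list_accessible_model_ids(registry_models, providers):
    # Hypothetical sketch of the precedence rule: pre-registered registry
    # entries come first, then models discovered with per-request credentials
    # are appended without being cached.
    seen = {m.identifier for m in registry_models}
    result = [m.identifier for m in registry_models]
    for provider in providers:
        try:
            # In the real code, per-request credentials are validated via
            # NeedsRequestProviderData.get_request_provider_data().
            provider_model_ids = await provider.list_provider_model_ids()
        except Exception:
            continue  # no usable credentials for this provider on this request
        for model_id in provider_model_ids:
            qualified = f"{provider.provider_id}/{model_id}"
            if qualified not in seen:
                seen.add(qualified)
                result.append(qualified)
    return result
```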
## Test Script
```python
#!/usr/bin/env python3
import json
import os
from openai import OpenAI
# Test 1: Without provider_data header
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="dummy")
models = client.models.list()
anthropic_without = [m.id for m in models.data if m.id and "anthropic" in m.id]
print(f"Without header: {len(models.data)} models, {len(anthropic_without)} anthropic")
# Test 2: With provider_data header containing Anthropic API key
anthropic_api_key = os.environ["ANTHROPIC_API_KEY"]
client_with_key = OpenAI(
base_url="http://localhost:8321/v1/openai/v1",
api_key="dummy",
default_headers={
"X-LlamaStack-Provider-Data": json.dumps({"anthropic_api_key": anthropic_api_key})
}
)
models_with_key = client_with_key.models.list()
anthropic_with = [m.id for m in models_with_key.data if m.id and "anthropic" in m.id]
print(f"With header: {len(models_with_key.data)} models, {len(anthropic_with)} anthropic")
print(f"Anthropic models: {anthropic_with}")
assert len(anthropic_with) > len(anthropic_without), "Should have more anthropic models with API key"
print("\n✓ Test passed!")
```
Run with a stack that has the Anthropic provider configured (but without an API
key in the config):
```bash
ANTHROPIC_API_KEY=sk-ant-... python test_provider_data_models.py
```
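The Databricks adapter below shows the provider-side hook for the same mechanism: its `provider_data_api_key_field = "databricks_api_token"` declares which `X-LlamaStack-Provider-Data` field supplies a per-request token, and `list_provider_model_ids()` uses that token when enumerating serving endpoints.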
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import Iterable

from databricks.sdk import WorkspaceClient

from llama_stack.apis.inference import OpenAICompletion, OpenAICompletionRequestWithExtraBody
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

from .config import DatabricksImplConfig

logger = get_logger(name=__name__, category="inference::databricks")


class DatabricksInferenceAdapter(OpenAIMixin):
    config: DatabricksImplConfig

    provider_data_api_key_field: str = "databricks_api_token"

    # source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
    embedding_model_metadata: dict[str, dict[str, int]] = {
        "databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
        "databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
    }

    def get_base_url(self) -> str:
        return f"{self.config.url}/serving-endpoints"

    async def list_provider_model_ids(self) -> Iterable[str]:
        # Filter out None values from endpoint names
        api_token = self._get_api_key_from_config_or_provider_data()
        return [
            endpoint.name  # type: ignore[misc]
            for endpoint in WorkspaceClient(
                host=self.config.url, token=api_token
            ).serving_endpoints.list()  # TODO: this is not async
        ]

    async def openai_completion(
        self,
        params: OpenAICompletionRequestWithExtraBody,
    ) -> OpenAICompletion:
        raise NotImplementedError()
```
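Given `provider_data_api_key_field = "databricks_api_token"`, the same per-request pattern from the test script should carry over to Databricks. A sketch, where the stack URL and the `DATABRICKS_TOKEN` environment variable are assumptions:

```python
import json
import os

from openai import OpenAI

# Sketch only: pass a per-request Databricks token via the provider-data
# header; the JSON key matches provider_data_api_key_field above.
client = OpenAI(
    base_url="http://localhost:8321/v1/openai/v1",
    api_key="dummy",
    default_headers={
        "X-LlamaStack-Provider-Data": json.dumps(
            {"databricks_api_token": os.environ["DATABRICKS_TOKEN"]}
        )
    },
)
print([m.id for m in client.models.list().data if m.id and "databricks" in m.id])
```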