feat: add provider data keys for Cerebras, Databricks, NVIDIA, and RunPod

- added missing tests for Fireworks, Anthropic, Gemini, SambaNova, and vLLM
This commit is contained in:
Matthew Farrellee 2025-10-08 09:31:17 -04:00
parent 5d711d4bcb
commit bb95c1a7c5
10 changed files with 125 additions and 8 deletions

View file

@ -55,6 +55,7 @@ def available_providers() -> list[ProviderSpec]:
pip_packages=[],
module="llama_stack.providers.remote.inference.cerebras",
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.cerebras.config.CerebrasProviderDataValidator",
description="Cerebras inference provider for running models on Cerebras Cloud platform.",
),
RemoteProviderSpec(
@ -143,6 +144,7 @@ def available_providers() -> list[ProviderSpec]:
pip_packages=["databricks-sdk"],
module="llama_stack.providers.remote.inference.databricks",
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.databricks.config.DatabricksProviderDataValidator",
description="Databricks inference provider for running models on Databricks' unified analytics platform.",
),
RemoteProviderSpec(
@ -152,6 +154,7 @@ def available_providers() -> list[ProviderSpec]:
pip_packages=[],
module="llama_stack.providers.remote.inference.nvidia",
config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
provider_data_validator="llama_stack.providers.remote.inference.nvidia.config.NVIDIAProviderDataValidator",
description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
),
RemoteProviderSpec(
@ -161,6 +164,7 @@ def available_providers() -> list[ProviderSpec]:
pip_packages=[],
module="llama_stack.providers.remote.inference.runpod",
config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.runpod.config.RunpodProviderDataValidator",
description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
),
RemoteProviderSpec(

View file

@ -15,6 +15,8 @@ from .config import CerebrasImplConfig
class CerebrasInferenceAdapter(OpenAIMixin):
config: CerebrasImplConfig
provider_data_api_key_field: str = "cerebras_api_key"
def get_api_key(self) -> str:
return self.config.api_key.get_secret_value()

View file

@ -7,7 +7,7 @@
import os
from typing import Any
from pydantic import Field, SecretStr
from pydantic import BaseModel, Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -15,6 +15,13 @@ from llama_stack.schema_utils import json_schema_type
DEFAULT_BASE_URL = "https://api.cerebras.ai"
class CerebrasProviderDataValidator(BaseModel):
cerebras_api_key: str | None = Field(
default=None,
description="API key for Cerebras models",
)
@json_schema_type
class CerebrasImplConfig(RemoteInferenceProviderConfig):
base_url: str = Field(

View file

@ -6,12 +6,19 @@
from typing import Any
from pydantic import Field, SecretStr
from pydantic import BaseModel, Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
class DatabricksProviderDataValidator(BaseModel):
databricks_api_token: str | None = Field(
default=None,
description="API token for Databricks models",
)
@json_schema_type
class DatabricksImplConfig(RemoteInferenceProviderConfig):
url: str | None = Field(

View file

@ -21,6 +21,8 @@ logger = get_logger(name=__name__, category="inference::databricks")
class DatabricksInferenceAdapter(OpenAIMixin):
config: DatabricksImplConfig
provider_data_api_key_field: str = "databricks_api_token"
# source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
embedding_model_metadata: dict[str, dict[str, int]] = {
"databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},

View file

@ -7,12 +7,19 @@
import os
from typing import Any
from pydantic import Field, SecretStr
from pydantic import BaseModel, Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
class NVIDIAProviderDataValidator(BaseModel):
nvidia_api_key: str | None = Field(
default=None,
description="API key for NVIDIA NIM models",
)
@json_schema_type
class NVIDIAConfig(RemoteInferenceProviderConfig):
"""

View file

@ -24,6 +24,8 @@ logger = get_logger(name=__name__, category="inference::nvidia")
class NVIDIAInferenceAdapter(OpenAIMixin):
config: NVIDIAConfig
provider_data_api_key_field: str = "nvidia_api_key"
"""
NVIDIA Inference Adapter for Llama Stack.

View file

@ -6,12 +6,19 @@
from typing import Any
from pydantic import Field
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
class RunpodProviderDataValidator(BaseModel):
runpod_api_token: str | None = Field(
default=None,
description="API token for RunPod models",
)
@json_schema_type
class RunpodImplConfig(RemoteInferenceProviderConfig):
url: str | None = Field(

View file

@ -24,6 +24,8 @@ class RunpodInferenceAdapter(OpenAIMixin):
config: RunpodImplConfig
provider_data_api_key_field: str = "runpod_api_token"
def get_api_key(self) -> str:
"""Get API key for OpenAI client."""
return self.config.api_token

View file

@ -10,47 +10,124 @@ from unittest.mock import MagicMock
import pytest
from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.providers.remote.inference.anthropic.anthropic import AnthropicInferenceAdapter
from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig
from llama_stack.providers.remote.inference.cerebras.cerebras import CerebrasInferenceAdapter
from llama_stack.providers.remote.inference.cerebras.config import CerebrasImplConfig
from llama_stack.providers.remote.inference.databricks.config import DatabricksImplConfig
from llama_stack.providers.remote.inference.databricks.databricks import DatabricksInferenceAdapter
from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig
from llama_stack.providers.remote.inference.fireworks.fireworks import FireworksInferenceAdapter
from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
from llama_stack.providers.remote.inference.gemini.gemini import GeminiInferenceAdapter
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.remote.inference.llama_openai_compat.llama import LlamaCompatInferenceAdapter
from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
from llama_stack.providers.remote.inference.runpod.config import RunpodImplConfig
from llama_stack.providers.remote.inference.runpod.runpod import RunpodInferenceAdapter
from llama_stack.providers.remote.inference.sambanova.config import SambaNovaImplConfig
from llama_stack.providers.remote.inference.sambanova.sambanova import SambaNovaInferenceAdapter
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter
@pytest.mark.parametrize(
"config_cls,adapter_cls,provider_data_validator",
"config_cls,adapter_cls,provider_data_validator,config_params",
[
(
GroqConfig,
GroqInferenceAdapter,
"llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
{},
),
(
OpenAIConfig,
OpenAIInferenceAdapter,
"llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
{},
),
(
TogetherImplConfig,
TogetherInferenceAdapter,
"llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
{},
),
(
LlamaCompatConfig,
LlamaCompatInferenceAdapter,
"llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
{},
),
(
CerebrasImplConfig,
CerebrasInferenceAdapter,
"llama_stack.providers.remote.inference.cerebras.config.CerebrasProviderDataValidator",
{},
),
(
DatabricksImplConfig,
DatabricksInferenceAdapter,
"llama_stack.providers.remote.inference.databricks.config.DatabricksProviderDataValidator",
{},
),
(
NVIDIAConfig,
NVIDIAInferenceAdapter,
"llama_stack.providers.remote.inference.nvidia.config.NVIDIAProviderDataValidator",
{},
),
(
RunpodImplConfig,
RunpodInferenceAdapter,
"llama_stack.providers.remote.inference.runpod.config.RunpodProviderDataValidator",
{},
),
(
FireworksImplConfig,
FireworksInferenceAdapter,
"llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
{},
),
(
AnthropicConfig,
AnthropicInferenceAdapter,
"llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
{},
),
(
GeminiConfig,
GeminiInferenceAdapter,
"llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
{},
),
(
SambaNovaImplConfig,
SambaNovaInferenceAdapter,
"llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
{},
),
(
VLLMInferenceAdapterConfig,
VLLMInferenceAdapter,
"llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
{
"url": "http://fake",
},
),
],
)
def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str, config_params: dict):
"""Ensure the OpenAI provider does not cache api keys across client requests"""
inference_adapter = adapter_cls(config=config_cls())
inference_adapter = adapter_cls(config=config_cls(**config_params))
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator