From bb95c1a7c5bddf42427ee1bb2c3ec1d0bc17298e Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Wed, 8 Oct 2025 09:31:17 -0400
Subject: [PATCH] feat: add provider data keys for Cerebras, Databricks, NVIDIA, and RunPod

- added missing tests for Fireworks, Anthropic, Gemini, SambaNova, and vLLM
---
 llama_stack/providers/registry/inference.py    |  4 +
 .../remote/inference/cerebras/cerebras.py      |  2 +
 .../remote/inference/cerebras/config.py        |  9 +-
 .../remote/inference/databricks/config.py      |  9 +-
 .../remote/inference/databricks/databricks.py  |  2 +
 .../remote/inference/nvidia/config.py          |  9 +-
 .../remote/inference/nvidia/nvidia.py          |  2 +
 .../remote/inference/runpod/config.py          |  9 +-
 .../remote/inference/runpod/runpod.py          |  2 +
 .../test_inference_client_caching.py           | 85 ++++++++++++++++++-
 10 files changed, 125 insertions(+), 8 deletions(-)
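Usage note (commentary only; like the diffstat above, this section is ignored
by git am): declaring provider_data_api_key_field on an adapter lets a
per-request key carried in the X-LlamaStack-Provider-Data header take the
place of the static key in the provider config. A minimal sketch of that
flow for the Cerebras adapter, mirroring the unit test below. The
"cerebras-user-key" value and the mocked provider spec are illustrative, and
the sketch assumes OpenAIMixin exposes the request-scoped OpenAI-compatible
client as the `.client` property:

    import json
    from unittest.mock import MagicMock

    from llama_stack.core.request_headers import request_provider_data_context
    from llama_stack.providers.remote.inference.cerebras.cerebras import CerebrasInferenceAdapter
    from llama_stack.providers.remote.inference.cerebras.config import CerebrasImplConfig

    adapter = CerebrasInferenceAdapter(config=CerebrasImplConfig())
    # Illustrative stand-in for the spec the registry would normally attach.
    adapter.__provider_spec__ = MagicMock()
    adapter.__provider_spec__.provider_data_validator = (
        "llama_stack.providers.remote.inference.cerebras.config.CerebrasProviderDataValidator"
    )

    # Simulate a request whose provider-data header carries a Cerebras key;
    # the adapter resolves it via provider_data_api_key_field instead of the
    # key baked into the stack config.
    with request_provider_data_context(
        {"x-llamastack-provider-data": json.dumps({"cerebras_api_key": "cerebras-user-key"})}
    ):
        client = adapter.client  # assumed OpenAIMixin property, built per request
        assert client.api_key == "cerebras-user-key"

The same shape applies to the other adapters touched here, keyed by
databricks_api_token, nvidia_api_key, and runpod_api_token respectively.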
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index f89565892..6b69038c6 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -55,6 +55,7 @@ def available_providers() -> list[ProviderSpec]:
             pip_packages=[],
             module="llama_stack.providers.remote.inference.cerebras",
             config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.cerebras.config.CerebrasProviderDataValidator",
             description="Cerebras inference provider for running models on Cerebras Cloud platform.",
         ),
         RemoteProviderSpec(
@@ -143,6 +144,7 @@ def available_providers() -> list[ProviderSpec]:
             pip_packages=["databricks-sdk"],
             module="llama_stack.providers.remote.inference.databricks",
             config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.databricks.config.DatabricksProviderDataValidator",
             description="Databricks inference provider for running models on Databricks' unified analytics platform.",
         ),
         RemoteProviderSpec(
@@ -152,6 +154,7 @@ def available_providers() -> list[ProviderSpec]:
             pip_packages=[],
             module="llama_stack.providers.remote.inference.nvidia",
             config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.nvidia.config.NVIDIAProviderDataValidator",
             description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
         ),
         RemoteProviderSpec(
@@ -161,6 +164,7 @@ def available_providers() -> list[ProviderSpec]:
             pip_packages=[],
             module="llama_stack.providers.remote.inference.runpod",
             config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.runpod.config.RunpodProviderDataValidator",
             description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
         ),
         RemoteProviderSpec(
diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py
index 11ef218a1..291336f86 100644
--- a/llama_stack/providers/remote/inference/cerebras/cerebras.py
+++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py
@@ -15,6 +15,8 @@ from .config import CerebrasImplConfig
 
 class CerebrasInferenceAdapter(OpenAIMixin):
     config: CerebrasImplConfig
 
+    provider_data_api_key_field: str = "cerebras_api_key"
+
     def get_api_key(self) -> str:
         return self.config.api_key.get_secret_value()
diff --git a/llama_stack/providers/remote/inference/cerebras/config.py b/llama_stack/providers/remote/inference/cerebras/config.py
index 40db38935..dbab60a4b 100644
--- a/llama_stack/providers/remote/inference/cerebras/config.py
+++ b/llama_stack/providers/remote/inference/cerebras/config.py
@@ -7,7 +7,7 @@ import os
 from typing import Any
 
-from pydantic import Field, SecretStr
+from pydantic import BaseModel, Field, SecretStr
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
@@ -15,6 +15,13 @@ from llama_stack.schema_utils import json_schema_type
 DEFAULT_BASE_URL = "https://api.cerebras.ai"
 
 
+class CerebrasProviderDataValidator(BaseModel):
+    cerebras_api_key: str | None = Field(
+        default=None,
+        description="API key for Cerebras models",
+    )
+
+
 @json_schema_type
 class CerebrasImplConfig(RemoteInferenceProviderConfig):
     base_url: str = Field(
diff --git a/llama_stack/providers/remote/inference/databricks/config.py b/llama_stack/providers/remote/inference/databricks/config.py
index 68e94151e..279e741be 100644
--- a/llama_stack/providers/remote/inference/databricks/config.py
+++ b/llama_stack/providers/remote/inference/databricks/config.py
@@ -6,12 +6,19 @@
 
 from typing import Any
 
-from pydantic import Field, SecretStr
+from pydantic import BaseModel, Field, SecretStr
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
+class DatabricksProviderDataValidator(BaseModel):
+    databricks_api_token: str | None = Field(
+        default=None,
+        description="API token for Databricks models",
+    )
+
+
 @json_schema_type
 class DatabricksImplConfig(RemoteInferenceProviderConfig):
     url: str | None = Field(
diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py
index 200b36171..cf7c72924 100644
--- a/llama_stack/providers/remote/inference/databricks/databricks.py
+++ b/llama_stack/providers/remote/inference/databricks/databricks.py
@@ -21,6 +21,8 @@ logger = get_logger(name=__name__, category="inference::databricks")
 class DatabricksInferenceAdapter(OpenAIMixin):
     config: DatabricksImplConfig
 
+    provider_data_api_key_field: str = "databricks_api_token"
+
     # source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
     embedding_model_metadata: dict[str, dict[str, int]] = {
         "databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
diff --git a/llama_stack/providers/remote/inference/nvidia/config.py b/llama_stack/providers/remote/inference/nvidia/config.py
index 4b310d770..df623934b 100644
--- a/llama_stack/providers/remote/inference/nvidia/config.py
+++ b/llama_stack/providers/remote/inference/nvidia/config.py
@@ -7,12 +7,19 @@ import os
 from typing import Any
 
-from pydantic import Field, SecretStr
+from pydantic import BaseModel, Field, SecretStr
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
+class NVIDIAProviderDataValidator(BaseModel):
+    nvidia_api_key: str | None = Field(
+        default=None,
+        description="API key for NVIDIA NIM models",
+    )
+
+
 @json_schema_type
 class NVIDIAConfig(RemoteInferenceProviderConfig):
     """
diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index 7a2697327..d30d8b0e1 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -24,6 +24,8 @@ logger = get_logger(name=__name__, category="inference::nvidia")
 class NVIDIAInferenceAdapter(OpenAIMixin):
     config: NVIDIAConfig
 
+    provider_data_api_key_field: str = "nvidia_api_key"
+
     """
     NVIDIA Inference Adapter for Llama Stack.
 
diff --git a/llama_stack/providers/remote/inference/runpod/config.py b/llama_stack/providers/remote/inference/runpod/config.py
index cdfe0f885..93db2d0f5 100644
--- a/llama_stack/providers/remote/inference/runpod/config.py
+++ b/llama_stack/providers/remote/inference/runpod/config.py
@@ -6,12 +6,19 @@
 
 from typing import Any
 
-from pydantic import Field
+from pydantic import BaseModel, Field
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
+class RunpodProviderDataValidator(BaseModel):
+    runpod_api_token: str | None = Field(
+        default=None,
+        description="API token for RunPod models",
+    )
+
+
 @json_schema_type
 class RunpodImplConfig(RemoteInferenceProviderConfig):
     url: str | None = Field(
diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py
index f752740e5..6d5968f82 100644
--- a/llama_stack/providers/remote/inference/runpod/runpod.py
+++ b/llama_stack/providers/remote/inference/runpod/runpod.py
@@ -24,6 +24,8 @@
 class RunpodInferenceAdapter(OpenAIMixin):
     config: RunpodImplConfig
 
+    provider_data_api_key_field: str = "runpod_api_token"
+
     def get_api_key(self) -> str:
         """Get API key for OpenAI client."""
         return self.config.api_token
diff --git a/tests/unit/providers/inference/test_inference_client_caching.py b/tests/unit/providers/inference/test_inference_client_caching.py
index 55a6793c2..aa3a2c77a 100644
--- a/tests/unit/providers/inference/test_inference_client_caching.py
+++ b/tests/unit/providers/inference/test_inference_client_caching.py
@@ -10,47 +10,124 @@ from unittest.mock import MagicMock
 import pytest
 
 from llama_stack.core.request_headers import request_provider_data_context
+from llama_stack.providers.remote.inference.anthropic.anthropic import AnthropicInferenceAdapter
+from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig
+from llama_stack.providers.remote.inference.cerebras.cerebras import CerebrasInferenceAdapter
+from llama_stack.providers.remote.inference.cerebras.config import CerebrasImplConfig
+from llama_stack.providers.remote.inference.databricks.config import DatabricksImplConfig
+from llama_stack.providers.remote.inference.databricks.databricks import DatabricksInferenceAdapter
+from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig
+from llama_stack.providers.remote.inference.fireworks.fireworks import FireworksInferenceAdapter
+from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
+from llama_stack.providers.remote.inference.gemini.gemini import GeminiInferenceAdapter
 from llama_stack.providers.remote.inference.groq.config import GroqConfig
 from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
 from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
 from llama_stack.providers.remote.inference.llama_openai_compat.llama import LlamaCompatInferenceAdapter
+from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
+from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter
 from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
 from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
+from llama_stack.providers.remote.inference.runpod.config import RunpodImplConfig
+from llama_stack.providers.remote.inference.runpod.runpod import RunpodInferenceAdapter
+from llama_stack.providers.remote.inference.sambanova.config import SambaNovaImplConfig
+from llama_stack.providers.remote.inference.sambanova.sambanova import SambaNovaInferenceAdapter
 from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
 from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
+from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
+from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter
 from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
 from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter
 
 
 @pytest.mark.parametrize(
-    "config_cls,adapter_cls,provider_data_validator",
+    "config_cls,adapter_cls,provider_data_validator,config_params",
     [
         (
             GroqConfig,
             GroqInferenceAdapter,
             "llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
+            {},
         ),
         (
             OpenAIConfig,
             OpenAIInferenceAdapter,
             "llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
+            {},
         ),
         (
             TogetherImplConfig,
             TogetherInferenceAdapter,
             "llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
+            {},
         ),
         (
             LlamaCompatConfig,
             LlamaCompatInferenceAdapter,
             "llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
+            {},
+        ),
+        (
+            CerebrasImplConfig,
+            CerebrasInferenceAdapter,
+            "llama_stack.providers.remote.inference.cerebras.config.CerebrasProviderDataValidator",
+            {},
+        ),
+        (
+            DatabricksImplConfig,
+            DatabricksInferenceAdapter,
+            "llama_stack.providers.remote.inference.databricks.config.DatabricksProviderDataValidator",
+            {},
+        ),
+        (
+            NVIDIAConfig,
+            NVIDIAInferenceAdapter,
+            "llama_stack.providers.remote.inference.nvidia.config.NVIDIAProviderDataValidator",
+            {},
+        ),
+        (
+            RunpodImplConfig,
+            RunpodInferenceAdapter,
+            "llama_stack.providers.remote.inference.runpod.config.RunpodProviderDataValidator",
+            {},
+        ),
+        (
+            FireworksImplConfig,
+            FireworksInferenceAdapter,
+            "llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
+            {},
+        ),
+        (
+            AnthropicConfig,
+            AnthropicInferenceAdapter,
+            "llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
+            {},
+        ),
+        (
+            GeminiConfig,
+            GeminiInferenceAdapter,
+            "llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
+            {},
+        ),
+        (
+            SambaNovaImplConfig,
+            SambaNovaInferenceAdapter,
+            "llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
+            {},
+        ),
+        (
+            VLLMInferenceAdapterConfig,
+            VLLMInferenceAdapter,
+            "llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
+            {
+                "url": "http://fake",
+            },
         ),
     ],
 )
-def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
+def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str, config_params: dict):
     """Ensure the OpenAI provider does not cache api keys across client requests"""
-
-    inference_adapter = adapter_cls(config=config_cls())
+    inference_adapter = adapter_cls(config=config_cls(**config_params))
 
     inference_adapter.__provider_spec__ = MagicMock()
     inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator