mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-11 19:56:03 +00:00
feat: add provider data keys for Cerebras, Databricks, NVIDIA, and RunPod
- added missing tests for Fireworks, Anthropic, Gemini, SambaNova, and vLLM
This commit is contained in:
parent
5d711d4bcb
commit
bb95c1a7c5
10 changed files with 125 additions and 8 deletions
|
|
@ -55,6 +55,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.cerebras",
|
||||
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.cerebras.config.CerebrasProviderDataValidator",
|
||||
description="Cerebras inference provider for running models on Cerebras Cloud platform.",
|
||||
),
|
||||
RemoteProviderSpec(
|
||||
|
|
@ -143,6 +144,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
pip_packages=["databricks-sdk"],
|
||||
module="llama_stack.providers.remote.inference.databricks",
|
||||
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.databricks.config.DatabricksProviderDataValidator",
|
||||
description="Databricks inference provider for running models on Databricks' unified analytics platform.",
|
||||
),
|
||||
RemoteProviderSpec(
|
||||
|
|
@ -152,6 +154,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.nvidia",
|
||||
config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.nvidia.config.NVIDIAProviderDataValidator",
|
||||
description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
|
||||
),
|
||||
RemoteProviderSpec(
|
||||
|
|
@ -161,6 +164,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.runpod",
|
||||
config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.runpod.config.RunpodProviderDataValidator",
|
||||
description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
|
||||
),
|
||||
RemoteProviderSpec(
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ from .config import CerebrasImplConfig
|
|||
class CerebrasInferenceAdapter(OpenAIMixin):
|
||||
config: CerebrasImplConfig
|
||||
|
||||
provider_data_api_key_field: str = "cerebras_api_key"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_key.get_secret_value()
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, SecretStr
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
|
@ -15,6 +15,13 @@ from llama_stack.schema_utils import json_schema_type
|
|||
DEFAULT_BASE_URL = "https://api.cerebras.ai"
|
||||
|
||||
|
||||
class CerebrasProviderDataValidator(BaseModel):
|
||||
cerebras_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Cerebras models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CerebrasImplConfig(RemoteInferenceProviderConfig):
|
||||
base_url: str = Field(
|
||||
|
|
|
|||
|
|
@ -6,12 +6,19 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, SecretStr
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class DatabricksProviderDataValidator(BaseModel):
|
||||
databricks_api_token: str | None = Field(
|
||||
default=None,
|
||||
description="API token for Databricks models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DatabricksImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str | None = Field(
|
||||
|
|
|
|||
|
|
@ -21,6 +21,8 @@ logger = get_logger(name=__name__, category="inference::databricks")
|
|||
class DatabricksInferenceAdapter(OpenAIMixin):
|
||||
config: DatabricksImplConfig
|
||||
|
||||
provider_data_api_key_field: str = "databricks_api_token"
|
||||
|
||||
# source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
|
||||
|
|
|
|||
|
|
@ -7,12 +7,19 @@
|
|||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, SecretStr
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class NVIDIAProviderDataValidator(BaseModel):
|
||||
nvidia_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for NVIDIA NIM models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class NVIDIAConfig(RemoteInferenceProviderConfig):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@ logger = get_logger(name=__name__, category="inference::nvidia")
|
|||
class NVIDIAInferenceAdapter(OpenAIMixin):
|
||||
config: NVIDIAConfig
|
||||
|
||||
provider_data_api_key_field: str = "nvidia_api_key"
|
||||
|
||||
"""
|
||||
NVIDIA Inference Adapter for Llama Stack.
|
||||
|
||||
|
|
|
|||
|
|
@ -6,12 +6,19 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class RunpodProviderDataValidator(BaseModel):
|
||||
runpod_api_token: str | None = Field(
|
||||
default=None,
|
||||
description="API token for RunPod models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RunpodImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str | None = Field(
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@ class RunpodInferenceAdapter(OpenAIMixin):
|
|||
|
||||
config: RunpodImplConfig
|
||||
|
||||
provider_data_api_key_field: str = "runpod_api_token"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
"""Get API key for OpenAI client."""
|
||||
return self.config.api_token
|
||||
|
|
|
|||
|
|
@ -10,47 +10,124 @@ from unittest.mock import MagicMock
|
|||
import pytest
|
||||
|
||||
from llama_stack.core.request_headers import request_provider_data_context
|
||||
from llama_stack.providers.remote.inference.anthropic.anthropic import AnthropicInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig
|
||||
from llama_stack.providers.remote.inference.cerebras.cerebras import CerebrasInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.cerebras.config import CerebrasImplConfig
|
||||
from llama_stack.providers.remote.inference.databricks.config import DatabricksImplConfig
|
||||
from llama_stack.providers.remote.inference.databricks.databricks import DatabricksInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig
|
||||
from llama_stack.providers.remote.inference.fireworks.fireworks import FireworksInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
|
||||
from llama_stack.providers.remote.inference.gemini.gemini import GeminiInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.groq.config import GroqConfig
|
||||
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
|
||||
from llama_stack.providers.remote.inference.llama_openai_compat.llama import LlamaCompatInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
|
||||
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
|
||||
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.runpod.config import RunpodImplConfig
|
||||
from llama_stack.providers.remote.inference.runpod.runpod import RunpodInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.sambanova.config import SambaNovaImplConfig
|
||||
from llama_stack.providers.remote.inference.sambanova.sambanova import SambaNovaInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
|
||||
from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
|
||||
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
|
||||
from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config_cls,adapter_cls,provider_data_validator",
|
||||
"config_cls,adapter_cls,provider_data_validator,config_params",
|
||||
[
|
||||
(
|
||||
GroqConfig,
|
||||
GroqInferenceAdapter,
|
||||
"llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
|
||||
{},
|
||||
),
|
||||
(
|
||||
OpenAIConfig,
|
||||
OpenAIInferenceAdapter,
|
||||
"llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
|
||||
{},
|
||||
),
|
||||
(
|
||||
TogetherImplConfig,
|
||||
TogetherInferenceAdapter,
|
||||
"llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
|
||||
{},
|
||||
),
|
||||
(
|
||||
LlamaCompatConfig,
|
||||
LlamaCompatInferenceAdapter,
|
||||
"llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
|
||||
{},
|
||||
),
|
||||
(
|
||||
CerebrasImplConfig,
|
||||
CerebrasInferenceAdapter,
|
||||
"llama_stack.providers.remote.inference.cerebras.config.CerebrasProviderDataValidator",
|
||||
{},
|
||||
),
|
||||
(
|
||||
DatabricksImplConfig,
|
||||
DatabricksInferenceAdapter,
|
||||
"llama_stack.providers.remote.inference.databricks.config.DatabricksProviderDataValidator",
|
||||
{},
|
||||
),
|
||||
(
|
||||
NVIDIAConfig,
|
||||
NVIDIAInferenceAdapter,
|
||||
"llama_stack.providers.remote.inference.nvidia.config.NVIDIAProviderDataValidator",
|
||||
{},
|
||||
),
|
||||
(
|
||||
RunpodImplConfig,
|
||||
RunpodInferenceAdapter,
|
||||
"llama_stack.providers.remote.inference.runpod.config.RunpodProviderDataValidator",
|
||||
{},
|
||||
),
|
||||
(
|
||||
FireworksImplConfig,
|
||||
FireworksInferenceAdapter,
|
||||
"llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
|
||||
{},
|
||||
),
|
||||
(
|
||||
AnthropicConfig,
|
||||
AnthropicInferenceAdapter,
|
||||
"llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
|
||||
{},
|
||||
),
|
||||
(
|
||||
GeminiConfig,
|
||||
GeminiInferenceAdapter,
|
||||
"llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
|
||||
{},
|
||||
),
|
||||
(
|
||||
SambaNovaImplConfig,
|
||||
SambaNovaInferenceAdapter,
|
||||
"llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
|
||||
{},
|
||||
),
|
||||
(
|
||||
VLLMInferenceAdapterConfig,
|
||||
VLLMInferenceAdapter,
|
||||
"llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
|
||||
{
|
||||
"url": "http://fake",
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
|
||||
def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str, config_params: dict):
|
||||
"""Ensure the OpenAI provider does not cache api keys across client requests"""
|
||||
|
||||
inference_adapter = adapter_cls(config=config_cls())
|
||||
inference_adapter = adapter_cls(config=config_cls(**config_params))
|
||||
|
||||
inference_adapter.__provider_spec__ = MagicMock()
|
||||
inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue