Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-11 19:56:03 +00:00
feat: add provider data keys for Cerebras, Databricks, NVIDIA, and RunPod
- added previously missing provider-data tests for Fireworks, Anthropic, Gemini, SambaNova, and vLLM
commit bb95c1a7c5 (parent 5d711d4bcb)
10 changed files with 125 additions and 8 deletions
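What the change enables, end to end: a caller can now send a provider API key with each request instead of fixing it in the stack's run config. Below is a minimal client-side sketch, not part of this commit; the X-LlamaStack-Provider-Data header is the stack's existing provider-data channel, while the endpoint path, port, and model id are assumptions for illustration.

import json

import httpx

# Per-request provider data rides in the X-LlamaStack-Provider-Data header;
# the server validates it with the provider's *ProviderDataValidator before
# the adapter reads "cerebras_api_key" (one of the keys this commit adds).
resp = httpx.post(
    "http://localhost:8321/v1/openai/v1/chat/completions",  # assumed endpoint/port
    headers={"X-LlamaStack-Provider-Data": json.dumps({"cerebras_api_key": "csk-..."})},
    json={
        "model": "llama-3.3-70b",  # placeholder model id
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(resp.json())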
@@ -55,6 +55,7 @@ def available_providers() -> list[ProviderSpec]:
             pip_packages=[],
             module="llama_stack.providers.remote.inference.cerebras",
             config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.cerebras.config.CerebrasProviderDataValidator",
             description="Cerebras inference provider for running models on Cerebras Cloud platform.",
         ),
         RemoteProviderSpec(
@@ -143,6 +144,7 @@ def available_providers() -> list[ProviderSpec]:
             pip_packages=["databricks-sdk"],
             module="llama_stack.providers.remote.inference.databricks",
             config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.databricks.config.DatabricksProviderDataValidator",
             description="Databricks inference provider for running models on Databricks' unified analytics platform.",
         ),
         RemoteProviderSpec(
@@ -152,6 +154,7 @@ def available_providers() -> list[ProviderSpec]:
             pip_packages=[],
             module="llama_stack.providers.remote.inference.nvidia",
             config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.nvidia.config.NVIDIAProviderDataValidator",
             description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
         ),
         RemoteProviderSpec(
@@ -161,6 +164,7 @@ def available_providers() -> list[ProviderSpec]:
             pip_packages=[],
             module="llama_stack.providers.remote.inference.runpod",
             config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.runpod.config.RunpodProviderDataValidator",
             description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
         ),
         RemoteProviderSpec(
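Each registry hunk above adds one line: a dotted path to a pydantic validator for that provider's per-request data. A hedged sketch of the idea behind the indirection follows; the real resolution happens inside the stack's request-header machinery, so validate_provider_data here is illustrative only.

import importlib
import json

def validate_provider_data(validator_path: str, header_value: str):
    # Resolve "pkg.module.ClassName" to the pydantic validator class, then
    # parse the raw header JSON with it; missing optional keys default to None.
    module_name, class_name = validator_path.rsplit(".", 1)
    validator_cls = getattr(importlib.import_module(module_name), class_name)
    return validator_cls.model_validate(json.loads(header_value))

data = validate_provider_data(
    "llama_stack.providers.remote.inference.cerebras.config.CerebrasProviderDataValidator",
    '{"cerebras_api_key": "csk-test"}',
)
assert data.cerebras_api_key == "csk-test"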
@@ -15,6 +15,8 @@ from .config import CerebrasImplConfig
 class CerebrasInferenceAdapter(OpenAIMixin):
     config: CerebrasImplConfig
 
+    provider_data_api_key_field: str = "cerebras_api_key"
+
     def get_api_key(self) -> str:
         return self.config.api_key.get_secret_value()
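The adapter half of the wiring is the provider_data_api_key_field class attribute: it names the validated field that should take precedence over the static config key. A simplified, self-contained sketch of that precedence, not the actual OpenAIMixin code; get_provider_data stands in for the stack's request-context accessor.

def resolve_api_key(adapter, get_provider_data) -> str:
    # Prefer a key carried in the validated per-request provider data...
    provider_data = get_provider_data()
    field = getattr(adapter, "provider_data_api_key_field", None)
    if provider_data is not None and field:
        key = getattr(provider_data, field, None)
        if key:
            return key
    # ...otherwise fall back to the static key from the adapter's config.
    return adapter.get_api_key()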
@@ -7,7 +7,7 @@
 import os
 from typing import Any
 
-from pydantic import Field, SecretStr
+from pydantic import BaseModel, Field, SecretStr
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -15,6 +15,13 @@ from llama_stack.schema_utils import json_schema_type
 DEFAULT_BASE_URL = "https://api.cerebras.ai"
 
 
+class CerebrasProviderDataValidator(BaseModel):
+    cerebras_api_key: str | None = Field(
+        default=None,
+        description="API key for Cerebras models",
+    )
+
+
 @json_schema_type
 class CerebrasImplConfig(RemoteInferenceProviderConfig):
     base_url: str = Field(
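Every field on the new validator is optional, so an empty provider-data payload still validates and the adapter falls back to its config key. A quick check against the class added above:

from llama_stack.providers.remote.inference.cerebras.config import CerebrasProviderDataValidator

# No key supplied: validation succeeds and the field is simply None.
assert CerebrasProviderDataValidator.model_validate({}).cerebras_api_key is None

# Key supplied: it comes through verbatim for the adapter to use.
validated = CerebrasProviderDataValidator.model_validate({"cerebras_api_key": "csk-test"})
assert validated.cerebras_api_key == "csk-test"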
@@ -6,12 +6,19 @@
 
 from typing import Any
 
-from pydantic import Field, SecretStr
+from pydantic import BaseModel, Field, SecretStr
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
+class DatabricksProviderDataValidator(BaseModel):
+    databricks_api_token: str | None = Field(
+        default=None,
+        description="API token for Databricks models",
+    )
+
+
 @json_schema_type
 class DatabricksImplConfig(RemoteInferenceProviderConfig):
     url: str | None = Field(
@@ -21,6 +21,8 @@ logger = get_logger(name=__name__, category="inference::databricks")
 class DatabricksInferenceAdapter(OpenAIMixin):
     config: DatabricksImplConfig
 
+    provider_data_api_key_field: str = "databricks_api_token"
+
     # source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
     embedding_model_metadata: dict[str, dict[str, int]] = {
         "databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
@@ -7,12 +7,19 @@
 import os
 from typing import Any
 
-from pydantic import Field, SecretStr
+from pydantic import BaseModel, Field, SecretStr
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
+class NVIDIAProviderDataValidator(BaseModel):
+    nvidia_api_key: str | None = Field(
+        default=None,
+        description="API key for NVIDIA NIM models",
+    )
+
+
 @json_schema_type
 class NVIDIAConfig(RemoteInferenceProviderConfig):
     """
@@ -24,6 +24,8 @@ logger = get_logger(name=__name__, category="inference::nvidia")
 class NVIDIAInferenceAdapter(OpenAIMixin):
     config: NVIDIAConfig
 
+    provider_data_api_key_field: str = "nvidia_api_key"
+
     """
     NVIDIA Inference Adapter for Llama Stack.
@@ -6,12 +6,19 @@
 
 from typing import Any
 
-from pydantic import Field
+from pydantic import BaseModel, Field
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
+class RunpodProviderDataValidator(BaseModel):
+    runpod_api_token: str | None = Field(
+        default=None,
+        description="API token for RunPod models",
+    )
+
+
 @json_schema_type
 class RunpodImplConfig(RemoteInferenceProviderConfig):
     url: str | None = Field(
@@ -24,6 +24,8 @@ class RunpodInferenceAdapter(OpenAIMixin):
 
     config: RunpodImplConfig
 
+    provider_data_api_key_field: str = "runpod_api_token"
+
     def get_api_key(self) -> str:
         """Get API key for OpenAI client."""
         return self.config.api_token
@@ -10,47 +10,124 @@ from unittest.mock import MagicMock
 import pytest
 
 from llama_stack.core.request_headers import request_provider_data_context
+from llama_stack.providers.remote.inference.anthropic.anthropic import AnthropicInferenceAdapter
+from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig
+from llama_stack.providers.remote.inference.cerebras.cerebras import CerebrasInferenceAdapter
+from llama_stack.providers.remote.inference.cerebras.config import CerebrasImplConfig
+from llama_stack.providers.remote.inference.databricks.config import DatabricksImplConfig
+from llama_stack.providers.remote.inference.databricks.databricks import DatabricksInferenceAdapter
+from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig
+from llama_stack.providers.remote.inference.fireworks.fireworks import FireworksInferenceAdapter
+from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
+from llama_stack.providers.remote.inference.gemini.gemini import GeminiInferenceAdapter
 from llama_stack.providers.remote.inference.groq.config import GroqConfig
 from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
 from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
 from llama_stack.providers.remote.inference.llama_openai_compat.llama import LlamaCompatInferenceAdapter
+from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
+from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter
 from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
 from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
+from llama_stack.providers.remote.inference.runpod.config import RunpodImplConfig
+from llama_stack.providers.remote.inference.runpod.runpod import RunpodInferenceAdapter
+from llama_stack.providers.remote.inference.sambanova.config import SambaNovaImplConfig
+from llama_stack.providers.remote.inference.sambanova.sambanova import SambaNovaInferenceAdapter
 from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
 from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
+from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
+from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter
 from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
 from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter
 
 
 @pytest.mark.parametrize(
-    "config_cls,adapter_cls,provider_data_validator",
+    "config_cls,adapter_cls,provider_data_validator,config_params",
     [
         (
             GroqConfig,
             GroqInferenceAdapter,
             "llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
+            {},
         ),
         (
             OpenAIConfig,
             OpenAIInferenceAdapter,
             "llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
+            {},
         ),
         (
             TogetherImplConfig,
             TogetherInferenceAdapter,
             "llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
+            {},
         ),
         (
             LlamaCompatConfig,
             LlamaCompatInferenceAdapter,
             "llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
+            {},
+        ),
+        (
+            CerebrasImplConfig,
+            CerebrasInferenceAdapter,
+            "llama_stack.providers.remote.inference.cerebras.config.CerebrasProviderDataValidator",
+            {},
+        ),
+        (
+            DatabricksImplConfig,
+            DatabricksInferenceAdapter,
+            "llama_stack.providers.remote.inference.databricks.config.DatabricksProviderDataValidator",
+            {},
+        ),
+        (
+            NVIDIAConfig,
+            NVIDIAInferenceAdapter,
+            "llama_stack.providers.remote.inference.nvidia.config.NVIDIAProviderDataValidator",
+            {},
+        ),
+        (
+            RunpodImplConfig,
+            RunpodInferenceAdapter,
+            "llama_stack.providers.remote.inference.runpod.config.RunpodProviderDataValidator",
+            {},
+        ),
+        (
+            FireworksImplConfig,
+            FireworksInferenceAdapter,
+            "llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
+            {},
+        ),
+        (
+            AnthropicConfig,
+            AnthropicInferenceAdapter,
+            "llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
+            {},
+        ),
+        (
+            GeminiConfig,
+            GeminiInferenceAdapter,
+            "llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
+            {},
+        ),
+        (
+            SambaNovaImplConfig,
+            SambaNovaInferenceAdapter,
+            "llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
+            {},
+        ),
+        (
+            VLLMInferenceAdapterConfig,
+            VLLMInferenceAdapter,
+            "llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
+            {
+                "url": "http://fake",
+            },
         ),
     ],
 )
-def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
+def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str, config_params: dict):
     """Ensure the OpenAI provider does not cache api keys across client requests"""
-    inference_adapter = adapter_cls(config=config_cls())
+    inference_adapter = adapter_cls(config=config_cls(**config_params))
 
     inference_adapter.__provider_spec__ = MagicMock()
     inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator
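For reference, the behavior the parametrized test drives, sketched with one of the newly covered providers. The header shape matches how request_provider_data_context is used in this test file; reading the key back via the adapter's client attribute is an assumption of the sketch.

import json

from llama_stack.core.request_headers import request_provider_data_context

# inference_adapter: an adapter instance built as in the test body above.
# Two back-to-back requests with different per-request keys; the adapter must
# pick each key up fresh rather than reusing a cached client.
with request_provider_data_context(
    {"x-llamastack-provider-data": json.dumps({"cerebras_api_key": "csk-first"})}
):
    assert inference_adapter.client.api_key == "csk-first"

with request_provider_data_context(
    {"x-llamastack-provider-data": json.dumps({"cerebras_api_key": "csk-second"})}
):
    assert inference_adapter.client.api_key == "csk-second"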