feat: add provider data keys for Cerebras, Databricks, NVIDIA, and RunPod (#3734)

# What does this PR do?

add provider-data key passing support to Cerebras, Databricks, NVIDIA
and RunPod

also, added missing tests for Fireworks, Anthropic, Gemini, SambaNova,
and vLLM

addresses #3517 

## Test Plan

ci w/ new tests

---------

Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
Matthew Farrellee 2025-10-27 16:09:35 -04:00 committed by GitHub
parent 471b1b248b
commit a9b00db421
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 171 additions and 8 deletions

View file

@ -18,6 +18,8 @@ from .config import CerebrasImplConfig
class CerebrasInferenceAdapter(OpenAIMixin):
config: CerebrasImplConfig
provider_data_api_key_field: str = "cerebras_api_key"
def get_base_url(self) -> str:
return urljoin(self.config.base_url, "v1")

View file

@ -7,7 +7,7 @@
import os
from typing import Any
from pydantic import Field
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -15,6 +15,13 @@ from llama_stack.schema_utils import json_schema_type
DEFAULT_BASE_URL = "https://api.cerebras.ai"
class CerebrasProviderDataValidator(BaseModel):
cerebras_api_key: str | None = Field(
default=None,
description="API key for Cerebras models",
)
@json_schema_type
class CerebrasImplConfig(RemoteInferenceProviderConfig):
base_url: str = Field(

View file

@ -6,12 +6,19 @@
from typing import Any
from pydantic import Field, SecretStr
from pydantic import BaseModel, Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
class DatabricksProviderDataValidator(BaseModel):
databricks_api_token: str | None = Field(
default=None,
description="API token for Databricks models",
)
@json_schema_type
class DatabricksImplConfig(RemoteInferenceProviderConfig):
url: str | None = Field(

View file

@ -20,6 +20,8 @@ logger = get_logger(name=__name__, category="inference::databricks")
class DatabricksInferenceAdapter(OpenAIMixin):
config: DatabricksImplConfig
provider_data_api_key_field: str = "databricks_api_token"
# source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
embedding_model_metadata: dict[str, dict[str, int]] = {
"databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},

View file

@ -7,12 +7,19 @@
import os
from typing import Any
from pydantic import Field
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
class NVIDIAProviderDataValidator(BaseModel):
nvidia_api_key: str | None = Field(
default=None,
description="API key for NVIDIA NIM models",
)
@json_schema_type
class NVIDIAConfig(RemoteInferenceProviderConfig):
"""

View file

@ -17,6 +17,8 @@ logger = get_logger(name=__name__, category="inference::nvidia")
class NVIDIAInferenceAdapter(OpenAIMixin):
config: NVIDIAConfig
provider_data_api_key_field: str = "nvidia_api_key"
"""
NVIDIA Inference Adapter for Llama Stack.
"""

View file

@ -6,12 +6,19 @@
from typing import Any
from pydantic import Field, SecretStr
from pydantic import BaseModel, Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
class RunpodProviderDataValidator(BaseModel):
runpod_api_token: str | None = Field(
default=None,
description="API token for RunPod models",
)
@json_schema_type
class RunpodImplConfig(RemoteInferenceProviderConfig):
url: str | None = Field(

View file

@ -24,6 +24,7 @@ class RunpodInferenceAdapter(OpenAIMixin):
"""
config: RunpodImplConfig
provider_data_api_key_field: str = "runpod_api_token"
def get_base_url(self) -> str:
"""Get base URL for OpenAI client."""