Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 09:53:45 +00:00)
feat: add provider data keys for Cerebras, Databricks, NVIDIA, and RunPod (#3734)
# What does this PR do?

Adds provider-data key passing support to Cerebras, Databricks, NVIDIA, and RunPod. Also adds previously missing tests for Fireworks, Anthropic, Gemini, SambaNova, and vLLM.

Addresses #3517.

## Test Plan

CI, with the new tests.

---------

Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
Parent: 471b1b248b
Commit: a9b00db421

12 changed files with 171 additions and 8 deletions
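For context before the per-file diffs, here is roughly how a client would exercise the new keys. This is a minimal sketch, assuming the Python client's `provider_data` argument (serialized into the `X-LlamaStack-Provider-Data` header) and a server with the Cerebras provider enabled; the model id and key are placeholders:

```python
# Sketch: passing a per-request Cerebras key via provider data.
# Assumes a Llama Stack server on localhost:8321 with the Cerebras
# provider enabled; model id and key below are placeholders.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(
    base_url="http://localhost:8321",
    # Sent to the server as the X-LlamaStack-Provider-Data header.
    provider_data={"cerebras_api_key": "sk-..."},
)

response = client.chat.completions.create(
    model="llama-3.3-70b",
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)
```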
`llama_stack/providers/registry/inference.py`

```diff
@@ -61,6 +61,7 @@ def available_providers() -> list[ProviderSpec]:
             pip_packages=[],
             module="llama_stack.providers.remote.inference.cerebras",
             config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.cerebras.config.CerebrasProviderDataValidator",
             description="Cerebras inference provider for running models on Cerebras Cloud platform.",
         ),
         RemoteProviderSpec(
@@ -149,6 +150,7 @@ def available_providers() -> list[ProviderSpec]:
             pip_packages=["databricks-sdk"],
             module="llama_stack.providers.remote.inference.databricks",
             config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.databricks.config.DatabricksProviderDataValidator",
             description="Databricks inference provider for running models on Databricks' unified analytics platform.",
         ),
         RemoteProviderSpec(
@@ -158,6 +160,7 @@ def available_providers() -> list[ProviderSpec]:
             pip_packages=[],
             module="llama_stack.providers.remote.inference.nvidia",
             config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.nvidia.config.NVIDIAProviderDataValidator",
             description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
         ),
         RemoteProviderSpec(
@@ -167,6 +170,7 @@ def available_providers() -> list[ProviderSpec]:
             pip_packages=[],
             module="llama_stack.providers.remote.inference.runpod",
             config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.runpod.config.RunpodProviderDataValidator",
             description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
         ),
         RemoteProviderSpec(
```
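The `provider_data_validator` strings above are dotted paths to pydantic models. Conceptually (this is a sketch of the idea, not the stack's actual internals), the server resolves the path and validates the per-request header payload against it:

```python
# Sketch: resolve a provider_data_validator dotted path to its pydantic
# model and validate the JSON carried in X-LlamaStack-Provider-Data.
import importlib
import json


def validate_provider_data(validator_path: str, header_value: str):
    module_path, class_name = validator_path.rsplit(".", 1)
    validator_cls = getattr(importlib.import_module(module_path), class_name)
    return validator_cls(**json.loads(header_value))


# Hypothetical usage with the Cerebras validator registered above:
# data = validate_provider_data(
#     "llama_stack.providers.remote.inference.cerebras.config.CerebrasProviderDataValidator",
#     '{"cerebras_api_key": "sk-..."}',
# )
# print(data.cerebras_api_key)
```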
`llama_stack/providers/remote/inference/cerebras/cerebras.py`

```diff
@@ -18,6 +18,8 @@ from .config import CerebrasImplConfig
 class CerebrasInferenceAdapter(OpenAIMixin):
     config: CerebrasImplConfig
 
+    provider_data_api_key_field: str = "cerebras_api_key"
+
     def get_base_url(self) -> str:
         return urljoin(self.config.base_url, "v1")
```
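`provider_data_api_key_field` tells the `OpenAIMixin` which provider-data field carries this adapter's API key. A sketch of the precedence this implies, using a hypothetical helper rather than the mixin's real code:

```python
# Hypothetical helper (not OpenAIMixin's actual implementation):
# prefer the per-request key from provider data, then fall back to
# the adapter's static configuration.
def resolve_api_key(adapter, provider_data: dict | None) -> str | None:
    field = adapter.provider_data_api_key_field  # e.g. "cerebras_api_key"
    if provider_data and provider_data.get(field):
        return provider_data[field]
    # "api_key" is an assumed config attribute for illustration.
    return getattr(adapter.config, "api_key", None)
```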
`llama_stack/providers/remote/inference/cerebras/config.py`

```diff
@@ -7,7 +7,7 @@
 import os
 from typing import Any
 
-from pydantic import Field
+from pydantic import BaseModel, Field
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -15,6 +15,13 @@ from llama_stack.schema_utils import json_schema_type
 DEFAULT_BASE_URL = "https://api.cerebras.ai"
 
 
+class CerebrasProviderDataValidator(BaseModel):
+    cerebras_api_key: str | None = Field(
+        default=None,
+        description="API key for Cerebras models",
+    )
+
+
 @json_schema_type
 class CerebrasImplConfig(RemoteInferenceProviderConfig):
     base_url: str = Field(
```
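Because the field defaults to `None`, a request without the key still validates, and the adapter can fall back to its static configuration. A quick illustrative check:

```python
# Illustrative check of the new validator's behavior.
from llama_stack.providers.remote.inference.cerebras.config import (
    CerebrasProviderDataValidator,
)

# Key present: parsed and exposed as an attribute.
assert CerebrasProviderDataValidator(cerebras_api_key="sk-...").cerebras_api_key == "sk-..."
# Key absent: still valid, defaults to None, so the server's statically
# configured key can be used instead.
assert CerebrasProviderDataValidator().cerebras_api_key is None
```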
`llama_stack/providers/remote/inference/databricks/config.py`

```diff
@@ -6,12 +6,19 @@
 
 from typing import Any
 
-from pydantic import Field, SecretStr
+from pydantic import BaseModel, Field, SecretStr
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
+class DatabricksProviderDataValidator(BaseModel):
+    databricks_api_token: str | None = Field(
+        default=None,
+        description="API token for Databricks models",
+    )
+
+
 @json_schema_type
 class DatabricksImplConfig(RemoteInferenceProviderConfig):
     url: str | None = Field(
```
`llama_stack/providers/remote/inference/databricks/databricks.py`

```diff
@@ -20,6 +20,8 @@ logger = get_logger(name=__name__, category="inference::databricks")
 class DatabricksInferenceAdapter(OpenAIMixin):
     config: DatabricksImplConfig
 
+    provider_data_api_key_field: str = "databricks_api_token"
+
     # source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
     embedding_model_metadata: dict[str, dict[str, int]] = {
         "databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
```
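The same mechanism works without the SDK, since the key travels in a plain request header. An illustrative raw-HTTP call, assuming the OpenAI-compatible endpoint path; the model id and token are placeholders:

```python
# Sketch: sending a per-request Databricks token over raw HTTP.
# Endpoint path, model id, and token are placeholders for illustration.
import json

import requests

resp = requests.post(
    "http://localhost:8321/v1/openai/v1/chat/completions",
    headers={
        "X-LlamaStack-Provider-Data": json.dumps({"databricks_api_token": "dapi..."}),
    },
    json={
        "model": "databricks-meta-llama-3-3-70b-instruct",
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(resp.json())
```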
`llama_stack/providers/remote/inference/nvidia/config.py`

```diff
@@ -7,12 +7,19 @@
 import os
 from typing import Any
 
-from pydantic import Field
+from pydantic import BaseModel, Field
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
+class NVIDIAProviderDataValidator(BaseModel):
+    nvidia_api_key: str | None = Field(
+        default=None,
+        description="API key for NVIDIA NIM models",
+    )
+
+
 @json_schema_type
 class NVIDIAConfig(RemoteInferenceProviderConfig):
     """
```
`llama_stack/providers/remote/inference/nvidia/nvidia.py`

```diff
@@ -17,6 +17,8 @@ logger = get_logger(name=__name__, category="inference::nvidia")
 class NVIDIAInferenceAdapter(OpenAIMixin):
     config: NVIDIAConfig
 
+    provider_data_api_key_field: str = "nvidia_api_key"
+
     """
     NVIDIA Inference Adapter for Llama Stack.
     """
```
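The PR also backfills tests for several providers. An illustrative test of the new NVIDIA wiring (not the PR's actual test code) might look like:

```python
# Illustrative test: the validator accepts the key the adapter declares
# via provider_data_api_key_field ("nvidia_api_key").
from llama_stack.providers.remote.inference.nvidia.config import (
    NVIDIAProviderDataValidator,
)


def test_nvidia_provider_data_key():
    data = NVIDIAProviderDataValidator(nvidia_api_key="nvapi-...")
    assert data.nvidia_api_key == "nvapi-..."
    # The field name must match what the adapter looks up per request.
    assert "nvidia_api_key" in NVIDIAProviderDataValidator.model_fields
```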
`llama_stack/providers/remote/inference/runpod/config.py`

```diff
@@ -6,12 +6,19 @@
 
 from typing import Any
 
-from pydantic import Field, SecretStr
+from pydantic import BaseModel, Field, SecretStr
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
+class RunpodProviderDataValidator(BaseModel):
+    runpod_api_token: str | None = Field(
+        default=None,
+        description="API token for RunPod models",
+    )
+
+
 @json_schema_type
 class RunpodImplConfig(RemoteInferenceProviderConfig):
     url: str | None = Field(
```
`llama_stack/providers/remote/inference/runpod/runpod.py`

```diff
@@ -24,6 +24,7 @@ class RunpodInferenceAdapter(OpenAIMixin):
     """
 
     config: RunpodImplConfig
+    provider_data_api_key_field: str = "runpod_api_token"
 
     def get_base_url(self) -> str:
         """Get base URL for OpenAI client."""
```