From 633bb9c5b3273762ccf0d91845243bb11166066d Mon Sep 17 00:00:00 2001 From: Jorge Piedrahita Ortiz Date: Wed, 21 May 2025 17:33:02 -0500 Subject: [PATCH] feat(providers): sambanova safety provider (#2221) # What does this PR do? Includes SambaNova safety adaptor to use the sambanova cloud served Meta-Llama-Guard-3-8B minor updates in sambanova docs ## Test Plan pytest -s -v tests/integration/safety/test_safety.py --stack-config=sambanova --safety-shield=sambanova/Meta-Llama-Guard-3-8B --- README.md | 2 +- .../self_hosted_distro/sambanova.md | 25 +++-- llama_stack/providers/registry/safety.py | 10 ++ .../remote/safety/sambanova/__init__.py | 18 ++++ .../remote/safety/sambanova/config.py | 37 +++++++ .../remote/safety/sambanova/sambanova.py | 100 ++++++++++++++++++ llama_stack/templates/sambanova/build.yaml | 4 +- .../templates/sambanova/doc_template.md | 23 ++-- llama_stack/templates/sambanova/run.yaml | 10 +- llama_stack/templates/sambanova/sambanova.py | 14 ++- pyproject.toml | 1 + 11 files changed, 222 insertions(+), 22 deletions(-) create mode 100644 llama_stack/providers/remote/safety/sambanova/__init__.py create mode 100644 llama_stack/providers/remote/safety/sambanova/config.py create mode 100644 llama_stack/providers/remote/safety/sambanova/sambanova.py diff --git a/README.md b/README.md index 5dfe3577a..e54b505cf 100644 --- a/README.md +++ b/README.md @@ -110,7 +110,7 @@ Here is a list of the various API providers and available distributions that can | **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** | |:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:| | Meta Reference | Single Node | ✅ | ✅ | ✅ | ✅ | ✅ | -| SambaNova | Hosted | | ✅ | | | | +| SambaNova | Hosted | | ✅ | | ✅ | | | Cerebras | Hosted | | ✅ | | | | | Fireworks | Hosted | ✅ | ✅ | ✅ | | | | AWS Bedrock | Hosted | | ✅ | | ✅ | | diff --git a/docs/source/distributions/self_hosted_distro/sambanova.md b/docs/source/distributions/self_hosted_distro/sambanova.md index aaa8fd3cc..bb4842362 100644 --- a/docs/source/distributions/self_hosted_distro/sambanova.md +++ b/docs/source/distributions/self_hosted_distro/sambanova.md @@ -17,7 +17,7 @@ The `llamastack/distribution-sambanova` distribution consists of the following p |-----|-------------| | agents | `inline::meta-reference` | | inference | `remote::sambanova`, `inline::sentence-transformers` | -| safety | `inline::llama-guard` | +| safety | `remote::sambanova` | | telemetry | `inline::meta-reference` | | tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` | | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` | @@ -48,33 +48,44 @@ The following models are available by default: ### Prerequisite: API Keys -Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](https://sambanova.ai/). +Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](http://cloud.sambanova.ai?utm_source=llamastack&utm_medium=external&utm_campaign=cloud_signup). ## Running Llama Stack with SambaNova You can do this via Conda (build code) or Docker which has a pre-built image. -### Via Docker -This method allows you to get started quickly without having to build the distribution code. +### Via Docker ```bash LLAMA_STACK_PORT=8321 +llama stack build --template sambanova --image-type container docker run \ -it \ - --pull always \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ - llamastack/distribution-sambanova \ + -v ~/.llama:/root/.llama \ + distribution-sambanova \ --port $LLAMA_STACK_PORT \ --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY ``` + +### Via Venv + +```bash +llama stack build --template sambanova --image-type venv +llama stack run --image-type venv ~/.llama/distributions/sambanova/sambanova-run.yaml \ + --port $LLAMA_STACK_PORT \ + --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY +``` + + ### Via Conda ```bash llama stack build --template sambanova --image-type conda -llama stack run ./run.yaml \ +llama stack run --image-type conda ~/.llama/distributions/sambanova/sambanova-run.yaml \ --port $LLAMA_STACK_PORT \ --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY ``` diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index c209da092..e0a04be48 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -63,4 +63,14 @@ def available_providers() -> list[ProviderSpec]: config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig", ), ), + remote_provider_spec( + api=Api.safety, + adapter=AdapterSpec( + adapter_type="sambanova", + pip_packages=["litellm"], + module="llama_stack.providers.remote.safety.sambanova", + config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig", + provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator", + ), + ), ] diff --git a/llama_stack/providers/remote/safety/sambanova/__init__.py b/llama_stack/providers/remote/safety/sambanova/__init__.py new file mode 100644 index 000000000..bb9d15374 --- /dev/null +++ b/llama_stack/providers/remote/safety/sambanova/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +from typing import Any + +from .config import SambaNovaSafetyConfig + + +async def get_adapter_impl(config: SambaNovaSafetyConfig, _deps) -> Any: + from .sambanova import SambaNovaSafetyAdapter + + impl = SambaNovaSafetyAdapter(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/safety/sambanova/config.py b/llama_stack/providers/remote/safety/sambanova/config.py new file mode 100644 index 000000000..383cea244 --- /dev/null +++ b/llama_stack/providers/remote/safety/sambanova/config.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from pydantic import BaseModel, Field, SecretStr + +from llama_stack.schema_utils import json_schema_type + + +class SambaNovaProviderDataValidator(BaseModel): + sambanova_api_key: str | None = Field( + default=None, + description="Sambanova Cloud API key", + ) + + +@json_schema_type +class SambaNovaSafetyConfig(BaseModel): + url: str = Field( + default="https://api.sambanova.ai/v1", + description="The URL for the SambaNova AI server", + ) + api_key: SecretStr | None = Field( + default=None, + description="The SambaNova cloud API Key", + ) + + @classmethod + def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]: + return { + "url": "https://api.sambanova.ai/v1", + "api_key": api_key, + } diff --git a/llama_stack/providers/remote/safety/sambanova/sambanova.py b/llama_stack/providers/remote/safety/sambanova/sambanova.py new file mode 100644 index 000000000..84c8267ae --- /dev/null +++ b/llama_stack/providers/remote/safety/sambanova/sambanova.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +import logging +from typing import Any + +import litellm +import requests + +from llama_stack.apis.inference import Message +from llama_stack.apis.safety import ( + RunShieldResponse, + Safety, + SafetyViolation, + ViolationLevel, +) +from llama_stack.apis.shields import Shield +from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.providers.datatypes import ShieldsProtocolPrivate +from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new + +from .config import SambaNovaSafetyConfig + +logger = logging.getLogger(__name__) + +CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?" + + +class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData): + def __init__(self, config: SambaNovaSafetyConfig) -> None: + self.config = config + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + def _get_api_key(self) -> str: + config_api_key = self.config.api_key if self.config.api_key else None + if config_api_key: + return config_api_key.get_secret_value() + else: + provider_data = self.get_request_provider_data() + if provider_data is None or not provider_data.sambanova_api_key: + raise ValueError( + 'Pass Sambanova API Key in the header X-LlamaStack-Provider-Data as { "sambanova_api_key": }' + ) + return provider_data.sambanova_api_key + + async def register_shield(self, shield: Shield) -> None: + list_models_url = self.config.url + "/models" + try: + response = requests.get(list_models_url) + response.raise_for_status() + except requests.exceptions.RequestException as e: + raise RuntimeError(f"Request to {list_models_url} failed") from e + available_models = [model.get("id") for model in response.json().get("data", {})] + if ( + len(available_models) == 0 + or "guard" not in shield.provider_resource_id.lower() + or shield.provider_resource_id.split("sambanova/")[-1] not in available_models + ): + raise ValueError(f"Shield {shield.provider_resource_id} not found in SambaNova") + + async def run_shield( + self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None + ) -> RunShieldResponse: + shield = await self.shield_store.get_shield(shield_id) + if not shield: + raise ValueError(f"Shield {shield_id} not found") + + shield_params = shield.params + logger.debug(f"run_shield::{shield_params}::messages={messages}") + content_messages = [await convert_message_to_openai_dict_new(m) for m in messages] + logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:") + + response = litellm.completion( + model=shield.provider_resource_id, messages=content_messages, api_key=self._get_api_key() + ) + shield_message = response.choices[0].message.content + + if "unsafe" in shield_message.lower(): + user_message = CANNED_RESPONSE_TEXT + violation_type = shield_message.split("\n")[-1] + metadata = {"violation_type": violation_type} + + return RunShieldResponse( + violation=SafetyViolation( + user_message=user_message, + violation_level=ViolationLevel.ERROR, + metadata=metadata, + ) + ) + + return RunShieldResponse() diff --git a/llama_stack/templates/sambanova/build.yaml b/llama_stack/templates/sambanova/build.yaml index 81d90f420..79bb68c68 100644 --- a/llama_stack/templates/sambanova/build.yaml +++ b/llama_stack/templates/sambanova/build.yaml @@ -1,6 +1,6 @@ version: '2' distribution_spec: - description: Use SambaNova for running LLM inference + description: Use SambaNova for running LLM inference and safety providers: inference: - remote::sambanova @@ -10,7 +10,7 @@ distribution_spec: - remote::chromadb - remote::pgvector safety: - - inline::llama-guard + - remote::sambanova agents: - inline::meta-reference telemetry: diff --git a/llama_stack/templates/sambanova/doc_template.md b/llama_stack/templates/sambanova/doc_template.md index 42d9efb66..1dc76fd3f 100644 --- a/llama_stack/templates/sambanova/doc_template.md +++ b/llama_stack/templates/sambanova/doc_template.md @@ -37,33 +37,44 @@ The following models are available by default: ### Prerequisite: API Keys -Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](https://sambanova.ai/). +Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](http://cloud.sambanova.ai?utm_source=llamastack&utm_medium=external&utm_campaign=cloud_signup). ## Running Llama Stack with SambaNova You can do this via Conda (build code) or Docker which has a pre-built image. -### Via Docker -This method allows you to get started quickly without having to build the distribution code. +### Via Docker ```bash LLAMA_STACK_PORT=8321 +llama stack build --template sambanova --image-type container docker run \ -it \ - --pull always \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ - llamastack/distribution-{{ name }} \ + -v ~/.llama:/root/.llama \ + distribution-{{ name }} \ --port $LLAMA_STACK_PORT \ --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY ``` + +### Via Venv + +```bash +llama stack build --template sambanova --image-type venv +llama stack run --image-type venv ~/.llama/distributions/sambanova/sambanova-run.yaml \ + --port $LLAMA_STACK_PORT \ + --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY +``` + + ### Via Conda ```bash llama stack build --template sambanova --image-type conda -llama stack run ./run.yaml \ +llama stack run --image-type conda ~/.llama/distributions/sambanova/sambanova-run.yaml \ --port $LLAMA_STACK_PORT \ --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY ``` diff --git a/llama_stack/templates/sambanova/run.yaml b/llama_stack/templates/sambanova/run.yaml index 620d50307..fa8735002 100644 --- a/llama_stack/templates/sambanova/run.yaml +++ b/llama_stack/templates/sambanova/run.yaml @@ -38,10 +38,11 @@ providers: user: ${env.PGVECTOR_USER:} password: ${env.PGVECTOR_PASSWORD:} safety: - - provider_id: llama-guard - provider_type: inline::llama-guard + - provider_id: sambanova + provider_type: remote::sambanova config: - excluded_categories: [] + url: https://api.sambanova.ai/v1 + api_key: ${env.SAMBANOVA_API_KEY} agents: - provider_id: meta-reference provider_type: inline::meta-reference @@ -189,6 +190,9 @@ models: model_type: embedding shields: - shield_id: meta-llama/Llama-Guard-3-8B + provider_shield_id: sambanova/Meta-Llama-Guard-3-8B +- shield_id: sambanova/Meta-Llama-Guard-3-8B + provider_shield_id: sambanova/Meta-Llama-Guard-3-8B vector_dbs: [] datasets: [] scoring_fns: [] diff --git a/llama_stack/templates/sambanova/sambanova.py b/llama_stack/templates/sambanova/sambanova.py index 2f8a0b08a..54a49423d 100644 --- a/llama_stack/templates/sambanova/sambanova.py +++ b/llama_stack/templates/sambanova/sambanova.py @@ -34,7 +34,7 @@ def get_distribution_template() -> DistributionTemplate: providers = { "inference": ["remote::sambanova", "inline::sentence-transformers"], "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], - "safety": ["inline::llama-guard"], + "safety": ["remote::sambanova"], "agents": ["inline::meta-reference"], "telemetry": ["inline::meta-reference"], "tool_runtime": [ @@ -110,7 +110,7 @@ def get_distribution_template() -> DistributionTemplate: return DistributionTemplate( name=name, distro_type="self_hosted", - description="Use SambaNova for running LLM inference", + description="Use SambaNova for running LLM inference and safety", container_image=None, template_path=Path(__file__).parent / "doc_template.md", providers=providers, @@ -122,7 +122,15 @@ def get_distribution_template() -> DistributionTemplate: "vector_io": vector_io_providers, }, default_models=default_models + [embedding_model], - default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")], + default_shields=[ + ShieldInput( + shield_id="meta-llama/Llama-Guard-3-8B", provider_shield_id="sambanova/Meta-Llama-Guard-3-8B" + ), + ShieldInput( + shield_id="sambanova/Meta-Llama-Guard-3-8B", + provider_shield_id="sambanova/Meta-Llama-Guard-3-8B", + ), + ], default_tool_groups=default_tool_groups, ), }, diff --git a/pyproject.toml b/pyproject.toml index ce44479ca..6b873968a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -270,6 +270,7 @@ exclude = [ "^llama_stack/providers/remote/inference/watsonx/", "^llama_stack/providers/remote/safety/bedrock/", "^llama_stack/providers/remote/safety/nvidia/", + "^llama_stack/providers/remote/safety/sambanova/", "^llama_stack/providers/remote/safety/sample/", "^llama_stack/providers/remote/tool_runtime/bing_search/", "^llama_stack/providers/remote/tool_runtime/brave_search/",