From ed0d5a4df15e2de58844a13cbafdb6d1e7a4fe4a Mon Sep 17 00:00:00 2001 From: jhpiedrahitao Date: Tue, 25 Mar 2025 14:56:49 -0500 Subject: [PATCH] feat add sambanova safety adaptor --- .../self_hosted_distro/sambanova.md | 2 +- llama_stack/providers/registry/safety.py | 9 ++ .../remote/safety/bedrock/bedrock.py | 1 + .../remote/safety/sambanova/__init__.py | 18 ++++ .../remote/safety/sambanova/config.py | 30 +++++++ .../remote/safety/sambanova/sambanova.py | 85 +++++++++++++++++++ llama_stack/templates/sambanova/build.yaml | 4 +- llama_stack/templates/sambanova/run.yaml | 10 ++- llama_stack/templates/sambanova/sambanova.py | 14 ++- pyproject.toml | 1 + 10 files changed, 165 insertions(+), 9 deletions(-) create mode 100644 llama_stack/providers/remote/safety/sambanova/__init__.py create mode 100644 llama_stack/providers/remote/safety/sambanova/config.py create mode 100644 llama_stack/providers/remote/safety/sambanova/sambanova.py diff --git a/docs/source/distributions/self_hosted_distro/sambanova.md b/docs/source/distributions/self_hosted_distro/sambanova.md index bd102dae0..aa8d98bff 100644 --- a/docs/source/distributions/self_hosted_distro/sambanova.md +++ b/docs/source/distributions/self_hosted_distro/sambanova.md @@ -17,7 +17,7 @@ The `llamastack/distribution-sambanova` distribution consists of the following p |-----|-------------| | agents | `inline::meta-reference` | | inference | `remote::sambanova`, `inline::sentence-transformers` | -| safety | `inline::llama-guard` | +| safety | `remote::sambanova` | | telemetry | `inline::meta-reference` | | tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` | | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` | diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index 54dc51034..0f48748b6 100644 --- 
a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -64,4 +64,13 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig", ), ), + remote_provider_spec( + api=Api.safety, + adapter=AdapterSpec( + adapter_type="sambanova", + pip_packages=["litellm"], + module="llama_stack.providers.remote.safety.sambanova", + config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig", + ), + ), ] diff --git a/llama_stack/providers/remote/safety/sambanova/__init__.py b/llama_stack/providers/remote/safety/sambanova/__init__.py new file mode 100644 index 000000000..bb9d15374 --- /dev/null +++ b/llama_stack/providers/remote/safety/sambanova/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ + +from typing import Any + +from .config import SambaNovaSafetyConfig + + +async def get_adapter_impl(config: SambaNovaSafetyConfig, _deps) -> Any: + from .sambanova import SambaNovaSafetyAdapter + + impl = SambaNovaSafetyAdapter(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/safety/sambanova/config.py b/llama_stack/providers/remote/safety/sambanova/config.py new file mode 100644 index 000000000..eb4a521e6 --- /dev/null +++ b/llama_stack/providers/remote/safety/sambanova/config.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict, Optional + +from pydantic import BaseModel, Field + +from llama_stack.schema_utils import json_schema_type + + +@json_schema_type +class SambaNovaSafetyConfig(BaseModel): + url: str = Field( + default="https://api.sambanova.ai/v1", + description="The URL for the SambaNova AI server", + ) + api_key: Optional[str] = Field( + default=None, + description="The SambaNova cloud API Key", + ) + + @classmethod + def sample_run_config(cls, **kwargs) -> Dict[str, Any]: + return { + "url": "https://api.sambanova.ai/v1", + "api_key": "${env.SAMBANOVA_API_KEY}", + } diff --git a/llama_stack/providers/remote/safety/sambanova/sambanova.py b/llama_stack/providers/remote/safety/sambanova/sambanova.py new file mode 100644 index 000000000..7463fbf13 --- /dev/null +++ b/llama_stack/providers/remote/safety/sambanova/sambanova.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import json +import logging +from typing import Any, Dict, List, Optional + +import litellm +import requests + +from llama_stack.apis.inference import Message +from llama_stack.apis.safety import ( + RunShieldResponse, + Safety, + SafetyViolation, + ViolationLevel, +) +from llama_stack.apis.shields import Shield +from llama_stack.providers.datatypes import ShieldsProtocolPrivate +from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new + +from .config import SambaNovaSafetyConfig + +logger = logging.getLogger(__name__) + +CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?" + + +class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate): + def __init__(self, config: SambaNovaSafetyConfig) -> None: + self.config = config + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def register_shield(self, shield: Shield) -> None: + list_models_url = self.config.url + "/models" + try: + response = requests.get(list_models_url) + response.raise_for_status() + except requests.exceptions.RequestException as e: + raise RuntimeError(f"Request to {list_models_url} failed") from e + available_models = [model.get("id") for model in response.json().get("data", {})] + if ( + len(available_models) == 0 + or "guard" not in shield.provider_resource_id.lower() + or shield.provider_resource_id.split("sambanova/")[-1] not in available_models + ): + raise ValueError(f"Shield {shield.provider_resource_id} not found in SambaNova") + + async def run_shield( + self, shield_id: str, messages: List[Message], params: Optional[Dict[str, Any]] = None + ) -> RunShieldResponse: + shield = await self.shield_store.get_shield(shield_id) + if not shield: + raise ValueError(f"Shield {shield_id} not found") + + shield_params = shield.params + logger.debug(f"run_shield::{shield_params}::messages={messages}") + content_messages = [await convert_message_to_openai_dict_new(m) for m in 
messages] + logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:") + + response = litellm.completion(model=shield.provider_resource_id, messages=content_messages) + shield_message = response.choices[0].message.content + + if "unsafe" in shield_message.lower(): + user_message = CANNED_RESPONSE_TEXT + violation_type = shield_message.split("\n")[-1] + metadata = {"violation_type": violation_type} + + return RunShieldResponse( + violation=SafetyViolation( + user_message=user_message, + violation_level=ViolationLevel.ERROR, + metadata=metadata, + ) + ) + + return RunShieldResponse() diff --git a/llama_stack/templates/sambanova/build.yaml b/llama_stack/templates/sambanova/build.yaml index 9b09b9e90..25bed77ce 100644 --- a/llama_stack/templates/sambanova/build.yaml +++ b/llama_stack/templates/sambanova/build.yaml @@ -1,6 +1,6 @@ version: '2' distribution_spec: - description: Use SambaNova for running LLM inference + description: Use SambaNova for running LLM inference and safety providers: inference: - remote::sambanova @@ -10,7 +10,7 @@ distribution_spec: - remote::chromadb - remote::pgvector safety: - - inline::llama-guard + - remote::sambanova agents: - inline::meta-reference telemetry: diff --git a/llama_stack/templates/sambanova/run.yaml b/llama_stack/templates/sambanova/run.yaml index 080c8b8b3..d0ba1fb0d 100644 --- a/llama_stack/templates/sambanova/run.yaml +++ b/llama_stack/templates/sambanova/run.yaml @@ -38,10 +38,11 @@ providers: user: ${env.PGVECTOR_USER:} password: ${env.PGVECTOR_PASSWORD:} safety: - - provider_id: llama-guard - provider_type: inline::llama-guard + - provider_id: sambanova + provider_type: remote::sambanova config: - excluded_categories: [] + url: https://api.sambanova.ai/v1 + api_key: ${env.SAMBANOVA_API_KEY} agents: - provider_id: meta-reference provider_type: inline::meta-reference @@ -181,6 +182,9 @@ models: model_type: embedding shields: - shield_id: meta-llama/Llama-Guard-3-8B + provider_shield_id: 
sambanova/Meta-Llama-Guard-3-8B +- shield_id: sambanova/Meta-Llama-Guard-3-8B + provider_shield_id: sambanova/Meta-Llama-Guard-3-8B vector_dbs: [] datasets: [] scoring_fns: [] diff --git a/llama_stack/templates/sambanova/sambanova.py b/llama_stack/templates/sambanova/sambanova.py index d48eebac1..1c42c6b44 100644 --- a/llama_stack/templates/sambanova/sambanova.py +++ b/llama_stack/templates/sambanova/sambanova.py @@ -34,7 +34,7 @@ def get_distribution_template() -> DistributionTemplate: providers = { "inference": ["remote::sambanova", "inline::sentence-transformers"], "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], - "safety": ["inline::llama-guard"], + "safety": ["remote::sambanova"], "agents": ["inline::meta-reference"], "telemetry": ["inline::meta-reference"], "tool_runtime": [ @@ -115,7 +115,7 @@ def get_distribution_template() -> DistributionTemplate: return DistributionTemplate( name=name, distro_type="self_hosted", - description="Use SambaNova for running LLM inference", + description="Use SambaNova for running LLM inference and safety", container_image=None, template_path=Path(__file__).parent / "doc_template.md", providers=providers, @@ -127,7 +127,15 @@ def get_distribution_template() -> DistributionTemplate: "vector_io": vector_io_providers, }, default_models=default_models + [embedding_model], - default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")], + default_shields=[ + ShieldInput( + shield_id="meta-llama/Llama-Guard-3-8B", provider_shield_id="sambanova/Meta-Llama-Guard-3-8B" + ), + ShieldInput( + shield_id="sambanova/Meta-Llama-Guard-3-8B", + provider_shield_id="sambanova/Meta-Llama-Guard-3-8B", + ), + ], default_tool_groups=default_tool_groups, ), }, diff --git a/pyproject.toml b/pyproject.toml index f7125b724..7c741d77c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -255,6 +255,7 @@ exclude = [ "^llama_stack/providers/remote/inference/together/", "^llama_stack/providers/remote/inference/vllm/", 
"^llama_stack/providers/remote/safety/bedrock/", + "^llama_stack/providers/remote/safety/sambanova/", "^llama_stack/providers/remote/safety/nvidia/", "^llama_stack/providers/remote/safety/sample/", "^llama_stack/providers/remote/tool_runtime/bing_search/",