feat: add SambaNova safety adapter

This commit is contained in:
jhpiedrahitao 2025-03-25 14:56:49 -05:00
parent 8783dd8162
commit ed0d5a4df1
10 changed files with 165 additions and 9 deletions

View file

@ -17,7 +17,7 @@ The `llamastack/distribution-sambanova` distribution consists of the following p
|-----|-------------|
| agents | `inline::meta-reference` |
| inference | `remote::sambanova`, `inline::sentence-transformers` |
| safety | `inline::llama-guard` |
| safety | `remote::sambanova` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

View file

@ -64,4 +64,13 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
),
),
remote_provider_spec(
api=Api.safety,
adapter=AdapterSpec(
adapter_type="sambanova",
pip_packages=["litellm"],
module="llama_stack.providers.remote.safety.sambanova",
config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig",
),
),
]

View file

@ -28,6 +28,7 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
def __init__(self, config: BedrockSafetyConfig) -> None:
    """Store the Bedrock safety config and start with no registered shields."""
    self.config = config
    self.registered_shields = []
    # NOTE(review): removed stray debug leftover `self.test = 2` — it was an
    # unused attribute with a meaningless constant, unrelated to this adapter.
async def initialize(self) -> None:
try:

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import SambaNovaSafetyConfig
async def get_adapter_impl(config: SambaNovaSafetyConfig, _deps) -> Any:
    """Build, initialize, and return the SambaNova safety adapter instance."""
    # Imported lazily so the provider module loads only when actually used.
    from .sambanova import SambaNovaSafetyAdapter

    adapter = SambaNovaSafetyAdapter(config)
    await adapter.initialize()
    return adapter

View file

@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class SambaNovaSafetyConfig(BaseModel):
    """Configuration for the remote SambaNova safety provider."""

    # Base URL of the SambaNova OpenAI-compatible API endpoint.
    url: str = Field(
        default="https://api.sambanova.ai/v1",
        description="The URL for the SambaNova AI server",
    )
    # Optional API key; presumably resolved from the environment when unset —
    # TODO(review): confirm against the adapter/litellm fallback behavior.
    api_key: Optional[str] = Field(
        default=None,
        description="The SambaNova cloud API Key",
    )

    @classmethod
    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
        """Return the sample run-config entry used by distribution templates."""
        sample: Dict[str, Any] = {"url": "https://api.sambanova.ai/v1"}
        sample["api_key"] = "${env.SAMBANOVA_API_KEY}"
        return sample

View file

@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
from typing import Any, Dict, List, Optional
import litellm
import requests
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
SafetyViolation,
ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
from .config import SambaNovaSafetyConfig
logger = logging.getLogger(__name__)
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate):
    """Safety adapter backed by SambaNova-hosted guard models.

    Shields are validated against the SambaNova ``/models`` endpoint at
    registration time and evaluated through ``litellm.completion``.
    """

    def __init__(self, config: SambaNovaSafetyConfig) -> None:
        self.config = config

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def register_shield(self, shield: Shield) -> None:
        """Validate that the shield maps to a guard model available on SambaNova.

        Raises:
            RuntimeError: if the models endpoint cannot be reached.
            ValueError: if the shield does not resolve to an available guard model.
        """
        list_models_url = self.config.url + "/models"
        try:
            # Bounded timeout so a hung endpoint cannot block registration forever
            # (the original call had no timeout and could hang indefinitely).
            response = requests.get(list_models_url, timeout=10)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            raise RuntimeError(f"Request to {list_models_url} failed") from e
        # "data" holds a list of model entries; default to [] (not {}) when absent.
        available_models = [model.get("id") for model in response.json().get("data", [])]
        if (
            not shield.provider_resource_id  # guard against None before .lower()
            or len(available_models) == 0
            or "guard" not in shield.provider_resource_id.lower()
            or shield.provider_resource_id.split("sambanova/")[-1] not in available_models
        ):
            raise ValueError(f"Shield {shield.provider_resource_id} not found in SambaNova")

    async def run_shield(
        self, shield_id: str, messages: List[Message], params: Optional[Dict[str, Any]] = None
    ) -> RunShieldResponse:
        """Run the guard model over ``messages`` and report any violation found."""
        shield = await self.shield_store.get_shield(shield_id)
        if not shield:
            raise ValueError(f"Shield {shield_id} not found")

        shield_params = shield.params
        logger.debug(f"run_shield::{shield_params}::messages={messages}")
        content_messages = [await convert_message_to_openai_dict_new(m) for m in messages]
        logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")

        # Pass the configured key explicitly: previously config.api_key was never
        # used anywhere, so litellm silently fell back to environment credentials.
        response = litellm.completion(
            model=shield.provider_resource_id,
            messages=content_messages,
            api_key=self.config.api_key,
        )
        shield_message = response.choices[0].message.content

        # Guard models signal a violation by including "unsafe" in the reply;
        # the violated category appears on the last line of that reply.
        if "unsafe" in shield_message.lower():
            user_message = CANNED_RESPONSE_TEXT
            violation_type = shield_message.split("\n")[-1]
            metadata = {"violation_type": violation_type}
            return RunShieldResponse(
                violation=SafetyViolation(
                    user_message=user_message,
                    violation_level=ViolationLevel.ERROR,
                    metadata=metadata,
                )
            )

        return RunShieldResponse()

View file

@ -1,6 +1,6 @@
version: '2'
distribution_spec:
description: Use SambaNova for running LLM inference
description: Use SambaNova for running LLM inference and safety
providers:
inference:
- remote::sambanova
@ -10,7 +10,7 @@ distribution_spec:
- remote::chromadb
- remote::pgvector
safety:
- inline::llama-guard
- remote::sambanova
agents:
- inline::meta-reference
telemetry:

View file

@ -38,10 +38,11 @@ providers:
user: ${env.PGVECTOR_USER:}
password: ${env.PGVECTOR_PASSWORD:}
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
- provider_id: sambanova
provider_type: remote::sambanova
config:
excluded_categories: []
url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY}
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
@ -181,6 +182,9 @@ models:
model_type: embedding
shields:
- shield_id: meta-llama/Llama-Guard-3-8B
provider_shield_id: sambanova/Meta-Llama-Guard-3-8B
- shield_id: sambanova/Meta-Llama-Guard-3-8B
provider_shield_id: sambanova/Meta-Llama-Guard-3-8B
vector_dbs: []
datasets: []
scoring_fns: []

View file

@ -34,7 +34,7 @@ def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::sambanova", "inline::sentence-transformers"],
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"safety": ["remote::sambanova"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
"tool_runtime": [
@ -115,7 +115,7 @@ def get_distribution_template() -> DistributionTemplate:
return DistributionTemplate(
name=name,
distro_type="self_hosted",
description="Use SambaNova for running LLM inference",
description="Use SambaNova for running LLM inference and safety",
container_image=None,
template_path=Path(__file__).parent / "doc_template.md",
providers=providers,
@ -127,7 +127,15 @@ def get_distribution_template() -> DistributionTemplate:
"vector_io": vector_io_providers,
},
default_models=default_models + [embedding_model],
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
default_shields=[
ShieldInput(
shield_id="meta-llama/Llama-Guard-3-8B", provider_shield_id="sambanova/Meta-Llama-Guard-3-8B"
),
ShieldInput(
shield_id="sambanova/Meta-Llama-Guard-3-8B",
provider_shield_id="sambanova/Meta-Llama-Guard-3-8B",
),
],
default_tool_groups=default_tool_groups,
),
},

View file

@ -255,6 +255,7 @@ exclude = [
"^llama_stack/providers/remote/inference/together/",
"^llama_stack/providers/remote/inference/vllm/",
"^llama_stack/providers/remote/safety/bedrock/",
"^llama_stack/providers/remote/safety/sambanova/",
"^llama_stack/providers/remote/safety/nvidia/",
"^llama_stack/providers/remote/safety/sample/",
"^llama_stack/providers/remote/tool_runtime/bing_search/",