forked from phoenix-oss/llama-stack-mirror
		
	feat(providers): sambanova safety provider (#2221)
# What does this PR do? Includes SambaNova safety adaptor to use the sambanova cloud served Meta-Llama-Guard-3-8B minor updates in sambanova docs ## Test Plan pytest -s -v tests/integration/safety/test_safety.py --stack-config=sambanova --safety-shield=sambanova/Meta-Llama-Guard-3-8B
This commit is contained in:
		
							parent
							
								
									02e5e8a633
								
							
						
					
					
						commit
						633bb9c5b3
					
				
					 11 changed files with 222 additions and 22 deletions
				
			
		
							
								
								
									
										18
									
								
								llama_stack/providers/remote/safety/sambanova/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								llama_stack/providers/remote/safety/sambanova/__init__.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,18 @@ | |||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the terms described in the LICENSE file in | ||||
| # the root directory of this source tree. | ||||
| 
 | ||||
| 
 | ||||
| from typing import Any | ||||
| 
 | ||||
| from .config import SambaNovaSafetyConfig | ||||
| 
 | ||||
| 
 | ||||
| async def get_adapter_impl(config: SambaNovaSafetyConfig, _deps) -> Any: | ||||
|     from .sambanova import SambaNovaSafetyAdapter | ||||
| 
 | ||||
|     impl = SambaNovaSafetyAdapter(config) | ||||
|     await impl.initialize() | ||||
|     return impl | ||||
							
								
								
									
										37
									
								
								llama_stack/providers/remote/safety/sambanova/config.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								llama_stack/providers/remote/safety/sambanova/config.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,37 @@ | |||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the terms described in the LICENSE file in | ||||
| # the root directory of this source tree. | ||||
| 
 | ||||
| from typing import Any | ||||
| 
 | ||||
| from pydantic import BaseModel, Field, SecretStr | ||||
| 
 | ||||
| from llama_stack.schema_utils import json_schema_type | ||||
| 
 | ||||
| 
 | ||||
| class SambaNovaProviderDataValidator(BaseModel): | ||||
|     sambanova_api_key: str | None = Field( | ||||
|         default=None, | ||||
|         description="Sambanova Cloud API key", | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| @json_schema_type | ||||
| class SambaNovaSafetyConfig(BaseModel): | ||||
|     url: str = Field( | ||||
|         default="https://api.sambanova.ai/v1", | ||||
|         description="The URL for the SambaNova AI server", | ||||
|     ) | ||||
|     api_key: SecretStr | None = Field( | ||||
|         default=None, | ||||
|         description="The SambaNova cloud API Key", | ||||
|     ) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]: | ||||
|         return { | ||||
|             "url": "https://api.sambanova.ai/v1", | ||||
|             "api_key": api_key, | ||||
|         } | ||||
							
								
								
									
										100
									
								
								llama_stack/providers/remote/safety/sambanova/sambanova.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										100
									
								
								llama_stack/providers/remote/safety/sambanova/sambanova.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,100 @@ | |||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the terms described in the LICENSE file in | ||||
| # the root directory of this source tree. | ||||
| 
 | ||||
| import json | ||||
| import logging | ||||
| from typing import Any | ||||
| 
 | ||||
| import litellm | ||||
| import requests | ||||
| 
 | ||||
| from llama_stack.apis.inference import Message | ||||
| from llama_stack.apis.safety import ( | ||||
|     RunShieldResponse, | ||||
|     Safety, | ||||
|     SafetyViolation, | ||||
|     ViolationLevel, | ||||
| ) | ||||
| from llama_stack.apis.shields import Shield | ||||
| from llama_stack.distribution.request_headers import NeedsRequestProviderData | ||||
| from llama_stack.providers.datatypes import ShieldsProtocolPrivate | ||||
| from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new | ||||
| 
 | ||||
| from .config import SambaNovaSafetyConfig | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?" | ||||
| 
 | ||||
| 
 | ||||
| class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData): | ||||
|     def __init__(self, config: SambaNovaSafetyConfig) -> None: | ||||
|         self.config = config | ||||
| 
 | ||||
|     async def initialize(self) -> None: | ||||
|         pass | ||||
| 
 | ||||
|     async def shutdown(self) -> None: | ||||
|         pass | ||||
| 
 | ||||
|     def _get_api_key(self) -> str: | ||||
|         config_api_key = self.config.api_key if self.config.api_key else None | ||||
|         if config_api_key: | ||||
|             return config_api_key.get_secret_value() | ||||
|         else: | ||||
|             provider_data = self.get_request_provider_data() | ||||
|             if provider_data is None or not provider_data.sambanova_api_key: | ||||
|                 raise ValueError( | ||||
|                     'Pass Sambanova API Key in the header X-LlamaStack-Provider-Data as { "sambanova_api_key": <your api key> }' | ||||
|                 ) | ||||
|             return provider_data.sambanova_api_key | ||||
| 
 | ||||
|     async def register_shield(self, shield: Shield) -> None: | ||||
|         list_models_url = self.config.url + "/models" | ||||
|         try: | ||||
|             response = requests.get(list_models_url) | ||||
|             response.raise_for_status() | ||||
|         except requests.exceptions.RequestException as e: | ||||
|             raise RuntimeError(f"Request to {list_models_url} failed") from e | ||||
|         available_models = [model.get("id") for model in response.json().get("data", {})] | ||||
|         if ( | ||||
|             len(available_models) == 0 | ||||
|             or "guard" not in shield.provider_resource_id.lower() | ||||
|             or shield.provider_resource_id.split("sambanova/")[-1] not in available_models | ||||
|         ): | ||||
|             raise ValueError(f"Shield {shield.provider_resource_id} not found in SambaNova") | ||||
| 
 | ||||
|     async def run_shield( | ||||
|         self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None | ||||
|     ) -> RunShieldResponse: | ||||
|         shield = await self.shield_store.get_shield(shield_id) | ||||
|         if not shield: | ||||
|             raise ValueError(f"Shield {shield_id} not found") | ||||
| 
 | ||||
|         shield_params = shield.params | ||||
|         logger.debug(f"run_shield::{shield_params}::messages={messages}") | ||||
|         content_messages = [await convert_message_to_openai_dict_new(m) for m in messages] | ||||
|         logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:") | ||||
| 
 | ||||
|         response = litellm.completion( | ||||
|             model=shield.provider_resource_id, messages=content_messages, api_key=self._get_api_key() | ||||
|         ) | ||||
|         shield_message = response.choices[0].message.content | ||||
| 
 | ||||
|         if "unsafe" in shield_message.lower(): | ||||
|             user_message = CANNED_RESPONSE_TEXT | ||||
|             violation_type = shield_message.split("\n")[-1] | ||||
|             metadata = {"violation_type": violation_type} | ||||
| 
 | ||||
|             return RunShieldResponse( | ||||
|                 violation=SafetyViolation( | ||||
|                     user_message=user_message, | ||||
|                     violation_level=ViolationLevel.ERROR, | ||||
|                     metadata=metadata, | ||||
|                 ) | ||||
|             ) | ||||
| 
 | ||||
|         return RunShieldResponse() | ||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue