mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-08 19:10:56 +00:00
Push registration methods onto the backing providers
This commit is contained in:
parent
5a7b01d292
commit
4215cc9331
14 changed files with 269 additions and 220 deletions
|
|
@ -6,28 +6,23 @@
|
|||
from together import Together
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
SafetyViolation,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
|
||||
from .config import TogetherSafetyConfig
|
||||
|
||||
|
||||
SAFETY_SHIELD_TYPES = {
|
||||
SAFETY_SHIELD_MODEL_MAP = {
|
||||
"llama_guard": "meta-llama/Meta-Llama-Guard-3-8B",
|
||||
"Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
|
||||
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
|
||||
}
|
||||
|
||||
|
||||
class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
|
||||
class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
|
||||
def __init__(self, config: TogetherSafetyConfig) -> None:
|
||||
self.config = config
|
||||
self.register_shields = []
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
|
@ -35,16 +30,31 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
for key in routing_keys:
|
||||
if key not in SAFETY_SHIELD_TYPES:
|
||||
raise ValueError(f"Unknown safety shield type: {key}")
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
if shield.type != ShieldType.llama_guard.value:
|
||||
raise ValueError(f"Unsupported safety shield type: {shield.type}")
|
||||
|
||||
self.registered_shields.append(shield)
|
||||
|
||||
async def list_shields(self) -> List[ShieldDef]:
|
||||
return self.registered_shields
|
||||
|
||||
async def get_shield(self, identifier: str) -> Optional[ShieldDef]:
|
||||
for shield in self.registered_shields:
|
||||
if shield.identifier == identifier:
|
||||
return shield
|
||||
return None
|
||||
|
||||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse:
|
||||
if shield_type not in SAFETY_SHIELD_TYPES:
|
||||
raise ValueError(f"Unknown safety shield type: {shield_type}")
|
||||
shield_def = await self.get_shield(shield_type)
|
||||
if not shield_def:
|
||||
raise ValueError(f"Unknown shield {shield_type}")
|
||||
|
||||
model = shield_def.params.get("model", "llama_guard")
|
||||
if model not in SAFETY_SHIELD_MODEL_MAP:
|
||||
raise ValueError(f"Unsupported safety model: {model}")
|
||||
|
||||
together_api_key = None
|
||||
if self.config.api_key is not None:
|
||||
|
|
@ -57,17 +67,13 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
|
|||
)
|
||||
together_api_key = provider_data.together_api_key
|
||||
|
||||
model_name = SAFETY_SHIELD_TYPES[shield_type]
|
||||
|
||||
# messages can have role assistant or user
|
||||
api_messages = []
|
||||
for message in messages:
|
||||
if message.role in (Role.user.value, Role.assistant.value):
|
||||
api_messages.append({"role": message.role, "content": message.content})
|
||||
|
||||
violation = await get_safety_response(
|
||||
together_api_key, model_name, api_messages
|
||||
)
|
||||
violation = await get_safety_response(together_api_key, model, api_messages)
|
||||
return RunShieldResponse(violation=violation)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue