Introduce model_store, shield_store, memory_bank_store

This commit is contained in:
Ashwin Bharambe 2024-10-06 16:29:33 -07:00 committed by Ashwin Bharambe
parent e45a417543
commit 91e0063593
19 changed files with 172 additions and 297 deletions

View file

@ -12,7 +12,7 @@ from llama_stack.distribution.request_headers import NeedsRequestProviderData
from .config import TogetherSafetyConfig
SAFETY_SHIELD_MODEL_MAP = {
TOGETHER_SHIELD_MODEL_MAP = {
"llama_guard": "meta-llama/Meta-Llama-Guard-3-8B",
"Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
@ -22,7 +22,6 @@ SAFETY_SHIELD_MODEL_MAP = {
class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
def __init__(self, config: TogetherSafetyConfig) -> None:
self.config = config
self.register_shields = []
async def initialize(self) -> None:
pass
@ -34,26 +33,15 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
if shield.type != ShieldType.llama_guard.value:
raise ValueError(f"Unsupported safety shield type: {shield.type}")
self.registered_shields.append(shield)
async def list_shields(self) -> List[ShieldDef]:
return self.registered_shields
async def get_shield(self, identifier: str) -> Optional[ShieldDef]:
for shield in self.registered_shields:
if shield.identifier == identifier:
return shield
return None
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
shield_def = await self.get_shield(shield_type)
shield_def = await self.shield_store.get_shield(shield_type)
if not shield_def:
raise ValueError(f"Unknown shield {shield_type}")
model = shield_def.params.get("model", "llama_guard")
if model not in SAFETY_SHIELD_MODEL_MAP:
if model not in TOGETHER_SHIELD_MODEL_MAP:
raise ValueError(f"Unsupported safety model: {model}")
together_api_key = None
@ -73,7 +61,9 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
if message.role in (Role.user.value, Role.assistant.value):
api_messages.append({"role": message.role, "content": message.content})
violation = await get_safety_response(together_api_key, model, api_messages)
violation = await get_safety_response(
together_api_key, TOGETHER_SHIELD_MODEL_MAP[model], api_messages
)
return RunShieldResponse(violation=violation)