# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
from typing import Any, Dict, List, Optional

from pydantic import BaseModel
from together import Together

from llama_stack.apis.safety import *  # noqa: F403
from llama_stack.distribution.request_headers import get_request_provider_data

from .config import TogetherSafetyConfig
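
# For reference (an assumption, not taken from the original module): run_shield
# below reads `self.config.api_key`, so TogetherSafetyConfig is expected to
# look roughly like:
#
#   class TogetherSafetyConfig(BaseModel):
#       api_key: Optional[str] = None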


class TogetherHeaderInfo(BaseModel):
    # Per-request provider data carrying the caller's Together API key.
    together_api_key: str
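
# Note (assumption, not in the original module): llama-stack builds
# TogetherHeaderInfo from the JSON carried in the request's provider-data
# header, so a caller could pass its own key per request with something like:
#
#   X-LlamaStack-ProviderData: {"together_api_key": "<your-together-api-key>"}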


class TogetherSafetyImpl(Safety):
    def __init__(self, config: TogetherSafetyConfig) -> None:
        self.config = config

    async def initialize(self) -> None:
        pass

    async def run_shield(
        self,
        shield_type: str,
        messages: List[Message],
        params: Optional[Dict[str, Any]] = None,
    ) -> RunShieldResponse:
        # Only the Llama Guard shield is supported.
        if shield_type != "llama_guard":
            raise ValueError(f"shield type {shield_type} is not supported")

        # Prefer an API key supplied in the request's provider data over the
        # one in the server-side config.
        together_api_key = self.config.api_key
        # TODO: error out if together_api_key is missing in the header
        provider_data = get_request_provider_data()
        if provider_data is not None and isinstance(provider_data, TogetherHeaderInfo):
            together_api_key = provider_data.together_api_key

        # Convert the messages into the plain dicts the Together client
        # expects; only user messages are supported for now.
        api_messages = []
        for message in messages:
            if isinstance(message, UserMessage):
                api_messages.append({"role": message.role, "content": message.content})
            else:
                raise ValueError(f"role {message.role} is not supported")

        # get_safety_response is a coroutine, so await it directly;
        # asyncio.run() cannot be called from inside a running event loop.
        violation = await get_safety_response(together_api_key, api_messages)
        return RunShieldResponse(violation=violation)


async def get_safety_response(
    api_key: str, messages: List[Dict[str, str]]
) -> Optional[SafetyViolation]:
    client = Together(api_key=api_key)
    response = client.chat.completions.create(
        messages=messages, model="meta-llama/Meta-Llama-Guard-3-8B"
    )
    if len(response.choices) == 0:
        return SafetyViolation(violation_level=ViolationLevel.INFO, user_message="safe")

    # Llama Guard answers either "safe" or "unsafe\n<category>", e.g. "unsafe\nS2".
    response_text = response.choices[0].message.content
    if response_text == "safe":
        return SafetyViolation(violation_level=ViolationLevel.INFO, user_message="safe")

    parts = response_text.split("\n")
    if len(parts) != 2:
        return None

    if parts[0] == "unsafe":
        return SafetyViolation(
            violation_level=ViolationLevel.WARN,
            user_message="unsafe",
            metadata={"violation_type": parts[1]},
        )

    return None
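

# Usage sketch (illustrative, not part of the original module). Assumes
# TogetherSafetyConfig accepts an `api_key` keyword and that UserMessage can
# be constructed from plain text content; replace the placeholder key to run.
if __name__ == "__main__":

    async def _demo() -> None:
        impl = TogetherSafetyImpl(
            TogetherSafetyConfig(api_key="<your-together-api-key>")
        )
        await impl.initialize()
        response = await impl.run_shield(
            shield_type="llama_guard",
            messages=[UserMessage(content="How do I bake a chocolate cake?")],
        )
        print(response.violation)

    asyncio.run(_demo())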