llama-stack-mirror/llama_stack/providers/adapters/safety/together/safety.py
2024-09-24 00:18:24 -07:00

83 lines
2.7 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from together import Together
import asyncio
from .config import TogetherSafetyConfig
from llama_stack.apis.safety import *
import logging
class TogetherSafetyImpl(Safety):
    """Safety provider that runs Llama Guard through the Together API.

    Only the built-in ``llama_guard`` shield is supported; any other shield
    type is rejected with ``ValueError``.
    """

    def __init__(self, config: TogetherSafetyConfig) -> None:
        self.config = config
        # The Together client is built lazily on first access of ``client``.
        self._client = None

    @property
    def client(self) -> Together:
        """Return the Together client, creating it from ``config.api_key`` on first use."""
        if self._client is None:
            self._client = Together(api_key=self.config.api_key)
        return self._client

    @client.setter
    def client(self, client: Together) -> None:
        # Lets callers inject an already-constructed client instead of the lazy one.
        self._client = client

    async def initialize(self) -> None:
        # Nothing to set up eagerly; the client is created on first use.
        pass

    async def run_shields(
        self,
        messages: List[Message],
        shields: List[ShieldDefinition],
    ) -> RunShieldResponse:
        """Check a conversation with Llama Guard hosted on Together.

        Args:
            messages: Conversation turns. Only ``user`` and ``assistant`` roles
                are accepted; they are forwarded as ``{"role", "content"}`` dicts.
            shields: Requested shields; each must be ``BuiltinShield.llama_guard``.

        Returns:
            A ``RunShieldResponse`` whose ``responses`` list holds the single
            result of ``get_safety_response`` (which may be ``None`` when the
            model output could not be parsed).

        Raises:
            ValueError: If a shield other than llama_guard is requested, or a
                message has a role other than ``user``/``assistant``.
        """
        # Validate every shield before making any network call.
        for shield in shields:
            shield_type = shield.shield_type
            if (
                not isinstance(shield_type, BuiltinShield)
                or shield_type != BuiltinShield.llama_guard
            ):
                raise ValueError(f"shield type {shield.shield_type} is not supported")

        # Llama Guard is given user prompts and assistant replies; any other
        # role (e.g. system or tool output) is rejected.
        api_messages = []
        for message in messages:
            if message.role not in ("user", "assistant"):
                raise ValueError(f"role {message.role} is not supported")
            # NOTE(review): content is forwarded unchanged; presumably the
            # Together API needs plain-text content here — confirm for media.
            api_messages.append({"role": message.role, "content": message.content})

        # One Together request is made regardless of how many llama_guard
        # shields were listed.
        response = await get_safety_response(self.client, api_messages)
        return RunShieldResponse(responses=[response])
async def get_safety_response(
    client: Together,
    messages: List[Dict[str, str]],
    *,
    model: str = "meta-llama/Meta-Llama-Guard-3-8B",
) -> Optional[ShieldResponse]:
    """Classify ``messages`` with a Llama Guard model hosted on Together.

    Args:
        client: Together SDK client; its ``chat.completions.create`` is a
            blocking (non-awaitable) call.
        messages: Chat turns as ``{"role": ..., "content": ...}`` dicts.
        model: Together model id of the Llama Guard model to query.

    Returns:
        A non-violating ``ShieldResponse`` when the model returns no choices
        or its first output line is ``safe``; a violating ``ShieldResponse``
        when the first line is ``unsafe`` (the remaining text becomes
        ``violation_type``); ``None`` when the output matches neither.
    """
    # The SDK call is synchronous: run it in a worker thread so the event
    # loop is not blocked while the HTTP request is in flight.
    response = await asyncio.to_thread(
        client.chat.completions.create, messages=messages, model=model
    )
    if not response.choices:
        return ShieldResponse(shield_type=BuiltinShield.llama_guard, is_violation=False)

    # Tolerate a missing body and surrounding whitespace / trailing newlines.
    response_text = (response.choices[0].message.content or "").strip()

    # The first line carries the verdict; for "unsafe" the rest of the text
    # is passed through as the violation type.
    verdict, _, details = response_text.partition("\n")
    verdict = verdict.strip()
    if verdict == "safe":
        return ShieldResponse(
            shield_type=BuiltinShield.llama_guard,
            is_violation=False,
        )
    if verdict == "unsafe":
        return ShieldResponse(
            shield_type=BuiltinShield.llama_guard,
            is_violation=True,
            violation_type=details.strip(),
            violation_return_message="Sorry, I cannot do that",
        )

    logging.getLogger(__name__).warning(
        "Could not parse Llama Guard response: %r", response_text
    )
    return None