support json format

Hardik Shah 2024-08-14 12:43:43 -07:00
parent 48b78430eb
commit 86df597a83
7 changed files with 97 additions and 29 deletions

@@ -6,14 +6,15 @@
 from typing import List
-from llama_models.llama3_1.api.datatypes import Message, Role
+from llama_models.llama3_1.api.datatypes import Message, Role, UserMessage
+from termcolor import cprint
 from llama_toolchain.safety.api.datatypes import (
     OnViolationAction,
     ShieldDefinition,
     ShieldResponse,
 )
 from llama_toolchain.safety.api.endpoints import RunShieldRequest, Safety
-from termcolor import cprint
 class SafetyException(Exception):  # noqa: N818
@@ -36,12 +37,11 @@ class ShieldRunnerMixin:
     async def run_shields(
         self, messages: List[Message], shields: List[ShieldDefinition]
     ) -> List[ShieldResponse]:
+        messages = messages.copy()
         # some shields like llama-guard require the first message to be a user message
         # since this might be a tool call, first role might not be user
         if len(messages) > 0 and messages[0].role != Role.user.value:
-            # TODO(ashwin): we need to change the type of the message, this kind of modification
-            # is no longer appropriate
-            messages[0].role = Role.user.value
+            messages[0] = UserMessage(content=messages[0].content)
         res = await self.safety_api.run_shields(
             RunShieldRequest(
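
For readers skimming the diff, here is a minimal standalone sketch of the behavior the second hunk introduces: the message list is copied and, when the conversation does not start with a user turn (for example, it starts with a tool response), the first message is re-wrapped as a user message before shields such as llama-guard run. SimpleMessage and normalize_for_shields are hypothetical stand-ins for illustration only; they are not part of this commit or of llama_toolchain.

# Illustrative sketch only -- SimpleMessage and normalize_for_shields are
# hypothetical names, not part of this commit or of llama_toolchain.
from dataclasses import dataclass
from typing import List


@dataclass
class SimpleMessage:
    role: str
    content: str


def normalize_for_shields(messages: List[SimpleMessage]) -> List[SimpleMessage]:
    # Copy the list so the caller's conversation is left untouched,
    # mirroring the messages.copy() added in the hunk above.
    messages = messages.copy()
    # Shields like llama-guard expect the first message to be a user turn;
    # if it is not (e.g. it is a tool response), replace the whole message
    # rather than mutating its role in place.
    if messages and messages[0].role != "user":
        messages[0] = SimpleMessage(role="user", content=messages[0].content)
    return messages


turns = [
    SimpleMessage(role="ipython", content="tool output"),
    SimpleMessage(role="user", content="what does the result mean?"),
]
print([m.role for m in normalize_for_shields(turns)])  # ['user', 'user']
print(turns[0].role)  # still 'ipython' -- the caller's list is not modified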