diff --git a/llama_stack/providers/impls/meta_reference/agents/safety.py b/llama_stack/providers/impls/meta_reference/agents/safety.py index f7148ddce..7363fa0b1 100644 --- a/llama_stack/providers/impls/meta_reference/agents/safety.py +++ b/llama_stack/providers/impls/meta_reference/agents/safety.py @@ -7,15 +7,14 @@ from typing import List from llama_models.llama3.api.datatypes import Message, Role, UserMessage +from termcolor import cprint from llama_stack.apis.safety import ( OnViolationAction, - RunShieldRequest, Safety, ShieldDefinition, ShieldResponse, ) -from termcolor import cprint class SafetyException(Exception): # noqa: N818 @@ -45,10 +44,8 @@ class ShieldRunnerMixin: messages[0] = UserMessage(content=messages[0].content) res = await self.safety_api.run_shields( - RunShieldRequest( - messages=messages, - shields=shields, - ) + messages=messages, + shields=shields, ) results = res.responses