fix shields step

This commit is contained in:
Xi Yan 2024-12-16 20:59:10 -08:00
parent c2f7905fa4
commit bf961f8aa5

View file

@ -239,6 +239,8 @@ class ChatAgent(ShieldRunnerMixin):
# return a "final value" for the `yield from` statement. we simulate that by yielding a
# final boolean (to see whether an exception happened) and then explicitly testing for it.
if len(self.input_shields) > 0:
print("input shields", self.input_shields)
async for res in self.run_multiple_shields_wrapper(
turn_id, input_messages, self.input_shields, "user-input"
):
@ -262,6 +264,7 @@ class ChatAgent(ShieldRunnerMixin):
# for output shields run on the full input and output combination
messages = input_messages + [final_response]
if len(self.output_shields) > 0:
async for res in self.run_multiple_shields_wrapper(
turn_id, messages, self.output_shields, "assistant-output"
):
@ -279,6 +282,7 @@ class ChatAgent(ShieldRunnerMixin):
shields: List[str],
touchpoint: str,
) -> AsyncGenerator:
print("!!! HI run_multiple_shields_wrapper")
with tracing.span("run_shields") as span:
span.set_attribute("input", [m.model_dump_json() for m in messages])
if len(shields) == 0:
@ -531,7 +535,6 @@ class ChatAgent(ShieldRunnerMixin):
input_messages = input_messages + [message]
else:
log.info(f"{str(message)}")
try:
tool_call = message.tool_calls[0]
name = tool_call.tool_name
@ -597,39 +600,6 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_details=ShieldCallStep(
step_id=str(uuid.uuid4()),
turn_id=turn_id,
violation=None,
),
)
)
)
except SafetyException as e:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_details=ShieldCallStep(
step_id=str(uuid.uuid4()),
turn_id=turn_id,
violation=e.violation,
),
)
)
)
yield CompletionMessage(
content=str(e),
stop_reason=StopReason.end_of_turn,
)
yield False
return
if out_attachment := interpret_content_as_attachment(
result_message.content