fix shields step

Xi Yan 2024-12-16 20:59:10 -08:00
parent c2f7905fa4
commit bf961f8aa5


@@ -239,13 +239,15 @@ class ChatAgent(ShieldRunnerMixin):
         # return a "final value" for the `yield from` statement. we simulate that by yielding a
         # final boolean (to see whether an exception happened) and then explicitly testing for it.
-        async for res in self.run_multiple_shields_wrapper(
-            turn_id, input_messages, self.input_shields, "user-input"
-        ):
-            if isinstance(res, bool):
-                return
-            else:
-                yield res
+        if len(self.input_shields) > 0:
+            print("input shields", self.input_shields)
+            async for res in self.run_multiple_shields_wrapper(
+                turn_id, input_messages, self.input_shields, "user-input"
+            ):
+                if isinstance(res, bool):
+                    return
+                else:
+                    yield res
 
         async for res in self._run(
             session_id, turn_id, input_messages, attachments, sampling_params, stream
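
The comment at the top of the hunk above describes a workaround for a Python limitation: an async generator cannot return a value through `yield from`, so `run_multiple_shields_wrapper` signals "a SafetyException happened" by yielding a plain boolean as its last item, and the caller aborts the turn as soon as it sees one. A minimal, self-contained sketch of that pattern follows; `run_shields_sketch` and `drive_turn` are illustrative names, not functions from this repo, and `ValueError` stands in for `SafetyException`.

import asyncio


async def run_shields_sketch(messages):
    # Stand-in for run_multiple_shields_wrapper: yields normal events,
    # and yields a bare bool only if a "violation" exception occurred.
    try:
        for m in messages:
            if "blocked" in m:
                raise ValueError(f"violation in {m!r}")
        yield "shield_call step complete"
    except ValueError as e:
        yield f"violation event: {e}"
        yield False  # sentinel: an exception happened, caller should stop


async def drive_turn(messages):
    async for res in run_shields_sketch(messages):
        if isinstance(res, bool):
            print("turn aborted by shields")
            return
        print("event:", res)
    print("shields passed, continue with inference")


asyncio.run(drive_turn(["hello", "world"]))      # shields pass
asyncio.run(drive_turn(["hello", "blocked!"]))   # turn aborted
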
@@ -262,13 +264,14 @@ class ChatAgent(ShieldRunnerMixin):
         # for output shields run on the full input and output combination
         messages = input_messages + [final_response]
-        async for res in self.run_multiple_shields_wrapper(
-            turn_id, messages, self.output_shields, "assistant-output"
-        ):
-            if isinstance(res, bool):
-                return
-            else:
-                yield res
+        if len(self.output_shields) > 0:
+            async for res in self.run_multiple_shields_wrapper(
+                turn_id, messages, self.output_shields, "assistant-output"
+            ):
+                if isinstance(res, bool):
+                    return
+                else:
+                    yield res
 
         yield final_response
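
Both shield hunks add the same guard: the wrapper is only entered when the agent actually has shields configured, so a shield-free turn never opens the `run_shields` tracing span or hits the debug print added in the next hunk. A toy illustration of the effect of the guard, using hypothetical names rather than the repo's classes:

import asyncio


async def shields_wrapper_sketch(shields):
    # Hypothetical stand-in: entering the wrapper always has some cost
    # (here a print), regardless of whether the shield list is empty.
    print("entered wrapper with shields:", shields)
    for s in shields:
        yield f"ran shield {s}"


async def run_turn(shields, guarded):
    if not guarded or len(shields) > 0:  # the len(...) > 0 guard from this commit
        async for ev in shields_wrapper_sketch(shields):
            print(ev)
    print("inference step")


asyncio.run(run_turn([], guarded=False))              # wrapper still entered
asyncio.run(run_turn([], guarded=True))               # wrapper skipped entirely
asyncio.run(run_turn(["example_shield"], guarded=True))
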
@@ -279,6 +282,7 @@ class ChatAgent(ShieldRunnerMixin):
         shields: List[str],
         touchpoint: str,
     ) -> AsyncGenerator:
+        print("!!! HI run_multiple_shields_wrapper")
         with tracing.span("run_shields") as span:
             span.set_attribute("input", [m.model_dump_json() for m in messages])
             if len(shields) == 0:
@@ -531,106 +535,72 @@ class ChatAgent(ShieldRunnerMixin):
                     input_messages = input_messages + [message]
             else:
                 log.info(f"{str(message)}")
-                try:
-                    tool_call = message.tool_calls[0]
-
-                    name = tool_call.tool_name
-                    if not isinstance(name, BuiltinTool):
-                        yield message
-                        return
-
-                    step_id = str(uuid.uuid4())
-                    yield AgentTurnResponseStreamChunk(
-                        event=AgentTurnResponseEvent(
-                            payload=AgentTurnResponseStepStartPayload(
-                                step_type=StepType.tool_execution.value,
-                                step_id=step_id,
-                            )
-                        )
-                    )
-                    yield AgentTurnResponseStreamChunk(
-                        event=AgentTurnResponseEvent(
-                            payload=AgentTurnResponseStepProgressPayload(
-                                step_type=StepType.tool_execution.value,
-                                step_id=step_id,
-                                tool_call=tool_call,
-                            )
-                        )
-                    )
-                    with tracing.span(
-                        "tool_execution",
-                        {
-                            "tool_name": tool_call.tool_name,
-                            "input": message.model_dump_json(),
-                        },
-                    ) as span:
-                        result_messages = await execute_tool_call_maybe(
-                            self.tools_dict,
-                            [message],
-                        )
-                        assert (
-                            len(result_messages) == 1
-                        ), "Currently not supporting multiple messages"
-                        result_message = result_messages[0]
-                        span.set_attribute("output", result_message.model_dump_json())
-
-                    yield AgentTurnResponseStreamChunk(
-                        event=AgentTurnResponseEvent(
-                            payload=AgentTurnResponseStepCompletePayload(
-                                step_type=StepType.tool_execution.value,
-                                step_details=ToolExecutionStep(
-                                    step_id=step_id,
-                                    turn_id=turn_id,
-                                    tool_calls=[tool_call],
-                                    tool_responses=[
-                                        ToolResponse(
-                                            call_id=result_message.call_id,
-                                            tool_name=result_message.tool_name,
-                                            content=result_message.content,
-                                        )
-                                    ],
-                                ),
-                            )
-                        )
-                    )
-
-                    # TODO: add tool-input touchpoint and a "start" event for this step also
-                    # but that needs a lot more refactoring of Tool code potentially
-                    yield AgentTurnResponseStreamChunk(
-                        event=AgentTurnResponseEvent(
-                            payload=AgentTurnResponseStepCompletePayload(
-                                step_type=StepType.shield_call.value,
-                                step_details=ShieldCallStep(
-                                    step_id=str(uuid.uuid4()),
-                                    turn_id=turn_id,
-                                    violation=None,
-                                ),
-                            )
-                        )
-                    )
-                except SafetyException as e:
-                    yield AgentTurnResponseStreamChunk(
-                        event=AgentTurnResponseEvent(
-                            payload=AgentTurnResponseStepCompletePayload(
-                                step_type=StepType.shield_call.value,
-                                step_details=ShieldCallStep(
-                                    step_id=str(uuid.uuid4()),
-                                    turn_id=turn_id,
-                                    violation=e.violation,
-                                ),
-                            )
-                        )
-                    )
-
-                    yield CompletionMessage(
-                        content=str(e),
-                        stop_reason=StopReason.end_of_turn,
-                    )
-                    yield False
-                    return
+                tool_call = message.tool_calls[0]
+
+                name = tool_call.tool_name
+                if not isinstance(name, BuiltinTool):
+                    yield message
+                    return
+
+                step_id = str(uuid.uuid4())
+                yield AgentTurnResponseStreamChunk(
+                    event=AgentTurnResponseEvent(
+                        payload=AgentTurnResponseStepStartPayload(
+                            step_type=StepType.tool_execution.value,
+                            step_id=step_id,
+                        )
+                    )
+                )
+                yield AgentTurnResponseStreamChunk(
+                    event=AgentTurnResponseEvent(
+                        payload=AgentTurnResponseStepProgressPayload(
+                            step_type=StepType.tool_execution.value,
+                            step_id=step_id,
+                            tool_call=tool_call,
+                        )
+                    )
+                )
+                with tracing.span(
+                    "tool_execution",
+                    {
+                        "tool_name": tool_call.tool_name,
+                        "input": message.model_dump_json(),
+                    },
+                ) as span:
+                    result_messages = await execute_tool_call_maybe(
+                        self.tools_dict,
+                        [message],
+                    )
+                    assert (
+                        len(result_messages) == 1
+                    ), "Currently not supporting multiple messages"
+                    result_message = result_messages[0]
+                    span.set_attribute("output", result_message.model_dump_json())
+
+                yield AgentTurnResponseStreamChunk(
+                    event=AgentTurnResponseEvent(
+                        payload=AgentTurnResponseStepCompletePayload(
+                            step_type=StepType.tool_execution.value,
+                            step_details=ToolExecutionStep(
+                                step_id=step_id,
+                                turn_id=turn_id,
+                                tool_calls=[tool_call],
+                                tool_responses=[
+                                    ToolResponse(
+                                        call_id=result_message.call_id,
+                                        tool_name=result_message.tool_name,
+                                        content=result_message.content,
+                                    )
+                                ],
+                            ),
+                        )
+                    )
+                )
+
+                # TODO: add tool-input touchpoint and a "start" event for this step also
+                # but that needs a lot more refactoring of Tool code potentially
 
                 if out_attachment := interpret_content_as_attachment(
                     result_message.content
                 ):
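
For context on the tool-execution hunk above: each tool call produces three stream events, a step-start, a step-progress carrying the `tool_call`, and a step-complete whose `ToolExecutionStep` pairs the call with its `ToolResponse`; the spurious shield_call step and the inner `try/except SafetyException` are what this commit removes from that path. A self-contained toy sketch of that event lifecycle, with simplified stand-in dataclasses rather than the repo's payload models:

import asyncio
import uuid
from dataclasses import dataclass


@dataclass
class StepStart:
    step_type: str
    step_id: str


@dataclass
class StepProgress:
    step_type: str
    step_id: str
    tool_call: str


@dataclass
class StepComplete:
    step_type: str
    step_details: dict


async def execute_tool_sketch(tool_call: str) -> str:
    await asyncio.sleep(0)  # stands in for the real tool invocation
    return f"result of {tool_call}"


async def tool_execution_steps(tool_call: str):
    # start -> progress (with the call) -> complete (with call + response)
    step_id = str(uuid.uuid4())
    yield StepStart("tool_execution", step_id)
    yield StepProgress("tool_execution", step_id, tool_call)
    result = await execute_tool_sketch(tool_call)
    yield StepComplete(
        "tool_execution",
        {
            "step_id": step_id,
            "tool_calls": [tool_call],
            "tool_responses": [result],
        },
    )


async def main():
    async for ev in tool_execution_steps("search(query='llama')"):
        print(type(ev).__name__, ev)


asyncio.run(main())
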