fix tests

This commit is contained in:
Swapna Lekkala 2025-10-14 13:44:30 -07:00
parent 9a4d3d7576
commit 6e028023f9
4 changed files with 75 additions and 99 deletions

View file

@ -73,7 +73,7 @@ from .types import ChatCompletionContext, ChatCompletionResult
from .utils import (
convert_chat_choice_to_response_message,
is_function_tool_call,
run_multiple_guardrails,
run_guardrails,
)
logger = get_logger(name=__name__, category="agents::meta_reference")
@ -195,7 +195,7 @@ class StreamingResponseOrchestrator:
# Input safety validation - check messages before processing
if self.guardrail_ids:
combined_text = interleaved_content_as_str([msg.content for msg in self.ctx.messages])
input_violation_message = await run_multiple_guardrails(self.safety_api, combined_text, self.guardrail_ids)
input_violation_message = await run_guardrails(self.safety_api, combined_text, self.guardrail_ids)
if input_violation_message:
logger.info(f"Input guardrail violation: {input_violation_message}")
# Return refusal response immediately
@ -708,7 +708,7 @@ class StreamingResponseOrchestrator:
# Output Safety Validation for a chunk
if self.guardrail_ids:
accumulated_text = "".join(chat_response_content)
violation_message = await run_multiple_guardrails(self.safety_api, accumulated_text, self.guardrail_ids)
violation_message = await run_guardrails(self.safety_api, accumulated_text, self.guardrail_ids)
if violation_message:
logger.info(f"Output guardrail violation: {violation_message}")
yield await self._create_refusal_response(violation_message)

View file

@ -313,9 +313,9 @@ def is_function_tool_call(
return False
async def run_multiple_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> str | None:
"""Run multiple guardrails against messages and return violation message if blocked."""
if not guardrail_ids or not messages:
async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> str | None:
"""Run guardrails against messages and return violation message if blocked."""
if not messages:
return None
# Look up shields to get their provider_resource_id (actual model ID)