mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-15 22:12:40 +00:00
fix tests
This commit is contained in:
parent
9a4d3d7576
commit
6e028023f9
4 changed files with 75 additions and 99 deletions
|
|
@ -73,7 +73,7 @@ from .types import ChatCompletionContext, ChatCompletionResult
|
|||
from .utils import (
|
||||
convert_chat_choice_to_response_message,
|
||||
is_function_tool_call,
|
||||
run_multiple_guardrails,
|
||||
run_guardrails,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
|
@ -195,7 +195,7 @@ class StreamingResponseOrchestrator:
|
|||
# Input safety validation - check messages before processing
|
||||
if self.guardrail_ids:
|
||||
combined_text = interleaved_content_as_str([msg.content for msg in self.ctx.messages])
|
||||
input_violation_message = await run_multiple_guardrails(self.safety_api, combined_text, self.guardrail_ids)
|
||||
input_violation_message = await run_guardrails(self.safety_api, combined_text, self.guardrail_ids)
|
||||
if input_violation_message:
|
||||
logger.info(f"Input guardrail violation: {input_violation_message}")
|
||||
# Return refusal response immediately
|
||||
|
|
@ -708,7 +708,7 @@ class StreamingResponseOrchestrator:
|
|||
# Output Safety Validation for a chunk
|
||||
if self.guardrail_ids:
|
||||
accumulated_text = "".join(chat_response_content)
|
||||
violation_message = await run_multiple_guardrails(self.safety_api, accumulated_text, self.guardrail_ids)
|
||||
violation_message = await run_guardrails(self.safety_api, accumulated_text, self.guardrail_ids)
|
||||
if violation_message:
|
||||
logger.info(f"Output guardrail violation: {violation_message}")
|
||||
yield await self._create_refusal_response(violation_message)
|
||||
|
|
|
|||
|
|
@ -313,9 +313,9 @@ def is_function_tool_call(
|
|||
return False
|
||||
|
||||
|
||||
async def run_multiple_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> str | None:
|
||||
"""Run multiple guardrails against messages and return violation message if blocked."""
|
||||
if not guardrail_ids or not messages:
|
||||
async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> str | None:
|
||||
"""Run guardrails against messages and return violation message if blocked."""
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# Look up shields to get their provider_resource_id (actual model ID)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue