Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-15 16:12:45 +00:00
fix tests
This commit is contained in:
parent 495d233007
commit da07772480
148 changed files with 24750 additions and 12 deletions
@@ -47,7 +47,9 @@ from llama_stack.apis.inference import (
     OpenAIToolMessageParam,
     OpenAIUserMessageParam,
 )
-from llama_stack.apis.safety import Safety
+from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
+
+from ..safety import SafetyException
 
 
 async def convert_chat_choice_to_response_message(
@@ -337,17 +339,15 @@ async def run_multiple_guardrails(safety_api: Safety, messages: str, guardrail_i
     responses = await asyncio.gather(*guardrail_tasks)
 
     for response in responses:
-        if response.flagged:
-            from llama_stack.apis.safety import SafetyViolation, ViolationLevel
-
-            from ..safety import SafetyException
-
-            violation = SafetyViolation(
-                violation_level=ViolationLevel.ERROR,
-                user_message="Content flagged by moderation",
-                metadata={"categories": response.categories},
-            )
-            raise SafetyException(violation)
+        # Check if any of the results are flagged
+        for result in response.results:
+            if result.flagged:
+                violation = SafetyViolation(
+                    violation_level=ViolationLevel.ERROR,
+                    user_message="Content flagged by moderation",
+                    metadata={"categories": result.categories},
+                )
+                raise SafetyException(violation)
 
 
 def extract_guardrail_ids(guardrails: list | None) -> list[str]:
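For context on the behavioral change above: the guardrail check now inspects each entry in response.results rather than a top-level response.flagged attribute, and raises on the first flagged result. Below is a minimal, self-contained sketch of that per-result pattern; FakeModerationResult, FakeModerationResponse, and check_responses are hypothetical stand-ins (not llama-stack types), and a plain exception is used in place of SafetyException/SafetyViolation.

# Sketch only: stand-ins mimic the shape the diff relies on, i.e. a
# moderation response exposing .results, where each result exposes
# .flagged and .categories. None of these names come from llama-stack.
from dataclasses import dataclass, field


@dataclass
class FakeModerationResult:
    flagged: bool
    categories: dict = field(default_factory=dict)


@dataclass
class FakeModerationResponse:
    results: list


def check_responses(responses: list) -> None:
    # Mirrors the updated loop: walk every result of every response and
    # raise on the first flagged result.
    for response in responses:
        for result in response.results:
            if result.flagged:
                raise RuntimeError(
                    f"Content flagged by moderation: {result.categories}"
                )


# A clean response passes silently; a flagged one raises.
check_responses([FakeModerationResponse(results=[FakeModerationResult(flagged=False)])])
try:
    check_responses(
        [FakeModerationResponse(results=[FakeModerationResult(flagged=True, categories={"violence": 0.9})])]
    )
except RuntimeError as exc:
    print(exc)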