mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-05 20:27:35 +00:00
Update safety implementation inside agents
This commit is contained in:
parent
82ddd851c8
commit
d6a41d98d2
8 changed files with 17 additions and 67 deletions
|
@ -94,12 +94,11 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
elif step.step_type == StepType.shield_call.value:
|
||||
response = step.response
|
||||
if response.is_violation:
|
||||
if step.violation:
|
||||
# CompletionMessage itself in the ShieldResponse
|
||||
messages.append(
|
||||
CompletionMessage(
|
||||
content=response.violation_return_message,
|
||||
content=violation.user_message,
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
)
|
||||
|
@ -276,7 +275,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
step_details=ShieldCallStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
response=e.response,
|
||||
violation=e.violation,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
@ -295,12 +294,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
step_details=ShieldCallStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
response=ShieldResponse(
|
||||
# TODO: fix this, give each shield a shield type method and
|
||||
# fire one event for each shield run
|
||||
shield_type=BuiltinShield.llama_guard,
|
||||
is_violation=False,
|
||||
),
|
||||
violation=None,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
@ -550,12 +544,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
step_details=ShieldCallStep(
|
||||
step_id=str(uuid.uuid4()),
|
||||
turn_id=turn_id,
|
||||
response=ShieldResponse(
|
||||
# TODO: fix this, give each shield a shield type method and
|
||||
# fire one event for each shield run
|
||||
shield_type=BuiltinShield.llama_guard,
|
||||
is_violation=False,
|
||||
),
|
||||
violation=None,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
@ -569,7 +558,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
step_details=ShieldCallStep(
|
||||
step_id=str(uuid.uuid4()),
|
||||
turn_id=turn_id,
|
||||
response=e.response,
|
||||
violation=e.violation,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from typing import AsyncIterator, List, Optional, Union
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -80,9 +79,9 @@ class MockInferenceAPI:
|
|||
|
||||
class MockSafetyAPI:
|
||||
async def run_shields(
|
||||
self, messages: List[Message], shields: List[MagicMock]
|
||||
) -> List[ShieldResponse]:
|
||||
return [ShieldResponse(shield_type="mock_shield", is_violation=False)]
|
||||
self, shield_type: str, messages: List[Message]
|
||||
) -> RunShieldResponse:
|
||||
return RunShieldResponse(violation=None)
|
||||
|
||||
|
||||
class MockMemoryAPI:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue