Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-05 12:21:52 +00:00)

test safety against safety client
commit 9252e81a7b (parent d6a41d98d2)

19 changed files with 1076 additions and 10754 deletions

@@ -211,7 +211,7 @@ class ChatAgent(ShieldRunnerMixin):
         # return a "final value" for the `yield from` statement. we simulate that by yielding a
         # final boolean (to see whether an exception happened) and then explicitly testing for it.
 
-        async for res in self.run_shields_wrapper(
+        async for res in self.run_multiple_shields_wrapper(
             turn_id, input_messages, self.input_shields, "user-input"
         ):
             if isinstance(res, bool):

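The comment in this hunk describes simulating the return value of a `yield from` with a trailing boolean sentinel. A minimal, self-contained sketch of that pattern; the names ShieldTripped, guarded_step, and run_turn are illustrative, not taken from the repository:

import asyncio
from typing import AsyncGenerator, Union

class ShieldTripped(Exception):
    """Illustrative stand-in for the real SafetyException."""

async def check(message: str) -> None:
    # Placeholder shield: reject one hard-coded string.
    if message == "unsafe":
        raise ShieldTripped(message)

async def guarded_step(message: str) -> AsyncGenerator[Union[str, bool], None]:
    yield "shield_call_started"        # normal streamed event
    try:
        await check(message)
    except ShieldTripped:
        yield "shield_call_violation"  # stream the failure event
        yield True                     # sentinel: an exception happened
        return
    yield "shield_call_completed"
    yield False                        # sentinel: no exception

async def run_turn(message: str) -> None:
    async for res in guarded_step(message):
        if isinstance(res, bool):
            if res:
                return                 # a shield tripped; abort the turn
        else:
            print(res)                 # forward the streamed event

asyncio.run(run_turn("hello"))
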
@@ -234,7 +234,7 @@ class ChatAgent(ShieldRunnerMixin):
         # for output shields run on the full input and output combination
         messages = input_messages + [final_response]
 
-        async for res in self.run_shields_wrapper(
+        async for res in self.run_multiple_shields_wrapper(
             turn_id, messages, self.output_shields, "assistant-output"
         ):
             if isinstance(res, bool):

@@ -244,7 +244,7 @@ class ChatAgent(ShieldRunnerMixin):
 
         yield final_response
 
-    async def run_shields_wrapper(
+    async def run_multiple_shields_wrapper(
         self,
         turn_id: str,
         messages: List[Message],

@@ -265,7 +265,7 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )
             )
-            await self.run_shields(messages, shields)
+            await self.run_multiple_shields(messages, shields)
 
         except SafetyException as e:
             yield AgentTurnResponseStreamChunk(

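The full body of the renamed wrapper is not shown in this hunk; the following is a rough sketch of the shape implied by the visible lines (emit a step event, await run_multiple_shields, catch SafetyException), using simplified stand-in types rather than the real Message/AgentTurnResponseStreamChunk classes:

from dataclasses import dataclass
from typing import AsyncGenerator, List, Union

@dataclass
class Message:
    content: str

@dataclass
class StreamChunk:
    event: str

class SafetyException(Exception):
    pass

class ShieldWrapperSketch:
    async def run_multiple_shields(
        self, messages: List[Message], shields: List[str]
    ) -> None:
        ...  # delegated to ShieldRunnerMixin in the real class

    async def run_multiple_shields_wrapper(
        self,
        turn_id: str,
        messages: List[Message],
        shields: List[str],
        touchpoint: str,
    ) -> AsyncGenerator[Union[StreamChunk, bool], None]:
        if not shields:
            yield False
            return
        yield StreamChunk(event=f"shield_call_start:{touchpoint}")
        try:
            await self.run_multiple_shields(messages, shields)
        except SafetyException:
            yield StreamChunk(event="shield_call_violation")
            yield True   # sentinel: an exception happened
            return
        yield StreamChunk(event="shield_call_complete")
        yield False      # sentinel: no exception
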
@@ -31,7 +31,9 @@ class ShieldRunnerMixin:
         self.input_shields = input_shields
         self.output_shields = output_shields
 
-    async def run_shields(self, messages: List[Message], shields: List[str]) -> None:
+    async def run_multiple_shields(
+        self, messages: List[Message], shields: List[str]
+    ) -> None:
         responses = await asyncio.gather(
             *[
                 self.safety_api.run_shield(

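This hunk shows run_multiple_shields fanning out with asyncio.gather, issuing one safety_api.run_shield call per shield. A self-contained sketch under that assumption; the violation check and the exact RunShieldResponse/SafetyException fields are simplified guesses, and FakeSafetyAPI is a stand-in for the real safety provider:

import asyncio
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Message:
    content: str

@dataclass
class Violation:
    user_message: str

@dataclass
class RunShieldResponse:
    violation: Optional[Violation] = None

class SafetyException(Exception):
    def __init__(self, violation: Violation):
        super().__init__(violation.user_message)
        self.violation = violation

class FakeSafetyAPI:
    async def run_shield(
        self, shield_type: str, messages: List[Message]
    ) -> RunShieldResponse:
        return RunShieldResponse(violation=None)  # always clean, for the sketch

class ShieldRunnerSketch:
    def __init__(self, safety_api: FakeSafetyAPI):
        self.safety_api = safety_api

    async def run_multiple_shields(
        self, messages: List[Message], shields: List[str]
    ) -> None:
        # Fan out one run_shield call per shield and wait for all of them.
        responses = await asyncio.gather(
            *[
                self.safety_api.run_shield(shield_type=shield, messages=messages)
                for shield in shields
            ]
        )
        for response in responses:
            if response.violation is not None:
                raise SafetyException(response.violation)

asyncio.run(
    ShieldRunnerSketch(FakeSafetyAPI()).run_multiple_shields(
        [Message("hi")], ["llama_guard", "injection_shield"]
    )
)
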
@@ -78,7 +78,7 @@ class MockInferenceAPI:
 
 
 class MockSafetyAPI:
-    async def run_shields(
+    async def run_shield(
         self, shield_type: str, messages: List[Message]
     ) -> RunShieldResponse:
         return RunShieldResponse(violation=None)

@@ -220,13 +220,13 @@ async def test_chat_agent_create_and_execute_turn(chat_agent):
 
 
 @pytest.mark.asyncio
-async def test_run_shields_wrapper(chat_agent):
+async def test_run_multiple_shields_wrapper(chat_agent):
     messages = [UserMessage(content="Test message")]
     shields = ["test_shield"]
 
     responses = [
         chunk
-        async for chunk in chat_agent.run_shields_wrapper(
+        async for chunk in chat_agent.run_multiple_shields_wrapper(
             turn_id="test_turn_id",
             messages=messages,
             shields=shields,

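The test hunk cuts off inside the list comprehension. A hedged, self-contained approximation of how such a stream is collected and checked, using a fake_wrapper stand-in instead of the real chat_agent fixture; the specific assertions here are illustrative, not the repository's:

import asyncio
from typing import AsyncGenerator, List, Union

async def fake_wrapper(
    turn_id: str, messages: List[str], shields: List[str]
) -> AsyncGenerator[Union[str, bool], None]:
    # Stand-in for chat_agent.run_multiple_shields_wrapper in the real test.
    yield f"shield_call_start:{turn_id}"
    yield "shield_call_complete"
    yield False  # sentinel: no violation

async def collect() -> None:
    responses = [
        chunk
        async for chunk in fake_wrapper(
            turn_id="test_turn_id",
            messages=["Test message"],
            shields=["test_shield"],
        )
    ]
    # The stream should end with the boolean sentinel and contain the step events.
    assert isinstance(responses[-1], bool)
    assert len(responses) == 3

asyncio.run(collect())
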
@@ -34,11 +34,11 @@ class SafeTool(BaseTool, ShieldRunnerMixin):
 
     async def run(self, messages: List[Message]) -> List[Message]:
         if self.input_shields:
-            await self.run_shields(messages, self.input_shields)
+            await self.run_multiple_shields(messages, self.input_shields)
         # run the underlying tool
         res = await self._tool.run(messages)
         if self.output_shields:
-            await self.run_shields(messages, self.output_shields)
+            await self.run_multiple_shields(messages, self.output_shields)
 
         return res
 

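SafeTool wraps another tool and runs shields before and after it. A self-contained sketch that mirrors the run() body shown above; EchoTool and SafeToolSketch are illustrative stand-ins for BaseTool/ShieldRunnerMixin, and run_multiple_shields is stubbed out here:

import asyncio
from typing import List

Message = str  # simplified stand-in for the real Message type

class EchoTool:
    async def run(self, messages: List[Message]) -> List[Message]:
        return messages

class SafeToolSketch:
    def __init__(self, tool, input_shields: List[str], output_shields: List[str]):
        self._tool = tool
        self.input_shields = input_shields
        self.output_shields = output_shields

    async def run_multiple_shields(
        self, messages: List[Message], shields: List[str]
    ) -> None:
        pass  # provided by ShieldRunnerMixin in the real code

    async def run(self, messages: List[Message]) -> List[Message]:
        if self.input_shields:
            await self.run_multiple_shields(messages, self.input_shields)
        # run the underlying tool
        res = await self._tool.run(messages)
        if self.output_shields:
            await self.run_multiple_shields(messages, self.output_shields)
        return res

print(asyncio.run(SafeToolSketch(EchoTool(), ["llama_guard"], []).run(["hi"])))
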
@@ -15,7 +15,6 @@ from .base import ( # noqa: F401
     TextShield,
 )
 from .code_scanner import CodeScannerShield  # noqa: F401
-from .contrib.third_party_shield import ThirdPartyShield  # noqa: F401
 from .llama_guard import LlamaGuardShield  # noqa: F401
 from .prompt_guard import (  # noqa: F401
     InjectionShield,