[bugfix] no shield_call when there's no shields configured (#642)

# What does this PR do?

**Why**
- When AgentConfig has no `input_shields` / `output_shields` defined, we
still outputs a shield_call step with violation=None. This is impossible
to distinguish the case b/w (1) no violation from running shields v.s.
(2) no shields call

**What**
- We should not have a shield_call step when no `input_shields` /
`output_shields` are defined.

- Also removes a never reached try/catch code block in agent loop.
`run_multiple_shields` is never called in the try block (verified by
stacktrace print)

**Side Note**
- pre-commit fix

## Test Plan

Tested w/ DirectClient via:
https://gist.github.com/yanxi0830/b48f2a53b6f5391b9ff1e39992bc05b3

**No Shields**
<img width="858" alt="image"
src="https://github.com/user-attachments/assets/67319370-329f-4954-bd16-d21ce54c6ebf"
/>

**With Input + Output Shields**
<img width="854" alt="image"
src="https://github.com/user-attachments/assets/75ab1bee-3ba9-4549-ab51-23210be83da7"
/>

**Input Shields Only**
<img width="858" alt="image"
src="https://github.com/user-attachments/assets/1897206b-13dd-4ea5-92c2-b39bf68e9286"
/>


E2E pytest
```
LLAMA_STACK_BASE_URL=http://localhost:5000 pytest -v ./tests/client-sdk/agents/test_agents.py
```

## Sources

Please link relevant resources if necessary.


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
This commit is contained in:
Xi Yan 2024-12-17 11:10:19 -08:00 committed by GitHub
parent c2f7905fa4
commit 99f331f5c8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 84 additions and 113 deletions

View file

@ -239,6 +239,7 @@ class ChatAgent(ShieldRunnerMixin):
# return a "final value" for the `yield from` statement. we simulate that by yielding a # return a "final value" for the `yield from` statement. we simulate that by yielding a
# final boolean (to see whether an exception happened) and then explicitly testing for it. # final boolean (to see whether an exception happened) and then explicitly testing for it.
if len(self.input_shields) > 0:
async for res in self.run_multiple_shields_wrapper( async for res in self.run_multiple_shields_wrapper(
turn_id, input_messages, self.input_shields, "user-input" turn_id, input_messages, self.input_shields, "user-input"
): ):
@ -262,6 +263,7 @@ class ChatAgent(ShieldRunnerMixin):
# for output shields run on the full input and output combination # for output shields run on the full input and output combination
messages = input_messages + [final_response] messages = input_messages + [final_response]
if len(self.output_shields) > 0:
async for res in self.run_multiple_shields_wrapper( async for res in self.run_multiple_shields_wrapper(
turn_id, messages, self.output_shields, "assistant-output" turn_id, messages, self.output_shields, "assistant-output"
): ):
@ -531,7 +533,6 @@ class ChatAgent(ShieldRunnerMixin):
input_messages = input_messages + [message] input_messages = input_messages + [message]
else: else:
log.info(f"{str(message)}") log.info(f"{str(message)}")
try:
tool_call = message.tool_calls[0] tool_call = message.tool_calls[0]
name = tool_call.tool_name name = tool_call.tool_name
@ -597,39 +598,6 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: add tool-input touchpoint and a "start" event for this step also # TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially # but that needs a lot more refactoring of Tool code potentially
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_details=ShieldCallStep(
step_id=str(uuid.uuid4()),
turn_id=turn_id,
violation=None,
),
)
)
)
except SafetyException as e:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_details=ShieldCallStep(
step_id=str(uuid.uuid4()),
turn_id=turn_id,
violation=e.violation,
),
)
)
)
yield CompletionMessage(
content=str(e),
stop_reason=StopReason.end_of_turn,
)
yield False
return
if out_attachment := interpret_content_as_attachment( if out_attachment := interpret_content_as_attachment(
result_message.content result_message.content

View file

@ -7,6 +7,7 @@
from typing import * # noqa: F403 from typing import * # noqa: F403
import json import json
import uuid import uuid
from botocore.client import BaseClient from botocore.client import BaseClient
from llama_models.datatypes import CoreModelId from llama_models.datatypes import CoreModelId

View file

@ -7,12 +7,14 @@
from pathlib import Path from pathlib import Path
from llama_models.sku_list import all_registered_models from llama_models.sku_list import all_registered_models
from llama_stack.apis.models import ModelInput
from llama_stack.distribution.datatypes import Provider from llama_stack.distribution.datatypes import Provider
from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
from llama_stack.providers.remote.inference.bedrock.bedrock import MODEL_ALIASES from llama_stack.providers.remote.inference.bedrock.bedrock import MODEL_ALIASES
from llama_stack.apis.models import ModelInput from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
def get_distribution_template() -> DistributionTemplate: def get_distribution_template() -> DistributionTemplate:
providers = { providers = {