fix(bedrock): wrong system prompt transformation (#10120)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 16s
Helm unit test / unit-test (push) Successful in 25s

* fix(bedrock): wrong system transformation

* chore: add one more test case

---------

Co-authored-by: Krish Dholakia <krrishdholakia@gmail.com>
This commit is contained in:
Li Yang 2025-04-21 23:48:14 +08:00 committed by GitHub
parent 0b63c7a2eb
commit 10257426a2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 95 additions and 13 deletions

View file

@ -376,25 +376,27 @@ class AmazonConverseConfig(BaseConfig):
system_content_blocks: List[SystemContentBlock] = []
for idx, message in enumerate(messages):
if message["role"] == "system":
_system_content_block: Optional[SystemContentBlock] = None
_cache_point_block: Optional[SystemContentBlock] = None
if isinstance(message["content"], str) and len(message["content"]) > 0:
_system_content_block = SystemContentBlock(text=message["content"])
_cache_point_block = self._get_cache_point_block(
system_prompt_indices.append(idx)
if isinstance(message["content"], str) and message["content"]:
system_content_blocks.append(
SystemContentBlock(text=message["content"])
)
cache_block = self._get_cache_point_block(
message, block_type="system"
)
if cache_block:
system_content_blocks.append(cache_block)
elif isinstance(message["content"], list):
for m in message["content"]:
if m.get("type", "") == "text" and len(m["text"]) > 0:
_system_content_block = SystemContentBlock(text=m["text"])
_cache_point_block = self._get_cache_point_block(
if m.get("type") == "text" and m.get("text"):
system_content_blocks.append(
SystemContentBlock(text=m["text"])
)
cache_block = self._get_cache_point_block(
m, block_type="system"
)
if _system_content_block is not None:
system_content_blocks.append(_system_content_block)
if _cache_point_block is not None:
system_content_blocks.append(_cache_point_block)
system_prompt_indices.append(idx)
if cache_block:
system_content_blocks.append(cache_block)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)

View file

@ -41,6 +41,85 @@ def test_transform_usage():
assert openai_usage._cache_creation_input_tokens == usage["cacheWriteInputTokens"]
assert openai_usage._cache_read_input_tokens == usage["cacheReadInputTokens"]
def test_transform_system_message():
    """Check how ``AmazonConverseConfig._transform_system_message`` splits input.

    Every scenario hands the method a shallow copy of its message list and
    inspects the two returned values: the remaining messages and the system
    content blocks.
    """
    converse_config = AmazonConverseConfig()

    # Scenario 1: a plain-string system prompt is removed from the messages
    # and comes back as a single text block; the user message stays.
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]
    remaining, blocks = converse_config._transform_system_message(
        list(conversation)
    )
    assert len(remaining) == 1
    assert remaining[0]["role"] == "user"
    assert len(blocks) == 1
    assert blocks[0]["text"] == "You are a helpful assistant."

    # Scenario 2: list-style system content yields one block per text part,
    # in the order the parts were given.
    conversation = [
        {
            "role": "system",
            "content": [
                {"type": "text", "text": "System prompt 1"},
                {"type": "text", "text": "System prompt 2"},
            ],
        },
        {"role": "user", "content": "Hi!"},
    ]
    remaining, blocks = converse_config._transform_system_message(
        list(conversation)
    )
    assert len(remaining) == 1
    assert remaining[0]["role"] == "user"
    assert len(blocks) == 2
    assert blocks[0]["text"] == "System prompt 1"
    assert blocks[1]["text"] == "System prompt 2"

    # Scenario 3: cache_control on the system message adds a cachePoint block
    # right after the text block.
    conversation = [
        {
            "role": "system",
            "content": "Cache this!",
            "cache_control": {"type": "ephemeral"},
        },
        {"role": "user", "content": "Hi!"},
    ]
    remaining, blocks = converse_config._transform_system_message(
        list(conversation)
    )
    assert len(remaining) == 1
    assert len(blocks) == 2
    text_block, cache_block = blocks
    assert text_block["text"] == "Cache this!"
    assert "cachePoint" in cache_block
    assert cache_block["cachePoint"]["type"] == "default"

    # Scenario 3b: with two text parts where only the first carries
    # cache_control, the cachePoint sits between the two text blocks.
    conversation = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "Cache this!",
                    "cache_control": {"type": "ephemeral"},
                },
                {"type": "text", "text": "Don't cache this!"},
            ],
        },
        {"role": "user", "content": "Hi!"},
    ]
    remaining, blocks = converse_config._transform_system_message(
        list(conversation)
    )
    assert len(remaining) == 1
    assert len(blocks) == 3
    cached_text, cache_block, uncached_text = blocks
    assert cached_text["text"] == "Cache this!"
    assert "cachePoint" in cache_block
    assert cache_block["cachePoint"]["type"] == "default"
    assert uncached_text["text"] == "Don't cache this!"

    # Scenario 4: without any system message, the input passes through and no
    # system blocks are produced.
    conversation = [
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi!"},
    ]
    remaining, blocks = converse_config._transform_system_message(
        list(conversation)
    )
    assert len(remaining) == 2
    assert remaining[0]["role"] == "user"
    assert remaining[1]["role"] == "assistant"
    assert blocks == []
def test_transform_thinking_blocks_with_redacted_content():
thinking_blocks = [
@ -59,3 +138,4 @@ def test_transform_thinking_blocks_with_redacted_content():
assert len(transformed_thinking_blocks) == 2
assert transformed_thinking_blocks[0]["type"] == "thinking"
assert transformed_thinking_blocks[1]["type"] == "redacted_thinking"