mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
fix(bedrock): wrong system prompt transformation (#10120)
* fix(bedrock): wrong system transformation
* chore: add one more test case

---------

Co-authored-by: Krish Dholakia <krrishdholakia@gmail.com>
This commit is contained in:
parent
0b63c7a2eb
commit
10257426a2
2 changed files with 95 additions and 13 deletions
|
@ -376,25 +376,27 @@ class AmazonConverseConfig(BaseConfig):
|
|||
system_content_blocks: List[SystemContentBlock] = []
|
||||
for idx, message in enumerate(messages):
|
||||
if message["role"] == "system":
|
||||
_system_content_block: Optional[SystemContentBlock] = None
|
||||
_cache_point_block: Optional[SystemContentBlock] = None
|
||||
if isinstance(message["content"], str) and len(message["content"]) > 0:
|
||||
_system_content_block = SystemContentBlock(text=message["content"])
|
||||
_cache_point_block = self._get_cache_point_block(
|
||||
system_prompt_indices.append(idx)
|
||||
if isinstance(message["content"], str) and message["content"]:
|
||||
system_content_blocks.append(
|
||||
SystemContentBlock(text=message["content"])
|
||||
)
|
||||
cache_block = self._get_cache_point_block(
|
||||
message, block_type="system"
|
||||
)
|
||||
if cache_block:
|
||||
system_content_blocks.append(cache_block)
|
||||
elif isinstance(message["content"], list):
|
||||
for m in message["content"]:
|
||||
if m.get("type", "") == "text" and len(m["text"]) > 0:
|
||||
_system_content_block = SystemContentBlock(text=m["text"])
|
||||
_cache_point_block = self._get_cache_point_block(
|
||||
if m.get("type") == "text" and m.get("text"):
|
||||
system_content_blocks.append(
|
||||
SystemContentBlock(text=m["text"])
|
||||
)
|
||||
cache_block = self._get_cache_point_block(
|
||||
m, block_type="system"
|
||||
)
|
||||
if _system_content_block is not None:
|
||||
system_content_blocks.append(_system_content_block)
|
||||
if _cache_point_block is not None:
|
||||
system_content_blocks.append(_cache_point_block)
|
||||
system_prompt_indices.append(idx)
|
||||
if cache_block:
|
||||
system_content_blocks.append(cache_block)
|
||||
if len(system_prompt_indices) > 0:
|
||||
for idx in reversed(system_prompt_indices):
|
||||
messages.pop(idx)
|
||||
|
|
|
@ -41,6 +41,85 @@ def test_transform_usage():
|
|||
assert openai_usage._cache_creation_input_tokens == usage["cacheWriteInputTokens"]
|
||||
assert openai_usage._cache_read_input_tokens == usage["cacheReadInputTokens"]
|
||||
|
||||
def test_transform_system_message():
    """Check how ``_transform_system_message`` splits system prompts out.

    Every scenario hands the method a shallow copy of the conversation and
    then inspects the two values it returns: the messages that are left
    and the list of system content blocks that was built.
    """
    converse_config = AmazonConverseConfig()

    # Case 1: a plain string system prompt.
    # The system message is popped and only the user message remains.
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]
    remaining, blocks = converse_config._transform_system_message(
        list(conversation)
    )
    assert len(remaining) == 1
    assert remaining[0]["role"] == "user"
    assert len(blocks) == 1
    assert blocks[0]["text"] == "You are a helpful assistant."

    # Case 2: system message with list content (type text).
    # Each text part is expected to become its own block, in order.
    conversation = [
        {
            "role": "system",
            "content": [
                {"type": "text", "text": "System prompt 1"},
                {"type": "text", "text": "System prompt 2"},
            ],
        },
        {"role": "user", "content": "Hi!"},
    ]
    remaining, blocks = converse_config._transform_system_message(
        list(conversation)
    )
    assert len(remaining) == 1
    assert remaining[0]["role"] == "user"
    assert len(blocks) == 2
    first_block, second_block = blocks[0], blocks[1]
    assert first_block["text"] == "System prompt 1"
    assert second_block["text"] == "System prompt 2"

    # Case 3: system message with cache_control (should add cachePoint).
    conversation = [
        {
            "role": "system",
            "content": "Cache this!",
            "cache_control": {"type": "ephemeral"},
        },
        {"role": "user", "content": "Hi!"},
    ]
    remaining, blocks = converse_config._transform_system_message(
        list(conversation)
    )
    assert len(remaining) == 1
    assert len(blocks) == 2
    assert blocks[0]["text"] == "Cache this!"
    assert "cachePoint" in blocks[1]
    assert blocks[1]["cachePoint"]["type"] == "default"

    # Case 3b: system message with two blocks, one with cache_control and
    # one without. The cache point must sit directly after the block that
    # requested it, before the uncached text.
    conversation = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "Cache this!",
                    "cache_control": {"type": "ephemeral"},
                },
                {"type": "text", "text": "Don't cache this!"},
            ],
        },
        {"role": "user", "content": "Hi!"},
    ]
    remaining, blocks = converse_config._transform_system_message(
        list(conversation)
    )
    assert len(remaining) == 1
    assert len(blocks) == 3
    assert blocks[0]["text"] == "Cache this!"
    assert "cachePoint" in blocks[1]
    assert blocks[1]["cachePoint"]["type"] == "default"
    assert blocks[2]["text"] == "Don't cache this!"

    # Case 4: non-system messages are not affected.
    conversation = [
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi!"},
    ]
    remaining, blocks = converse_config._transform_system_message(
        list(conversation)
    )
    assert len(remaining) == 2
    assert remaining[0]["role"] == "user"
    assert remaining[1]["role"] == "assistant"
    assert blocks == []
|
||||
|
||||
def test_transform_thinking_blocks_with_redacted_content():
|
||||
thinking_blocks = [
|
||||
|
@ -59,3 +138,4 @@ def test_transform_thinking_blocks_with_redacted_content():
|
|||
assert len(transformed_thinking_blocks) == 2
|
||||
assert transformed_thinking_blocks[0]["type"] == "thinking"
|
||||
assert transformed_thinking_blocks[1]["type"] == "redacted_thinking"
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue