Bedrock document processing fixes (#8005)

* refactor(factory.py): refactor async bedrock message transformation to use async get request for image url conversion

improve latency of bedrock call

* test(test_bedrock_completion.py): add unit test to ensure the async image URL GET request is called for async bedrock calls

* refactor(factory.py): refactor bedrock translation to use BedrockImageProcessor

reduces duplicate code

* fix(factory.py): fix bug preventing PDFs from being processed

* fix(factory.py): fix bedrock converse document understanding with image url

* docs(bedrock.md): clarify all bedrock document types are supported

* refactor: cleanup redundant test + unused imports

* perf: improve perf with reusable clients

* test: fix test
This commit is contained in:
Krish Dholakia 2025-01-28 17:48:32 -08:00 committed by GitHub
parent c2e3986bbc
commit 8eaa5dc797
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 578 additions and 86 deletions

View file

@ -140,7 +140,7 @@ def exception_type( # type: ignore # noqa: PLR0915
"\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m" # noqa
) # noqa
print( # noqa
"LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'." # noqa
"LiteLLM.Info: If you need to debug this error, use `litellm._turn_on_debug()'." # noqa
) # noqa
print() # noqa

View file

@ -625,10 +625,6 @@ class Logging(LiteLLMLoggingBaseClass):
masked_api_base = api_base
self.model_call_details["litellm_params"]["api_base"] = masked_api_base
verbose_logger.debug(
"PRE-API-CALL ADDITIONAL ARGS: %s", additional_args
)
curl_command = self._get_request_curl_command(
api_base=api_base,
headers=headers,

View file

@ -13,9 +13,10 @@ import litellm
import litellm.types
import litellm.types.llms
from litellm import verbose_logger
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.llms.custom_httpx.http_handler import HTTPHandler, get_async_httpx_client
from litellm.types.llms.anthropic import *
from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
from litellm.types.llms.custom_http import httpxSpecialProvider
from litellm.types.llms.ollama import OllamaVisionModelObject
from litellm.types.llms.openai import (
AllMessageValues,
@ -2150,6 +2151,12 @@ def stringify_json_tool_call_content(messages: List) -> List:
###### AMAZON BEDROCK #######
import base64
import mimetypes
from cgi import parse_header
import httpx
from litellm.types.llms.bedrock import ContentBlock as BedrockContentBlock
from litellm.types.llms.bedrock import DocumentBlock as BedrockDocumentBlock
from litellm.types.llms.bedrock import ImageBlock as BedrockImageBlock
@ -2166,42 +2173,64 @@ from litellm.types.llms.bedrock import ToolSpecBlock as BedrockToolSpecBlock
from litellm.types.llms.bedrock import ToolUseBlock as BedrockToolUseBlock
def get_image_details(image_url) -> Tuple[str, str]:
try:
import base64
def _parse_content_type(content_type: str) -> str:
main_type, _ = parse_header(content_type)
return main_type
client = HTTPHandler(concurrent_limit=1)
# Send a GET request to the image URL
response = client.get(image_url)
response.raise_for_status() # Raise an exception for HTTP errors
class BedrockImageProcessor:
"""Handles both sync and async image processing for Bedrock conversations."""
@staticmethod
def _post_call_image_processing(response: httpx.Response) -> Tuple[str, str]:
# Check the response's content type to ensure it is an image
content_type = response.headers.get("content-type")
if not content_type or "image" not in content_type:
if not content_type:
raise ValueError(
f"URL does not point to a valid image (content-type: {content_type})"
f"URL does not contain content-type (content-type: {content_type})"
)
content_type = _parse_content_type(content_type)
# Convert the image content to base64 bytes
base64_bytes = base64.b64encode(response.content).decode("utf-8")
return base64_bytes, content_type
except Exception as e:
raise e
@staticmethod
async def get_image_details_async(image_url) -> Tuple[str, str]:
try:
client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.PromptFactory,
params={"concurrent_limit": 1},
)
# Send a GET request to the image URL
response = await client.get(image_url, follow_redirects=True)
response.raise_for_status() # Raise an exception for HTTP errors
def _process_bedrock_converse_image_block(
image_url: str,
) -> BedrockContentBlock:
if "base64" in image_url:
# Case 1: Images with base64 encoding
import re
return BedrockImageProcessor._post_call_image_processing(response)
# base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
except Exception as e:
raise e
@staticmethod
def get_image_details(image_url) -> Tuple[str, str]:
    """Download ``image_url`` synchronously and return its encoded payload.

    Args:
        image_url: URL to fetch with a GET request (redirects are followed).

    Returns:
        The tuple produced by
        ``BedrockImageProcessor._post_call_image_processing`` for the response.

    Raises:
        Any exception raised by the HTTP call, by ``raise_for_status()`` on
        an HTTP error status, or by the post-processing step propagates
        unchanged.
    """
    http_handler = HTTPHandler(concurrent_limit=1)
    fetched = http_handler.get(image_url, follow_redirects=True)
    # Turn HTTP error statuses into exceptions before the body is used.
    fetched.raise_for_status()
    return BedrockImageProcessor._post_call_image_processing(fetched)
@staticmethod
def _parse_base64_image(image_url: str) -> Tuple[str, str, str]:
"""Parse base64 encoded image data."""
image_metadata, img_without_base_64 = image_url.split(",")
# read mime_type from img_without_base_64=data:image/jpeg;base64
# Extract MIME type using regular expression
mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
if mime_type_match:
@ -2210,37 +2239,102 @@ def _process_bedrock_converse_image_block(
else:
mime_type = "image/jpeg"
image_format = "jpeg"
_blob = BedrockSourceBlock(bytes=img_without_base_64)
elif "https:/" in image_url:
# Case 2: Images with direct links
image_bytes, mime_type = get_image_details(image_url)
image_format = mime_type.split("/")[1]
return img_without_base_64, mime_type, image_format
@staticmethod
def _validate_format(mime_type: str, image_format: str) -> str:
    """Resolve the format string to send to Bedrock for an image or document.

    Args:
        mime_type: MIME type of the payload (e.g. ``image/png``).
        image_format: Format derived by the caller; only used when the
            payload is not a document.

    Returns:
        For MIME types starting with ``application`` or ``text``: the first
        extension guessed by ``mimetypes`` (without the leading dot) that is
        listed in the supported document types. Otherwise ``image_format``
        unchanged.

    Raises:
        ValueError: If no guessed extension is a supported document type,
            or if ``image_format`` is not a supported image type.
    """
    supported_image_formats = (
        litellm.AmazonConverseConfig().get_supported_image_types()
    )
    supported_doc_formats = (
        litellm.AmazonConverseConfig().get_supported_document_types()
    )

    # Anything that is not application/* or text/* is handled as an image.
    if not mime_type.startswith(("application", "text")):
        if image_format in supported_image_formats:
            return image_format
        raise ValueError(
            f"Unsupported image format: {image_format}. Supported formats: {supported_image_formats}"
        )

    # Document path: map the MIME type to candidate extensions and keep the
    # ones Bedrock supports, preserving the order mimetypes returned them in.
    matching_extensions = []
    for dotted_ext in mimetypes.guess_all_extensions(mime_type):
        bare_ext = dotted_ext[1:]
        if bare_ext in supported_doc_formats:
            matching_extensions.append(bare_ext)

    if matching_extensions:
        # The first supported extension replaces the caller-provided format.
        return matching_extensions[0]

    raise ValueError(
        f"No supported extensions for MIME type: {mime_type}. Supported formats: {supported_doc_formats}"
    )
@staticmethod
def _create_bedrock_block(
image_bytes: str, mime_type: str, image_format: str
) -> BedrockContentBlock:
"""Create appropriate Bedrock content block based on mime type."""
_blob = BedrockSourceBlock(bytes=image_bytes)
else:
raise ValueError(
"Unsupported image type. Expected either image url or base64 encoded string - \
e.g. 'data:image/jpeg;base64,<base64-encoded-string>'"
)
supported_image_formats = litellm.AmazonConverseConfig().get_supported_image_types()
document_types = ["application", "text"]
is_document = any(mime_type.startswith(doc_type) for doc_type in document_types)
document_types = ["application", "text"]
is_document = any(
mime_type.startswith(document_type) for document_type in document_types
)
if image_format in supported_image_formats:
return BedrockContentBlock(image=BedrockImageBlock(source=_blob, format=image_format)) # type: ignore
elif is_document:
return BedrockContentBlock(document=BedrockDocumentBlock(source=_blob, format=image_format, name="DocumentPDFmessages_{}".format(str(uuid.uuid4())))) # type: ignore
else:
# Handle the case when the image format is not supported
raise ValueError(
"Unsupported image format: {}. Supported formats: {}".format(
image_format, supported_image_formats
if is_document:
return BedrockContentBlock(
document=BedrockDocumentBlock(
source=_blob,
format=image_format,
name=f"DocumentPDFmessages_{str(uuid.uuid4())}",
)
)
)
else:
return BedrockContentBlock(
image=BedrockImageBlock(source=_blob, format=image_format)
)
@classmethod
def process_image_sync(cls, image_url: str) -> BedrockContentBlock:
    """Synchronously convert an image/document reference into a Bedrock block.

    Args:
        image_url: Either a base64 data URL
            (``data:<mime>;base64,<payload>``) or an ``http``/``https`` URL
            that is downloaded with ``get_image_details``.

    Returns:
        The content block built by ``_create_bedrock_block`` after the
        format has been checked by ``_validate_format``.

    Raises:
        ValueError: If ``image_url`` is neither base64 data nor an
            http(s) URL, or if format validation fails.
    """
    if "base64" in image_url:
        img_bytes, mime_type, image_format = cls._parse_base64_image(image_url)
    # Accept plain http:// as well, matching process_image_async; the
    # looser "https:/" check is kept so inputs accepted before still are.
    elif "http://" in image_url or "https:/" in image_url:
        img_bytes, mime_type = BedrockImageProcessor.get_image_details(image_url)
        image_format = mime_type.split("/")[1]
    else:
        raise ValueError(
            "Unsupported image type. Expected either image url or base64 encoded string"
        )

    image_format = cls._validate_format(mime_type, image_format)
    return cls._create_bedrock_block(img_bytes, mime_type, image_format)
@classmethod
async def process_image_async(cls, image_url: str) -> BedrockContentBlock:
    """Asynchronously convert an image/document reference into a Bedrock block.

    Args:
        image_url: Either a base64 data URL or an ``http://`` / ``https://``
            URL, which is fetched with ``get_image_details_async``.

    Returns:
        The content block built by ``_create_bedrock_block`` after the
        format has been checked by ``_validate_format``.

    Raises:
        ValueError: If ``image_url`` is neither base64 data nor an
            http(s) URL, or if format validation fails.
    """
    is_inline_data = "base64" in image_url
    is_remote_url = any(
        scheme in image_url for scheme in ("http://", "https://")
    )

    # Reject anything that is neither inline base64 nor a remote URL.
    if not (is_inline_data or is_remote_url):
        raise ValueError(
            "Unsupported image type. Expected either image url or base64 encoded string"
        )

    if is_inline_data:
        payload, mime_type, raw_format = cls._parse_base64_image(image_url)
    else:
        payload, mime_type = await BedrockImageProcessor.get_image_details_async(
            image_url
        )
        # The subtype after the slash is the candidate format.
        raw_format = mime_type.split("/")[1]

    checked_format = cls._validate_format(mime_type, raw_format)
    return cls._create_bedrock_block(payload, mime_type, checked_format)
def _convert_to_bedrock_tool_call_invoke(
@ -2662,6 +2756,219 @@ def get_assistant_message_block_or_continue_message(
raise ValueError(f"Unsupported content type: {type(content_block)}")
class BedrockConverseMessagesProcessor:
@staticmethod
def _initial_message_setup(
    messages: List,
    user_continue_message: Optional[ChatCompletionUserMessage] = None,
) -> List:
    """Pad ``messages`` so it neither starts nor ends with an assistant turn.

    When the first (respectively last) message has role ``"assistant"``, a
    user message is inserted before it (respectively appended after it):
    ``user_continue_message`` if one was given, otherwise
    ``DEFAULT_USER_CONTINUE_MESSAGE`` when ``litellm.modify_params`` is
    truthy. If neither applies the list is left as is.

    Args:
        messages: Conversation messages; modified in place.
        user_continue_message: Optional user message to use as padding.

    Returns:
        The same ``messages`` list that was passed in.
    """

    def _pick_padding():
        # An explicit continue message takes priority over the default.
        if user_continue_message is not None:
            return user_continue_message
        if litellm.modify_params:
            return DEFAULT_USER_CONTINUE_MESSAGE
        return None

    # Leading assistant message: put a user turn in front of it.
    if messages[0].get("role") == "assistant":
        padding = _pick_padding()
        if padding is not None:
            messages.insert(0, padding)

    # Trailing assistant message: follow it with a user turn.
    if messages[-1].get("role") == "assistant":
        padding = _pick_padding()
        if padding is not None:
            messages.append(padding)

    return messages
@staticmethod
async def _bedrock_converse_messages_pt_async( # noqa: PLR0915
messages: List,
model: str,
llm_provider: str,
user_continue_message: Optional[ChatCompletionUserMessage] = None,
assistant_continue_message: Optional[
Union[str, ChatCompletionAssistantMessage]
] = None,
) -> List[BedrockMessageBlock]:
# Async translation of chat `messages` into a list of BedrockMessageBlock
# entries. Runs of consecutive user, tool and assistant messages are each
# merged into a single block; `image_url` parts are resolved with
# `await BedrockImageProcessor.process_image_async`.
#
# `model` and `llm_provider` are only used to populate the
# litellm.BadRequestError raised below.
#
# Raises litellm.BadRequestError when `messages` is empty, or when an
# iteration of the main loop consumes no message (see the guard at the end).
contents: List[BedrockMessageBlock] = []
msg_i = 0
## BASE CASE ##
if len(messages) == 0:
raise litellm.BadRequestError(
message=BAD_MESSAGE_ERROR_STR
+ "bedrock requires at least one non-system message",
model=model,
llm_provider=llm_provider,
)
# if initial message is assistant message
# (the helper also handles a trailing assistant message)
messages = BedrockConverseMessagesProcessor._initial_message_setup(
messages, user_continue_message
)
while msg_i < len(messages):
user_content: List[BedrockContentBlock] = []
# Remember where this pass started so a pass that consumes nothing can be detected.
init_msg_i = msg_i
## MERGE CONSECUTIVE USER CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "user":
message_block = get_user_message_block_or_continue_message(
message=messages[msg_i],
user_continue_message=user_continue_message,
)
if isinstance(message_block["content"], list):
_parts: List[BedrockContentBlock] = []
for element in message_block["content"]:
if isinstance(element, dict):
if element["type"] == "text":
_part = BedrockContentBlock(text=element["text"])
_parts.append(_part)
elif element["type"] == "image_url":
# `image_url` may be a dict carrying a "url" key or the URL value itself.
if isinstance(element["image_url"], dict):
image_url = element["image_url"]["url"]
else:
image_url = element["image_url"]
_part = await BedrockImageProcessor.process_image_async( # type: ignore
image_url=image_url
)
_parts.append(_part) # type: ignore
# Ask the converse config for a cache point block for this element; it is
# appended only when one is returned.
_cache_point_block = (
litellm.AmazonConverseConfig()._get_cache_point_block(
message_block=cast(
OpenAIMessageContentListBlock, element
),
block_type="content_block",
)
)
if _cache_point_block is not None:
_parts.append(_cache_point_block)
user_content.extend(_parts)
elif message_block["content"] and isinstance(
message_block["content"], str
):
# NOTE(review): the text is read from messages[msg_i], not from message_block.
_part = BedrockContentBlock(text=messages[msg_i]["content"])
_cache_point_block = (
litellm.AmazonConverseConfig()._get_cache_point_block(
message_block, block_type="content_block"
)
)
user_content.append(_part)
if _cache_point_block is not None:
user_content.append(_cache_point_block)
msg_i += 1
if user_content:
# The previously emitted block is already a user block: either separate the
# two with an assistant continue message, or merge into the previous block.
if len(contents) > 0 and contents[-1]["role"] == "user":
if (
assistant_continue_message is not None
or litellm.modify_params is True
):
# if last message was a 'user' message, then add a dummy assistant message (bedrock requires alternating roles)
contents = _insert_assistant_continue_message(
messages=contents,
assistant_continue_message=assistant_continue_message,
)
contents.append(
BedrockMessageBlock(role="user", content=user_content)
)
else:
verbose_logger.warning(
"Potential consecutive user/tool blocks. Trying to merge. If error occurs, please set a 'assistant_continue_message' or set 'modify_params=True' to insert a dummy assistant message for bedrock calls."
)
contents[-1]["content"].extend(user_content)
else:
contents.append(
BedrockMessageBlock(role="user", content=user_content)
)
## MERGE CONSECUTIVE TOOL CALL MESSAGES ##
tool_content: List[BedrockContentBlock] = []
while msg_i < len(messages) and messages[msg_i]["role"] == "tool":
tool_call_result = _convert_to_bedrock_tool_call_result(messages[msg_i])
tool_content.append(tool_call_result)
msg_i += 1
if tool_content:
# Tool results are emitted with role="user"; the same separate-or-merge
# handling as for user content applies.
# if last message was a 'user' message, then add a blank assistant message (bedrock requires alternating roles)
if len(contents) > 0 and contents[-1]["role"] == "user":
if (
assistant_continue_message is not None
or litellm.modify_params is True
):
# if last message was a 'user' message, then add a dummy assistant message (bedrock requires alternating roles)
contents = _insert_assistant_continue_message(
messages=contents,
assistant_continue_message=assistant_continue_message,
)
contents.append(
BedrockMessageBlock(role="user", content=tool_content)
)
else:
verbose_logger.warning(
"Potential consecutive user/tool blocks. Trying to merge. If error occurs, please set a 'assistant_continue_message' or set 'modify_params=True' to insert a dummy assistant message for bedrock calls."
)
contents[-1]["content"].extend(tool_content)
else:
contents.append(
BedrockMessageBlock(role="user", content=tool_content)
)
assistant_content: List[BedrockContentBlock] = []
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
assistant_message_block = (
get_assistant_message_block_or_continue_message(
message=messages[msg_i],
assistant_continue_message=assistant_continue_message,
)
)
_assistant_content = assistant_message_block.get("content", None)
if _assistant_content is not None and isinstance(
_assistant_content, list
):
assistants_parts: List[BedrockContentBlock] = []
for element in _assistant_content:
if isinstance(element, dict):
if element["type"] == "text":
assistants_part = BedrockContentBlock(
text=element["text"]
)
assistants_parts.append(assistants_part)
elif element["type"] == "image_url":
# Same dict-or-plain-value handling of `image_url` as in the user branch.
if isinstance(element["image_url"], dict):
image_url = element["image_url"]["url"]
else:
image_url = element["image_url"]
assistants_part = await BedrockImageProcessor.process_image_async( # type: ignore
image_url=image_url
)
assistants_parts.append(assistants_part)
assistant_content.extend(assistants_parts)
elif _assistant_content is not None and isinstance(
_assistant_content, str
):
assistant_content.append(
BedrockContentBlock(text=_assistant_content)
)
# Tool calls on the assistant message are converted and added to the same block.
_tool_calls = assistant_message_block.get("tool_calls", [])
if _tool_calls:
assistant_content.extend(
_convert_to_bedrock_tool_call_invoke(_tool_calls)
)
msg_i += 1
if assistant_content:
contents.append(
BedrockMessageBlock(role="assistant", content=assistant_content)
)
# None of the three loops above advanced msg_i, i.e. the current message's
# role is not user, tool or assistant.
if msg_i == init_msg_i: # prevent infinite loops
raise litellm.BadRequestError(
message=BAD_MESSAGE_ERROR_STR + f"passed in {messages[msg_i]}",
model=model,
llm_provider=llm_provider,
)
return contents
def _bedrock_converse_messages_pt( # noqa: PLR0915
messages: List,
model: str,
@ -2726,7 +3033,7 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915
image_url = element["image_url"]["url"]
else:
image_url = element["image_url"]
_part = _process_bedrock_converse_image_block( # type: ignore
_part = BedrockImageProcessor.process_image_sync( # type: ignore
image_url=image_url
)
_parts.append(_part) # type: ignore
@ -2825,7 +3132,7 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915
image_url = element["image_url"]["url"]
else:
image_url = element["image_url"]
assistants_part = _process_bedrock_converse_image_block( # type: ignore
assistants_part = BedrockImageProcessor.process_image_sync( # type: ignore
image_url=image_url
)
assistants_parts.append(assistants_part)