Bedrock document processing fixes (#8005)

* refactor(factory.py): refactor async bedrock message transformation to use async get request for image url conversion

improve latency of bedrock call

* test(test_bedrock_completion.py): add unit testing to ensure async image url get called for async bedrock call

* refactor(factory.py): refactor bedrock translation to use BedrockImageProcessor

reduces duplicate code

* fix(factory.py): fix bug preventing PDFs from being processed

* fix(factory.py): fix bedrock converse document understanding with image url

* docs(bedrock.md): clarify all bedrock document types are supported

* refactor: cleanup redundant test + unused imports

* perf: improve perf with reusable clients

* test: fix test
This commit is contained in:
Krish Dholakia 2025-01-28 17:48:32 -08:00 committed by GitHub
parent c2e3986bbc
commit 8eaa5dc797
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 578 additions and 86 deletions

View file

@ -792,6 +792,16 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
LiteLLM supports Document Understanding for Bedrock models - [AWS Bedrock Docs](https://docs.aws.amazon.com/nova/latest/userguide/modalities-document.html).
:::info
LiteLLM supports ALL Bedrock document types,
e.g. "pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md".
You can pass a document in the `image_url` field either as a URL or as a base64-encoded string.
:::
### url
<Tabs>

View file

@ -140,7 +140,7 @@ def exception_type( # type: ignore # noqa: PLR0915
"\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m" # noqa
) # noqa
print( # noqa
"LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'." # noqa
"LiteLLM.Info: If you need to debug this error, use `litellm._turn_on_debug()'." # noqa
) # noqa
print() # noqa

View file

@ -625,10 +625,6 @@ class Logging(LiteLLMLoggingBaseClass):
masked_api_base = api_base
self.model_call_details["litellm_params"]["api_base"] = masked_api_base
verbose_logger.debug(
"PRE-API-CALL ADDITIONAL ARGS: %s", additional_args
)
curl_command = self._get_request_curl_command(
api_base=api_base,
headers=headers,

View file

@ -13,9 +13,10 @@ import litellm
import litellm.types
import litellm.types.llms
from litellm import verbose_logger
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.llms.custom_httpx.http_handler import HTTPHandler, get_async_httpx_client
from litellm.types.llms.anthropic import *
from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
from litellm.types.llms.custom_http import httpxSpecialProvider
from litellm.types.llms.ollama import OllamaVisionModelObject
from litellm.types.llms.openai import (
AllMessageValues,
@ -2150,6 +2151,12 @@ def stringify_json_tool_call_content(messages: List) -> List:
###### AMAZON BEDROCK #######
import base64
import mimetypes
from cgi import parse_header
import httpx
from litellm.types.llms.bedrock import ContentBlock as BedrockContentBlock
from litellm.types.llms.bedrock import DocumentBlock as BedrockDocumentBlock
from litellm.types.llms.bedrock import ImageBlock as BedrockImageBlock
@ -2166,42 +2173,64 @@ from litellm.types.llms.bedrock import ToolSpecBlock as BedrockToolSpecBlock
from litellm.types.llms.bedrock import ToolUseBlock as BedrockToolUseBlock
def get_image_details(image_url) -> Tuple[str, str]:
try:
import base64
def _parse_content_type(content_type: str) -> str:
main_type, _ = parse_header(content_type)
return main_type
client = HTTPHandler(concurrent_limit=1)
# Send a GET request to the image URL
response = client.get(image_url)
response.raise_for_status() # Raise an exception for HTTP errors
class BedrockImageProcessor:
"""Handles both sync and async image processing for Bedrock conversations."""
@staticmethod
def _post_call_image_processing(response: httpx.Response) -> Tuple[str, str]:
# Check the response's content type to ensure it is an image
content_type = response.headers.get("content-type")
if not content_type or "image" not in content_type:
if not content_type:
raise ValueError(
f"URL does not point to a valid image (content-type: {content_type})"
f"URL does not contain content-type (content-type: {content_type})"
)
content_type = _parse_content_type(content_type)
# Convert the image content to base64 bytes
base64_bytes = base64.b64encode(response.content).decode("utf-8")
return base64_bytes, content_type
@staticmethod
async def get_image_details_async(image_url) -> Tuple[str, str]:
    """
    Asynchronously download ``image_url`` and return its encoded content.

    Args:
        image_url: URL to fetch; redirects are followed.

    Returns:
        The tuple produced by
        ``BedrockImageProcessor._post_call_image_processing`` for the
        HTTP response.

    Raises:
        Any exception raised by the HTTP client, by ``raise_for_status()``
        or by the post-processing step. Errors propagate unchanged; the
        previous ``try/except ... raise e`` wrapper added nothing and was
        removed.
    """
    client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.PromptFactory,
        params={"concurrent_limit": 1},
    )
    # Send a GET request to the image URL
    response = await client.get(image_url, follow_redirects=True)
    response.raise_for_status()  # Raise an exception for HTTP errors
    return BedrockImageProcessor._post_call_image_processing(response)
@staticmethod
def get_image_details(image_url) -> Tuple[str, str]:
try:
client = HTTPHandler(concurrent_limit=1)
# Send a GET request to the image URL
response = client.get(image_url, follow_redirects=True)
response.raise_for_status() # Raise an exception for HTTP errors
def _process_bedrock_converse_image_block(
image_url: str,
) -> BedrockContentBlock:
if "base64" in image_url:
# Case 1: Images with base64 encoding
import re
return BedrockImageProcessor._post_call_image_processing(response)
# base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
except Exception as e:
raise e
@staticmethod
def _parse_base64_image(image_url: str) -> Tuple[str, str, str]:
"""Parse base64 encoded image data."""
image_metadata, img_without_base_64 = image_url.split(",")
# read mime_type from img_without_base_64=data:image/jpeg;base64
# Extract MIME type using regular expression
mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
if mime_type_match:
@ -2210,37 +2239,102 @@ def _process_bedrock_converse_image_block(
else:
mime_type = "image/jpeg"
image_format = "jpeg"
_blob = BedrockSourceBlock(bytes=img_without_base_64)
elif "https:/" in image_url:
# Case 2: Images with direct links
image_bytes, mime_type = get_image_details(image_url)
image_format = mime_type.split("/")[1]
_blob = BedrockSourceBlock(bytes=image_bytes)
else:
raise ValueError(
"Unsupported image type. Expected either image url or base64 encoded string - \
e.g. 'data:image/jpeg;base64,<base64-encoded-string>'"
return img_without_base_64, mime_type, image_format
@staticmethod
def _validate_format(mime_type: str, image_format: str) -> str:
"""Validate image format and mime type for both images and documents."""
supported_image_formats = (
litellm.AmazonConverseConfig().get_supported_image_types()
)
supported_doc_formats = (
litellm.AmazonConverseConfig().get_supported_document_types()
)
supported_image_formats = litellm.AmazonConverseConfig().get_supported_image_types()
document_types = ["application", "text"]
is_document = any(
mime_type.startswith(document_type) for document_type in document_types
is_document = any(mime_type.startswith(doc_type) for doc_type in document_types)
if is_document:
potential_extensions = mimetypes.guess_all_extensions(mime_type)
valid_extensions = [
ext[1:]
for ext in potential_extensions
if ext[1:] in supported_doc_formats
]
if not valid_extensions:
raise ValueError(
f"No supported extensions for MIME type: {mime_type}. Supported formats: {supported_doc_formats}"
)
if image_format in supported_image_formats:
return BedrockContentBlock(image=BedrockImageBlock(source=_blob, format=image_format)) # type: ignore
elif is_document:
return BedrockContentBlock(document=BedrockDocumentBlock(source=_blob, format=image_format, name="DocumentPDFmessages_{}".format(str(uuid.uuid4())))) # type: ignore
# Use first valid extension instead of provided image_format
return valid_extensions[0]
else:
# Handle the case when the image format is not supported
if image_format not in supported_image_formats:
raise ValueError(
"Unsupported image format: {}. Supported formats: {}".format(
image_format, supported_image_formats
f"Unsupported image format: {image_format}. Supported formats: {supported_image_formats}"
)
return image_format
@staticmethod
def _create_bedrock_block(
    image_bytes: str, mime_type: str, image_format: str
) -> BedrockContentBlock:
    """
    Wrap encoded content in a Bedrock image or document content block.

    A MIME type starting with ``application`` or ``text`` yields a document
    block with a generated unique name; any other MIME type yields an
    image block.
    """
    source = BedrockSourceBlock(bytes=image_bytes)

    if not mime_type.startswith(("application", "text")):
        return BedrockContentBlock(
            image=BedrockImageBlock(source=source, format=image_format)
        )

    document = BedrockDocumentBlock(
        source=source,
        format=image_format,
        name=f"DocumentPDFmessages_{str(uuid.uuid4())}",
    )
    return BedrockContentBlock(document=document)
@classmethod
def process_image_sync(cls, image_url: str) -> BedrockContentBlock:
    """
    Synchronously turn an ``image_url`` value into a Bedrock content block.

    Args:
        image_url: Either a string containing ``base64`` (handled by
            ``_parse_base64_image``) or an ``http``/``https`` URL whose
            content is downloaded with ``get_image_details``.

    Returns:
        The block built by ``_create_bedrock_block`` after the format has
        been checked by ``_validate_format``.

    Raises:
        ValueError: If ``image_url`` is neither base64 data nor a URL.
    """
    if "base64" in image_url:
        img_bytes, mime_type, image_format = cls._parse_base64_image(image_url)
    # Accept plain http:// as well, matching process_image_async; the original
    # "https:/" check is kept so previously accepted inputs still pass.
    elif "http://" in image_url or "https:/" in image_url:
        img_bytes, mime_type = BedrockImageProcessor.get_image_details(image_url)
        image_format = mime_type.split("/")[1]
    else:
        raise ValueError(
            "Unsupported image type. Expected either image url or base64 encoded string"
        )

    image_format = cls._validate_format(mime_type, image_format)
    return cls._create_bedrock_block(img_bytes, mime_type, image_format)
@classmethod
async def process_image_async(cls, image_url: str) -> BedrockContentBlock:
    """
    Asynchronously turn an ``image_url`` value into a Bedrock content block.

    Strings containing ``base64`` are parsed in place; ``http://`` and
    ``https://`` URLs are downloaded with ``get_image_details_async``.
    Anything else raises ``ValueError``.
    """
    is_inline = "base64" in image_url
    is_remote = "http://" in image_url or "https://" in image_url

    if not (is_inline or is_remote):
        raise ValueError(
            "Unsupported image type. Expected either image url or base64 encoded string"
        )

    if is_inline:
        # The base64 check takes precedence over the URL check.
        payload, mime_type, image_format = cls._parse_base64_image(image_url)
    else:
        payload, mime_type = await BedrockImageProcessor.get_image_details_async(
            image_url
        )
        image_format = mime_type.split("/")[1]

    checked_format = cls._validate_format(mime_type, image_format)
    return cls._create_bedrock_block(payload, mime_type, checked_format)
def _convert_to_bedrock_tool_call_invoke(
@ -2662,6 +2756,219 @@ def get_assistant_message_block_or_continue_message(
raise ValueError(f"Unsupported content type: {type(content_block)}")
class BedrockConverseMessagesProcessor:
    """
    Async translation of OpenAI-style chat messages into Bedrock Converse
    message blocks.

    ``image_url`` parts are resolved with
    ``BedrockImageProcessor.process_image_async`` so remote content is
    fetched without blocking the event loop.
    """

    @staticmethod
    def _initial_message_setup(
        messages: List,
        user_continue_message: Optional[ChatCompletionUserMessage] = None,
    ) -> List:
        """
        Pad a conversation that starts and/or ends with an assistant turn.

        A user message is inserted before a leading assistant message and
        appended after a trailing one. ``user_continue_message`` is used
        when given; otherwise ``DEFAULT_USER_CONTINUE_MESSAGE`` is used, but
        only if ``litellm.modify_params`` is truthy.

        Note: ``messages`` is modified in place and the same list is
        returned. It must be non-empty (index 0 and -1 are read).
        """
        # if initial message is assistant message
        if messages[0].get("role") is not None and messages[0]["role"] == "assistant":
            if user_continue_message is not None:
                messages.insert(0, user_continue_message)
            elif litellm.modify_params:
                messages.insert(0, DEFAULT_USER_CONTINUE_MESSAGE)

        # if final message is assistant message
        if messages[-1].get("role") is not None and messages[-1]["role"] == "assistant":
            if user_continue_message is not None:
                messages.append(user_continue_message)
            elif litellm.modify_params:
                messages.append(DEFAULT_USER_CONTINUE_MESSAGE)
        return messages

    @staticmethod
    async def _bedrock_converse_messages_pt_async(  # noqa: PLR0915
        messages: List,
        model: str,
        llm_provider: str,
        user_continue_message: Optional[ChatCompletionUserMessage] = None,
        assistant_continue_message: Optional[
            Union[str, ChatCompletionAssistantMessage]
        ] = None,
    ) -> List[BedrockMessageBlock]:
        """
        Convert chat messages into a list of Bedrock message blocks.

        Consecutive messages of the same kind are merged: runs of ``user``
        messages, then runs of ``tool`` messages (emitted with role
        ``user``), then runs of ``assistant`` messages. The cycle repeats
        until every message is consumed.

        Args:
            messages: OpenAI-style message dicts; each must have a ``role``.
            model: Model name, used only for error reporting.
            llm_provider: Provider name, used only for error reporting.
            user_continue_message: Optional user message used to pad
                assistant-first/assistant-last conversations and passed to
                ``get_user_message_block_or_continue_message``.
            assistant_continue_message: Optional assistant message inserted
                between two consecutive user-role blocks and passed to
                ``get_assistant_message_block_or_continue_message``.

        Returns:
            The Bedrock message blocks, in conversation order.

        Raises:
            litellm.BadRequestError: If ``messages`` is empty, or if a
                message is not consumed by any of the user/tool/assistant
                loops (which would otherwise loop forever).
        """
        contents: List[BedrockMessageBlock] = []
        msg_i = 0

        ## BASE CASE ##
        if len(messages) == 0:
            raise litellm.BadRequestError(
                message=BAD_MESSAGE_ERROR_STR
                + "bedrock requires at least one non-system message",
                model=model,
                llm_provider=llm_provider,
            )

        # if initial message is assistant message
        messages = BedrockConverseMessagesProcessor._initial_message_setup(
            messages, user_continue_message
        )

        while msg_i < len(messages):
            user_content: List[BedrockContentBlock] = []
            init_msg_i = msg_i
            ## MERGE CONSECUTIVE USER CONTENT ##
            while msg_i < len(messages) and messages[msg_i]["role"] == "user":
                message_block = get_user_message_block_or_continue_message(
                    message=messages[msg_i],
                    user_continue_message=user_continue_message,
                )
                if isinstance(message_block["content"], list):
                    _parts: List[BedrockContentBlock] = []
                    for element in message_block["content"]:
                        if isinstance(element, dict):
                            if element["type"] == "text":
                                _part = BedrockContentBlock(text=element["text"])
                                _parts.append(_part)
                            elif element["type"] == "image_url":
                                # image_url is either {"url": ...} or a plain string
                                if isinstance(element["image_url"], dict):
                                    image_url = element["image_url"]["url"]
                                else:
                                    image_url = element["image_url"]
                                _part = await BedrockImageProcessor.process_image_async(  # type: ignore
                                    image_url=image_url
                                )
                                _parts.append(_part)  # type: ignore
                            _cache_point_block = (
                                litellm.AmazonConverseConfig()._get_cache_point_block(
                                    message_block=cast(
                                        OpenAIMessageContentListBlock, element
                                    ),
                                    block_type="content_block",
                                )
                            )
                            if _cache_point_block is not None:
                                _parts.append(_cache_point_block)
                    user_content.extend(_parts)
                elif message_block["content"] and isinstance(
                    message_block["content"], str
                ):
                    # Use the block returned by the helper (the same value the
                    # guard above checked), not the raw input message, so any
                    # substitution made by the helper is what gets sent.
                    _part = BedrockContentBlock(text=message_block["content"])
                    _cache_point_block = (
                        litellm.AmazonConverseConfig()._get_cache_point_block(
                            message_block, block_type="content_block"
                        )
                    )
                    user_content.append(_part)
                    if _cache_point_block is not None:
                        user_content.append(_cache_point_block)

                msg_i += 1
            if user_content:
                if len(contents) > 0 and contents[-1]["role"] == "user":
                    if (
                        assistant_continue_message is not None
                        or litellm.modify_params is True
                    ):
                        # if last message was a 'user' message, then add a dummy assistant message (bedrock requires alternating roles)
                        contents = _insert_assistant_continue_message(
                            messages=contents,
                            assistant_continue_message=assistant_continue_message,
                        )
                        contents.append(
                            BedrockMessageBlock(role="user", content=user_content)
                        )
                    else:
                        verbose_logger.warning(
                            "Potential consecutive user/tool blocks. Trying to merge. If error occurs, please set a 'assistant_continue_message' or set 'modify_params=True' to insert a dummy assistant message for bedrock calls."
                        )
                        contents[-1]["content"].extend(user_content)
                else:
                    contents.append(
                        BedrockMessageBlock(role="user", content=user_content)
                    )

            ## MERGE CONSECUTIVE TOOL CALL MESSAGES ##
            tool_content: List[BedrockContentBlock] = []
            while msg_i < len(messages) and messages[msg_i]["role"] == "tool":
                tool_call_result = _convert_to_bedrock_tool_call_result(messages[msg_i])

                tool_content.append(tool_call_result)
                msg_i += 1
            if tool_content:
                # if last message was a 'user' message, then add a blank assistant message (bedrock requires alternating roles)
                if len(contents) > 0 and contents[-1]["role"] == "user":
                    if (
                        assistant_continue_message is not None
                        or litellm.modify_params is True
                    ):
                        # if last message was a 'user' message, then add a dummy assistant message (bedrock requires alternating roles)
                        contents = _insert_assistant_continue_message(
                            messages=contents,
                            assistant_continue_message=assistant_continue_message,
                        )
                        contents.append(
                            BedrockMessageBlock(role="user", content=tool_content)
                        )
                    else:
                        verbose_logger.warning(
                            "Potential consecutive user/tool blocks. Trying to merge. If error occurs, please set a 'assistant_continue_message' or set 'modify_params=True' to insert a dummy assistant message for bedrock calls."
                        )
                        contents[-1]["content"].extend(tool_content)
                else:
                    contents.append(
                        BedrockMessageBlock(role="user", content=tool_content)
                    )

            assistant_content: List[BedrockContentBlock] = []
            ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
            while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
                assistant_message_block = (
                    get_assistant_message_block_or_continue_message(
                        message=messages[msg_i],
                        assistant_continue_message=assistant_continue_message,
                    )
                )
                _assistant_content = assistant_message_block.get("content", None)

                if _assistant_content is not None and isinstance(
                    _assistant_content, list
                ):
                    assistants_parts: List[BedrockContentBlock] = []
                    for element in _assistant_content:
                        if isinstance(element, dict):
                            if element["type"] == "text":
                                assistants_part = BedrockContentBlock(
                                    text=element["text"]
                                )
                                assistants_parts.append(assistants_part)
                            elif element["type"] == "image_url":
                                if isinstance(element["image_url"], dict):
                                    image_url = element["image_url"]["url"]
                                else:
                                    image_url = element["image_url"]
                                assistants_part = await BedrockImageProcessor.process_image_async(  # type: ignore
                                    image_url=image_url
                                )
                                assistants_parts.append(assistants_part)
                    assistant_content.extend(assistants_parts)
                elif _assistant_content is not None and isinstance(
                    _assistant_content, str
                ):
                    assistant_content.append(
                        BedrockContentBlock(text=_assistant_content)
                    )
                _tool_calls = assistant_message_block.get("tool_calls", [])
                if _tool_calls:
                    assistant_content.extend(
                        _convert_to_bedrock_tool_call_invoke(_tool_calls)
                    )

                msg_i += 1

            if assistant_content:
                contents.append(
                    BedrockMessageBlock(role="assistant", content=assistant_content)
                )

            if msg_i == init_msg_i:  # prevent infinite loops
                raise litellm.BadRequestError(
                    message=BAD_MESSAGE_ERROR_STR + f"passed in {messages[msg_i]}",
                    model=model,
                    llm_provider=llm_provider,
                )

        return contents
def _bedrock_converse_messages_pt( # noqa: PLR0915
messages: List,
model: str,
@ -2726,7 +3033,7 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915
image_url = element["image_url"]["url"]
else:
image_url = element["image_url"]
_part = _process_bedrock_converse_image_block( # type: ignore
_part = BedrockImageProcessor.process_image_sync( # type: ignore
image_url=image_url
)
_parts.append(_part) # type: ignore
@ -2825,7 +3132,7 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915
image_url = element["image_url"]["url"]
else:
image_url = element["image_url"]
assistants_part = _process_bedrock_converse_image_block( # type: ignore
assistants_part = BedrockImageProcessor.process_image_sync( # type: ignore
image_url=image_url
)
assistants_parts.append(assistants_part)

View file

@ -10,10 +10,10 @@ from typing import List, Literal, Optional, Tuple, Union, overload
import httpx
import litellm
from litellm.litellm_core_utils.asyncify import asyncify
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.litellm_core_utils.prompt_templates.factory import (
BedrockConverseMessagesProcessor,
_bedrock_converse_messages_pt,
_bedrock_tools_pt,
)
@ -154,6 +154,9 @@ class AmazonConverseConfig:
def get_supported_document_types(self) -> List[str]:
    """Return the document format names accepted for Bedrock Converse document content."""
    return ["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"]
def get_all_supported_content_types(self) -> List[str]:
    """Return the supported image types followed by the supported document types."""
    image_types = self.get_supported_image_types()
    document_types = self.get_supported_document_types()
    return image_types + document_types
def _create_json_tool_call_for_response_format(
self,
json_schema: Optional[dict] = None,
@ -426,14 +429,23 @@ class AmazonConverseConfig:
) -> RequestObject:
messages, system_content_blocks = self._transform_system_message(messages)
## TRANSFORMATION ##
bedrock_messages: List[MessageBlock] = await asyncify(
_bedrock_converse_messages_pt
)(
# bedrock_messages: List[MessageBlock] = await asyncify(
# _bedrock_converse_messages_pt
# )(
# messages=messages,
# model=model,
# llm_provider="bedrock_converse",
# user_continue_message=litellm_params.pop("user_continue_message", None),
# )
bedrock_messages = (
await BedrockConverseMessagesProcessor._bedrock_converse_messages_pt_async(
messages=messages,
model=model,
llm_provider="bedrock_converse",
user_continue_message=litellm_params.pop("user_continue_message", None),
)
)
_data: CommonRequestObject = self._transform_request_helper(
system_content_blocks=system_content_blocks,

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -28,13 +28,21 @@ class SourceBlock(TypedDict):
bytes: Optional[str] # base 64 encoded string
BedrockImageTypes = Literal["png", "jpeg", "gif", "webp"]
class ImageBlock(TypedDict):
format: Literal["png", "jpeg", "gif", "webp"]
format: Union[BedrockImageTypes, str]
source: SourceBlock
BedrockDocumentTypes = Literal[
"pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"
]
class DocumentBlock(TypedDict):
format: Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"]
format: Union[BedrockDocumentTypes, str]
source: SourceBlock
name: str

View file

@ -18,3 +18,4 @@ class httpxSpecialProvider(str, Enum):
Oauth2Check = "oauth2_check"
SecretManager = "secret_manager"
PassThroughEndpoint = "pass_through_endpoint"
PromptFactory = "prompt_factory"

View file

@ -0,0 +1,130 @@
import ast
import os
import sys
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
import asyncio
import aiohttp
import base64
import time
from typing import Tuple
import statistics
async def asyncify(func, *args, **kwargs):
    """Run the blocking call ``func(*args, **kwargs)`` via ``asyncio.to_thread`` and return its result."""
    outcome = await asyncio.to_thread(func, *args, **kwargs)
    return outcome
def get_image_details(image_url) -> Tuple[str, str]:
    """
    Fetch ``image_url`` with the sync ``HTTPHandler`` and return
    ``(base64_content, content_type)``.

    Raises ``ValueError`` when the response has no content-type header or
    the header does not contain ``image``; HTTP and client errors propagate.
    """
    try:
        http_client = HTTPHandler(concurrent_limit=1)
        resp = http_client.get(image_url)
        resp.raise_for_status()

        mime = resp.headers.get("content-type")
        if not (mime and "image" in mime):
            raise ValueError(
                f"URL does not point to a valid image (content-type: {mime})"
            )

        encoded = base64.b64encode(resp.content).decode("utf-8")
        return encoded, mime
    except Exception as err:
        raise err
async def get_image_details_async(image_url) -> Tuple[str, str]:
    """
    Fetch ``image_url`` with ``AsyncHTTPHandler`` and return
    ``(base64_content, content_type)``.

    Raises ``ValueError`` when the response has no content-type header or
    the header does not contain ``image``; HTTP and client errors propagate.
    """
    try:
        http_client = AsyncHTTPHandler(concurrent_limit=1)
        resp = await http_client.get(image_url)
        resp.raise_for_status()

        mime = resp.headers.get("content-type")
        if not (mime and "image" in mime):
            raise ValueError(
                f"URL does not point to a valid image (content-type: {mime})"
            )

        encoded = base64.b64encode(resp.content).decode("utf-8")
        return encoded, mime
    except Exception as err:
        raise err
async def get_image_details_aio(image_url) -> Tuple[str, str]:
    """
    Fetch ``image_url`` with an ``aiohttp`` session and return
    ``(base64_content, content_type)``.

    Raises ``ValueError`` when the response has no content-type header or
    the header does not contain ``image``; HTTP and client errors propagate.
    """
    try:
        async with aiohttp.ClientSession() as http_session:
            async with http_session.get(image_url) as resp:
                resp.raise_for_status()

                mime = resp.headers.get("content-type")
                if not (mime and "image" in mime):
                    raise ValueError(
                        f"URL does not point to a valid image (content-type: {mime})"
                    )

                raw = await resp.read()
                encoded = base64.b64encode(raw).decode("utf-8")
                return encoded, mime
    except Exception as err:
        raise err
async def test_asyncified(urls: list[str], iterations: int = 3) -> list[float]:
    """
    Time ``iterations`` rounds of fetching every URL concurrently with the
    sync ``get_image_details`` wrapped by ``asyncify``.

    Returns the elapsed seconds of each round.
    """
    durations = []
    for _round in range(iterations):
        began = time.perf_counter()
        pending = [asyncify(get_image_details, url) for url in urls]
        await asyncio.gather(*pending)
        durations.append(time.perf_counter() - began)
    return durations
async def test_async_httpx(urls: list[str], iterations: int = 3) -> list[float]:
    """
    Time ``iterations`` rounds of fetching every URL concurrently with
    ``get_image_details_async``.

    Returns the elapsed seconds of each round.
    """
    durations = []
    for _round in range(iterations):
        began = time.perf_counter()
        pending = [get_image_details_async(url) for url in urls]
        await asyncio.gather(*pending)
        durations.append(time.perf_counter() - began)
    return durations
async def test_aiohttp(urls: list[str], iterations: int = 3) -> list[float]:
    """
    Time ``iterations`` rounds of fetching every URL concurrently with
    ``get_image_details_aio``.

    Returns the elapsed seconds of each round.
    """
    durations = []
    for _round in range(iterations):
        began = time.perf_counter()
        pending = [get_image_details_aio(url) for url in urls]
        await asyncio.gather(*pending)
        durations.append(time.perf_counter() - began)
    return durations
async def run_comparison():
    """
    Benchmark the three fetch strategies against 150 copies of one image
    URL and print the mean/std of each plus the aiohttp speed-up ratios.
    """
    urls = [
        "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
    ] * 150

    print("Testing asyncified version...")
    asyncified_times = await test_asyncified(urls)

    print("Testing async httpx version...")
    async_httpx_times = await test_async_httpx(urls)

    print("Testing aiohttp version...")
    aiohttp_times = await test_aiohttp(urls)

    print("\nResults:")
    summaries = (
        ("Asyncified version", asyncified_times),
        ("Async HTTPX version", async_httpx_times),
        ("Aiohttp version", aiohttp_times),
    )
    for label, samples in summaries:
        print(
            f"{label} - Mean: {statistics.mean(samples):.3f}s, Std: {statistics.stdev(samples):.3f}s"
        )

    baselines = (
        ("asyncified", asyncified_times),
        ("async httpx", async_httpx_times),
    )
    for label, samples in baselines:
        print(
            f"Speed improvement over {label}: {statistics.mean(samples)/statistics.mean(aiohttp_times):.2f}x"
        )
if __name__ == "__main__":
asyncio.run(run_comparison())

View file

@ -2405,18 +2405,6 @@ class TestBedrockEmbedding(BaseLLMEmbeddingTest):
] == "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkBAMAAACCzIhnAAAAG1BMVEURAAD///+ln5/h39/Dv79qX18uHx+If39MPz9oMSdmAAAACXBIWXMAAA7EAAAOxAGVKw4bAAABB0lEQVRYhe2SzWrEIBCAh2A0jxEs4j6GLDS9hqWmV5Flt0cJS+lRwv742DXpEjY1kOZW6HwHFZnPmVEBEARBEARB/jd0KYA/bcUYbPrRLh6amXHJ/K+ypMoyUaGthILzw0l+xI0jsO7ZcmCcm4ILd+QuVYgpHOmDmz6jBeJImdcUCmeBqQpuqRIbVmQsLCrAalrGpfoEqEogqbLTWuXCPCo+Ki1XGqgQ+jVVuhB8bOaHkvmYuzm/b0KYLWwoK58oFqi6XfxQ4Uz7d6WeKpna6ytUs5e8betMcqAv5YPC5EZB2Lm9FIn0/VP6R58+/GEY1X1egVoZ/3bt/EqF6malgSAIgiDIH+QL41409QMY0LMAAAAASUVORK5CYII="
def test_process_bedrock_converse_image_block():
from litellm.litellm_core_utils.prompt_templates.factory import (
_process_bedrock_converse_image_block,
)
block = _process_bedrock_converse_image_block(
image_url="data:text/plain;base64,base64file"
)
assert block["document"] is not None
@pytest.mark.asyncio
async def test_bedrock_image_url_sync_client():
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
@ -2483,3 +2471,38 @@ def test_bedrock_error_handling_streaming():
assert isinstance(e.value, BedrockError)
assert "Bedrock is unable to process your request." in e.value.message
assert e.value.status_code == 400
@pytest.mark.parametrize(
    "image_url",
    [
        "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf",
        # "https://raw.githubusercontent.com/datasets/gdp/master/data/gdp.csv",
        "https://www.cmu.edu/blackboard/files/evaluate/tests-example.xls",
        "http://www.krishdholakia.com/",
        # "https://raw.githubusercontent.com/datasets/sample-data/master/README.txt", # invalid url
        "https://raw.githubusercontent.com/mdn/content/main/README.md",
    ],
)
@pytest.mark.flaky(retries=6, delay=2)
@pytest.mark.asyncio
async def test_bedrock_document_understanding(image_url):
    """Send a document URL as ``image_url`` content to a Bedrock model and expect a non-empty reply."""
    from litellm import acompletion

    litellm._turn_on_debug()

    question_part = {"type": "text", "text": "What's this file about?"}
    document_part = {"type": "image_url", "image_url": image_url}
    user_message = {"role": "user", "content": [question_part, document_part]}

    response = await acompletion(
        model="bedrock/us.amazon.nova-pro-v1:0",
        messages=[user_message],
    )

    assert response is not None
    assert response.choices[0].message.content != ""

View file

@ -721,14 +721,22 @@ def test_stream_chunk_builder_openai_audio_output_usage():
print(f"response usage: {response.usage}")
check_non_streaming_response(response)
print(f"response: {response}")
# Convert both usage objects to dictionaries for easier comparison
usage_dict = usage_obj.model_dump(exclude_none=True)
response_usage_dict = response.usage.model_dump(exclude_none=True)
# Simple dictionary comparison
assert (
usage_dict == response_usage_dict
), f"\nExpected: {usage_dict}\nGot: {response_usage_dict}"
for k, v in usage_obj.model_dump(exclude_none=True).items():
print(k, v)
response_usage_value = getattr(response.usage, k) # type: ignore
print(f"response_usage_value: {response_usage_value}")
print(f"type: {type(response_usage_value)}")
if isinstance(response_usage_value, BaseModel):
response_usage_value_dict = response_usage_value.model_dump(
exclude_none=True
)
if isinstance(v, dict):
for key, value in v.items():
assert response_usage_value_dict[key] == value
else:
assert response_usage_value_dict == v
else:
assert response_usage_value == v
def test_stream_chunk_builder_empty_initial_chunk():