refactor: make bedrock image transformation requests async (#7840)

* refactor: initial commit for using separate sync vs. async transformation routes for bedrock

Ensures there are no blocking calls, e.g. when converting an image URL to base64.

* perf(converse_transformation.py): make bedrock converse transformation async

Asyncifies the Bedrock message transformation - useful for handling image URLs for Bedrock.

* fix(converse_handler.py): fix logging for async streaming

* style: cleanup unused imports
This commit is contained in:
Krish Dholakia 2025-01-17 20:14:15 -08:00 committed by GitHub
parent 32c8933935
commit 2b58f16fda
8 changed files with 266 additions and 93 deletions

View file

@ -1056,6 +1056,7 @@ ALL_LITELLM_RESPONSE_TYPES = [
] ]
from .llms.custom_llm import CustomLLM from .llms.custom_llm import CustomLLM
from .llms.bedrock.chat.converse_transformation import AmazonConverseConfig
from .llms.openai_like.chat.handler import OpenAILikeChatConfig from .llms.openai_like.chat.handler import OpenAILikeChatConfig
from .llms.aiohttp_openai.chat.transformation import AiohttpOpenAIChatConfig from .llms.aiohttp_openai.chat.transformation import AiohttpOpenAIChatConfig
from .llms.galadriel.chat.transformation import GaladrielChatConfig from .llms.galadriel.chat.transformation import GaladrielChatConfig
@ -1130,7 +1131,7 @@ from .llms.bedrock.chat.invoke_handler import (
AmazonCohereChatConfig, AmazonCohereChatConfig,
bedrock_tool_name_mappings, bedrock_tool_name_mappings,
) )
from .llms.bedrock.chat.converse_transformation import AmazonConverseConfig
from .llms.bedrock.common_utils import ( from .llms.bedrock.common_utils import (
AmazonTitanConfig, AmazonTitanConfig,
AmazonAI21Config, AmazonAI21Config,

View file

@ -12,9 +12,11 @@ from litellm.caching.caching import DualCache
from litellm.secret_managers.main import get_secret, get_secret_str from litellm.secret_managers.main import get_secret, get_secret_str
if TYPE_CHECKING: if TYPE_CHECKING:
from botocore.awsrequest import AWSPreparedRequest
from botocore.credentials import Credentials from botocore.credentials import Credentials
else: else:
Credentials = Any Credentials = Any
AWSPreparedRequest = Any
class Boto3CredentialsInfo(BaseModel): class Boto3CredentialsInfo(BaseModel):
@ -471,3 +473,32 @@ class BaseAWSLLM:
aws_region_name=aws_region_name, aws_region_name=aws_region_name,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
) )
def get_request_headers(
    self,
    credentials: Credentials,
    aws_region_name: str,
    extra_headers: Optional[dict],
    endpoint_url: str,
    data: str,
    headers: dict,
) -> AWSPreparedRequest:
    """
    Build a SigV4-signed, prepared POST request for the "bedrock" service.

    Args:
        credentials: Credentials object handed to botocore's ``SigV4Auth``.
        aws_region_name: Region name used when signing.
        extra_headers: Optional caller-supplied headers. If it contains an
            "Authorization" key, that value replaces the one SigV4 produced.
        endpoint_url: URL the request is built for.
        data: Request body.
        headers: Headers placed on the request before it is signed.

    Returns:
        The result of ``AWSRequest.prepare()``.

    Raises:
        ImportError: If the botocore modules cannot be imported. The original
            import failure is attached as ``__cause__``.
    """
    try:
        from botocore.auth import SigV4Auth
        from botocore.awsrequest import AWSRequest
    except ImportError as err:
        # Chain explicitly so the underlying import failure stays visible.
        raise ImportError(
            "Missing boto3 to call bedrock. Run 'pip install boto3'."
        ) from err

    sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
    request = AWSRequest(
        method="POST", url=endpoint_url, data=data, headers=headers
    )
    sigv4.add_auth(request)

    # prevent sigv4 from overwriting a caller-provided auth header
    if extra_headers is not None and "Authorization" in extra_headers:
        request.headers["Authorization"] = extra_headers["Authorization"]

    return request.prepare()

View file

@ -14,7 +14,7 @@ from litellm.llms.custom_httpx.http_handler import (
from litellm.types.utils import ModelResponse from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper, get_secret from litellm.utils import CustomStreamWrapper, get_secret
from ..base_aws_llm import BaseAWSLLM from ..base_aws_llm import BaseAWSLLM, Credentials
from ..common_utils import BedrockError from ..common_utils import BedrockError
from .invoke_handler import AWSEventStreamDecoder, MockResponseIterator, make_call from .invoke_handler import AWSEventStreamDecoder, MockResponseIterator, make_call
@ -41,7 +41,9 @@ def make_sync_call(
) )
if response.status_code != 200: if response.status_code != 200:
raise BedrockError(status_code=response.status_code, message=response.read()) raise BedrockError(
status_code=response.status_code, message=str(response.read())
)
if fake_stream: if fake_stream:
model_response: ( model_response: (
@ -78,6 +80,7 @@ def make_sync_call(
class BedrockConverseLLM(BaseAWSLLM): class BedrockConverseLLM(BaseAWSLLM):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
@ -98,13 +101,13 @@ class BedrockConverseLLM(BaseAWSLLM):
api_base: str, api_base: str,
model_response: ModelResponse, model_response: ModelResponse,
print_verbose: Callable, print_verbose: Callable,
data: str,
timeout: Optional[Union[float, httpx.Timeout]], timeout: Optional[Union[float, httpx.Timeout]],
encoding, encoding,
logging_obj, logging_obj,
stream, stream,
optional_params: dict, optional_params: dict,
litellm_params=None, litellm_params: dict,
credentials: Credentials,
logger_fn=None, logger_fn=None,
headers={}, headers={},
client: Optional[AsyncHTTPHandler] = None, client: Optional[AsyncHTTPHandler] = None,
@ -112,10 +115,38 @@ class BedrockConverseLLM(BaseAWSLLM):
json_mode: Optional[bool] = False, json_mode: Optional[bool] = False,
) -> CustomStreamWrapper: ) -> CustomStreamWrapper:
request_data = await litellm.AmazonConverseConfig()._async_transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
)
data = json.dumps(request_data)
prepped = self.get_request_headers(
credentials=credentials,
aws_region_name=litellm_params.get("aws_region_name") or "us-west-2",
extra_headers=headers,
endpoint_url=api_base,
data=data,
headers=headers,
)
## LOGGING
logging_obj.pre_call(
input=messages,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": dict(prepped.headers),
},
)
completion_stream = await make_call( completion_stream = await make_call(
client=client, client=client,
api_base=api_base, api_base=api_base,
headers=headers, headers=dict(prepped.headers),
data=data, data=data,
model=model, model=model,
messages=messages, messages=messages,
@ -138,17 +169,47 @@ class BedrockConverseLLM(BaseAWSLLM):
api_base: str, api_base: str,
model_response: ModelResponse, model_response: ModelResponse,
print_verbose: Callable, print_verbose: Callable,
data: str,
timeout: Optional[Union[float, httpx.Timeout]], timeout: Optional[Union[float, httpx.Timeout]],
encoding, encoding,
logging_obj, logging_obj,
stream, stream,
optional_params: dict, optional_params: dict,
litellm_params=None, litellm_params: dict,
credentials: Credentials,
logger_fn=None, logger_fn=None,
headers={}, headers: dict = {},
client: Optional[AsyncHTTPHandler] = None, client: Optional[AsyncHTTPHandler] = None,
) -> Union[ModelResponse, CustomStreamWrapper]: ) -> Union[ModelResponse, CustomStreamWrapper]:
request_data = await litellm.AmazonConverseConfig()._async_transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
)
data = json.dumps(request_data)
prepped = self.get_request_headers(
credentials=credentials,
aws_region_name=litellm_params.get("aws_region_name") or "us-west-2",
extra_headers=headers,
endpoint_url=api_base,
data=data,
headers=headers,
)
## LOGGING
logging_obj.pre_call(
input=messages,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": headers,
},
)
headers = dict(prepped.headers)
if client is None or not isinstance(client, AsyncHTTPHandler): if client is None or not isinstance(client, AsyncHTTPHandler):
_params = {} _params = {}
if timeout is not None: if timeout is not None:
@ -203,8 +264,6 @@ class BedrockConverseLLM(BaseAWSLLM):
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None, client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
): ):
try: try:
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials from botocore.credentials import Credentials
except ImportError: except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
@ -237,6 +296,8 @@ class BedrockConverseLLM(BaseAWSLLM):
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None) aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
litellm_params["aws_region_name"] = aws_region_name
### SET REGION NAME ### ### SET REGION NAME ###
if aws_region_name is None: if aws_region_name is None:
# check env # # check env #
@ -281,7 +342,54 @@ class BedrockConverseLLM(BaseAWSLLM):
endpoint_url = f"{endpoint_url}/model/{modelId}/converse" endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse" proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse"
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) ## COMPLETION CALL
headers = {"Content-Type": "application/json"}
if extra_headers is not None:
headers = {"Content-Type": "application/json", **extra_headers}
### ROUTING (ASYNC, STREAMING, SYNC)
if acompletion:
if isinstance(client, HTTPHandler):
client = None
if stream is True:
return self.async_streaming(
model=model,
messages=messages,
api_base=proxy_endpoint_url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=True,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
client=client,
json_mode=json_mode,
fake_stream=fake_stream,
credentials=credentials,
) # type: ignore
### ASYNC COMPLETION
return self.async_completion(
model=model,
messages=messages,
api_base=proxy_endpoint_url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream, # type: ignore
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
client=client,
credentials=credentials,
) # type: ignore
## TRANSFORMATION ## ## TRANSFORMATION ##
@ -292,20 +400,15 @@ class BedrockConverseLLM(BaseAWSLLM):
litellm_params=litellm_params, litellm_params=litellm_params,
) )
data = json.dumps(_data) data = json.dumps(_data)
## COMPLETION CALL
headers = {"Content-Type": "application/json"} prepped = self.get_request_headers(
if extra_headers is not None: credentials=credentials,
headers = {"Content-Type": "application/json", **extra_headers} aws_region_name=aws_region_name,
request = AWSRequest( extra_headers=extra_headers,
method="POST", url=endpoint_url, data=data, headers=headers endpoint_url=proxy_endpoint_url,
data=data,
headers=headers,
) )
sigv4.add_auth(request)
if (
extra_headers is not None and "Authorization" in extra_headers
): # prevent sigv4 from overwriting the auth header
request.headers["Authorization"] = extra_headers["Authorization"]
prepped = request.prepare()
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
@ -317,50 +420,6 @@ class BedrockConverseLLM(BaseAWSLLM):
"headers": prepped.headers, "headers": prepped.headers,
}, },
) )
### ROUTING (ASYNC, STREAMING, SYNC)
if acompletion:
if isinstance(client, HTTPHandler):
client = None
if stream is True:
return self.async_streaming(
model=model,
messages=messages,
data=data,
api_base=proxy_endpoint_url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=True,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=prepped.headers,
timeout=timeout,
client=client,
json_mode=json_mode,
fake_stream=fake_stream,
) # type: ignore
### ASYNC COMPLETION
return self.async_completion(
model=model,
messages=messages,
data=data,
api_base=proxy_endpoint_url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream, # type: ignore
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=prepped.headers,
timeout=timeout,
client=client,
) # type: ignore
if client is None or isinstance(client, AsyncHTTPHandler): if client is None or isinstance(client, AsyncHTTPHandler):
_params = {} _params = {}
if timeout is not None: if timeout is not None:

View file

@ -10,6 +10,7 @@ from typing import List, Literal, Optional, Tuple, Union, overload
import httpx import httpx
import litellm import litellm
from litellm.litellm_core_utils.asyncify import asyncify
from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.litellm_logging import Logging from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.litellm_core_utils.prompt_templates.factory import ( from litellm.litellm_core_utils.prompt_templates.factory import (
@ -347,14 +348,9 @@ class AmazonConverseConfig:
inference_params["topK"] = inference_params.pop("top_k") inference_params["topK"] = inference_params.pop("top_k")
return InferenceConfig(**inference_params) return InferenceConfig(**inference_params)
def _transform_request( def _transform_request_helper(
self, self, system_content_blocks: List[SystemContentBlock], optional_params: dict
model: str, ) -> CommonRequestObject:
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
) -> RequestObject:
messages, system_content_blocks = self._transform_system_message(messages)
inference_params = copy.deepcopy(optional_params) inference_params = copy.deepcopy(optional_params)
additional_request_keys = [] additional_request_keys = []
additional_request_params = {} additional_request_params = {}
@ -364,14 +360,6 @@ class AmazonConverseConfig:
supported_tool_call_params = ["tools", "tool_choice"] supported_tool_call_params = ["tools", "tool_choice"]
supported_guardrail_params = ["guardrailConfig"] supported_guardrail_params = ["guardrailConfig"]
inference_params.pop("json_mode", None) # used for handling json_schema inference_params.pop("json_mode", None) # used for handling json_schema
## TRANSFORMATION ##
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
messages=messages,
model=model,
llm_provider="bedrock_converse",
user_continue_message=litellm_params.pop("user_continue_message", None),
)
# send all model-specific params in 'additional_request_params' # send all model-specific params in 'additional_request_params'
for k, v in inference_params.items(): for k, v in inference_params.items():
@ -408,8 +396,7 @@ class AmazonConverseConfig:
if tool_choice_values is not None: if tool_choice_values is not None:
bedrock_tool_config["toolChoice"] = tool_choice_values bedrock_tool_config["toolChoice"] = tool_choice_values
_data: RequestObject = { data: CommonRequestObject = {
"messages": bedrock_messages,
"additionalModelRequestFields": additional_request_params, "additionalModelRequestFields": additional_request_params,
"system": system_content_blocks, "system": system_content_blocks,
"inferenceConfig": self._transform_inference_params( "inferenceConfig": self._transform_inference_params(
@ -422,13 +409,65 @@ class AmazonConverseConfig:
request_guardrails_config = inference_params.pop("guardrailConfig", None) request_guardrails_config = inference_params.pop("guardrailConfig", None)
if request_guardrails_config is not None: if request_guardrails_config is not None:
guardrail_config = GuardrailConfigBlock(**request_guardrails_config) guardrail_config = GuardrailConfigBlock(**request_guardrails_config)
_data["guardrailConfig"] = guardrail_config data["guardrailConfig"] = guardrail_config
# Tool Config # Tool Config
if bedrock_tool_config is not None: if bedrock_tool_config is not None:
_data["toolConfig"] = bedrock_tool_config data["toolConfig"] = bedrock_tool_config
return _data return data
async def _async_transform_request(
    self,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
) -> RequestObject:
    """
    Async variant of the Converse request transformation.

    Splits system content out of ``messages``, converts the remaining
    messages with ``_bedrock_converse_messages_pt`` (wrapped in ``asyncify``
    and awaited), then merges the result with the fields produced by
    ``_transform_request_helper``.

    Note: "user_continue_message" is popped from ``litellm_params``, so the
    caller's dict is modified.
    """
    messages, system_content_blocks = self._transform_system_message(messages)

    ## TRANSFORMATION ##
    # The prompt-template conversion is a sync function; the commit notes it
    # may do blocking work such as fetching image URLs.
    convert_messages = asyncify(_bedrock_converse_messages_pt)
    continue_message = litellm_params.pop("user_continue_message", None)
    converted_messages: List[MessageBlock] = await convert_messages(
        messages=messages,
        model=model,
        llm_provider="bedrock_converse",
        user_continue_message=continue_message,
    )

    common_fields: CommonRequestObject = self._transform_request_helper(
        system_content_blocks=system_content_blocks,
        optional_params=optional_params,
    )

    request: RequestObject = {"messages": converted_messages, **common_fields}
    return request
def _transform_request(
    self,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
) -> RequestObject:
    """
    Synchronous Converse request transformation.

    Splits system content out of ``messages``, converts the remaining
    messages with ``_bedrock_converse_messages_pt``, then merges the result
    with the fields produced by ``_transform_request_helper``.

    Note: "user_continue_message" is popped from ``litellm_params``, so the
    caller's dict is modified.
    """
    messages, system_content_blocks = self._transform_system_message(messages)

    ## TRANSFORMATION ##
    continue_message = litellm_params.pop("user_continue_message", None)
    converted_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
        messages=messages,
        model=model,
        llm_provider="bedrock_converse",
        user_continue_message=continue_message,
    )

    common_fields: CommonRequestObject = self._transform_request_helper(
        system_content_blocks=system_content_blocks,
        optional_params=optional_params,
    )

    request: RequestObject = {"messages": converted_messages, **common_fields}
    return request
def _transform_response( def _transform_response(
self, self,

View file

@ -161,16 +161,21 @@ class ContentBlockDeltaEvent(TypedDict, total=False):
toolUse: ToolBlockDeltaEvent toolUse: ToolBlockDeltaEvent
class RequestObject(TypedDict, total=False): class CommonRequestObject(
TypedDict, total=False
): # common request object across sync + async flows
additionalModelRequestFields: dict additionalModelRequestFields: dict
additionalModelResponseFieldPaths: List[str] additionalModelResponseFieldPaths: List[str]
inferenceConfig: InferenceConfig inferenceConfig: InferenceConfig
messages: Required[List[MessageBlock]]
system: List[SystemContentBlock] system: List[SystemContentBlock]
toolConfig: ToolConfigBlock toolConfig: ToolConfigBlock
guardrailConfig: Optional[GuardrailConfigBlock] guardrailConfig: Optional[GuardrailConfigBlock]
class RequestObject(CommonRequestObject, total=False):
messages: Required[List[MessageBlock]]
class GenericStreamingChunk(TypedDict): class GenericStreamingChunk(TypedDict):
text: Required[str] text: Required[str]
tool_use: Optional[ChatCompletionToolCallChunk] tool_use: Optional[ChatCompletionToolCallChunk]

View file

@ -2391,3 +2391,41 @@ def test_process_bedrock_converse_image_block():
) )
assert block["document"] is not None assert block["document"] is not None
@pytest.mark.asyncio
async def test_bedrock_image_url_sync_client():
    """
    Call ``litellm.acompletion`` for a Bedrock model with an image-URL message
    and an ``AsyncHTTPHandler`` whose ``post`` is patched, then check that
    ``post`` was called exactly once.

    Any exception raised by ``acompletion`` is printed and ignored; only the
    call count on the patched ``post`` is asserted.
    """
    import logging

    from litellm import verbose_logger
    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

    verbose_logger.setLevel(level=logging.DEBUG)
    litellm._turn_on_debug()

    text_part = {"type": "text", "text": "What's in this image?"}
    image_part = {
        "type": "image_url",
        "image_url": {
            "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
        },
    }
    messages = [{"role": "user", "content": [text_part, image_part]}]

    http_handler = AsyncHTTPHandler()

    with patch.object(http_handler, "post") as mock_post:
        try:
            await litellm.acompletion(
                model="bedrock/us.amazon.nova-pro-v1:0",
                messages=messages,
                client=http_handler,
            )
        except Exception as e:
            print(e)
        mock_post.assert_called_once()

View file

@ -1387,7 +1387,7 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
[ [
# ["bedrock/ai21.jamba-instruct-v1:0", "us-east-1"], # ["bedrock/ai21.jamba-instruct-v1:0", "us-east-1"],
# ["bedrock/cohere.command-r-plus-v1:0", None], # ["bedrock/cohere.command-r-plus-v1:0", None],
# ["anthropic.claude-3-sonnet-20240229-v1:0", None], ["anthropic.claude-3-sonnet-20240229-v1:0", None],
# ["anthropic.claude-instant-v1", None], # ["anthropic.claude-instant-v1", None],
# ["mistral.mistral-7b-instruct-v0:2", None], # ["mistral.mistral-7b-instruct-v0:2", None],
["bedrock/amazon.titan-tg1-large", None], ["bedrock/amazon.titan-tg1-large", None],