refactor: make bedrock image transformation requests async (#7840)

* refactor: initial commit for using separate sync vs. async transformation routes for bedrock

ensures no blocking calls e.g. when converting image url to b64

* perf(converse_transformation.py): make bedrock converse transformation async

asyncify's the bedrock message transformation - useful for handling image urls for bedrock

* fix(converse_handler.py): fix logging for async streaming

* style: cleanup unused imports
This commit is contained in:
Krish Dholakia 2025-01-17 20:14:15 -08:00 committed by GitHub
parent 32c8933935
commit 2b58f16fda
8 changed files with 266 additions and 93 deletions

View file

@ -14,7 +14,7 @@ from litellm.llms.custom_httpx.http_handler import (
from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper, get_secret
from ..base_aws_llm import BaseAWSLLM
from ..base_aws_llm import BaseAWSLLM, Credentials
from ..common_utils import BedrockError
from .invoke_handler import AWSEventStreamDecoder, MockResponseIterator, make_call
@ -41,7 +41,9 @@ def make_sync_call(
)
if response.status_code != 200:
raise BedrockError(status_code=response.status_code, message=response.read())
raise BedrockError(
status_code=response.status_code, message=str(response.read())
)
if fake_stream:
model_response: (
@ -78,6 +80,7 @@ def make_sync_call(
class BedrockConverseLLM(BaseAWSLLM):
def __init__(self) -> None:
super().__init__()
@ -98,13 +101,13 @@ class BedrockConverseLLM(BaseAWSLLM):
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
data: str,
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
stream,
optional_params: dict,
litellm_params=None,
litellm_params: dict,
credentials: Credentials,
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
@ -112,10 +115,38 @@ class BedrockConverseLLM(BaseAWSLLM):
json_mode: Optional[bool] = False,
) -> CustomStreamWrapper:
request_data = await litellm.AmazonConverseConfig()._async_transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
)
data = json.dumps(request_data)
prepped = self.get_request_headers(
credentials=credentials,
aws_region_name=litellm_params.get("aws_region_name") or "us-west-2",
extra_headers=headers,
endpoint_url=api_base,
data=data,
headers=headers,
)
## LOGGING
logging_obj.pre_call(
input=messages,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": dict(prepped.headers),
},
)
completion_stream = await make_call(
client=client,
api_base=api_base,
headers=headers,
headers=dict(prepped.headers),
data=data,
model=model,
messages=messages,
@ -138,17 +169,47 @@ class BedrockConverseLLM(BaseAWSLLM):
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
data: str,
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
stream,
optional_params: dict,
litellm_params=None,
litellm_params: dict,
credentials: Credentials,
logger_fn=None,
headers={},
headers: dict = {},
client: Optional[AsyncHTTPHandler] = None,
) -> Union[ModelResponse, CustomStreamWrapper]:
request_data = await litellm.AmazonConverseConfig()._async_transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
)
data = json.dumps(request_data)
prepped = self.get_request_headers(
credentials=credentials,
aws_region_name=litellm_params.get("aws_region_name") or "us-west-2",
extra_headers=headers,
endpoint_url=api_base,
data=data,
headers=headers,
)
## LOGGING
logging_obj.pre_call(
input=messages,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": headers,
},
)
headers = dict(prepped.headers)
if client is None or not isinstance(client, AsyncHTTPHandler):
_params = {}
if timeout is not None:
@ -203,8 +264,6 @@ class BedrockConverseLLM(BaseAWSLLM):
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
):
try:
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
@ -237,6 +296,8 @@ class BedrockConverseLLM(BaseAWSLLM):
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
litellm_params["aws_region_name"] = aws_region_name
### SET REGION NAME ###
if aws_region_name is None:
# check env #
@ -281,7 +342,54 @@ class BedrockConverseLLM(BaseAWSLLM):
endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse"
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
## COMPLETION CALL
headers = {"Content-Type": "application/json"}
if extra_headers is not None:
headers = {"Content-Type": "application/json", **extra_headers}
### ROUTING (ASYNC, STREAMING, SYNC)
if acompletion:
if isinstance(client, HTTPHandler):
client = None
if stream is True:
return self.async_streaming(
model=model,
messages=messages,
api_base=proxy_endpoint_url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=True,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
client=client,
json_mode=json_mode,
fake_stream=fake_stream,
credentials=credentials,
) # type: ignore
### ASYNC COMPLETION
return self.async_completion(
model=model,
messages=messages,
api_base=proxy_endpoint_url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream, # type: ignore
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
client=client,
credentials=credentials,
) # type: ignore
## TRANSFORMATION ##
@ -292,20 +400,15 @@ class BedrockConverseLLM(BaseAWSLLM):
litellm_params=litellm_params,
)
data = json.dumps(_data)
## COMPLETION CALL
headers = {"Content-Type": "application/json"}
if extra_headers is not None:
headers = {"Content-Type": "application/json", **extra_headers}
request = AWSRequest(
method="POST", url=endpoint_url, data=data, headers=headers
prepped = self.get_request_headers(
credentials=credentials,
aws_region_name=aws_region_name,
extra_headers=extra_headers,
endpoint_url=proxy_endpoint_url,
data=data,
headers=headers,
)
sigv4.add_auth(request)
if (
extra_headers is not None and "Authorization" in extra_headers
): # prevent sigv4 from overwriting the auth header
request.headers["Authorization"] = extra_headers["Authorization"]
prepped = request.prepare()
## LOGGING
logging_obj.pre_call(
@ -317,50 +420,6 @@ class BedrockConverseLLM(BaseAWSLLM):
"headers": prepped.headers,
},
)
### ROUTING (ASYNC, STREAMING, SYNC)
if acompletion:
if isinstance(client, HTTPHandler):
client = None
if stream is True:
return self.async_streaming(
model=model,
messages=messages,
data=data,
api_base=proxy_endpoint_url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=True,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=prepped.headers,
timeout=timeout,
client=client,
json_mode=json_mode,
fake_stream=fake_stream,
) # type: ignore
### ASYNC COMPLETION
return self.async_completion(
model=model,
messages=messages,
data=data,
api_base=proxy_endpoint_url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream, # type: ignore
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=prepped.headers,
timeout=timeout,
client=client,
) # type: ignore
if client is None or isinstance(client, AsyncHTTPHandler):
_params = {}
if timeout is not None: