mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
refactor: make bedrock image transformation requests async (#7840)
* refactor: initial commit for using separate sync vs. async transformation routes for bedrock ensures no blocking calls e.g. when converting image url to b64 * perf(converse_transformation.py): make bedrock converse transformation async asyncify's the bedrock message transformation - useful for handling image urls for bedrock * fix(converse_handler.py): fix logging for async streaming * style: cleanup unused imports
This commit is contained in:
parent
32c8933935
commit
2b58f16fda
8 changed files with 266 additions and 93 deletions
|
@ -1056,6 +1056,7 @@ ALL_LITELLM_RESPONSE_TYPES = [
|
|||
]
|
||||
|
||||
from .llms.custom_llm import CustomLLM
|
||||
from .llms.bedrock.chat.converse_transformation import AmazonConverseConfig
|
||||
from .llms.openai_like.chat.handler import OpenAILikeChatConfig
|
||||
from .llms.aiohttp_openai.chat.transformation import AiohttpOpenAIChatConfig
|
||||
from .llms.galadriel.chat.transformation import GaladrielChatConfig
|
||||
|
@ -1130,7 +1131,7 @@ from .llms.bedrock.chat.invoke_handler import (
|
|||
AmazonCohereChatConfig,
|
||||
bedrock_tool_name_mappings,
|
||||
)
|
||||
from .llms.bedrock.chat.converse_transformation import AmazonConverseConfig
|
||||
|
||||
from .llms.bedrock.common_utils import (
|
||||
AmazonTitanConfig,
|
||||
AmazonAI21Config,
|
||||
|
|
|
@ -12,9 +12,11 @@ from litellm.caching.caching import DualCache
|
|||
from litellm.secret_managers.main import get_secret, get_secret_str
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from botocore.awsrequest import AWSPreparedRequest
|
||||
from botocore.credentials import Credentials
|
||||
else:
|
||||
Credentials = Any
|
||||
AWSPreparedRequest = Any
|
||||
|
||||
|
||||
class Boto3CredentialsInfo(BaseModel):
|
||||
|
@ -471,3 +473,32 @@ class BaseAWSLLM:
|
|||
aws_region_name=aws_region_name,
|
||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||
)
|
||||
|
||||
def get_request_headers(
|
||||
self,
|
||||
credentials: Credentials,
|
||||
aws_region_name: str,
|
||||
extra_headers: Optional[dict],
|
||||
endpoint_url: str,
|
||||
data: str,
|
||||
headers: dict,
|
||||
) -> AWSPreparedRequest:
|
||||
try:
|
||||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
|
||||
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
||||
|
||||
request = AWSRequest(
|
||||
method="POST", url=endpoint_url, data=data, headers=headers
|
||||
)
|
||||
sigv4.add_auth(request)
|
||||
if (
|
||||
extra_headers is not None and "Authorization" in extra_headers
|
||||
): # prevent sigv4 from overwriting the auth header
|
||||
request.headers["Authorization"] = extra_headers["Authorization"]
|
||||
prepped = request.prepare()
|
||||
|
||||
return prepped
|
||||
|
|
|
@ -14,7 +14,7 @@ from litellm.llms.custom_httpx.http_handler import (
|
|||
from litellm.types.utils import ModelResponse
|
||||
from litellm.utils import CustomStreamWrapper, get_secret
|
||||
|
||||
from ..base_aws_llm import BaseAWSLLM
|
||||
from ..base_aws_llm import BaseAWSLLM, Credentials
|
||||
from ..common_utils import BedrockError
|
||||
from .invoke_handler import AWSEventStreamDecoder, MockResponseIterator, make_call
|
||||
|
||||
|
@ -41,7 +41,9 @@ def make_sync_call(
|
|||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise BedrockError(status_code=response.status_code, message=response.read())
|
||||
raise BedrockError(
|
||||
status_code=response.status_code, message=str(response.read())
|
||||
)
|
||||
|
||||
if fake_stream:
|
||||
model_response: (
|
||||
|
@ -78,6 +80,7 @@ def make_sync_call(
|
|||
|
||||
|
||||
class BedrockConverseLLM(BaseAWSLLM):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
@ -98,13 +101,13 @@ class BedrockConverseLLM(BaseAWSLLM):
|
|||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
data: str,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
encoding,
|
||||
logging_obj,
|
||||
stream,
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
litellm_params: dict,
|
||||
credentials: Credentials,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
|
@ -112,10 +115,38 @@ class BedrockConverseLLM(BaseAWSLLM):
|
|||
json_mode: Optional[bool] = False,
|
||||
) -> CustomStreamWrapper:
|
||||
|
||||
request_data = await litellm.AmazonConverseConfig()._async_transform_request(
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
data = json.dumps(request_data)
|
||||
|
||||
prepped = self.get_request_headers(
|
||||
credentials=credentials,
|
||||
aws_region_name=litellm_params.get("aws_region_name") or "us-west-2",
|
||||
extra_headers=headers,
|
||||
endpoint_url=api_base,
|
||||
data=data,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": api_base,
|
||||
"headers": dict(prepped.headers),
|
||||
},
|
||||
)
|
||||
|
||||
completion_stream = await make_call(
|
||||
client=client,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
headers=dict(prepped.headers),
|
||||
data=data,
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
@ -138,17 +169,47 @@ class BedrockConverseLLM(BaseAWSLLM):
|
|||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
data: str,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
encoding,
|
||||
logging_obj,
|
||||
stream,
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
litellm_params: dict,
|
||||
credentials: Credentials,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
headers: dict = {},
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
|
||||
request_data = await litellm.AmazonConverseConfig()._async_transform_request(
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
data = json.dumps(request_data)
|
||||
|
||||
prepped = self.get_request_headers(
|
||||
credentials=credentials,
|
||||
aws_region_name=litellm_params.get("aws_region_name") or "us-west-2",
|
||||
extra_headers=headers,
|
||||
endpoint_url=api_base,
|
||||
data=data,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": api_base,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
||||
headers = dict(prepped.headers)
|
||||
if client is None or not isinstance(client, AsyncHTTPHandler):
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
|
@ -203,8 +264,6 @@ class BedrockConverseLLM(BaseAWSLLM):
|
|||
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
|
||||
):
|
||||
try:
|
||||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
from botocore.credentials import Credentials
|
||||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
|
@ -237,6 +296,8 @@ class BedrockConverseLLM(BaseAWSLLM):
|
|||
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
|
||||
|
||||
litellm_params["aws_region_name"] = aws_region_name
|
||||
|
||||
### SET REGION NAME ###
|
||||
if aws_region_name is None:
|
||||
# check env #
|
||||
|
@ -281,7 +342,54 @@ class BedrockConverseLLM(BaseAWSLLM):
|
|||
endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
|
||||
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse"
|
||||
|
||||
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
||||
## COMPLETION CALL
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if extra_headers is not None:
|
||||
headers = {"Content-Type": "application/json", **extra_headers}
|
||||
|
||||
### ROUTING (ASYNC, STREAMING, SYNC)
|
||||
if acompletion:
|
||||
if isinstance(client, HTTPHandler):
|
||||
client = None
|
||||
if stream is True:
|
||||
return self.async_streaming(
|
||||
model=model,
|
||||
messages=messages,
|
||||
api_base=proxy_endpoint_url,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=True,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
json_mode=json_mode,
|
||||
fake_stream=fake_stream,
|
||||
credentials=credentials,
|
||||
) # type: ignore
|
||||
### ASYNC COMPLETION
|
||||
return self.async_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
api_base=proxy_endpoint_url,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=stream, # type: ignore
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
credentials=credentials,
|
||||
) # type: ignore
|
||||
|
||||
## TRANSFORMATION ##
|
||||
|
||||
|
@ -292,20 +400,15 @@ class BedrockConverseLLM(BaseAWSLLM):
|
|||
litellm_params=litellm_params,
|
||||
)
|
||||
data = json.dumps(_data)
|
||||
## COMPLETION CALL
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if extra_headers is not None:
|
||||
headers = {"Content-Type": "application/json", **extra_headers}
|
||||
request = AWSRequest(
|
||||
method="POST", url=endpoint_url, data=data, headers=headers
|
||||
prepped = self.get_request_headers(
|
||||
credentials=credentials,
|
||||
aws_region_name=aws_region_name,
|
||||
extra_headers=extra_headers,
|
||||
endpoint_url=proxy_endpoint_url,
|
||||
data=data,
|
||||
headers=headers,
|
||||
)
|
||||
sigv4.add_auth(request)
|
||||
if (
|
||||
extra_headers is not None and "Authorization" in extra_headers
|
||||
): # prevent sigv4 from overwriting the auth header
|
||||
request.headers["Authorization"] = extra_headers["Authorization"]
|
||||
prepped = request.prepare()
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
|
@ -317,50 +420,6 @@ class BedrockConverseLLM(BaseAWSLLM):
|
|||
"headers": prepped.headers,
|
||||
},
|
||||
)
|
||||
|
||||
### ROUTING (ASYNC, STREAMING, SYNC)
|
||||
if acompletion:
|
||||
if isinstance(client, HTTPHandler):
|
||||
client = None
|
||||
if stream is True:
|
||||
return self.async_streaming(
|
||||
model=model,
|
||||
messages=messages,
|
||||
data=data,
|
||||
api_base=proxy_endpoint_url,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=True,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=prepped.headers,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
json_mode=json_mode,
|
||||
fake_stream=fake_stream,
|
||||
) # type: ignore
|
||||
### ASYNC COMPLETION
|
||||
return self.async_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
data=data,
|
||||
api_base=proxy_endpoint_url,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=stream, # type: ignore
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=prepped.headers,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
) # type: ignore
|
||||
|
||||
if client is None or isinstance(client, AsyncHTTPHandler):
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
|
|
|
@ -10,6 +10,7 @@ from typing import List, Literal, Optional, Tuple, Union, overload
|
|||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.asyncify import asyncify
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||
from litellm.litellm_core_utils.prompt_templates.factory import (
|
||||
|
@ -347,14 +348,9 @@ class AmazonConverseConfig:
|
|||
inference_params["topK"] = inference_params.pop("top_k")
|
||||
return InferenceConfig(**inference_params)
|
||||
|
||||
def _transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
) -> RequestObject:
|
||||
messages, system_content_blocks = self._transform_system_message(messages)
|
||||
def _transform_request_helper(
|
||||
self, system_content_blocks: List[SystemContentBlock], optional_params: dict
|
||||
) -> CommonRequestObject:
|
||||
inference_params = copy.deepcopy(optional_params)
|
||||
additional_request_keys = []
|
||||
additional_request_params = {}
|
||||
|
@ -364,14 +360,6 @@ class AmazonConverseConfig:
|
|||
supported_tool_call_params = ["tools", "tool_choice"]
|
||||
supported_guardrail_params = ["guardrailConfig"]
|
||||
inference_params.pop("json_mode", None) # used for handling json_schema
|
||||
## TRANSFORMATION ##
|
||||
|
||||
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
|
||||
messages=messages,
|
||||
model=model,
|
||||
llm_provider="bedrock_converse",
|
||||
user_continue_message=litellm_params.pop("user_continue_message", None),
|
||||
)
|
||||
|
||||
# send all model-specific params in 'additional_request_params'
|
||||
for k, v in inference_params.items():
|
||||
|
@ -408,8 +396,7 @@ class AmazonConverseConfig:
|
|||
if tool_choice_values is not None:
|
||||
bedrock_tool_config["toolChoice"] = tool_choice_values
|
||||
|
||||
_data: RequestObject = {
|
||||
"messages": bedrock_messages,
|
||||
data: CommonRequestObject = {
|
||||
"additionalModelRequestFields": additional_request_params,
|
||||
"system": system_content_blocks,
|
||||
"inferenceConfig": self._transform_inference_params(
|
||||
|
@ -422,13 +409,65 @@ class AmazonConverseConfig:
|
|||
request_guardrails_config = inference_params.pop("guardrailConfig", None)
|
||||
if request_guardrails_config is not None:
|
||||
guardrail_config = GuardrailConfigBlock(**request_guardrails_config)
|
||||
_data["guardrailConfig"] = guardrail_config
|
||||
data["guardrailConfig"] = guardrail_config
|
||||
|
||||
# Tool Config
|
||||
if bedrock_tool_config is not None:
|
||||
_data["toolConfig"] = bedrock_tool_config
|
||||
data["toolConfig"] = bedrock_tool_config
|
||||
|
||||
return _data
|
||||
return data
|
||||
|
||||
async def _async_transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
) -> RequestObject:
|
||||
messages, system_content_blocks = self._transform_system_message(messages)
|
||||
## TRANSFORMATION ##
|
||||
bedrock_messages: List[MessageBlock] = await asyncify(
|
||||
_bedrock_converse_messages_pt
|
||||
)(
|
||||
messages=messages,
|
||||
model=model,
|
||||
llm_provider="bedrock_converse",
|
||||
user_continue_message=litellm_params.pop("user_continue_message", None),
|
||||
)
|
||||
|
||||
_data: CommonRequestObject = self._transform_request_helper(
|
||||
system_content_blocks=system_content_blocks,
|
||||
optional_params=optional_params,
|
||||
)
|
||||
|
||||
data: RequestObject = {"messages": bedrock_messages, **_data}
|
||||
|
||||
return data
|
||||
|
||||
def _transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
) -> RequestObject:
|
||||
messages, system_content_blocks = self._transform_system_message(messages)
|
||||
## TRANSFORMATION ##
|
||||
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
|
||||
messages=messages,
|
||||
model=model,
|
||||
llm_provider="bedrock_converse",
|
||||
user_continue_message=litellm_params.pop("user_continue_message", None),
|
||||
)
|
||||
|
||||
_data: CommonRequestObject = self._transform_request_helper(
|
||||
system_content_blocks=system_content_blocks,
|
||||
optional_params=optional_params,
|
||||
)
|
||||
|
||||
data: RequestObject = {"messages": bedrock_messages, **_data}
|
||||
|
||||
return data
|
||||
|
||||
def _transform_response(
|
||||
self,
|
||||
|
|
|
@ -161,16 +161,21 @@ class ContentBlockDeltaEvent(TypedDict, total=False):
|
|||
toolUse: ToolBlockDeltaEvent
|
||||
|
||||
|
||||
class RequestObject(TypedDict, total=False):
|
||||
class CommonRequestObject(
|
||||
TypedDict, total=False
|
||||
): # common request object across sync + async flows
|
||||
additionalModelRequestFields: dict
|
||||
additionalModelResponseFieldPaths: List[str]
|
||||
inferenceConfig: InferenceConfig
|
||||
messages: Required[List[MessageBlock]]
|
||||
system: List[SystemContentBlock]
|
||||
toolConfig: ToolConfigBlock
|
||||
guardrailConfig: Optional[GuardrailConfigBlock]
|
||||
|
||||
|
||||
class RequestObject(CommonRequestObject, total=False):
|
||||
messages: Required[List[MessageBlock]]
|
||||
|
||||
|
||||
class GenericStreamingChunk(TypedDict):
|
||||
text: Required[str]
|
||||
tool_use: Optional[ChatCompletionToolCallChunk]
|
||||
|
|
|
@ -2391,3 +2391,41 @@ def test_process_bedrock_converse_image_block():
|
|||
)
|
||||
|
||||
assert block["document"] is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bedrock_image_url_sync_client():
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
import logging
|
||||
from litellm import verbose_logger
|
||||
|
||||
verbose_logger.setLevel(level=logging.DEBUG)
|
||||
|
||||
litellm._turn_on_debug()
|
||||
client = AsyncHTTPHandler()
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
with patch.object(client, "post") as mock_post:
|
||||
try:
|
||||
await litellm.acompletion(
|
||||
model="bedrock/us.amazon.nova-pro-v1:0",
|
||||
messages=messages,
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
mock_post.assert_called_once()
|
||||
|
|
|
@ -1387,7 +1387,7 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
|
|||
[
|
||||
# ["bedrock/ai21.jamba-instruct-v1:0", "us-east-1"],
|
||||
# ["bedrock/cohere.command-r-plus-v1:0", None],
|
||||
# ["anthropic.claude-3-sonnet-20240229-v1:0", None],
|
||||
["anthropic.claude-3-sonnet-20240229-v1:0", None],
|
||||
# ["anthropic.claude-instant-v1", None],
|
||||
# ["mistral.mistral-7b-instruct-v0:2", None],
|
||||
["bedrock/amazon.titan-tg1-large", None],
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue