diff --git a/litellm/__init__.py b/litellm/__init__.py index c3ce93097f..9784adbd87 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -1056,6 +1056,7 @@ ALL_LITELLM_RESPONSE_TYPES = [ ] from .llms.custom_llm import CustomLLM +from .llms.bedrock.chat.converse_transformation import AmazonConverseConfig from .llms.openai_like.chat.handler import OpenAILikeChatConfig from .llms.aiohttp_openai.chat.transformation import AiohttpOpenAIChatConfig from .llms.galadriel.chat.transformation import GaladrielChatConfig @@ -1130,7 +1131,7 @@ from .llms.bedrock.chat.invoke_handler import ( AmazonCohereChatConfig, bedrock_tool_name_mappings, ) -from .llms.bedrock.chat.converse_transformation import AmazonConverseConfig + from .llms.bedrock.common_utils import ( AmazonTitanConfig, AmazonAI21Config, diff --git a/litellm/llms/bedrock/base_aws_llm.py b/litellm/llms/bedrock/base_aws_llm.py index a03b79106b..8c64203fd7 100644 --- a/litellm/llms/bedrock/base_aws_llm.py +++ b/litellm/llms/bedrock/base_aws_llm.py @@ -12,9 +12,11 @@ from litellm.caching.caching import DualCache from litellm.secret_managers.main import get_secret, get_secret_str if TYPE_CHECKING: + from botocore.awsrequest import AWSPreparedRequest from botocore.credentials import Credentials else: Credentials = Any + AWSPreparedRequest = Any class Boto3CredentialsInfo(BaseModel): @@ -471,3 +473,32 @@ class BaseAWSLLM: aws_region_name=aws_region_name, aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, ) + + def get_request_headers( + self, + credentials: Credentials, + aws_region_name: str, + extra_headers: Optional[dict], + endpoint_url: str, + data: str, + headers: dict, + ) -> AWSPreparedRequest: + try: + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + except ImportError: + raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") + + sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) + + request = AWSRequest( + method="POST", url=endpoint_url, data=data, headers=headers + ) + sigv4.add_auth(request) + if ( + extra_headers is not None and "Authorization" in extra_headers + ): # prevent sigv4 from overwriting the auth header + request.headers["Authorization"] = extra_headers["Authorization"] + prepped = request.prepare() + + return prepped diff --git a/litellm/llms/bedrock/chat/converse_handler.py b/litellm/llms/bedrock/chat/converse_handler.py index 0e3b21c373..b6553f8bcc 100644 --- a/litellm/llms/bedrock/chat/converse_handler.py +++ b/litellm/llms/bedrock/chat/converse_handler.py @@ -14,7 +14,7 @@ from litellm.llms.custom_httpx.http_handler import ( from litellm.types.utils import ModelResponse from litellm.utils import CustomStreamWrapper, get_secret -from ..base_aws_llm import BaseAWSLLM +from ..base_aws_llm import BaseAWSLLM, Credentials from ..common_utils import BedrockError from .invoke_handler import AWSEventStreamDecoder, MockResponseIterator, make_call @@ -41,7 +41,9 @@ def make_sync_call( ) if response.status_code != 200: - raise BedrockError(status_code=response.status_code, message=response.read()) + raise BedrockError( + status_code=response.status_code, message=str(response.read()) + ) if fake_stream: model_response: ( @@ -78,6 +80,7 @@ def make_sync_call( class BedrockConverseLLM(BaseAWSLLM): + def __init__(self) -> None: super().__init__() @@ -98,13 +101,13 @@ class BedrockConverseLLM(BaseAWSLLM): api_base: str, model_response: ModelResponse, print_verbose: Callable, - data: str, timeout: Optional[Union[float, httpx.Timeout]], encoding, logging_obj, stream, optional_params: dict, - litellm_params=None, + litellm_params: dict, + credentials: Credentials, logger_fn=None, headers={}, client: Optional[AsyncHTTPHandler] = None, @@ -112,10 +115,38 @@ class BedrockConverseLLM(BaseAWSLLM): json_mode: Optional[bool] = False, ) -> CustomStreamWrapper: + request_data = await litellm.AmazonConverseConfig()._async_transform_request( + model=model, + messages=messages, + optional_params=optional_params, + litellm_params=litellm_params, + ) + data = json.dumps(request_data) + + prepped = self.get_request_headers( + credentials=credentials, + aws_region_name=litellm_params.get("aws_region_name") or "us-west-2", + extra_headers=headers, + endpoint_url=api_base, + data=data, + headers=headers, + ) + + ## LOGGING + logging_obj.pre_call( + input=messages, + api_key="", + additional_args={ + "complete_input_dict": data, + "api_base": api_base, + "headers": dict(prepped.headers), + }, + ) + completion_stream = await make_call( client=client, api_base=api_base, - headers=headers, + headers=dict(prepped.headers), data=data, model=model, messages=messages, @@ -138,17 +169,47 @@ class BedrockConverseLLM(BaseAWSLLM): api_base: str, model_response: ModelResponse, print_verbose: Callable, - data: str, timeout: Optional[Union[float, httpx.Timeout]], encoding, logging_obj, stream, optional_params: dict, - litellm_params=None, + litellm_params: dict, + credentials: Credentials, logger_fn=None, - headers={}, + headers: dict = {}, client: Optional[AsyncHTTPHandler] = None, ) -> Union[ModelResponse, CustomStreamWrapper]: + + request_data = await litellm.AmazonConverseConfig()._async_transform_request( + model=model, + messages=messages, + optional_params=optional_params, + litellm_params=litellm_params, + ) + data = json.dumps(request_data) + + prepped = self.get_request_headers( + credentials=credentials, + aws_region_name=litellm_params.get("aws_region_name") or "us-west-2", + extra_headers=headers, + endpoint_url=api_base, + data=data, + headers=headers, + ) + + ## LOGGING + logging_obj.pre_call( + input=messages, + api_key="", + additional_args={ + "complete_input_dict": data, + "api_base": api_base, + "headers": headers, + }, + ) + + headers = dict(prepped.headers) if client is None or not isinstance(client, AsyncHTTPHandler): _params = {} if timeout is not None: @@ -203,8 +264,6 @@ class BedrockConverseLLM(BaseAWSLLM): client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None, ): try: - from botocore.auth import SigV4Auth - from botocore.awsrequest import AWSRequest from botocore.credentials import Credentials except ImportError: raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") @@ -237,6 +296,8 @@ class BedrockConverseLLM(BaseAWSLLM): aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None) + litellm_params["aws_region_name"] = aws_region_name + ### SET REGION NAME ### if aws_region_name is None: # check env # @@ -281,7 +342,54 @@ class BedrockConverseLLM(BaseAWSLLM): endpoint_url = f"{endpoint_url}/model/{modelId}/converse" proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse" - sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) + ## COMPLETION CALL + + headers = {"Content-Type": "application/json"} + if extra_headers is not None: + headers = {"Content-Type": "application/json", **extra_headers} + + ### ROUTING (ASYNC, STREAMING, SYNC) + if acompletion: + if isinstance(client, HTTPHandler): + client = None + if stream is True: + return self.async_streaming( + model=model, + messages=messages, + api_base=proxy_endpoint_url, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + logging_obj=logging_obj, + optional_params=optional_params, + stream=True, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + timeout=timeout, + client=client, + json_mode=json_mode, + fake_stream=fake_stream, + credentials=credentials, + ) # type: ignore + ### ASYNC COMPLETION + return self.async_completion( + model=model, + messages=messages, + api_base=proxy_endpoint_url, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + logging_obj=logging_obj, + optional_params=optional_params, + stream=stream, # type: ignore + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + timeout=timeout, + client=client, + credentials=credentials, + ) # type: ignore ## TRANSFORMATION ## @@ -292,20 +400,15 @@ class BedrockConverseLLM(BaseAWSLLM): litellm_params=litellm_params, ) data = json.dumps(_data) - ## COMPLETION CALL - headers = {"Content-Type": "application/json"} - if extra_headers is not None: - headers = {"Content-Type": "application/json", **extra_headers} - request = AWSRequest( - method="POST", url=endpoint_url, data=data, headers=headers + prepped = self.get_request_headers( + credentials=credentials, + aws_region_name=aws_region_name, + extra_headers=extra_headers, + endpoint_url=proxy_endpoint_url, + data=data, + headers=headers, ) - sigv4.add_auth(request) - if ( - extra_headers is not None and "Authorization" in extra_headers - ): # prevent sigv4 from overwriting the auth header - request.headers["Authorization"] = extra_headers["Authorization"] - prepped = request.prepare() ## LOGGING logging_obj.pre_call( @@ -317,50 +420,6 @@ class BedrockConverseLLM(BaseAWSLLM): "headers": prepped.headers, }, ) - - ### ROUTING (ASYNC, STREAMING, SYNC) - if acompletion: - if isinstance(client, HTTPHandler): - client = None - if stream is True: - return self.async_streaming( - model=model, - messages=messages, - data=data, - api_base=proxy_endpoint_url, - model_response=model_response, - print_verbose=print_verbose, - encoding=encoding, - logging_obj=logging_obj, - optional_params=optional_params, - stream=True, - litellm_params=litellm_params, - logger_fn=logger_fn, - headers=prepped.headers, - timeout=timeout, - client=client, - json_mode=json_mode, - fake_stream=fake_stream, - ) # type: ignore - ### ASYNC COMPLETION - return self.async_completion( - model=model, - messages=messages, - data=data, - api_base=proxy_endpoint_url, - model_response=model_response, - print_verbose=print_verbose, - encoding=encoding, - logging_obj=logging_obj, - optional_params=optional_params, - stream=stream, # type: ignore - litellm_params=litellm_params, - logger_fn=logger_fn, - headers=prepped.headers, - timeout=timeout, - client=client, - ) # type: ignore - if client is None or isinstance(client, AsyncHTTPHandler): _params = {} if timeout is not None: diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py index b4f1ea3d3c..52c42b790f 100644 --- a/litellm/llms/bedrock/chat/converse_transformation.py +++ b/litellm/llms/bedrock/chat/converse_transformation.py @@ -10,6 +10,7 @@ from typing import List, Literal, Optional, Tuple, Union, overload import httpx import litellm +from litellm.litellm_core_utils.asyncify import asyncify from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.litellm_core_utils.litellm_logging import Logging from litellm.litellm_core_utils.prompt_templates.factory import ( @@ -347,14 +348,9 @@ class AmazonConverseConfig: inference_params["topK"] = inference_params.pop("top_k") return InferenceConfig(**inference_params) - def _transform_request( - self, - model: str, - messages: List[AllMessageValues], - optional_params: dict, - litellm_params: dict, - ) -> RequestObject: - messages, system_content_blocks = self._transform_system_message(messages) + def _transform_request_helper( + self, system_content_blocks: List[SystemContentBlock], optional_params: dict + ) -> CommonRequestObject: inference_params = copy.deepcopy(optional_params) additional_request_keys = [] additional_request_params = {} @@ -364,14 +360,6 @@ class AmazonConverseConfig: supported_tool_call_params = ["tools", "tool_choice"] supported_guardrail_params = ["guardrailConfig"] inference_params.pop("json_mode", None) # used for handling json_schema - ## TRANSFORMATION ## - - bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt( - messages=messages, - model=model, - llm_provider="bedrock_converse", - user_continue_message=litellm_params.pop("user_continue_message", None), - ) # send all model-specific params in 'additional_request_params' for k, v in inference_params.items(): @@ -408,8 +396,7 @@ class AmazonConverseConfig: if tool_choice_values is not None: bedrock_tool_config["toolChoice"] = tool_choice_values - _data: RequestObject = { - "messages": bedrock_messages, + data: CommonRequestObject = { "additionalModelRequestFields": additional_request_params, "system": system_content_blocks, "inferenceConfig": self._transform_inference_params( @@ -422,13 +409,65 @@ class AmazonConverseConfig: request_guardrails_config = inference_params.pop("guardrailConfig", None) if request_guardrails_config is not None: guardrail_config = GuardrailConfigBlock(**request_guardrails_config) - _data["guardrailConfig"] = guardrail_config + data["guardrailConfig"] = guardrail_config # Tool Config if bedrock_tool_config is not None: - _data["toolConfig"] = bedrock_tool_config + data["toolConfig"] = bedrock_tool_config - return _data + return data + + async def _async_transform_request( + self, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + ) -> RequestObject: + messages, system_content_blocks = self._transform_system_message(messages) + ## TRANSFORMATION ## + bedrock_messages: List[MessageBlock] = await asyncify( + _bedrock_converse_messages_pt + )( + messages=messages, + model=model, + llm_provider="bedrock_converse", + user_continue_message=litellm_params.pop("user_continue_message", None), + ) + + _data: CommonRequestObject = self._transform_request_helper( + system_content_blocks=system_content_blocks, + optional_params=optional_params, + ) + + data: RequestObject = {"messages": bedrock_messages, **_data} + + return data + + def _transform_request( + self, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + ) -> RequestObject: + messages, system_content_blocks = self._transform_system_message(messages) + ## TRANSFORMATION ## + bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt( + messages=messages, + model=model, + llm_provider="bedrock_converse", + user_continue_message=litellm_params.pop("user_continue_message", None), + ) + + _data: CommonRequestObject = self._transform_request_helper( + system_content_blocks=system_content_blocks, + optional_params=optional_params, + ) + + data: RequestObject = {"messages": bedrock_messages, **_data} + + return data def _transform_response( self, diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index a8699b5eb8..bd863ec0c2 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -2,4 +2,4 @@ model_list: - model_name: bedrock/us.anthropic.claude-3-haiku-20240307-v1:0 litellm_params: model: bedrock/us.anthropic.claude-3-haiku-20240307-v1:0 - + diff --git a/litellm/types/llms/bedrock.py b/litellm/types/llms/bedrock.py index 8d43243dcb..2458d03622 100644 --- a/litellm/types/llms/bedrock.py +++ b/litellm/types/llms/bedrock.py @@ -161,16 +161,21 @@ class ContentBlockDeltaEvent(TypedDict, total=False): toolUse: ToolBlockDeltaEvent -class RequestObject(TypedDict, total=False): +class CommonRequestObject( + TypedDict, total=False +): # common request object across sync + async flows additionalModelRequestFields: dict additionalModelResponseFieldPaths: List[str] inferenceConfig: InferenceConfig - messages: Required[List[MessageBlock]] system: List[SystemContentBlock] toolConfig: ToolConfigBlock guardrailConfig: Optional[GuardrailConfigBlock] +class RequestObject(CommonRequestObject, total=False): + messages: Required[List[MessageBlock]] + + class GenericStreamingChunk(TypedDict): text: Required[str] tool_use: Optional[ChatCompletionToolCallChunk] diff --git a/tests/llm_translation/test_bedrock_completion.py b/tests/llm_translation/test_bedrock_completion.py index 0bd31dd5db..52690c242b 100644 --- a/tests/llm_translation/test_bedrock_completion.py +++ b/tests/llm_translation/test_bedrock_completion.py @@ -2391,3 +2391,41 @@ def test_process_bedrock_converse_image_block(): ) assert block["document"] is not None + + +@pytest.mark.asyncio +async def test_bedrock_image_url_sync_client(): + from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler + import logging + from litellm import verbose_logger + + verbose_logger.setLevel(level=logging.DEBUG) + + litellm._turn_on_debug() + client = AsyncHTTPHandler() + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + }, + }, + ], + } + ] + + with patch.object(client, "post") as mock_post: + try: + await litellm.acompletion( + model="bedrock/us.amazon.nova-pro-v1:0", + messages=messages, + client=client, + ) + except Exception as e: + print(e) + mock_post.assert_called_once() diff --git a/tests/local_testing/test_streaming.py b/tests/local_testing/test_streaming.py index 059e6e4824..793106368d 100644 --- a/tests/local_testing/test_streaming.py +++ b/tests/local_testing/test_streaming.py @@ -1387,7 +1387,7 @@ async def test_completion_replicate_llama3_streaming(sync_mode): [ # ["bedrock/ai21.jamba-instruct-v1:0", "us-east-1"], # ["bedrock/cohere.command-r-plus-v1:0", None], - # ["anthropic.claude-3-sonnet-20240229-v1:0", None], + ["anthropic.claude-3-sonnet-20240229-v1:0", None], # ["anthropic.claude-instant-v1", None], # ["mistral.mistral-7b-instruct-v0:2", None], ["bedrock/amazon.titan-tg1-large", None],