diff --git a/litellm/llms/bedrock/chat.py b/litellm/llms/bedrock/chat.py index 73e649c5b..972a3abd3 100644 --- a/litellm/llms/bedrock/chat.py +++ b/litellm/llms/bedrock/chat.py @@ -728,7 +728,7 @@ class BedrockLLM(BaseAWSLLM): ) ### SET RUNTIME ENDPOINT ### - endpoint_url = get_runtime_endpoint( + endpoint_url, proxy_endpoint_url = get_runtime_endpoint( api_base=api_base, aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, aws_region_name=aws_region_name, @@ -736,8 +736,10 @@ class BedrockLLM(BaseAWSLLM): if (stream is not None and stream is True) and provider != "ai21": endpoint_url = f"{endpoint_url}/model/{modelId}/invoke-with-response-stream" + proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke-with-response-stream" else: endpoint_url = f"{endpoint_url}/model/{modelId}/invoke" + proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke" sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) @@ -903,7 +905,7 @@ class BedrockLLM(BaseAWSLLM): api_key="", additional_args={ "complete_input_dict": data, - "api_base": prepped.url, + "api_base": proxy_endpoint_url, "headers": prepped.headers, }, ) @@ -917,7 +919,7 @@ class BedrockLLM(BaseAWSLLM): model=model, messages=messages, data=data, - api_base=prepped.url, + api_base=proxy_endpoint_url, model_response=model_response, print_verbose=print_verbose, encoding=encoding, @@ -935,7 +937,7 @@ class BedrockLLM(BaseAWSLLM): model=model, messages=messages, data=data, - api_base=prepped.url, + api_base=proxy_endpoint_url, model_response=model_response, print_verbose=print_verbose, encoding=encoding, @@ -960,7 +962,7 @@ class BedrockLLM(BaseAWSLLM): self.client = client if (stream is not None and stream == True) and provider != "ai21": response = self.client.post( - url=prepped.url, + url=proxy_endpoint_url, headers=prepped.headers, # type: ignore data=data, stream=stream, @@ -991,7 +993,7 @@ class BedrockLLM(BaseAWSLLM): return streaming_response try: - response = 
self.client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore + response = self.client.post(url=proxy_endpoint_url, headers=prepped.headers, data=data) # type: ignore response.raise_for_status() except httpx.HTTPStatusError as err: error_code = err.response.status_code @@ -1604,15 +1606,17 @@ class BedrockConverseLLM(BaseAWSLLM): ) ### SET RUNTIME ENDPOINT ### - endpoint_url = get_runtime_endpoint( + endpoint_url, proxy_endpoint_url = get_runtime_endpoint( api_base=api_base, aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, aws_region_name=aws_region_name, ) if (stream is not None and stream is True) and provider != "ai21": endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream" + proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse-stream" else: endpoint_url = f"{endpoint_url}/model/{modelId}/converse" + proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse" sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) @@ -1719,7 +1723,7 @@ class BedrockConverseLLM(BaseAWSLLM): api_key="", additional_args={ "complete_input_dict": data, - "api_base": prepped.url, + "api_base": proxy_endpoint_url, "headers": prepped.headers, }, ) @@ -1733,7 +1737,7 @@ class BedrockConverseLLM(BaseAWSLLM): model=model, messages=messages, data=data, - api_base=prepped.url, + api_base=proxy_endpoint_url, model_response=model_response, print_verbose=print_verbose, encoding=encoding, @@ -1751,7 +1755,7 @@ class BedrockConverseLLM(BaseAWSLLM): model=model, messages=messages, data=data, - api_base=prepped.url, + api_base=proxy_endpoint_url, model_response=model_response, print_verbose=print_verbose, encoding=encoding, @@ -1772,7 +1776,7 @@ class BedrockConverseLLM(BaseAWSLLM): make_call=partial( make_sync_call, client=None, - api_base=prepped.url, + api_base=proxy_endpoint_url, headers=prepped.headers, # type: ignore data=data, model=model, @@ -1797,7 +1801,7 @@ class BedrockConverseLLM(BaseAWSLLM): else: client = client 
try: - response = client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore + response = client.post(url=proxy_endpoint_url, headers=prepped.headers, data=data) # type: ignore response.raise_for_status() except httpx.HTTPStatusError as err: error_code = err.response.status_code diff --git a/litellm/llms/bedrock/common_utils.py b/litellm/llms/bedrock/common_utils.py index f2032d110..25379474e 100644 --- a/litellm/llms/bedrock/common_utils.py +++ b/litellm/llms/bedrock/common_utils.py @@ -5,7 +5,7 @@ Common utilities used across bedrock chat/embedding/image generation import os import types from enum import Enum -from typing import List, Optional, Union +from typing import List, Optional, Union, Tuple import httpx @@ -729,7 +729,7 @@ def get_runtime_endpoint( api_base: Optional[str], aws_bedrock_runtime_endpoint: Optional[str], aws_region_name: str, -) -> str: +) -> Tuple[str, str]: env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT") if api_base is not None: endpoint_url = api_base @@ -744,7 +744,19 @@ def get_runtime_endpoint( else: endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com" - return endpoint_url + # Determine proxy_endpoint_url + if env_aws_bedrock_runtime_endpoint and isinstance( + env_aws_bedrock_runtime_endpoint, str + ): + proxy_endpoint_url = env_aws_bedrock_runtime_endpoint + elif aws_bedrock_runtime_endpoint is not None and isinstance( + aws_bedrock_runtime_endpoint, str + ): + proxy_endpoint_url = aws_bedrock_runtime_endpoint + else: + proxy_endpoint_url = endpoint_url + + return endpoint_url, proxy_endpoint_url class ModelResponseIterator: diff --git a/litellm/llms/bedrock/embed/embedding.py b/litellm/llms/bedrock/embed/embedding.py index 6398c2c34..e6a1319b0 100644 --- a/litellm/llms/bedrock/embed/embedding.py +++ b/litellm/llms/bedrock/embed/embedding.py @@ -393,7 +393,7 @@ class BedrockEmbedding(BaseAWSLLM): batch_data.append(transformed_request) ### SET RUNTIME ENDPOINT ### 
- endpoint_url = get_runtime_endpoint( + endpoint_url, proxy_endpoint_url = get_runtime_endpoint( api_base=api_base, aws_bedrock_runtime_endpoint=optional_params.pop( "aws_bedrock_runtime_endpoint", None