(bedrock): Fix usage with Cloudflare AI Gateway, and proxies in general. (#5509)

Author: David Manouchehri, 2024-09-04 11:43:01 -05:00 (committed by GitHub)
Parent: 949cd51529
Commit: 3fac0349c2
3 changed files with 32 additions and 16 deletions

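What the patch does, in short: the SigV4-signed headers are still produced the same way (via `prepped`), but the HTTP POST now targets a separate `proxy_endpoint_url`, so a request can be signed for one URL while being sent to another (e.g. a gateway or proxy sitting in front of Bedrock). Below is a minimal standalone sketch of that sign-one-URL, post-to-another pattern with botocore and httpx; the gateway URL, model ID, and request body are illustrative placeholders, not values taken from this commit.

import json

import boto3
import httpx
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest

region = "us-east-1"
model_id = "anthropic.claude-3-haiku-20240307-v1:0"  # placeholder model ID

# URL the request is signed against (the real Bedrock runtime endpoint) ...
endpoint_url = f"https://bedrock-runtime.{region}.amazonaws.com/model/{model_id}/invoke"
# ... and the URL the request is actually sent to (hypothetical proxy/gateway).
proxy_endpoint_url = (
    f"https://my-ai-gateway.example.com/aws-bedrock/bedrock-runtime/{region}/model/{model_id}/invoke"
)

data = json.dumps(
    {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 256,
        "messages": [{"role": "user", "content": "Hello"}],
    }
)
headers = {"Content-Type": "application/json"}

credentials = boto3.Session().get_credentials()
request = AWSRequest(method="POST", url=endpoint_url, data=data, headers=headers)
SigV4Auth(credentials, "bedrock", region).add_auth(request)
prepped = request.prepare()

# The signed headers travel with the request; only the destination URL differs.
response = httpx.post(proxy_endpoint_url, headers=dict(prepped.headers.items()), content=data)
print(response.status_code, response.text[:200])

This mirrors what the diff below does inside BedrockLLM and BedrockConverseLLM: `prepped` is built as before, and only the `client.post(...)` target moves from `prepped.url` to `proxy_endpoint_url`.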

@@ -728,7 +728,7 @@ class BedrockLLM(BaseAWSLLM):
         )
 
         ### SET RUNTIME ENDPOINT ###
-        endpoint_url = get_runtime_endpoint(
+        endpoint_url, proxy_endpoint_url = get_runtime_endpoint(
             api_base=api_base,
             aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
             aws_region_name=aws_region_name,
@@ -736,8 +736,10 @@ class BedrockLLM(BaseAWSLLM):
         if (stream is not None and stream is True) and provider != "ai21":
             endpoint_url = f"{endpoint_url}/model/{modelId}/invoke-with-response-stream"
+            proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke-with-response-stream"
         else:
             endpoint_url = f"{endpoint_url}/model/{modelId}/invoke"
+            proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke"
 
         sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
@@ -903,7 +905,7 @@ class BedrockLLM(BaseAWSLLM):
             api_key="",
             additional_args={
                 "complete_input_dict": data,
-                "api_base": prepped.url,
+                "api_base": proxy_endpoint_url,
                 "headers": prepped.headers,
             },
         )
@@ -917,7 +919,7 @@ class BedrockLLM(BaseAWSLLM):
                 model=model,
                 messages=messages,
                 data=data,
-                api_base=prepped.url,
+                api_base=proxy_endpoint_url,
                 model_response=model_response,
                 print_verbose=print_verbose,
                 encoding=encoding,
@@ -935,7 +937,7 @@ class BedrockLLM(BaseAWSLLM):
                 model=model,
                 messages=messages,
                 data=data,
-                api_base=prepped.url,
+                api_base=proxy_endpoint_url,
                 model_response=model_response,
                 print_verbose=print_verbose,
                 encoding=encoding,
@@ -960,7 +962,7 @@ class BedrockLLM(BaseAWSLLM):
             self.client = client
         if (stream is not None and stream == True) and provider != "ai21":
             response = self.client.post(
-                url=prepped.url,
+                url=proxy_endpoint_url,
                 headers=prepped.headers,  # type: ignore
                 data=data,
                 stream=stream,
@@ -991,7 +993,7 @@ class BedrockLLM(BaseAWSLLM):
             return streaming_response
 
         try:
-            response = self.client.post(url=prepped.url, headers=prepped.headers, data=data)  # type: ignore
+            response = self.client.post(url=proxy_endpoint_url, headers=prepped.headers, data=data)  # type: ignore
             response.raise_for_status()
         except httpx.HTTPStatusError as err:
             error_code = err.response.status_code
@@ -1604,15 +1606,17 @@ class BedrockConverseLLM(BaseAWSLLM):
         )
 
         ### SET RUNTIME ENDPOINT ###
-        endpoint_url = get_runtime_endpoint(
+        endpoint_url, proxy_endpoint_url = get_runtime_endpoint(
             api_base=api_base,
             aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
             aws_region_name=aws_region_name,
         )
         if (stream is not None and stream is True) and provider != "ai21":
             endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
+            proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse-stream"
         else:
             endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
+            proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse"
 
         sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
@@ -1719,7 +1723,7 @@ class BedrockConverseLLM(BaseAWSLLM):
             api_key="",
             additional_args={
                 "complete_input_dict": data,
-                "api_base": prepped.url,
+                "api_base": proxy_endpoint_url,
                 "headers": prepped.headers,
             },
         )
@@ -1733,7 +1737,7 @@ class BedrockConverseLLM(BaseAWSLLM):
                 model=model,
                 messages=messages,
                 data=data,
-                api_base=prepped.url,
+                api_base=proxy_endpoint_url,
                 model_response=model_response,
                 print_verbose=print_verbose,
                 encoding=encoding,
@@ -1751,7 +1755,7 @@ class BedrockConverseLLM(BaseAWSLLM):
                 model=model,
                 messages=messages,
                 data=data,
-                api_base=prepped.url,
+                api_base=proxy_endpoint_url,
                 model_response=model_response,
                 print_verbose=print_verbose,
                 encoding=encoding,
@@ -1772,7 +1776,7 @@ class BedrockConverseLLM(BaseAWSLLM):
                 make_call=partial(
                     make_sync_call,
                     client=None,
-                    api_base=prepped.url,
+                    api_base=proxy_endpoint_url,
                     headers=prepped.headers,  # type: ignore
                     data=data,
                     model=model,
@@ -1797,7 +1801,7 @@ class BedrockConverseLLM(BaseAWSLLM):
         else:
             client = client
         try:
-            response = client.post(url=prepped.url, headers=prepped.headers, data=data)  # type: ignore
+            response = client.post(url=proxy_endpoint_url, headers=prepped.headers, data=data)  # type: ignore
             response.raise_for_status()
         except httpx.HTTPStatusError as err:
             error_code = err.response.status_code

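For context, a hedged usage sketch (not part of this commit) of how a caller might route Bedrock traffic through a proxy with LiteLLM: `aws_bedrock_runtime_endpoint` and the `AWS_BEDROCK_RUNTIME_ENDPOINT` env var are the inputs that feed `proxy_endpoint_url` in the helper changed below. The proxy URL here is a made-up placeholder, and the exact call shape is an assumption rather than something shown in this diff.

import litellm

response = litellm.completion(
    model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
    messages=[{"role": "user", "content": "Hello through a proxy"}],
    aws_region_name="us-east-1",
    # Hypothetical proxy in front of Bedrock; with this patch the POST goes to
    # this base URL (as proxy_endpoint_url) plus the /model/... path.
    aws_bedrock_runtime_endpoint="https://my-bedrock-proxy.example.com",
)
print(response.choices[0].message.content)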

@@ -5,7 +5,7 @@ Common utilities used across bedrock chat/embedding/image generation
 import os
 import types
 from enum import Enum
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Tuple
 
 import httpx
@@ -729,7 +729,7 @@ def get_runtime_endpoint(
     api_base: Optional[str],
     aws_bedrock_runtime_endpoint: Optional[str],
     aws_region_name: str,
-) -> str:
+) -> Tuple[str, str]:
     env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
     if api_base is not None:
         endpoint_url = api_base
@@ -744,7 +744,19 @@ def get_runtime_endpoint(
     else:
         endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
 
-    return endpoint_url
+    # Determine proxy_endpoint_url
+    if env_aws_bedrock_runtime_endpoint and isinstance(
+        env_aws_bedrock_runtime_endpoint, str
+    ):
+        proxy_endpoint_url = env_aws_bedrock_runtime_endpoint
+    elif aws_bedrock_runtime_endpoint is not None and isinstance(
+        aws_bedrock_runtime_endpoint, str
+    ):
+        proxy_endpoint_url = aws_bedrock_runtime_endpoint
+    else:
+        proxy_endpoint_url = endpoint_url
+
+    return endpoint_url, proxy_endpoint_url
 
 
 class ModelResponseIterator:

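To make the new return contract concrete, two hypothetical calls to the patched helper, assuming the `AWS_BEDROCK_RUNTIME_ENDPOINT` env var is unset (the import path is an assumption, since file names are not shown in this excerpt). The expected values follow directly from the branches above:

from litellm.llms.bedrock.common_utils import get_runtime_endpoint  # assumed path

# Nothing configured: both URLs fall back to the regional default,
# so pre-existing behaviour is preserved.
endpoint_url, proxy_endpoint_url = get_runtime_endpoint(
    api_base=None,
    aws_bedrock_runtime_endpoint=None,
    aws_region_name="us-east-1",
)
assert endpoint_url == proxy_endpoint_url == "https://bedrock-runtime.us-east-1.amazonaws.com"

# api_base plus a (hypothetical) proxy endpoint: the two values diverge, and the
# chat callers above, like the embedding caller below, now POST to proxy_endpoint_url.
endpoint_url, proxy_endpoint_url = get_runtime_endpoint(
    api_base="https://bedrock-runtime.us-east-1.amazonaws.com",
    aws_bedrock_runtime_endpoint="https://my-bedrock-proxy.example.com",
    aws_region_name="us-east-1",
)
assert endpoint_url == "https://bedrock-runtime.us-east-1.amazonaws.com"
assert proxy_endpoint_url == "https://my-bedrock-proxy.example.com"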

@@ -393,7 +393,7 @@ class BedrockEmbedding(BaseAWSLLM):
             batch_data.append(transformed_request)
 
         ### SET RUNTIME ENDPOINT ###
-        endpoint_url = get_runtime_endpoint(
+        endpoint_url, proxy_endpoint_url = get_runtime_endpoint(
             api_base=api_base,
             aws_bedrock_runtime_endpoint=optional_params.pop(
                 "aws_bedrock_runtime_endpoint", None