(bedrock): Fix usage with Cloudflare AI Gateway, and proxies in general. (#5509)

David Manouchehri 2024-09-04 11:43:01 -05:00 committed by GitHub
parent 949cd51529
commit 3fac0349c2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 32 additions and 16 deletions
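For context, the change matters whenever litellm's Bedrock traffic is routed through a proxy. A minimal usage sketch (the gateway URL shape and model name are illustrative placeholders, not taken from this commit; the aws_bedrock_runtime_endpoint parameter and the AWS_BEDROCK_RUNTIME_ENDPOINT env var are the knobs the diff actually touches):

    import litellm

    # Point the Bedrock runtime at a proxy, e.g. a Cloudflare AI Gateway.
    # Substitute your own account/gateway values; setting the
    # AWS_BEDROCK_RUNTIME_ENDPOINT env var works the same way.
    response = litellm.completion(
        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
        messages=[{"role": "user", "content": "Hello"}],
        aws_region_name="us-east-1",
        aws_bedrock_runtime_endpoint="https://gateway.ai.cloudflare.com/v1/<ACCOUNT_ID>/<GATEWAY>/aws-bedrock",
    )

Previously the SigV4-signed URL (prepped.url) was also used as the POST target, which breaks when the signing endpoint and the proxy differ; the diffs below split this into an endpoint_url used for signing and a proxy_endpoint_url that actually receives the HTTP call.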

View file

@@ -728,7 +728,7 @@ class BedrockLLM(BaseAWSLLM):
         )
         ### SET RUNTIME ENDPOINT ###
-        endpoint_url = get_runtime_endpoint(
+        endpoint_url, proxy_endpoint_url = get_runtime_endpoint(
             api_base=api_base,
             aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
             aws_region_name=aws_region_name,
@@ -736,8 +736,10 @@ class BedrockLLM(BaseAWSLLM):
         if (stream is not None and stream is True) and provider != "ai21":
             endpoint_url = f"{endpoint_url}/model/{modelId}/invoke-with-response-stream"
+            proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke-with-response-stream"
         else:
             endpoint_url = f"{endpoint_url}/model/{modelId}/invoke"
+            proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke"
         sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
@@ -903,7 +905,7 @@ class BedrockLLM(BaseAWSLLM):
             api_key="",
             additional_args={
                 "complete_input_dict": data,
-                "api_base": prepped.url,
+                "api_base": proxy_endpoint_url,
                 "headers": prepped.headers,
             },
         )
@@ -917,7 +919,7 @@ class BedrockLLM(BaseAWSLLM):
                 model=model,
                 messages=messages,
                 data=data,
-                api_base=prepped.url,
+                api_base=proxy_endpoint_url,
                 model_response=model_response,
                 print_verbose=print_verbose,
                 encoding=encoding,
@@ -935,7 +937,7 @@ class BedrockLLM(BaseAWSLLM):
                 model=model,
                 messages=messages,
                 data=data,
-                api_base=prepped.url,
+                api_base=proxy_endpoint_url,
                 model_response=model_response,
                 print_verbose=print_verbose,
                 encoding=encoding,
@@ -960,7 +962,7 @@ class BedrockLLM(BaseAWSLLM):
             self.client = client
         if (stream is not None and stream == True) and provider != "ai21":
             response = self.client.post(
-                url=prepped.url,
+                url=proxy_endpoint_url,
                 headers=prepped.headers,  # type: ignore
                 data=data,
                 stream=stream,
@@ -991,7 +993,7 @@ class BedrockLLM(BaseAWSLLM):
             return streaming_response
         try:
-            response = self.client.post(url=prepped.url, headers=prepped.headers, data=data)  # type: ignore
+            response = self.client.post(url=proxy_endpoint_url, headers=prepped.headers, data=data)  # type: ignore
             response.raise_for_status()
         except httpx.HTTPStatusError as err:
             error_code = err.response.status_code
@@ -1604,15 +1606,17 @@ class BedrockConverseLLM(BaseAWSLLM):
         )
         ### SET RUNTIME ENDPOINT ###
-        endpoint_url = get_runtime_endpoint(
+        endpoint_url, proxy_endpoint_url = get_runtime_endpoint(
             api_base=api_base,
             aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
             aws_region_name=aws_region_name,
         )
         if (stream is not None and stream is True) and provider != "ai21":
             endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
+            proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse-stream"
         else:
             endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
+            proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse"
         sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
@@ -1719,7 +1723,7 @@ class BedrockConverseLLM(BaseAWSLLM):
             api_key="",
             additional_args={
                 "complete_input_dict": data,
-                "api_base": prepped.url,
+                "api_base": proxy_endpoint_url,
                 "headers": prepped.headers,
             },
         )
@@ -1733,7 +1737,7 @@ class BedrockConverseLLM(BaseAWSLLM):
                 model=model,
                 messages=messages,
                 data=data,
-                api_base=prepped.url,
+                api_base=proxy_endpoint_url,
                 model_response=model_response,
                 print_verbose=print_verbose,
                 encoding=encoding,
@@ -1751,7 +1755,7 @@ class BedrockConverseLLM(BaseAWSLLM):
                 model=model,
                 messages=messages,
                 data=data,
-                api_base=prepped.url,
+                api_base=proxy_endpoint_url,
                 model_response=model_response,
                 print_verbose=print_verbose,
                 encoding=encoding,
@@ -1772,7 +1776,7 @@ class BedrockConverseLLM(BaseAWSLLM):
             make_call=partial(
                 make_sync_call,
                 client=None,
-                api_base=prepped.url,
+                api_base=proxy_endpoint_url,
                 headers=prepped.headers,  # type: ignore
                 data=data,
                 model=model,
@@ -1797,7 +1801,7 @@ class BedrockConverseLLM(BaseAWSLLM):
         else:
             client = client
         try:
-            response = client.post(url=prepped.url, headers=prepped.headers, data=data)  # type: ignore
+            response = client.post(url=proxy_endpoint_url, headers=prepped.headers, data=data)  # type: ignore
             response.raise_for_status()
         except httpx.HTTPStatusError as err:
             error_code = err.response.status_code
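The pattern the call sites above now follow, isolated into a sketch (signed_post is an assumed helper, not litellm's actual API; litellm threads these values through its own HTTP client, but the sign-here/send-there flow is the same):

    import httpx
    from botocore.auth import SigV4Auth
    from botocore.awsrequest import AWSRequest

    def signed_post(credentials, region, endpoint_url, proxy_endpoint_url, data):
        # Sign the request against the resolved Bedrock endpoint_url...
        req = AWSRequest(method="POST", url=endpoint_url, data=data,
                         headers={"Content-Type": "application/json"})
        SigV4Auth(credentials, "bedrock", region).add_auth(req)
        prepped = req.prepare()
        # ...but send the body to the proxy, reusing the signed headers.
        return httpx.post(url=proxy_endpoint_url,
                          headers=dict(prepped.headers.items()),
                          content=data)

When no proxy is configured, get_runtime_endpoint (next file) returns the same URL twice, so the behavior of existing deployments is unchanged.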

View file

@@ -5,7 +5,7 @@ Common utilities used across bedrock chat/embedding/image generation
 import os
 import types
 from enum import Enum
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Tuple
 import httpx
@@ -729,7 +729,7 @@ def get_runtime_endpoint(
     api_base: Optional[str],
     aws_bedrock_runtime_endpoint: Optional[str],
     aws_region_name: str,
-) -> str:
+) -> Tuple[str, str]:
     env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
     if api_base is not None:
         endpoint_url = api_base
@@ -744,7 +744,19 @@ def get_runtime_endpoint(
     else:
         endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
-    return endpoint_url
+    # Determine proxy_endpoint_url
+    if env_aws_bedrock_runtime_endpoint and isinstance(
+        env_aws_bedrock_runtime_endpoint, str
+    ):
+        proxy_endpoint_url = env_aws_bedrock_runtime_endpoint
+    elif aws_bedrock_runtime_endpoint is not None and isinstance(
+        aws_bedrock_runtime_endpoint, str
+    ):
+        proxy_endpoint_url = aws_bedrock_runtime_endpoint
+    else:
+        proxy_endpoint_url = endpoint_url
+    return endpoint_url, proxy_endpoint_url
 class ModelResponseIterator:
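A worked example of the new return contract, assuming AWS_BEDROCK_RUNTIME_ENDPOINT is unset: with no overrides, both values collapse to the regional default, so existing callers see no change.

    endpoint_url, proxy_endpoint_url = get_runtime_endpoint(
        api_base=None,
        aws_bedrock_runtime_endpoint=None,
        aws_region_name="us-east-1",
    )
    assert endpoint_url == "https://bedrock-runtime.us-east-1.amazonaws.com"
    assert proxy_endpoint_url == endpoint_url

Per the elif chain above, the AWS_BEDROCK_RUNTIME_ENDPOINT env var takes precedence over the aws_bedrock_runtime_endpoint argument when choosing proxy_endpoint_url, while endpoint_url continues to feed the SigV4 signing path shown in the chat diffs.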

View file

@@ -393,7 +393,7 @@ class BedrockEmbedding(BaseAWSLLM):
             batch_data.append(transformed_request)
         ### SET RUNTIME ENDPOINT ###
-        endpoint_url = get_runtime_endpoint(
+        endpoint_url, proxy_endpoint_url = get_runtime_endpoint(
             api_base=api_base,
             aws_bedrock_runtime_endpoint=optional_params.pop(
                 "aws_bedrock_runtime_endpoint", None