forked from phoenix/litellm-mirror
(bedrock): Fix usage with Cloudflare AI Gateway, and proxies in general. (#5509)
parent 949cd51529
commit 3fac0349c2
3 changed files with 32 additions and 16 deletions
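
Previously every request was POSTed to prepped.url, the URL baked into the SigV4-signed request. This patch threads a second URL through the Bedrock chat, converse, and embedding paths: requests are still prepared and signed against endpoint_url, but the bytes are sent to proxy_endpoint_url, so a proxy such as Cloudflare AI Gateway can front the real AWS host. When no proxy is configured the two URLs coincide and behaviour is unchanged. A standalone sketch of that sign-here-send-there pattern with botocore and httpx (the helper name and layout are illustrative, not litellm's internals; assumes AWS credentials are configured):

import httpx
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.session import Session

def signed_post_via_proxy(
    endpoint_url: str,        # URL the signature is computed for (AWS host + path)
    proxy_endpoint_url: str,  # URL the request is actually sent to (e.g. a gateway)
    body: str,
    region: str,
) -> httpx.Response:
    credentials = Session().get_credentials()
    request = AWSRequest(
        method="POST",
        url=endpoint_url,
        data=body,
        headers={"Content-Type": "application/json"},
    )
    # Sign against the Bedrock endpoint; the signature covers the signed URL,
    # not the transport target.
    SigV4Auth(credentials, "bedrock", region).add_auth(request)
    prepped = request.prepare()
    # POST the signed headers to the proxy, which forwards to Bedrock.
    return httpx.post(proxy_endpoint_url, headers=dict(prepped.headers), content=body)

The diff below applies exactly this split at every call site that previously used prepped.url.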
@@ -728,7 +728,7 @@ class BedrockLLM(BaseAWSLLM):
         )

         ### SET RUNTIME ENDPOINT ###
-        endpoint_url = get_runtime_endpoint(
+        endpoint_url, proxy_endpoint_url = get_runtime_endpoint(
             api_base=api_base,
             aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
             aws_region_name=aws_region_name,
@@ -736,8 +736,10 @@ class BedrockLLM(BaseAWSLLM):

         if (stream is not None and stream is True) and provider != "ai21":
             endpoint_url = f"{endpoint_url}/model/{modelId}/invoke-with-response-stream"
+            proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke-with-response-stream"
         else:
             endpoint_url = f"{endpoint_url}/model/{modelId}/invoke"
+            proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke"

         sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)

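For concreteness, a walk-through of the URL suffixing above with hypothetical values (the gateway URL and model ID are placeholders):

endpoint_url = "https://bedrock-runtime.us-east-1.amazonaws.com"
proxy_endpoint_url = "https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/aws-bedrock"
modelId = "anthropic.claude-3-sonnet-20240229-v1:0"
stream = True
suffix = "invoke-with-response-stream" if stream else "invoke"
endpoint_url = f"{endpoint_url}/model/{modelId}/{suffix}"
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/{suffix}"
# signed against the amazonaws.com URL, POSTed to the gateway URL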
@@ -903,7 +905,7 @@ class BedrockLLM(BaseAWSLLM):
             api_key="",
             additional_args={
                 "complete_input_dict": data,
-                "api_base": prepped.url,
+                "api_base": proxy_endpoint_url,
                 "headers": prepped.headers,
             },
         )
@@ -917,7 +919,7 @@ class BedrockLLM(BaseAWSLLM):
                 model=model,
                 messages=messages,
                 data=data,
-                api_base=prepped.url,
+                api_base=proxy_endpoint_url,
                 model_response=model_response,
                 print_verbose=print_verbose,
                 encoding=encoding,
@@ -935,7 +937,7 @@ class BedrockLLM(BaseAWSLLM):
                 model=model,
                 messages=messages,
                 data=data,
-                api_base=prepped.url,
+                api_base=proxy_endpoint_url,
                 model_response=model_response,
                 print_verbose=print_verbose,
                 encoding=encoding,
@@ -960,7 +962,7 @@ class BedrockLLM(BaseAWSLLM):
             self.client = client
         if (stream is not None and stream == True) and provider != "ai21":
             response = self.client.post(
-                url=prepped.url,
+                url=proxy_endpoint_url,
                 headers=prepped.headers,  # type: ignore
                 data=data,
                 stream=stream,
@@ -991,7 +993,7 @@ class BedrockLLM(BaseAWSLLM):
             return streaming_response

         try:
-            response = self.client.post(url=prepped.url, headers=prepped.headers, data=data)  # type: ignore
+            response = self.client.post(url=proxy_endpoint_url, headers=prepped.headers, data=data)  # type: ignore
             response.raise_for_status()
         except httpx.HTTPStatusError as err:
             error_code = err.response.status_code
@@ -1604,15 +1606,17 @@ class BedrockConverseLLM(BaseAWSLLM):
         )

         ### SET RUNTIME ENDPOINT ###
-        endpoint_url = get_runtime_endpoint(
+        endpoint_url, proxy_endpoint_url = get_runtime_endpoint(
             api_base=api_base,
             aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
             aws_region_name=aws_region_name,
         )
         if (stream is not None and stream is True) and provider != "ai21":
             endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
+            proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse-stream"
         else:
             endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
+            proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse"

         sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)

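At the caller level, the converse path can now be pointed at a gateway; a hedged usage sketch (aws_bedrock_runtime_endpoint is the parameter the diff reads, and the gateway URL is a placeholder):

import litellm

response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": "ping"}],
    aws_region_name="us-east-1",
    aws_bedrock_runtime_endpoint="https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/aws-bedrock",
)
print(response.choices[0].message.content)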
@@ -1719,7 +1723,7 @@ class BedrockConverseLLM(BaseAWSLLM):
             api_key="",
             additional_args={
                 "complete_input_dict": data,
-                "api_base": prepped.url,
+                "api_base": proxy_endpoint_url,
                 "headers": prepped.headers,
             },
         )
@@ -1733,7 +1737,7 @@ class BedrockConverseLLM(BaseAWSLLM):
                 model=model,
                 messages=messages,
                 data=data,
-                api_base=prepped.url,
+                api_base=proxy_endpoint_url,
                 model_response=model_response,
                 print_verbose=print_verbose,
                 encoding=encoding,
@@ -1751,7 +1755,7 @@ class BedrockConverseLLM(BaseAWSLLM):
                 model=model,
                 messages=messages,
                 data=data,
-                api_base=prepped.url,
+                api_base=proxy_endpoint_url,
                 model_response=model_response,
                 print_verbose=print_verbose,
                 encoding=encoding,
@@ -1772,7 +1776,7 @@ class BedrockConverseLLM(BaseAWSLLM):
             make_call=partial(
                 make_sync_call,
                 client=None,
-                api_base=prepped.url,
+                api_base=proxy_endpoint_url,
                 headers=prepped.headers,  # type: ignore
                 data=data,
                 model=model,
@@ -1797,7 +1801,7 @@ class BedrockConverseLLM(BaseAWSLLM):
         else:
             client = client
         try:
-            response = client.post(url=prepped.url, headers=prepped.headers, data=data)  # type: ignore
+            response = client.post(url=proxy_endpoint_url, headers=prepped.headers, data=data)  # type: ignore
             response.raise_for_status()
         except httpx.HTTPStatusError as err:
             error_code = err.response.status_code

@@ -5,7 +5,7 @@ Common utilities used across bedrock chat/embedding/image generation
 import os
 import types
 from enum import Enum
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Tuple

 import httpx

@@ -729,7 +729,7 @@ def get_runtime_endpoint(
     api_base: Optional[str],
     aws_bedrock_runtime_endpoint: Optional[str],
     aws_region_name: str,
-) -> str:
+) -> Tuple[str, str]:
     env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
     if api_base is not None:
         endpoint_url = api_base
@@ -744,7 +744,19 @@ def get_runtime_endpoint(
     else:
         endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"

-    return endpoint_url
+    # Determine proxy_endpoint_url
+    if env_aws_bedrock_runtime_endpoint and isinstance(
+        env_aws_bedrock_runtime_endpoint, str
+    ):
+        proxy_endpoint_url = env_aws_bedrock_runtime_endpoint
+    elif aws_bedrock_runtime_endpoint is not None and isinstance(
+        aws_bedrock_runtime_endpoint, str
+    ):
+        proxy_endpoint_url = aws_bedrock_runtime_endpoint
+    else:
+        proxy_endpoint_url = endpoint_url
+
+    return endpoint_url, proxy_endpoint_url


 class ModelResponseIterator:

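A quick sanity check of the fallback branch above (a sketch; assumes AWS_BEDROCK_RUNTIME_ENDPOINT is unset in the environment):

endpoint_url, proxy_endpoint_url = get_runtime_endpoint(
    api_base=None,
    aws_bedrock_runtime_endpoint=None,
    aws_region_name="us-west-2",
)
assert endpoint_url == "https://bedrock-runtime.us-west-2.amazonaws.com"
assert proxy_endpoint_url == endpoint_url  # no proxy configured, so they coincide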
@@ -393,7 +393,7 @@ class BedrockEmbedding(BaseAWSLLM):
             batch_data.append(transformed_request)

         ### SET RUNTIME ENDPOINT ###
-        endpoint_url = get_runtime_endpoint(
+        endpoint_url, proxy_endpoint_url = get_runtime_endpoint(
             api_base=api_base,
             aws_bedrock_runtime_endpoint=optional_params.pop(
                 "aws_bedrock_runtime_endpoint", None
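The embedding path unpacks the same tuple; a hedged usage sketch (model name and gateway URL are placeholders):

import litellm

embeddings = litellm.embedding(
    model="bedrock/amazon.titan-embed-text-v1",
    input=["hello world"],
    aws_region_name="us-east-1",
    aws_bedrock_runtime_endpoint="https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/aws-bedrock",
)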