feat(Support-pass-through-for-bedrock-endpoints): Allows pass-through support for bedrock endpoints

This commit is contained in:
Krrish Dholakia 2024-08-17 17:57:43 -07:00
parent f7a2e04426
commit 663a0c1b83
4 changed files with 117 additions and 22 deletions

View file

@ -283,6 +283,7 @@ async def pass_through_request(
target: str, target: str,
custom_headers: dict, custom_headers: dict,
user_api_key_dict: UserAPIKeyAuth, user_api_key_dict: UserAPIKeyAuth,
custom_body: Optional[dict] = None,
forward_headers: Optional[bool] = False, forward_headers: Optional[bool] = False,
query_params: Optional[dict] = None, query_params: Optional[dict] = None,
stream: Optional[bool] = None, stream: Optional[bool] = None,
@ -300,12 +301,15 @@ async def pass_through_request(
request=request, headers=headers, forward_headers=forward_headers request=request, headers=headers, forward_headers=forward_headers
) )
request_body = await request.body() if custom_body:
body_str = request_body.decode() _parsed_body = custom_body
try: else:
_parsed_body = ast.literal_eval(body_str) request_body = await request.body()
except Exception: body_str = request_body.decode()
_parsed_body = json.loads(body_str) try:
_parsed_body = ast.literal_eval(body_str)
except Exception:
_parsed_body = json.loads(body_str)
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
"Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format( "Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format(
@ -356,22 +360,24 @@ async def pass_through_request(
# combine url with query params for logging # combine url with query params for logging
requested_query_params = query_params or request.query_params.__dict__ # requested_query_params = query_params or request.query_params.__dict__
requested_query_params_str = "&".join( # requested_query_params_str = "&".join(
f"{k}={v}" for k, v in requested_query_params.items() # f"{k}={v}" for k, v in requested_query_params.items()
) # )
if "?" in str(url): requested_query_params = None
logging_url = str(url) + "&" + requested_query_params_str
else: # if "?" in str(url):
logging_url = str(url) + "?" + requested_query_params_str # logging_url = str(url) + "&" + requested_query_params_str
# else:
# logging_url = str(url) + "?" + requested_query_params_str
logging_obj.pre_call( logging_obj.pre_call(
input=[{"role": "user", "content": "no-message-pass-through-endpoint"}], input=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
api_key="", api_key="",
additional_args={ additional_args={
"complete_input_dict": _parsed_body, "complete_input_dict": _parsed_body,
"api_base": logging_url, "api_base": str(url),
"headers": headers, "headers": headers,
}, },
) )
@ -526,6 +532,7 @@ def create_pass_through_route(
fastapi_response: Response, fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
query_params: Optional[dict] = None, query_params: Optional[dict] = None,
custom_body: Optional[dict] = None,
stream: Optional[ stream: Optional[
bool bool
] = None, # if pass-through endpoint is a streaming request ] = None, # if pass-through endpoint is a streaming request
@ -538,6 +545,7 @@ def create_pass_through_route(
forward_headers=_forward_headers, forward_headers=_forward_headers,
query_params=query_params, query_params=query_params,
stream=stream, stream=stream,
custom_body=custom_body,
) )
return endpoint_func return endpoint_func

View file

@ -45,6 +45,16 @@ router = APIRouter()
default_vertex_config = None default_vertex_config = None
def create_request_copy(request: Request):
return {
"method": request.method,
"url": str(request.url),
"headers": dict(request.headers),
"cookies": request.cookies,
"query_params": dict(request.query_params),
}
@router.api_route("/gemini/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]) @router.api_route("/gemini/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def gemini_proxy_route( async def gemini_proxy_route(
endpoint: str, endpoint: str,
@ -136,3 +146,72 @@ async def cohere_proxy_route(
) )
return received_value return received_value
@router.api_route("/bedrock/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def bedrock_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
request_copy = create_request_copy(request)
try:
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError as e:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
aws_region_name = litellm.utils.get_secret(secret_name="AWS_REGION_NAME")
base_target_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
encoded_endpoint = httpx.URL(endpoint).path
# Ensure endpoint starts with '/' for proper URL construction
if not encoded_endpoint.startswith("/"):
encoded_endpoint = "/" + encoded_endpoint
# Construct the full target URL using httpx
base_url = httpx.URL(base_target_url)
updated_url = base_url.copy_with(path=encoded_endpoint)
# Add or update query parameters
from litellm.llms.bedrock_httpx import BedrockConverseLLM
credentials: Credentials = BedrockConverseLLM().get_credentials()
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
headers = {"Content-Type": "application/json"}
# Assuming the body contains JSON data, parse it
try:
data = await request.json()
except Exception as e:
raise HTTPException(status_code=400, detail={"error": e})
_request = AWSRequest(
method="POST", url=str(updated_url), data=json.dumps(data), headers=headers
)
sigv4.add_auth(_request)
prepped = _request.prepare()
## check for streaming
is_streaming_request = False
if "stream" in str(updated_url):
is_streaming_request = True
## CREATE PASS-THROUGH
endpoint_func = create_pass_through_route(
endpoint=endpoint,
target=str(prepped.url),
custom_headers=prepped.headers,
) # dynamically construct pass-through endpoint based on incoming path
received_value = await endpoint_func(
request,
fastapi_response,
user_api_key_dict,
stream=is_streaming_request,
custom_body=data,
query_params={},
)
return received_value

View file

@ -3375,14 +3375,14 @@ def response_format_tests(response: litellm.ModelResponse):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model", "model",
[ [
"bedrock/cohere.command-r-plus-v1:0", # "bedrock/cohere.command-r-plus-v1:0",
"anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0",
"anthropic.claude-instant-v1", # "anthropic.claude-instant-v1",
"bedrock/ai21.j2-mid", # "bedrock/ai21.j2-mid",
"mistral.mistral-7b-instruct-v0:2", # "mistral.mistral-7b-instruct-v0:2",
"bedrock/amazon.titan-tg1-large", # "bedrock/amazon.titan-tg1-large",
"meta.llama3-8b-instruct-v1:0", # "meta.llama3-8b-instruct-v1:0",
"cohere.command-text-v14", # "cohere.command-text-v14",
], ],
) )
@pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.parametrize("sync_mode", [True, False])

View file

@ -9053,6 +9053,9 @@ class CustomStreamWrapper:
text = "" text = ""
is_finished = False is_finished = False
finish_reason = "" finish_reason = ""
index: Optional[int] = None
if "index" in data_json:
index = data_json.get("index")
if "text" in data_json: if "text" in data_json:
text = data_json["text"] text = data_json["text"]
elif "is_finished" in data_json: elif "is_finished" in data_json:
@ -9061,6 +9064,7 @@ class CustomStreamWrapper:
else: else:
raise Exception(data_json) raise Exception(data_json)
return { return {
"index": index,
"text": text, "text": text,
"is_finished": is_finished, "is_finished": is_finished,
"finish_reason": finish_reason, "finish_reason": finish_reason,
@ -10246,6 +10250,10 @@ class CustomStreamWrapper:
completion_obj["role"] = "assistant" completion_obj["role"] = "assistant"
self.sent_first_chunk = True self.sent_first_chunk = True
model_response.choices[0].delta = Delta(**completion_obj) model_response.choices[0].delta = Delta(**completion_obj)
if completion_obj.get("index") is not None:
model_response.choices[0].index = completion_obj.get(
"index"
)
print_verbose(f"returning model_response: {model_response}") print_verbose(f"returning model_response: {model_response}")
return model_response return model_response
else: else: