feat(Support-pass-through-for-bedrock-endpoints): Allows pass-through support for bedrock endpoints

This commit is contained in:
Krrish Dholakia 2024-08-17 17:57:43 -07:00
parent f7a2e04426
commit 663a0c1b83
4 changed files with 117 additions and 22 deletions

View file

@ -283,6 +283,7 @@ async def pass_through_request(
target: str,
custom_headers: dict,
user_api_key_dict: UserAPIKeyAuth,
custom_body: Optional[dict] = None,
forward_headers: Optional[bool] = False,
query_params: Optional[dict] = None,
stream: Optional[bool] = None,
@ -300,12 +301,15 @@ async def pass_through_request(
request=request, headers=headers, forward_headers=forward_headers
)
request_body = await request.body()
body_str = request_body.decode()
try:
_parsed_body = ast.literal_eval(body_str)
except Exception:
_parsed_body = json.loads(body_str)
if custom_body:
_parsed_body = custom_body
else:
request_body = await request.body()
body_str = request_body.decode()
try:
_parsed_body = ast.literal_eval(body_str)
except Exception:
_parsed_body = json.loads(body_str)
verbose_proxy_logger.debug(
"Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format(
@ -356,22 +360,24 @@ async def pass_through_request(
# combine url with query params for logging
requested_query_params = query_params or request.query_params.__dict__
requested_query_params_str = "&".join(
f"{k}={v}" for k, v in requested_query_params.items()
)
# requested_query_params = query_params or request.query_params.__dict__
# requested_query_params_str = "&".join(
# f"{k}={v}" for k, v in requested_query_params.items()
# )
if "?" in str(url):
logging_url = str(url) + "&" + requested_query_params_str
else:
logging_url = str(url) + "?" + requested_query_params_str
requested_query_params = None
# if "?" in str(url):
# logging_url = str(url) + "&" + requested_query_params_str
# else:
# logging_url = str(url) + "?" + requested_query_params_str
logging_obj.pre_call(
input=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
api_key="",
additional_args={
"complete_input_dict": _parsed_body,
"api_base": logging_url,
"api_base": str(url),
"headers": headers,
},
)
@ -526,6 +532,7 @@ def create_pass_through_route(
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
query_params: Optional[dict] = None,
custom_body: Optional[dict] = None,
stream: Optional[
bool
] = None, # if pass-through endpoint is a streaming request
@ -538,6 +545,7 @@ def create_pass_through_route(
forward_headers=_forward_headers,
query_params=query_params,
stream=stream,
custom_body=custom_body,
)
return endpoint_func