fix pass through endpoints

This commit is contained in:
Ishaan Jaff 2024-08-21 17:21:22 -07:00
parent 0e1d3804ff
commit bcc0f99476

View file

@ -301,16 +301,19 @@ async def pass_through_request(
request=request, headers=headers, forward_headers=forward_headers request=request, headers=headers, forward_headers=forward_headers
) )
_parsed_body = None
if custom_body: if custom_body:
_parsed_body = custom_body _parsed_body = custom_body
else: else:
request_body = await request.body() request_body = await request.body()
body_str = request_body.decode() if request_body == b"" or request_body is None:
try: _parsed_body = None
_parsed_body = ast.literal_eval(body_str) else:
except Exception: body_str = request_body.decode()
_parsed_body = json.loads(body_str) try:
_parsed_body = ast.literal_eval(body_str)
except Exception:
_parsed_body = json.loads(body_str)
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
"Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format( "Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format(
url, headers, _parsed_body url, headers, _parsed_body
@ -320,7 +323,7 @@ async def pass_through_request(
### CALL HOOKS ### - modify incoming data / reject request before calling the model ### CALL HOOKS ### - modify incoming data / reject request before calling the model
_parsed_body = await proxy_logging_obj.pre_call_hook( _parsed_body = await proxy_logging_obj.pre_call_hook(
user_api_key_dict=user_api_key_dict, user_api_key_dict=user_api_key_dict,
data=_parsed_body, data=_parsed_body or {},
call_type="pass_through_endpoint", call_type="pass_through_endpoint",
) )
@ -360,15 +363,24 @@ async def pass_through_request(
# combine url with query params for logging # combine url with query params for logging
requested_query_params = query_params or request.query_params.__dict__ requested_query_params: Optional[dict] = (
requested_query_params_str = "&".join( query_params or request.query_params.__dict__
f"{k}={v}" for k, v in requested_query_params.items()
) )
if requested_query_params == request.query_params.__dict__:
requested_query_params = None
if "?" in str(url): requested_query_params_str = None
logging_url = str(url) + "&" + requested_query_params_str if requested_query_params:
else: requested_query_params_str = "&".join(
logging_url = str(url) + "?" + requested_query_params_str f"{k}={v}" for k, v in requested_query_params.items()
)
logging_url = str(url)
if requested_query_params_str:
if "?" in str(url):
logging_url = str(url) + "&" + requested_query_params_str
else:
logging_url = str(url) + "?" + requested_query_params_str
logging_obj.pre_call( logging_obj.pre_call(
input=[{"role": "user", "content": "no-message-pass-through-endpoint"}], input=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
@ -409,6 +421,14 @@ async def pass_through_request(
status_code=response.status_code, status_code=response.status_code,
) )
verbose_proxy_logger.debug("request method: {}".format(request.method))
verbose_proxy_logger.debug("request url: {}".format(url))
verbose_proxy_logger.debug("request headers: {}".format(headers))
verbose_proxy_logger.debug(
"requested_query_params={}".format(requested_query_params)
)
verbose_proxy_logger.debug("request body: {}".format(_parsed_body))
response = await async_client.request( response = await async_client.request(
method=request.method, method=request.method,
url=url, url=url,