feat(google_ai_studio_endpoints.py): support pass-through endpoint for all google ai studio requests

New Feature
This commit is contained in:
Krrish Dholakia 2024-08-17 10:46:59 -07:00
parent 668ea6cbc7
commit 29bedae79f
6 changed files with 186 additions and 20 deletions

View file

@ -273,6 +273,7 @@ async def pass_through_request(
custom_headers: dict,
user_api_key_dict: UserAPIKeyAuth,
forward_headers: Optional[bool] = False,
query_params: Optional[dict] = None,
):
try:
import time
@ -308,23 +309,9 @@ async def pass_through_request(
)
async_client = httpx.AsyncClient()
response = await async_client.request(
method=request.method,
url=url,
headers=headers,
params=request.query_params,
json=_parsed_body,
)
if response.status_code >= 300:
raise HTTPException(status_code=response.status_code, detail=response.text)
content = await response.aread()
## LOG SUCCESS
start_time = time.time()
end_time = time.time()
# create logging object
start_time = time.time()
logging_obj = Logging(
model="unknown",
messages=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
@ -334,6 +321,7 @@ async def pass_through_request(
litellm_call_id=str(uuid.uuid4()),
function_id="1245",
)
# done for supporting 'parallel_request_limiter.py' with pass-through endpoints
kwargs = {
"litellm_params": {
@ -354,6 +342,44 @@ async def pass_through_request(
call_type="pass_through_endpoint",
)
# combine url with query params for logging
requested_query_params = query_params or request.query_params.__dict__
requested_query_params_str = "&".join(
f"{k}={v}" for k, v in requested_query_params.items()
)
if "?" in str(url):
logging_url = str(url) + "&" + requested_query_params_str
else:
logging_url = str(url) + "?" + requested_query_params_str
logging_obj.pre_call(
input=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
api_key="",
additional_args={
"complete_input_dict": _parsed_body,
"api_base": logging_url,
"headers": headers,
},
)
response = await async_client.request(
method=request.method,
url=url,
headers=headers,
params=requested_query_params,
json=_parsed_body,
)
if response.status_code >= 300:
raise HTTPException(status_code=response.status_code, detail=response.text)
content = await response.aread()
## LOG SUCCESS
end_time = time.time()
await logging_obj.async_success_handler(
result="",
start_time=start_time,
@ -431,17 +457,19 @@ def create_pass_through_route(
except Exception:
verbose_proxy_logger.debug("Defaulting to target being a url.")
async def endpoint_func(
async def endpoint_func( # type: ignore
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
query_params: Optional[dict] = None,
):
return await pass_through_request(
return await pass_through_request( # type: ignore
request=request,
target=target,
custom_headers=custom_headers or {},
user_api_key_dict=user_api_key_dict,
forward_headers=_forward_headers,
query_params=query_params,
)
return endpoint_func