mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
feat(google_ai_studio_endpoints.py): support pass-through endpoint for all google ai studio requests
New Feature
This commit is contained in:
parent
668ea6cbc7
commit
29bedae79f
6 changed files with 186 additions and 20 deletions
|
@ -273,6 +273,7 @@ async def pass_through_request(
|
|||
custom_headers: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
forward_headers: Optional[bool] = False,
|
||||
query_params: Optional[dict] = None,
|
||||
):
|
||||
try:
|
||||
import time
|
||||
|
@ -308,23 +309,9 @@ async def pass_through_request(
|
|||
)
|
||||
|
||||
async_client = httpx.AsyncClient()
|
||||
response = await async_client.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
params=request.query_params,
|
||||
json=_parsed_body,
|
||||
)
|
||||
|
||||
if response.status_code >= 300:
|
||||
raise HTTPException(status_code=response.status_code, detail=response.text)
|
||||
|
||||
content = await response.aread()
|
||||
|
||||
## LOG SUCCESS
|
||||
start_time = time.time()
|
||||
end_time = time.time()
|
||||
# create logging object
|
||||
start_time = time.time()
|
||||
logging_obj = Logging(
|
||||
model="unknown",
|
||||
messages=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
|
||||
|
@ -334,6 +321,7 @@ async def pass_through_request(
|
|||
litellm_call_id=str(uuid.uuid4()),
|
||||
function_id="1245",
|
||||
)
|
||||
|
||||
# done for supporting 'parallel_request_limiter.py' with pass-through endpoints
|
||||
kwargs = {
|
||||
"litellm_params": {
|
||||
|
@ -354,6 +342,44 @@ async def pass_through_request(
|
|||
call_type="pass_through_endpoint",
|
||||
)
|
||||
|
||||
# combine url with query params for logging
|
||||
|
||||
requested_query_params = query_params or request.query_params.__dict__
|
||||
requested_query_params_str = "&".join(
|
||||
f"{k}={v}" for k, v in requested_query_params.items()
|
||||
)
|
||||
|
||||
if "?" in str(url):
|
||||
logging_url = str(url) + "&" + requested_query_params_str
|
||||
else:
|
||||
logging_url = str(url) + "?" + requested_query_params_str
|
||||
|
||||
logging_obj.pre_call(
|
||||
input=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": _parsed_body,
|
||||
"api_base": logging_url,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
||||
response = await async_client.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
params=requested_query_params,
|
||||
json=_parsed_body,
|
||||
)
|
||||
|
||||
if response.status_code >= 300:
|
||||
raise HTTPException(status_code=response.status_code, detail=response.text)
|
||||
|
||||
content = await response.aread()
|
||||
|
||||
## LOG SUCCESS
|
||||
end_time = time.time()
|
||||
|
||||
await logging_obj.async_success_handler(
|
||||
result="",
|
||||
start_time=start_time,
|
||||
|
@ -431,17 +457,19 @@ def create_pass_through_route(
|
|||
except Exception:
|
||||
verbose_proxy_logger.debug("Defaulting to target being a url.")
|
||||
|
||||
async def endpoint_func(
|
||||
async def endpoint_func( # type: ignore
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
query_params: Optional[dict] = None,
|
||||
):
|
||||
return await pass_through_request(
|
||||
return await pass_through_request( # type: ignore
|
||||
request=request,
|
||||
target=target,
|
||||
custom_headers=custom_headers or {},
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
forward_headers=_forward_headers,
|
||||
query_params=query_params,
|
||||
)
|
||||
|
||||
return endpoint_func
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue