mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
feat(pass_through_endpoint.py): support enforcing per-key RPM limits on pass-through endpoints
Closes https://github.com/BerriAI/litellm/issues/4698
This commit is contained in:
parent
caa01d20cb
commit
0cc273d77b
5 changed files with 105 additions and 22 deletions
|
@ -3,6 +3,7 @@ import asyncio
|
|||
import json
|
||||
import traceback
|
||||
from base64 import b64encode
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import (
|
||||
|
@ -240,14 +241,28 @@ async def chat_completion_pass_through_endpoint(
|
|||
)
|
||||
|
||||
|
||||
async def pass_through_request(request: Request, target: str, custom_headers: dict):
|
||||
async def pass_through_request(
|
||||
request: Request,
|
||||
target: str,
|
||||
custom_headers: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
):
|
||||
try:
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||
from litellm.proxy.proxy_server import proxy_logging_obj
|
||||
|
||||
url = httpx.URL(target)
|
||||
headers = custom_headers
|
||||
|
||||
request_body = await request.body()
|
||||
_parsed_body = ast.literal_eval(request_body.decode("utf-8"))
|
||||
body_str = request_body.decode()
|
||||
try:
|
||||
_parsed_body = ast.literal_eval(body_str)
|
||||
except:
|
||||
_parsed_body = json.loads(body_str)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format(
|
||||
|
@ -255,6 +270,13 @@ async def pass_through_request(request: Request, target: str, custom_headers: di
|
|||
)
|
||||
)
|
||||
|
||||
### CALL HOOKS ### - modify incoming data / reject request before calling the model
|
||||
_parsed_body = await proxy_logging_obj.pre_call_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
data=_parsed_body,
|
||||
call_type="pass_through_endpoint",
|
||||
)
|
||||
|
||||
response = await async_client.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
|
@ -267,6 +289,47 @@ async def pass_through_request(request: Request, target: str, custom_headers: di
|
|||
raise HTTPException(status_code=response.status_code, detail=response.text)
|
||||
|
||||
content = await response.aread()
|
||||
|
||||
## LOG SUCCESS
|
||||
start_time = time.time()
|
||||
end_time = time.time()
|
||||
# create logging object
|
||||
logging_obj = Logging(
|
||||
model="unknown",
|
||||
messages=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
|
||||
stream=False,
|
||||
call_type="pass_through_endpoint",
|
||||
start_time=start_time,
|
||||
litellm_call_id=str(uuid.uuid4()),
|
||||
function_id="1245",
|
||||
)
|
||||
# done for supporting 'parallel_request_limiter.py' with pass-through endpoints
|
||||
kwargs = {
|
||||
"litellm_params": {
|
||||
"metadata": {
|
||||
"user_api_key": user_api_key_dict.api_key,
|
||||
"user_api_key_user_id": user_api_key_dict.user_id,
|
||||
"user_api_key_team_id": user_api_key_dict.team_id,
|
||||
"user_api_key_end_user_id": user_api_key_dict.user_id,
|
||||
}
|
||||
},
|
||||
"call_type": "pass_through_endpoint",
|
||||
}
|
||||
logging_obj.update_environment_variables(
|
||||
model="unknown",
|
||||
user="unknown",
|
||||
optional_params={},
|
||||
litellm_params=kwargs["litellm_params"],
|
||||
call_type="pass_through_endpoint",
|
||||
)
|
||||
|
||||
await logging_obj.async_success_handler(
|
||||
result="",
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
cache_hit=False,
|
||||
)
|
||||
|
||||
return Response(
|
||||
content=content,
|
||||
status_code=response.status_code,
|
||||
|
@ -274,8 +337,8 @@ async def pass_through_request(request: Request, target: str, custom_headers: di
|
|||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.pass_through_endpoint(): Exception occured - {}".format(
|
||||
str(e)
|
||||
"litellm.proxy.proxy_server.pass_through_endpoint(): Exception occured - {}\n{}".format(
|
||||
str(e), traceback.format_exc()
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
|
@ -296,7 +359,9 @@ async def pass_through_request(request: Request, target: str, custom_headers: di
|
|||
)
|
||||
|
||||
|
||||
def create_pass_through_route(endpoint, target: str, custom_headers=None):
|
||||
def create_pass_through_route(
|
||||
endpoint, target: str, custom_headers: Optional[dict] = None
|
||||
):
|
||||
# check if target is an adapter.py or a url
|
||||
import uuid
|
||||
|
||||
|
@ -325,8 +390,17 @@ def create_pass_through_route(endpoint, target: str, custom_headers=None):
|
|||
except Exception:
|
||||
verbose_proxy_logger.warning("Defaulting to target being a url.")
|
||||
|
||||
async def endpoint_func(request: Request): # type: ignore
|
||||
return await pass_through_request(request, target, custom_headers)
|
||||
async def endpoint_func(
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
return await pass_through_request(
|
||||
request=request,
|
||||
target=target,
|
||||
custom_headers=custom_headers or {},
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
)
|
||||
|
||||
return endpoint_func
|
||||
|
||||
|
@ -349,7 +423,9 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list):
|
|||
if _auth is not None and str(_auth).lower() == "true":
|
||||
if premium_user is not True:
|
||||
raise ValueError(
|
||||
f"Error Setting Authentication on Pass Through Endpoint: {CommonProxyErrors.not_premium_user}"
|
||||
"Error Setting Authentication on Pass Through Endpoint: {}".format(
|
||||
CommonProxyErrors.not_premium_user.value
|
||||
)
|
||||
)
|
||||
_dependencies = [Depends(user_api_key_auth)]
|
||||
LiteLLMRoutes.openai_routes.value.append(_path)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue