feat(pass_through_endpoint.py): support enforcing key rpm limits on pass through endpoints

Closes https://github.com/BerriAI/litellm/issues/4698
Krrish Dholakia 2024-07-13 13:29:44 -07:00
parent caa01d20cb
commit 0cc273d77b
5 changed files with 105 additions and 22 deletions


@@ -99,6 +99,7 @@ class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callbac
             "image_generation",
             "moderation",
             "audio_transcription",
+            "pass_through_endpoint",
         ],
     ) -> Optional[
         Union[Exception, str, dict]
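With "pass_through_endpoint" accepted as a call_type, a proxy admin's custom hook can now branch on pass-through traffic. A minimal sketch, assuming the standard CustomLogger interface; the guard class and its team-key policy are hypothetical:

from typing import Optional, Union

from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth


class MyPassThroughGuard(CustomLogger):  # hypothetical example handler
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,  # may now be "pass_through_endpoint"
    ) -> Optional[Union[Exception, str, dict]]:
        # example policy: only team-scoped keys may hit pass-through routes
        if call_type == "pass_through_endpoint" and user_api_key_dict.team_id is None:
            raise Exception("pass-through endpoints require a team-scoped key")
        return data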


@@ -12,11 +12,12 @@ model_list:

 general_settings:
-  alerting: ["slack"]
-  alerting_threshold: 10
   master_key: sk-1234
   pass_through_endpoints:
-    - path: "/v1/test-messages" # route you want to add to LiteLLM Proxy Server
-      target: litellm.adapters.anthropic_adapter.anthropic_adapter # URL this route should forward requests to
-      headers: # headers to forward to this URL
-        litellm_user_api_key: "x-my-test-key"
+    - path: "/v1/rerank"
+      target: "https://api.cohere.com/v1/rerank"
+      auth: true # 👈 Key change to use LiteLLM Auth / Keys
+      headers:
+        Authorization: "bearer os.environ/COHERE_API_KEY"
+        content-type: application/json
+        accept: application/json
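Once the proxy restarts with this config, the route accepts any LiteLLM key. A quick sketch of exercising it; the port and the rerank payload are illustrative:

import httpx

# assumes the proxy runs locally on the default port with the config above
resp = httpx.post(
    "http://0.0.0.0:4000/v1/rerank",
    headers={"Authorization": "Bearer sk-1234"},  # a LiteLLM key, since auth: true
    json={
        "model": "rerank-english-v3.0",  # placeholder model name
        "query": "What is the capital of the United States?",
        "documents": ["Carson City is the capital city of Nevada."],
    },
)
print(resp.status_code, resp.json())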


@@ -1,12 +1,16 @@
-from typing import Optional
-import litellm, traceback, sys
-from litellm.caching import DualCache
-from litellm.proxy._types import UserAPIKeyAuth
-from litellm.integrations.custom_logger import CustomLogger
-from fastapi import HTTPException
-from litellm._logging import verbose_proxy_logger
-from litellm import ModelResponse
+import sys
+import traceback
 from datetime import datetime
+from typing import Optional
+
+from fastapi import HTTPException
+
+import litellm
+from litellm import ModelResponse
+from litellm._logging import verbose_proxy_logger
+from litellm.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import UserAPIKeyAuth


 class _PROXY_MaxParallelRequestsHandler(CustomLogger):
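This file's _PROXY_MaxParallelRequestsHandler is what actually enforces the per-key limits once the hook fires. A heavily simplified sketch of its RPM check; the real handler tracks parallel requests, TPM, and RPM in per-minute DualCache windows, and the names below are illustrative:

from fastapi import HTTPException

from litellm.caching import DualCache


async def check_rpm(cache: DualCache, api_key: str, rpm_limit: int) -> None:
    # one request counter per key per minute window (simplified)
    counter_key = f"{api_key}::1m::request_count"
    current = await cache.async_get_cache(key=counter_key) or 0
    if current + 1 > rpm_limit:
        raise HTTPException(status_code=429, detail="Max RPM limit reached")
    await cache.async_set_cache(key=counter_key, value=current + 1)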


@@ -3,6 +3,7 @@ import asyncio
 import json
 import traceback
 from base64 import b64encode
+from typing import Optional

 import httpx
 from fastapi import (
@@ -240,14 +241,28 @@ async def chat_completion_pass_through_endpoint(
     )


-async def pass_through_request(request: Request, target: str, custom_headers: dict):
+async def pass_through_request(
+    request: Request,
+    target: str,
+    custom_headers: dict,
+    user_api_key_dict: UserAPIKeyAuth,
+):
     try:
+        import time
+        import uuid
+
+        from litellm.litellm_core_utils.litellm_logging import Logging
+        from litellm.proxy.proxy_server import proxy_logging_obj
+
         url = httpx.URL(target)
         headers = custom_headers
         request_body = await request.body()
-        _parsed_body = ast.literal_eval(request_body.decode("utf-8"))
+        body_str = request_body.decode()
+        try:
+            _parsed_body = ast.literal_eval(body_str)
+        except:
+            _parsed_body = json.loads(body_str)

         verbose_proxy_logger.debug(
             "Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format(
@@ -255,6 +270,13 @@ async def pass_through_request(request: Request, target: str, custom_headers: di
             )
         )

+        ### CALL HOOKS ### - modify incoming data / reject request before calling the model
+        _parsed_body = await proxy_logging_obj.pre_call_hook(
+            user_api_key_dict=user_api_key_dict,
+            data=_parsed_body,
+            call_type="pass_through_endpoint",
+        )
+
         response = await async_client.request(
             method=request.method,
             url=url,
@@ -267,6 +289,47 @@ async def pass_through_request(request: Request, target: str, custom_headers: di
             raise HTTPException(status_code=response.status_code, detail=response.text)

         content = await response.aread()
+
+        ## LOG SUCCESS
+        start_time = time.time()
+        end_time = time.time()
+        # create logging object
+        logging_obj = Logging(
+            model="unknown",
+            messages=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
+            stream=False,
+            call_type="pass_through_endpoint",
+            start_time=start_time,
+            litellm_call_id=str(uuid.uuid4()),
+            function_id="1245",
+        )
+
+        # done for supporting 'parallel_request_limiter.py' with pass-through endpoints
+        kwargs = {
+            "litellm_params": {
+                "metadata": {
+                    "user_api_key": user_api_key_dict.api_key,
+                    "user_api_key_user_id": user_api_key_dict.user_id,
+                    "user_api_key_team_id": user_api_key_dict.team_id,
+                    "user_api_key_end_user_id": user_api_key_dict.user_id,
+                }
+            },
+            "call_type": "pass_through_endpoint",
+        }
+        logging_obj.update_environment_variables(
+            model="unknown",
+            user="unknown",
+            optional_params={},
+            litellm_params=kwargs["litellm_params"],
+            call_type="pass_through_endpoint",
+        )
+
+        await logging_obj.async_success_handler(
+            result="",
+            start_time=start_time,
+            end_time=end_time,
+            cache_hit=False,
+        )
+
         return Response(
             content=content,
             status_code=response.status_code,
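Because the metadata built above mirrors what LiteLLM attaches to regular completion calls, existing success callbacks (including the parallel request limiter's counters) see pass-through requests like any other. A sketch of a callback reading it; the tracker class is hypothetical:

from litellm.integrations.custom_logger import CustomLogger


class PassThroughUsageTracker(CustomLogger):  # hypothetical example callback
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        metadata = kwargs.get("litellm_params", {}).get("metadata", {})
        print("pass-through call completed by key:", metadata.get("user_api_key"))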
@@ -274,8 +337,8 @@ async def pass_through_request(request: Request, target: str, custom_headers: di
         )
     except Exception as e:
         verbose_proxy_logger.error(
-            "litellm.proxy.proxy_server.pass_through_endpoint(): Exception occured - {}".format(
-                str(e)
+            "litellm.proxy.proxy_server.pass_through_endpoint(): Exception occured - {}\n{}".format(
+                str(e), traceback.format_exc()
             )
         )
         verbose_proxy_logger.debug(traceback.format_exc())
@@ -296,7 +359,9 @@ async def pass_through_request(request: Request, target: str, custom_headers: di
     )


-def create_pass_through_route(endpoint, target: str, custom_headers=None):
+def create_pass_through_route(
+    endpoint, target: str, custom_headers: Optional[dict] = None
+):
     # check if target is an adapter.py or a url
     import uuid
@@ -325,8 +390,17 @@ def create_pass_through_route(endpoint, target: str, custom_headers=None):
     except Exception:
         verbose_proxy_logger.warning("Defaulting to target being a url.")

-    async def endpoint_func(request: Request):  # type: ignore
-        return await pass_through_request(request, target, custom_headers)
+    async def endpoint_func(
+        request: Request,
+        fastapi_response: Response,
+        user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+    ):
+        return await pass_through_request(
+            request=request,
+            target=target,
+            custom_headers=custom_headers or {},
+            user_api_key_dict=user_api_key_dict,
+        )

     return endpoint_func
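initialize_pass_through_endpoints then mounts the generated function on the proxy's FastAPI app, with user_api_key_auth as a dependency when auth is true. Roughly, as a sketch rather than the verbatim registration code; the import paths and HTTP method list are assumptions:

from fastapi import Depends, FastAPI

# import paths below are assumptions about this commit's module layout
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
    create_pass_through_route,
)

app = FastAPI()
app.add_api_route(
    path="/v1/rerank",
    endpoint=create_pass_through_route(
        endpoint="/v1/rerank", target="https://api.cohere.com/v1/rerank"
    ),
    methods=["POST"],
    dependencies=[Depends(user_api_key_auth)],
)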
@@ -349,7 +423,9 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list):
        if _auth is not None and str(_auth).lower() == "true":
            if premium_user is not True:
                raise ValueError(
-                    f"Error Setting Authentication on Pass Through Endpoint: {CommonProxyErrors.not_premium_user}"
+                    "Error Setting Authentication on Pass Through Endpoint: {}".format(
+                        CommonProxyErrors.not_premium_user.value
+                    )
                )
            _dependencies = [Depends(user_api_key_auth)]
            LiteLLMRoutes.openai_routes.value.append(_path)


@@ -299,6 +299,7 @@ class ProxyLogging:
             "image_generation",
             "moderation",
             "audio_transcription",
+            "pass_through_endpoint",
         ],
     ) -> dict:
         """