From 0cc273d77b85d1ba502067520a0c7ec01e9a9d27 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 13 Jul 2024 13:29:44 -0700
Subject: [PATCH 1/5] feat(pass_through_endpoint.py): support enforcing key
 rpm limits on pass through endpoints

Closes https://github.com/BerriAI/litellm/issues/4698
---
 litellm/integrations/custom_logger.py         |  1 +
 litellm/proxy/_new_secret_config.yaml         | 13 +--
 .../proxy/hooks/parallel_request_limiter.py   | 20 ++--
 .../pass_through_endpoints.py                 | 92 +++++++++++++++++--
 litellm/proxy/utils.py                        |  1 +
 5 files changed, 105 insertions(+), 22 deletions(-)

diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py
index 11d2fde8f..5139723ca 100644
--- a/litellm/integrations/custom_logger.py
+++ b/litellm/integrations/custom_logger.py
@@ -99,6 +99,7 @@ class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callbac
             "image_generation",
             "moderation",
             "audio_transcription",
+            "pass_through_endpoint",
         ],
     ) -> Optional[
         Union[Exception, str, dict]
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index cf4a823c3..5d301ea26 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -12,11 +12,12 @@ model_list:
 
 general_settings:
-  alerting: ["slack"]
-  alerting_threshold: 10
   master_key: sk-1234
   pass_through_endpoints:
-    - path: "/v1/test-messages"                                    # route you want to add to LiteLLM Proxy Server
-      target: litellm.adapters.anthropic_adapter.anthropic_adapter # URL this route should forward requests to
-      headers:                                                     # headers to forward to this URL
-        litellm_user_api_key: "x-my-test-key"
\ No newline at end of file
+    - path: "/v1/rerank"
+      target: "https://api.cohere.com/v1/rerank"
+      auth: true # 👈 Key change to use LiteLLM Auth / Keys
+      headers:
+        Authorization: "bearer os.environ/COHERE_API_KEY"
+        content-type: application/json
+        accept: application/json
\ No newline at end of file
diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py
index a17fcb2c9..e9c8649d0 100644
--- a/litellm/proxy/hooks/parallel_request_limiter.py
+++ b/litellm/proxy/hooks/parallel_request_limiter.py
@@ -1,12 +1,16 @@
-from typing import Optional
-import litellm, traceback, sys
-from litellm.caching import DualCache
-from litellm.proxy._types import UserAPIKeyAuth
-from litellm.integrations.custom_logger import CustomLogger
-from fastapi import HTTPException
-from litellm._logging import verbose_proxy_logger
-from litellm import ModelResponse
+import sys
+import traceback
 from datetime import datetime
+from typing import Optional
+
+from fastapi import HTTPException
+
+import litellm
+from litellm import ModelResponse
+from litellm._logging import verbose_proxy_logger
+from litellm.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import UserAPIKeyAuth
 
 
 class _PROXY_MaxParallelRequestsHandler(CustomLogger):
diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
index b13e9834a..3d17ba0d7 100644
--- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
+++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
@@ -3,6 +3,7 @@ import asyncio
 import json
 import traceback
 from base64 import b64encode
+from typing import Optional
 
 import httpx
 from fastapi import (
@@ -240,14 +241,28 @@ async def chat_completion_pass_through_endpoint(
     )
 
 
-async def pass_through_request(request: Request, target: str, custom_headers: dict):
+async def pass_through_request(
+    request: Request,
+    target: str,
+    custom_headers: dict,
+    user_api_key_dict: UserAPIKeyAuth,
+):
     try:
+        import time
+        import uuid
+
+        from litellm.litellm_core_utils.litellm_logging import Logging
+        from litellm.proxy.proxy_server import proxy_logging_obj
 
         url = httpx.URL(target)
         headers = custom_headers
 
         request_body = await request.body()
-        _parsed_body = ast.literal_eval(request_body.decode("utf-8"))
+        body_str = request_body.decode()
+        try:
+            _parsed_body = ast.literal_eval(body_str)
+        except Exception:
+            _parsed_body = json.loads(body_str)
 
         verbose_proxy_logger.debug(
             "Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format(
@@ -255,6 +270,13 @@ async def pass_through_request(request: Request, target: str, custom_headers: di
             )
         )
 
+        ### CALL HOOKS ### - modify incoming data / reject request before calling the model
+        _parsed_body = await proxy_logging_obj.pre_call_hook(
+            user_api_key_dict=user_api_key_dict,
+            data=_parsed_body,
+            call_type="pass_through_endpoint",
+        )
+
         response = await async_client.request(
             method=request.method,
             url=url,
@@ -267,6 +289,47 @@ async def pass_through_request(request: Request, target: str, custom_headers: di
             raise HTTPException(status_code=response.status_code, detail=response.text)
 
         content = await response.aread()
+
+        ## LOG SUCCESS
+        start_time = time.time()
+        end_time = time.time()
+        # create logging object
+        logging_obj = Logging(
+            model="unknown",
+            messages=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
+            stream=False,
+            call_type="pass_through_endpoint",
+            start_time=start_time,
+            litellm_call_id=str(uuid.uuid4()),
+            function_id="1245",
+        )
+        # done for supporting 'parallel_request_limiter.py' with pass-through endpoints
+        kwargs = {
+            "litellm_params": {
+                "metadata": {
+                    "user_api_key": user_api_key_dict.api_key,
+                    "user_api_key_user_id": user_api_key_dict.user_id,
+                    "user_api_key_team_id": user_api_key_dict.team_id,
+                    "user_api_key_end_user_id": user_api_key_dict.user_id,
+                }
+            },
+            "call_type": "pass_through_endpoint",
+        }
+        logging_obj.update_environment_variables(
+            model="unknown",
+            user="unknown",
+            optional_params={},
+            litellm_params=kwargs["litellm_params"],
+            call_type="pass_through_endpoint",
+        )
+
+        await logging_obj.async_success_handler(
+            result="",
+            start_time=start_time,
+            end_time=end_time,
+            cache_hit=False,
+        )
+
         return Response(
             content=content,
             status_code=response.status_code,
@@ -274,8 +337,8 @@ async def pass_through_request(request: Request, target: str, custom_headers: di
         )
     except Exception as e:
         verbose_proxy_logger.error(
-            "litellm.proxy.proxy_server.pass_through_endpoint(): Exception occured - {}".format(
-                str(e)
+            "litellm.proxy.proxy_server.pass_through_endpoint(): Exception occurred - {}\n{}".format(
+                str(e), traceback.format_exc()
             )
         )
         verbose_proxy_logger.debug(traceback.format_exc())
@@ -296,7 +359,9 @@ async def pass_through_request(request: Request, target: str, custom_headers: di
     )
 
 
-def create_pass_through_route(endpoint, target: str, custom_headers=None):
+def create_pass_through_route(
+    endpoint, target: str, custom_headers: Optional[dict] = None
+):
     # check if target is an adapter.py or a url
     import uuid
 
@@ -325,8 +390,17 @@ def create_pass_through_route(endpoint, target: str, custom_headers=None):
     except Exception:
         verbose_proxy_logger.warning("Defaulting to target being a url.")
 
-    async def endpoint_func(request: Request):  # type: ignore
-        return await pass_through_request(request, target, custom_headers)
+    async def endpoint_func(
+        request: Request,
+        fastapi_response: Response,
+        user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+    ):
+        return await pass_through_request(
+            request=request,
+            target=target,
+            custom_headers=custom_headers or {},
+            user_api_key_dict=user_api_key_dict,
+        )
 
     return endpoint_func
 
@@ -349,7 +423,9 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list):
         if _auth is not None and str(_auth).lower() == "true":
             if premium_user is not True:
                 raise ValueError(
-                    f"Error Setting Authentication on Pass Through Endpoint: {CommonProxyErrors.not_premium_user}"
+                    "Error Setting Authentication on Pass Through Endpoint: {}".format(
+                        CommonProxyErrors.not_premium_user.value
+                    )
                 )
             _dependencies = [Depends(user_api_key_auth)]
             LiteLLMRoutes.openai_routes.value.append(_path)
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index e129ccdcf..17fc2ac41 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -299,6 +299,7 @@ class ProxyLogging:
             "image_generation",
             "moderation",
             "audio_transcription",
+            "pass_through_endpoint",
         ],
     ) -> dict:
         """
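PATCH 1 routes every pass-through request through `proxy_logging_obj.pre_call_hook` with the new `"pass_through_endpoint"` call type, which is what lets hooks such as `parallel_request_limiter.py` reject a request before it is forwarded upstream. As a rough mental model of that gate — a minimal sketch, not LiteLLM's actual limiter, which tracks per-key counts in `DualCache` rather than in process memory:

```python
# Minimal sketch of an RPM gate keyed on the API key; illustrative only.
import time
from collections import defaultdict, deque


class RpmGate:
    """Reject a key's request once it exceeds rpm_limit calls per minute."""

    def __init__(self, window_seconds: int = 60):
        self.window_seconds = window_seconds
        self._hits = defaultdict(deque)  # api_key -> timestamps of recent calls

    def check(self, api_key: str, rpm_limit: int) -> None:
        now = time.monotonic()
        hits = self._hits[api_key]
        # evict calls that fell out of the sliding one-minute window
        while hits and now - hits[0] > self.window_seconds:
            hits.popleft()
        if len(hits) >= rpm_limit:
            # the proxy surfaces this rejection as an HTTP 429
            raise RuntimeError("rpm limit reached for key")
        hits.append(now)


gate = RpmGate()
gate.check("sk-my-test-key", rpm_limit=1)      # first call in the window: allowed
try:
    gate.check("sk-my-test-key", rpm_limit=1)  # second call within 60s: rejected
except RuntimeError as err:
    print(err)
```

The fabricated `Logging` object in the success path exists for the same reason: the limiter's post-call bookkeeping expects a standard logging payload, so pass-through traffic has to produce one even though no model call happened.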
From 55e153556a8bdbbf91bdc357fb6f39ba7a20f4bf Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 13 Jul 2024 13:49:20 -0700
Subject: [PATCH 2/5] test(test_pass_through_endpoints.py): add test for rpm
 limit support

---
 litellm/proxy/auth/user_api_key_auth.py      |  1 -
 litellm/tests/test_pass_through_endpoints.py | 62 ++++++++++++++++++++
 2 files changed, 62 insertions(+), 1 deletion(-)

diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py
index 03a87eb5b..8e79dffbe 100644
--- a/litellm/proxy/auth/user_api_key_auth.py
+++ b/litellm/proxy/auth/user_api_key_auth.py
@@ -96,7 +96,6 @@ async def user_api_key_auth(
         anthropic_api_key_header
     ),
 ) -> UserAPIKeyAuth:
-
     from litellm.proxy.proxy_server import (
         allowed_routes_check,
         common_checks,
diff --git a/litellm/tests/test_pass_through_endpoints.py b/litellm/tests/test_pass_through_endpoints.py
index 43543ecc7..9a4431e17 100644
--- a/litellm/tests/test_pass_through_endpoints.py
+++ b/litellm/tests/test_pass_through_endpoints.py
@@ -85,6 +85,68 @@ async def test_pass_through_endpoint_rerank(client):
     assert response.status_code == 200
 
 
+@pytest.mark.parametrize(
+    "auth, rpm_limit, expected_error_code",
+    [(True, 0, 429), (True, 1, 200), (False, 0, 401)],
+)
+@pytest.mark.asyncio
+async def test_pass_through_endpoint_rpm_limit(
+    client, auth, expected_error_code, rpm_limit
+):
+    import litellm
+    from litellm.proxy._types import UserAPIKeyAuth
+    from litellm.proxy.proxy_server import ProxyLogging, hash_token, user_api_key_cache
+
+    mock_api_key = "sk-my-test-key"
+    cache_value = UserAPIKeyAuth(token=hash_token(mock_api_key), rpm_limit=rpm_limit)
+
+    _cohere_api_key = os.environ.get("COHERE_API_KEY")
+
+    user_api_key_cache.set_cache(key=hash_token(mock_api_key), value=cache_value)
+
+    proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
+    proxy_logging_obj._init_litellm_callbacks()
+
+    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    setattr(litellm.proxy.proxy_server, "prisma_client", "FAKE-VAR")
+    setattr(litellm.proxy.proxy_server, "proxy_logging_obj", proxy_logging_obj)
+
+    # Define a pass-through endpoint
+    pass_through_endpoints = [
+        {
+            "path": "/v1/rerank",
+            "target": "https://api.cohere.com/v1/rerank",
+            "auth": auth,
+            "headers": {"Authorization": f"bearer {_cohere_api_key}"},
+        }
+    ]
+
+    # Initialize the pass-through endpoint
+    await initialize_pass_through_endpoints(pass_through_endpoints)
+
+    _json_data = {
+        "model": "rerank-english-v3.0",
+        "query": "What is the capital of the United States?",
+        "top_n": 3,
+        "documents": [
+            "Carson City is the capital city of the American state of Nevada."
+        ],
+    }
+
+    # Make a request to the pass-through endpoint
+    response = client.post(
+        "/v1/rerank",
+        json=_json_data,
+        headers={"Authorization": "Bearer {}".format(mock_api_key)},
+    )
+
+    print("Request JSON: ", _json_data)
+
+    # Assert the response
+    assert response.status_code == expected_error_code
+
+
 @pytest.mark.asyncio
 async def test_pass_through_endpoint_anthropic(client):
     import litellm
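The parametrized test covers three paths: rate-limited (`rpm_limit=0` → 429), allowed (`rpm_limit=1` → 200), and unauthenticated (`auth=False` → 401). The same behavior can be reproduced against a live proxy. A sketch, assuming a proxy at `http://0.0.0.0:4000` configured as in PATCH 1 (note that `auth: true` on pass-through endpoints is gated behind a premium license per the `initialize_pass_through_endpoints` check) and a virtual key created with `rpm_limit=1`:

```python
# Sketch: reproduce the test matrix against a running proxy.
# The proxy URL and the key below are assumptions, not values from the patch.
import httpx

payload = {
    "model": "rerank-english-v3.0",
    "query": "What is the capital of the United States?",
    "top_n": 3,
    "documents": ["Carson City is the capital city of the American state of Nevada."],
}
headers = {"Authorization": "Bearer sk-my-test-key"}  # hypothetical virtual key

with httpx.Client(base_url="http://0.0.0.0:4000") as client:
    first = client.post("/v1/rerank", json=payload, headers=headers)
    second = client.post("/v1/rerank", json=payload, headers=headers)

print(first.status_code)   # expected: 200 (within the key's rpm limit)
print(second.status_code)  # expected: 429 (limit exhausted inside the window)
```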
{"Authorization": f"bearer {_cohere_api_key}"}, + } + ] + + # Initialize the pass-through endpoint + await initialize_pass_through_endpoints(pass_through_endpoints) + + _json_data = { + "model": "rerank-english-v3.0", + "query": "What is the capital of the United States?", + "top_n": 3, + "documents": [ + "Carson City is the capital city of the American state of Nevada." + ], + } + + # Make a request to the pass-through endpoint + response = client.post( + "/v1/rerank", + json=_json_data, + headers={"Authorization": "Bearer {}".format(mock_api_key)}, + ) + + print("JSON response: ", _json_data) + + # Assert the response + assert response.status_code == expected_error_code + + @pytest.mark.asyncio async def test_pass_through_endpoint_anthropic(client): import litellm From 7e769f3b89d75f34e83448db6cc6a36b9871abda Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 13 Jul 2024 14:39:42 -0700 Subject: [PATCH 3/5] fix: fix linting errors --- litellm/proxy/custom_callbacks1.py | 8 +++++--- litellm/proxy/hooks/dynamic_rate_limiter.py | 1 + litellm/tests/test_proxy_reject_logging.py | 1 + 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/litellm/proxy/custom_callbacks1.py b/litellm/proxy/custom_callbacks1.py index 41962c9ab..37e4a6cdb 100644 --- a/litellm/proxy/custom_callbacks1.py +++ b/litellm/proxy/custom_callbacks1.py @@ -1,7 +1,8 @@ -from litellm.integrations.custom_logger import CustomLogger +from typing import Literal, Optional + import litellm -from litellm.proxy.proxy_server import UserAPIKeyAuth, DualCache -from typing import Optional, Literal +from litellm.integrations.custom_logger import CustomLogger +from litellm.proxy.proxy_server import DualCache, UserAPIKeyAuth # This file includes the custom callbacks for LiteLLM Proxy @@ -27,6 +28,7 @@ class MyCustomHandler( "image_generation", "moderation", "audio_transcription", + "pass_through_endpoint", ], ): return data diff --git a/litellm/proxy/hooks/dynamic_rate_limiter.py b/litellm/proxy/hooks/dynamic_rate_limiter.py index 33b5d2eb9..f5621055b 100644 --- a/litellm/proxy/hooks/dynamic_rate_limiter.py +++ b/litellm/proxy/hooks/dynamic_rate_limiter.py @@ -198,6 +198,7 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger): "image_generation", "moderation", "audio_transcription", + "pass_through_endpoint", ], ) -> Optional[ Union[Exception, str, dict] diff --git a/litellm/tests/test_proxy_reject_logging.py b/litellm/tests/test_proxy_reject_logging.py index 865566d00..d32f7783d 100644 --- a/litellm/tests/test_proxy_reject_logging.py +++ b/litellm/tests/test_proxy_reject_logging.py @@ -66,6 +66,7 @@ class testLogger(CustomLogger): "image_generation", "moderation", "audio_transcription", + "pass_through_endpoint", ], ): raise HTTPException( From 77325358b42bdfdefa56117a36979581ea9524ff Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 13 Jul 2024 14:46:56 -0700 Subject: [PATCH 4/5] fix(pass_through_endpoints.py): fix client init --- .../proxy/pass_through_endpoints/pass_through_endpoints.py | 4 ++-- litellm/tests/test_pass_through_endpoints.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 3d17ba0d7..351b19c25 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -23,8 +23,6 @@ from litellm.integrations.custom_logger import CustomLogger from litellm.proxy._types import 
From a6deb9c350d4306c6129c649856b424b5fbd8537 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 13 Jul 2024 15:10:13 -0700
Subject: [PATCH 5/5] docs(pass_through.md): update doc to specify key rpm
 limits will be enforced

---
 docs/my-website/docs/proxy/pass_through.md | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/docs/my-website/docs/proxy/pass_through.md b/docs/my-website/docs/proxy/pass_through.md
index 092b2424b..ffa4f4d76 100644
--- a/docs/my-website/docs/proxy/pass_through.md
+++ b/docs/my-website/docs/proxy/pass_through.md
@@ -156,6 +156,8 @@ POST /api/public/ingestion HTTP/1.1" 207 Multi-Status
 
 Use this if you want the pass through endpoint to honour LiteLLM keys/authentication
 
+This also enforces the key's rpm limits on pass-through endpoints.
+
 Usage - set `auth: true` on the config
 ```yaml
 general_settings:
@@ -361,4 +363,5 @@ curl --location 'http://0.0.0.0:4000/v1/messages' \
     {"role": "user", "content": "Hello, world"}
     ]
 }'
-```
\ No newline at end of file
+```
+
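To see the enforcement described in PATCH 5's doc change end to end, the key used on the pass-through route needs an `rpm_limit`. A sketch using LiteLLM's `/key/generate` key-management endpoint — it assumes a local proxy started with `master_key: sk-1234`, the `/v1/rerank` pass-through config from PATCH 1, and a license that permits `auth: true`:

```python
# Sketch: mint a rate-limited key, then exercise the authed pass-through route.
# URL, master key, and the rpm_limit value are assumptions for illustration.
import httpx

generated = httpx.post(
    "http://0.0.0.0:4000/key/generate",
    headers={"Authorization": "Bearer sk-1234"},
    json={"rpm_limit": 1},
)
generated.raise_for_status()
limited_key = generated.json()["key"]

# The second request inside the same minute should be rejected with HTTP 429.
for attempt in (1, 2):
    response = httpx.post(
        "http://0.0.0.0:4000/v1/rerank",
        headers={"Authorization": f"Bearer {limited_key}"},
        json={
            "model": "rerank-english-v3.0",
            "query": "What is the capital of the United States?",
            "documents": ["Carson City is the capital city of Nevada."],
        },
    )
    print(f"attempt {attempt}: HTTP {response.status_code}")
```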