From 55e153556a8bdbbf91bdc357fb6f39ba7a20f4bf Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 13 Jul 2024 13:49:20 -0700 Subject: [PATCH] test(test_pass_through_endpoints.py): add test for rpm limit support --- litellm/proxy/auth/user_api_key_auth.py | 1 - litellm/tests/test_pass_through_endpoints.py | 62 ++++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 03a87eb5bf..8e79dffbe7 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -96,7 +96,6 @@ async def user_api_key_auth( anthropic_api_key_header ), ) -> UserAPIKeyAuth: - from litellm.proxy.proxy_server import ( allowed_routes_check, common_checks, diff --git a/litellm/tests/test_pass_through_endpoints.py b/litellm/tests/test_pass_through_endpoints.py index 43543ecc76..9a4431e176 100644 --- a/litellm/tests/test_pass_through_endpoints.py +++ b/litellm/tests/test_pass_through_endpoints.py @@ -85,6 +85,68 @@ async def test_pass_through_endpoint_rerank(client): assert response.status_code == 200 +@pytest.mark.parametrize( + "auth, rpm_limit, expected_error_code", + [(True, 0, 429), (True, 1, 200), (False, 0, 401)], +) +@pytest.mark.asyncio +async def test_pass_through_endpoint_rpm_limit( + client, auth, expected_error_code, rpm_limit +): + import litellm + from litellm.proxy._types import UserAPIKeyAuth + from litellm.proxy.proxy_server import ProxyLogging, hash_token, user_api_key_cache + + mock_api_key = "sk-my-test-key" + cache_value = UserAPIKeyAuth(token=hash_token(mock_api_key), rpm_limit=rpm_limit) + + _cohere_api_key = os.environ.get("COHERE_API_KEY") + + user_api_key_cache.set_cache(key=hash_token(mock_api_key), value=cache_value) + + proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache) + proxy_logging_obj._init_litellm_callbacks() + + setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + setattr(litellm.proxy.proxy_server, "prisma_client", "FAKE-VAR") + setattr(litellm.proxy.proxy_server, "proxy_logging_obj", proxy_logging_obj) + + # Define a pass-through endpoint + pass_through_endpoints = [ + { + "path": "/v1/rerank", + "target": "https://api.cohere.com/v1/rerank", + "auth": auth, + "headers": {"Authorization": f"bearer {_cohere_api_key}"}, + } + ] + + # Initialize the pass-through endpoint + await initialize_pass_through_endpoints(pass_through_endpoints) + + _json_data = { + "model": "rerank-english-v3.0", + "query": "What is the capital of the United States?", + "top_n": 3, + "documents": [ + "Carson City is the capital city of the American state of Nevada." + ], + } + + # Make a request to the pass-through endpoint + response = client.post( + "/v1/rerank", + json=_json_data, + headers={"Authorization": "Bearer {}".format(mock_api_key)}, + ) + + print("JSON response: ", _json_data) + + # Assert the response + assert response.status_code == expected_error_code + + @pytest.mark.asyncio async def test_pass_through_endpoint_anthropic(client): import litellm