diff --git a/litellm/tests/test_pass_through_endpoints.py b/litellm/tests/test_pass_through_endpoints.py index 63e9e01860..951592b711 100644 --- a/litellm/tests/test_pass_through_endpoints.py +++ b/litellm/tests/test_pass_through_endpoints.py @@ -172,83 +172,113 @@ async def test_pass_through_endpoint_rpm_limit(auth, expected_error_code, rpm_li async def test_aaapass_through_endpoint_pass_through_keys_langfuse( auth, expected_error_code, rpm_limit ): + client = TestClient(app) import litellm from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.proxy_server import ProxyLogging, hash_token, user_api_key_cache - mock_api_key = "sk-my-test-key" - cache_value = UserAPIKeyAuth(token=hash_token(mock_api_key), rpm_limit=rpm_limit) + # Store original values + original_user_api_key_cache = getattr( + litellm.proxy.proxy_server, "user_api_key_cache", None + ) + original_master_key = getattr(litellm.proxy.proxy_server, "master_key", None) + original_prisma_client = getattr(litellm.proxy.proxy_server, "prisma_client", None) + original_proxy_logging_obj = getattr( + litellm.proxy.proxy_server, "proxy_logging_obj", None + ) - _cohere_api_key = os.environ.get("COHERE_API_KEY") + try: - user_api_key_cache.set_cache(key=hash_token(mock_api_key), value=cache_value) + mock_api_key = "sk-my-test-key" + cache_value = UserAPIKeyAuth( + token=hash_token(mock_api_key), rpm_limit=rpm_limit + ) - proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache) - proxy_logging_obj._init_litellm_callbacks() + _cohere_api_key = os.environ.get("COHERE_API_KEY") - setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache) - setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") - setattr(litellm.proxy.proxy_server, "prisma_client", "FAKE-VAR") - setattr(litellm.proxy.proxy_server, "proxy_logging_obj", proxy_logging_obj) + user_api_key_cache.set_cache(key=hash_token(mock_api_key), value=cache_value) - # Define a pass-through endpoint - pass_through_endpoints = [ - { - "path": "/api/public/ingestion", - "target": "https://cloud.langfuse.com/api/public/ingestion", - "auth": auth, - "custom_auth_parser": "langfuse", - "headers": { - "LANGFUSE_PUBLIC_KEY": "os.environ/LANGFUSE_PUBLIC_KEY", - "LANGFUSE_SECRET_KEY": "os.environ/LANGFUSE_SECRET_KEY", + proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache) + proxy_logging_obj._init_litellm_callbacks() + + setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + setattr(litellm.proxy.proxy_server, "prisma_client", "FAKE-VAR") + setattr(litellm.proxy.proxy_server, "proxy_logging_obj", proxy_logging_obj) + + # Define a pass-through endpoint + pass_through_endpoints = [ + { + "path": "/api/public/ingestion", + "target": "https://us.cloud.langfuse.com/api/public/ingestion", + "auth": auth, + "custom_auth_parser": "langfuse", + "headers": { + "LANGFUSE_PUBLIC_KEY": "os.environ/LANGFUSE_PUBLIC_KEY", + "LANGFUSE_SECRET_KEY": "os.environ/LANGFUSE_SECRET_KEY", + }, + } + ] + + # Initialize the pass-through endpoint + await initialize_pass_through_endpoints(pass_through_endpoints) + general_settings: Optional[dict] = ( + getattr(litellm.proxy.proxy_server, "general_settings", {}) or {} + ) + old_general_settings = general_settings + general_settings.update({"pass_through_endpoints": pass_through_endpoints}) + setattr(litellm.proxy.proxy_server, "general_settings", general_settings) + + _json_data = { + "batch": [ + { + "id": "80e2141f-0ca6-47b7-9c06-dde5e97de690", + "type": "trace-create", + "body": { + "id": "0687af7b-4a75-4de8-a4f6-cba1cdc00865", + "timestamp": "2024-08-14T02:38:56.092950Z", + "name": "test-trace-litellm-proxy-passthrough", + }, + "timestamp": "2024-08-14T02:38:56.093352Z", + } + ], + "metadata": { + "batch_size": 1, + "sdk_integration": "default", + "sdk_name": "python", + "sdk_version": "2.27.0", + "public_key": "anything", }, } - ] - # Initialize the pass-through endpoint - await initialize_pass_through_endpoints(pass_through_endpoints) - general_settings: Optional[dict] = ( - getattr(litellm.proxy.proxy_server, "general_settings", {}) or {} - ) - general_settings.update({"pass_through_endpoints": pass_through_endpoints}) - setattr(litellm.proxy.proxy_server, "general_settings", general_settings) + # Make a request to the pass-through endpoint + response = client.post( + "/api/public/ingestion", + json=_json_data, + headers={"Authorization": "Basic c2stbXktdGVzdC1rZXk6YW55dGhpbmc="}, + ) - _json_data = { - "batch": [ - { - "id": "80e2141f-0ca6-47b7-9c06-dde5e97de690", - "type": "trace-create", - "body": { - "id": "0687af7b-4a75-4de8-a4f6-cba1cdc00865", - "timestamp": "2024-08-14T02:38:56.092950Z", - "name": "test-trace-litellm-proxy-passthrough", - }, - "timestamp": "2024-08-14T02:38:56.093352Z", - } - ], - "metadata": { - "batch_size": 1, - "sdk_integration": "default", - "sdk_name": "python", - "sdk_version": "2.27.0", - "public_key": "anything", - }, - } + print("JSON response: ", _json_data) - # Make a request to the pass-through endpoint - response = client.post( - "/api/public/ingestion", - json=_json_data, - headers={"Authorization": "Basic c2stbXktdGVzdC1rZXk6YW55dGhpbmc="}, - ) + print("RESPONSE RECEIVED - {}".format(response.text)) - print("JSON response: ", _json_data) + # Assert the response + assert response.status_code == expected_error_code - print("RESPONSE RECEIVED - {}".format(response.text)) - - # Assert the response - assert response.status_code == expected_error_code + setattr(litellm.proxy.proxy_server, "general_settings", old_general_settings) + finally: + # Reset to original values + setattr( + litellm.proxy.proxy_server, + "user_api_key_cache", + original_user_api_key_cache, + ) + setattr(litellm.proxy.proxy_server, "master_key", original_master_key) + setattr(litellm.proxy.proxy_server, "prisma_client", original_prisma_client) + setattr( + litellm.proxy.proxy_server, "proxy_logging_obj", original_proxy_logging_obj + ) @pytest.mark.asyncio