From 742e3cbccfdef3d24e29e7db0391c1c3d2cf6b25 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Tue, 13 Aug 2024 20:26:24 -0700
Subject: [PATCH] feat(user_api_key_auth.py): support calling langfuse with litellm user_api_key_auth

---
 litellm/proxy/_new_secret_config.yaml        |  10 +-
 litellm/proxy/auth/user_api_key_auth.py      |  43 ++++++-
 .../pass_through_endpoints.py                |   2 +-
 litellm/tests/test_pass_through_endpoints.py | 112 +++++++++++++++++-
 4 files changed, 160 insertions(+), 7 deletions(-)

diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 87a561e318..bc3e0680f8 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -3,5 +3,11 @@ model_list:
     litellm_params:
       model: "*"
 
-litellm_settings:
-  success_callback: ["langsmith"]
\ No newline at end of file
+general_settings:
+  master_key: sk-1234
+  pass_through_endpoints:
+    - path: "/api/public/ingestion" # route you want to add to LiteLLM Proxy Server
+      target: "https://us.cloud.langfuse.com/api/public/ingestion" # URL this route should forward
+      headers:
+        LANGFUSE_PUBLIC_KEY: "os.environ/LANGFUSE_PUBLIC_KEY" # your langfuse account public key
+        LANGFUSE_SECRET_KEY: "os.environ/LANGFUSE_SECRET_KEY" # your langfuse account secret key
\ No newline at end of file
diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py
index 9bbbc1a430..3df90f37fa 100644
--- a/litellm/proxy/auth/user_api_key_auth.py
+++ b/litellm/proxy/auth/user_api_key_auth.py
@@ -86,7 +86,7 @@ def _get_bearer_token(
     if api_key.startswith("Bearer "):  # ensure Bearer token passed in
         api_key = api_key.replace("Bearer ", "")  # extract the token
     else:
-        api_key = ""
+        api_key = api_key
     return api_key
 
 
@@ -138,7 +138,6 @@ async def user_api_key_auth(
     pass_through_endpoints: Optional[List[dict]] = general_settings.get(
         "pass_through_endpoints", None
    )
-
    if isinstance(api_key, str):
        passed_in_key = api_key
        api_key = _get_bearer_token(api_key=api_key)
@@ -367,6 +366,40 @@ async def user_api_key_auth(
                 parent_otel_span=parent_otel_span,
             )
         #### ELSE ####
+
+        ## CHECK PASS-THROUGH ENDPOINTS ##
+        if pass_through_endpoints is not None:
+            for endpoint in pass_through_endpoints:
+                if endpoint.get("path", "") == route:
+                    ## IF AUTH DISABLED
+                    if endpoint.get("auth") is not True:
+                        return UserAPIKeyAuth()
+                    ## IF AUTH ENABLED
+                    ### IF CUSTOM PARSER REQUIRED
+                    if (
+                        endpoint.get("custom_auth_parser") is not None
+                        and endpoint.get("custom_auth_parser") == "langfuse"
+                    ):
+                        """
+                        - langfuse returns {'Authorization': 'Basic YW55dGhpbmc6YW55dGhpbmc'}
+                        - check the langfuse public key if it contains the litellm api key
+                        """
+                        import base64
+
+                        api_key = api_key.replace("Basic ", "").strip()
+                        decoded_bytes = base64.b64decode(api_key)
+                        decoded_str = decoded_bytes.decode("utf-8")
+                        api_key = decoded_str.split(":")[0]
+                    else:
+                        headers = endpoint.get("headers", None)
+                        if headers is not None:
+                            header_key = headers.get("litellm_user_api_key", "")
+                            if (
+                                isinstance(request.headers, dict)
+                                and request.headers.get(key=header_key) is not None
+                            ):
+                                api_key = request.headers.get(key=header_key)
+
     if master_key is None:
         if isinstance(api_key, str):
             return UserAPIKeyAuth(
@@ -533,7 +566,11 @@ async def user_api_key_auth(
             if isinstance(
                 api_key, str
             ):  # if generated token, make sure it starts with sk-.
-                assert api_key.startswith("sk-")  # prevent token hashes from being used
+                assert api_key.startswith(
+                    "sk-"
+                ), "LiteLLM Virtual Key expected. Received={}, expected to start with 'sk-'.".format(
+                    api_key
+                )  # prevent token hashes from being used
             else:
                 verbose_logger.warning(
                     "litellm.proxy.proxy_server.user_api_key_auth(): Warning - Key={} is not a string.".format(
diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
index d71863497f..15129854a3 100644
--- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
+++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
@@ -309,7 +309,7 @@ async def pass_through_request(
         json=_parsed_body,
     )
 
-    if response.status_code != 200:
+    if response.status_code >= 300:
         raise HTTPException(status_code=response.status_code, detail=response.text)
 
     content = await response.aread()
diff --git a/litellm/tests/test_pass_through_endpoints.py b/litellm/tests/test_pass_through_endpoints.py
index 4f52f3d192..0f57ca68f9 100644
--- a/litellm/tests/test_pass_through_endpoints.py
+++ b/litellm/tests/test_pass_through_endpoints.py
@@ -1,5 +1,6 @@
 import os
 import sys
+from typing import Optional
 
 import pytest
 from fastapi import FastAPI
@@ -30,6 +31,7 @@ def client():
 async def test_pass_through_endpoint(client, monkeypatch):
     # Mock the httpx.AsyncClient.request method
     monkeypatch.setattr("httpx.AsyncClient.request", mock_request)
+    import litellm
 
     # Define a pass-through endpoint
     pass_through_endpoints = [
@@ -42,6 +44,11 @@ async def test_pass_through_endpoint(client, monkeypatch):
 
     # Initialize the pass-through endpoint
     await initialize_pass_through_endpoints(pass_through_endpoints)
+    general_settings: Optional[dict] = (
+        getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
+    )
+    general_settings.update({"pass_through_endpoints": pass_through_endpoints})
+    setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
 
     # Make a request to the pass-through endpoint
     response = client.post("/test-endpoint", json={"prompt": "Hello, world!"})
@@ -54,6 +61,7 @@ async def test_pass_through_endpoint(client, monkeypatch):
 @pytest.mark.asyncio
 async def test_pass_through_endpoint_rerank(client):
     _cohere_api_key = os.environ.get("COHERE_API_KEY")
+    import litellm
 
     # Define a pass-through endpoint
     pass_through_endpoints = [
@@ -66,6 +74,11 @@ async def test_pass_through_endpoint_rerank(client):
 
     # Initialize the pass-through endpoint
     await initialize_pass_through_endpoints(pass_through_endpoints)
+    general_settings: Optional[dict] = (
+        getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
+    )
+    general_settings.update({"pass_through_endpoints": pass_through_endpoints})
+    setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
 
     _json_data = {
         "model": "rerank-english-v3.0",
@@ -87,7 +100,7 @@
 
 @pytest.mark.parametrize(
     "auth, rpm_limit, expected_error_code",
-    [(True, 0, 429), (True, 1, 200), (False, 0, 401)],
+    [(True, 0, 429), (True, 1, 200), (False, 0, 200)],
 )
 @pytest.mark.asyncio
 async def test_pass_through_endpoint_rpm_limit(auth, expected_error_code, rpm_limit):
@@ -123,6 +136,11 @@ async def test_pass_through_endpoint_rpm_limit(auth, expected_error_code, rpm_li
 
     # Initialize the pass-through endpoint
     await initialize_pass_through_endpoints(pass_through_endpoints)
+    general_settings: Optional[dict] = (
+        getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
+    )
+    general_settings.update({"pass_through_endpoints": pass_through_endpoints})
+    setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
 
     _json_data = {
         "model": "rerank-english-v3.0",
@@ -146,6 +164,93 @@ async def test_pass_through_endpoint_rpm_limit(auth, expected_error_code, rpm_li
     assert response.status_code == expected_error_code
 
 
+@pytest.mark.parametrize(
+    "auth, rpm_limit, expected_error_code",
+    [(True, 0, 429), (True, 1, 207), (False, 0, 207)],
+)
+@pytest.mark.asyncio
+async def test_pass_through_endpoint_pass_through_keys_langfuse(
+    auth, expected_error_code, rpm_limit
+):
+    client = TestClient(app)
+    import litellm
+    from litellm.proxy._types import UserAPIKeyAuth
+    from litellm.proxy.proxy_server import ProxyLogging, hash_token, user_api_key_cache
+
+    mock_api_key = "sk-my-test-key"
+    cache_value = UserAPIKeyAuth(token=hash_token(mock_api_key), rpm_limit=rpm_limit)
+
+    _cohere_api_key = os.environ.get("COHERE_API_KEY")
+
+    user_api_key_cache.set_cache(key=hash_token(mock_api_key), value=cache_value)
+
+    proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
+    proxy_logging_obj._init_litellm_callbacks()
+
+    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    setattr(litellm.proxy.proxy_server, "prisma_client", "FAKE-VAR")
+    setattr(litellm.proxy.proxy_server, "proxy_logging_obj", proxy_logging_obj)
+
+    # Define a pass-through endpoint
+    pass_through_endpoints = [
+        {
+            "path": "/api/public/ingestion",
+            "target": "https://us.cloud.langfuse.com/api/public/ingestion",
+            "auth": auth,
+            "custom_auth_parser": "langfuse",
+            "headers": {
+                "LANGFUSE_PUBLIC_KEY": "os.environ/LANGFUSE_PUBLIC_KEY",
+                "LANGFUSE_SECRET_KEY": "os.environ/LANGFUSE_SECRET_KEY",
+            },
+        }
+    ]
+
+    # Initialize the pass-through endpoint
+    await initialize_pass_through_endpoints(pass_through_endpoints)
+    general_settings: Optional[dict] = (
+        getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
+    )
+    general_settings.update({"pass_through_endpoints": pass_through_endpoints})
+    setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
+
+    _json_data = {
+        "batch": [
+            {
+                "id": "80e2141f-0ca6-47b7-9c06-dde5e97de690",
+                "type": "trace-create",
+                "body": {
+                    "id": "0687af7b-4a75-4de8-a4f6-cba1cdc00865",
+                    "timestamp": "2024-08-14T02:38:56.092950Z",
+                    "name": "test-trace-litellm-proxy-passthrough",
+                },
+                "timestamp": "2024-08-14T02:38:56.093352Z",
+            }
+        ],
+        "metadata": {
+            "batch_size": 1,
+            "sdk_integration": "default",
+            "sdk_name": "python",
+            "sdk_version": "2.27.0",
+            "public_key": "anything",
+        },
+    }
+
+    # Make a request to the pass-through endpoint
+    response = client.post(
+        "/api/public/ingestion",
+        json=_json_data,
+        headers={"Authorization": "Basic c2stbXktdGVzdC1rZXk6YW55dGhpbmc="},
+    )
+
+    print("JSON response: ", _json_data)
+
+    print("RESPONSE RECEIVED - {}".format(response.text))
+
+    # Assert the response
+    assert response.status_code == expected_error_code
+
+
 @pytest.mark.asyncio
 async def test_pass_through_endpoint_anthropic(client):
     import litellm
@@ -178,6 +283,11 @@ async def test_pass_through_endpoint_anthropic(client):
 
     # Initialize the pass-through endpoint
     await initialize_pass_through_endpoints(pass_through_endpoints)
+    general_settings: Optional[dict] = (
+        getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
+    )
+    general_settings.update({"pass_through_endpoints": pass_through_endpoints})
+    setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
 
     _json_data = {
         "model": "gpt-3.5-turbo",