Merge pull request #1748 from BerriAI/litellm_custom_auth_fixes

[Fix] user_custom_auth fixes when user passed bad api_keys
This commit is contained in:
Ishaan Jaff 2024-02-01 15:42:39 -08:00 committed by GitHub
commit b6a709fd8d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 157 additions and 4 deletions

View file

@ -233,9 +233,13 @@ def usage_telemetry(
).start() ).start()
def _get_bearer_token(api_key: str): def _get_bearer_token(
assert api_key.startswith("Bearer ") # ensure Bearer token passed in api_key: str,
api_key = api_key.replace("Bearer ", "") # extract the token ):
if api_key.startswith("Bearer "): # ensure Bearer token passed in
api_key = api_key.replace("Bearer ", "") # extract the token
else:
api_key = ""
return api_key return api_key
@ -253,11 +257,14 @@ async def user_api_key_auth(
global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client
try: try:
if isinstance(api_key, str): if isinstance(api_key, str):
passed_in_key = api_key
api_key = _get_bearer_token(api_key=api_key) api_key = _get_bearer_token(api_key=api_key)
### USER-DEFINED AUTH FUNCTION ### ### USER-DEFINED AUTH FUNCTION ###
if user_custom_auth is not None: if user_custom_auth is not None:
response = await user_custom_auth(request=request, api_key=api_key) response = await user_custom_auth(request=request, api_key=api_key)
return UserAPIKeyAuth.model_validate(response) return UserAPIKeyAuth.model_validate(response)
### LITELLM-DEFINED AUTH FUNCTION ### ### LITELLM-DEFINED AUTH FUNCTION ###
if master_key is None: if master_key is None:
if isinstance(api_key, str): if isinstance(api_key, str):
@ -265,6 +272,14 @@ async def user_api_key_auth(
else: else:
return UserAPIKeyAuth() return UserAPIKeyAuth()
if api_key is None:
raise Exception("No API Key passed in. api_key is None")
if secrets.compare_digest(api_key, ""):
# missing 'Bearer ' prefix
raise Exception(
f"Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: {passed_in_key}"
)
route: str = request.url.path route: str = request.url.path
if route == "/user/auth": if route == "/user/auth":
if general_settings.get("allow_user_auth", False) == True: if general_settings.get("allow_user_auth", False) == True:

View file

@ -9,8 +9,14 @@ load_dotenv()
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth: async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
try: try:
print(f"api_key: {api_key}") print(f"api_key: {api_key}")
if api_key == "":
raise Exception(
f"CustomAuth - Malformed API Key passed in. Ensure Key has `Bearer` prefix"
)
if api_key == f"{os.getenv('PROXY_MASTER_KEY')}-1234": if api_key == f"{os.getenv('PROXY_MASTER_KEY')}-1234":
return UserAPIKeyAuth(api_key=api_key) return UserAPIKeyAuth(api_key=api_key)
raise Exception raise Exception
except: except Exception as e:
if len(str(e)) > 0:
raise e
raise Exception("Failed custom auth") raise Exception("Failed custom auth")

View file

@ -1154,6 +1154,7 @@ async def test_key_name_null(prisma_client):
""" """
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": False})
await litellm.proxy.proxy_server.prisma_client.connect() await litellm.proxy.proxy_server.prisma_client.connect()
try: try:
request = GenerateKeyRequest() request = GenerateKeyRequest()
@ -1212,3 +1213,107 @@ async def test_default_key_params(prisma_client):
except Exception as e: except Exception as e:
print("Got Exception", e) print("Got Exception", e)
pytest.fail(f"Got exception {e}") pytest.fail(f"Got exception {e}")
def test_get_bearer_token():
from litellm.proxy.proxy_server import _get_bearer_token
# Test valid Bearer token
api_key = "Bearer valid_token"
result = _get_bearer_token(api_key)
assert result == "valid_token", f"Expected 'valid_token', got '{result}'"
# Test empty API key
api_key = ""
result = _get_bearer_token(api_key)
assert result == "", f"Expected '', got '{result}'"
# Test API key without Bearer prefix
api_key = "invalid_token"
result = _get_bearer_token(api_key)
assert result == "", f"Expected '', got '{result}'"
# Test API key with Bearer prefix in lowercase
api_key = "bearer valid_token"
result = _get_bearer_token(api_key)
assert result == "", f"Expected '', got '{result}'"
# Test API key with Bearer prefix and extra spaces
api_key = " Bearer valid_token "
result = _get_bearer_token(api_key)
assert result == "", f"Expected '', got '{result}'"
# Test API key with Bearer prefix and no token
api_key = "Bearer sk-1234"
result = _get_bearer_token(api_key)
assert result == "sk-1234", f"Expected 'valid_token', got '{result}'"
@pytest.mark.asyncio
async def test_user_api_key_auth(prisma_client):
from litellm.proxy.proxy_server import ProxyException
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": True})
await litellm.proxy.proxy_server.prisma_client.connect()
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
# Test case: No API Key passed in
try:
await user_api_key_auth(request, api_key=None)
pytest.fail(f"This should have failed!. IT's an invalid key")
except ProxyException as exc:
print(exc.message)
assert (
exc.message == "Authentication Error, No API Key passed in. api_key is None"
)
# Test case: Malformed API Key (missing 'Bearer ' prefix)
try:
await user_api_key_auth(request, api_key="my_token")
pytest.fail(f"This should have failed!. IT's an invalid key")
except ProxyException as exc:
print(exc.message)
assert (
exc.message
== "Authentication Error, Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: my_token"
)
# Test case: User passes empty string API Key
try:
await user_api_key_auth(request, api_key="")
pytest.fail(f"This should have failed!. IT's an invalid key")
except ProxyException as exc:
print(exc.message)
assert (
exc.message
== "Authentication Error, Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: "
)
@pytest.mark.asyncio
async def test_user_api_key_auth_without_master_key(prisma_client):
# if master key is not set, expect all calls to go through
try:
from litellm.proxy.proxy_server import ProxyException
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", None)
setattr(
litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": True}
)
await litellm.proxy.proxy_server.prisma_client.connect()
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
# Test case: No API Key passed in
await user_api_key_auth(request, api_key=None)
await user_api_key_auth(request, api_key="my_token")
await user_api_key_auth(request, api_key="")
await user_api_key_auth(request, api_key="Bearer " + "1234")
except Exception as e:
print("Got Exception", e)
pytest.fail(f"Got exception {e}")

View file

@ -65,3 +65,30 @@ def test_custom_auth(client):
assert e.code == 401 assert e.code == 401
assert e.message == "Authentication Error, Failed custom auth" assert e.message == "Authentication Error, Failed custom auth"
pass pass
def test_custom_auth_bearer(client):
try:
# Your test data
test_data = {
"model": "openai-model",
"messages": [
{"role": "user", "content": "hi"},
],
"max_tokens": 10,
}
# Your bearer token
token = os.getenv("PROXY_MASTER_KEY")
headers = {"Authorization": f"WITHOUT BEAR Er {token}"}
response = client.post("/chat/completions", json=test_data, headers=headers)
pytest.fail("LiteLLM Proxy test failed. This request should have been rejected")
except Exception as e:
print(vars(e))
print("got an exception")
assert e.code == 401
assert (
e.message
== "Authentication Error, CustomAuth - Malformed API Key passed in. Ensure Key has `Bearer` prefix"
)
pass