mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
Merge pull request #1748 from BerriAI/litellm_custom_auth_fixes
[Fix] user_custom_auth fixes when user passed bad api_keys
This commit is contained in:
commit
b6a709fd8d
4 changed files with 157 additions and 4 deletions
|
@ -233,9 +233,13 @@ def usage_telemetry(
|
|||
).start()
|
||||
|
||||
|
||||
def _get_bearer_token(api_key: str):
|
||||
assert api_key.startswith("Bearer ") # ensure Bearer token passed in
|
||||
api_key = api_key.replace("Bearer ", "") # extract the token
|
||||
def _get_bearer_token(
|
||||
api_key: str,
|
||||
):
|
||||
if api_key.startswith("Bearer "): # ensure Bearer token passed in
|
||||
api_key = api_key.replace("Bearer ", "") # extract the token
|
||||
else:
|
||||
api_key = ""
|
||||
return api_key
|
||||
|
||||
|
||||
|
@ -253,11 +257,14 @@ async def user_api_key_auth(
|
|||
global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client
|
||||
try:
|
||||
if isinstance(api_key, str):
|
||||
passed_in_key = api_key
|
||||
api_key = _get_bearer_token(api_key=api_key)
|
||||
|
||||
### USER-DEFINED AUTH FUNCTION ###
|
||||
if user_custom_auth is not None:
|
||||
response = await user_custom_auth(request=request, api_key=api_key)
|
||||
return UserAPIKeyAuth.model_validate(response)
|
||||
|
||||
### LITELLM-DEFINED AUTH FUNCTION ###
|
||||
if master_key is None:
|
||||
if isinstance(api_key, str):
|
||||
|
@ -265,6 +272,14 @@ async def user_api_key_auth(
|
|||
else:
|
||||
return UserAPIKeyAuth()
|
||||
|
||||
if api_key is None:
|
||||
raise Exception("No API Key passed in. api_key is None")
|
||||
if secrets.compare_digest(api_key, ""):
|
||||
# missing 'Bearer ' prefix
|
||||
raise Exception(
|
||||
f"Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: {passed_in_key}"
|
||||
)
|
||||
|
||||
route: str = request.url.path
|
||||
if route == "/user/auth":
|
||||
if general_settings.get("allow_user_auth", False) == True:
|
||||
|
|
|
@ -9,8 +9,14 @@ load_dotenv()
|
|||
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
|
||||
try:
|
||||
print(f"api_key: {api_key}")
|
||||
if api_key == "":
|
||||
raise Exception(
|
||||
f"CustomAuth - Malformed API Key passed in. Ensure Key has `Bearer` prefix"
|
||||
)
|
||||
if api_key == f"{os.getenv('PROXY_MASTER_KEY')}-1234":
|
||||
return UserAPIKeyAuth(api_key=api_key)
|
||||
raise Exception
|
||||
except:
|
||||
except Exception as e:
|
||||
if len(str(e)) > 0:
|
||||
raise e
|
||||
raise Exception("Failed custom auth")
|
||||
|
|
|
@ -1154,6 +1154,7 @@ async def test_key_name_null(prisma_client):
|
|||
"""
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||
setattr(litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": False})
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
try:
|
||||
request = GenerateKeyRequest()
|
||||
|
@ -1212,3 +1213,107 @@ async def test_default_key_params(prisma_client):
|
|||
except Exception as e:
|
||||
print("Got Exception", e)
|
||||
pytest.fail(f"Got exception {e}")
|
||||
|
||||
|
||||
def test_get_bearer_token():
|
||||
from litellm.proxy.proxy_server import _get_bearer_token
|
||||
|
||||
# Test valid Bearer token
|
||||
api_key = "Bearer valid_token"
|
||||
result = _get_bearer_token(api_key)
|
||||
assert result == "valid_token", f"Expected 'valid_token', got '{result}'"
|
||||
|
||||
# Test empty API key
|
||||
api_key = ""
|
||||
result = _get_bearer_token(api_key)
|
||||
assert result == "", f"Expected '', got '{result}'"
|
||||
|
||||
# Test API key without Bearer prefix
|
||||
api_key = "invalid_token"
|
||||
result = _get_bearer_token(api_key)
|
||||
assert result == "", f"Expected '', got '{result}'"
|
||||
|
||||
# Test API key with Bearer prefix in lowercase
|
||||
api_key = "bearer valid_token"
|
||||
result = _get_bearer_token(api_key)
|
||||
assert result == "", f"Expected '', got '{result}'"
|
||||
|
||||
# Test API key with Bearer prefix and extra spaces
|
||||
api_key = " Bearer valid_token "
|
||||
result = _get_bearer_token(api_key)
|
||||
assert result == "", f"Expected '', got '{result}'"
|
||||
|
||||
# Test API key with Bearer prefix and no token
|
||||
api_key = "Bearer sk-1234"
|
||||
result = _get_bearer_token(api_key)
|
||||
assert result == "sk-1234", f"Expected 'valid_token', got '{result}'"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_api_key_auth(prisma_client):
|
||||
from litellm.proxy.proxy_server import ProxyException
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||
setattr(litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": True})
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
|
||||
request = Request(scope={"type": "http"})
|
||||
request._url = URL(url="/chat/completions")
|
||||
# Test case: No API Key passed in
|
||||
try:
|
||||
await user_api_key_auth(request, api_key=None)
|
||||
pytest.fail(f"This should have failed!. IT's an invalid key")
|
||||
except ProxyException as exc:
|
||||
print(exc.message)
|
||||
assert (
|
||||
exc.message == "Authentication Error, No API Key passed in. api_key is None"
|
||||
)
|
||||
|
||||
# Test case: Malformed API Key (missing 'Bearer ' prefix)
|
||||
try:
|
||||
await user_api_key_auth(request, api_key="my_token")
|
||||
pytest.fail(f"This should have failed!. IT's an invalid key")
|
||||
except ProxyException as exc:
|
||||
print(exc.message)
|
||||
assert (
|
||||
exc.message
|
||||
== "Authentication Error, Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: my_token"
|
||||
)
|
||||
|
||||
# Test case: User passes empty string API Key
|
||||
try:
|
||||
await user_api_key_auth(request, api_key="")
|
||||
pytest.fail(f"This should have failed!. IT's an invalid key")
|
||||
except ProxyException as exc:
|
||||
print(exc.message)
|
||||
assert (
|
||||
exc.message
|
||||
== "Authentication Error, Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: "
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_api_key_auth_without_master_key(prisma_client):
|
||||
# if master key is not set, expect all calls to go through
|
||||
try:
|
||||
from litellm.proxy.proxy_server import ProxyException
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", None)
|
||||
setattr(
|
||||
litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": True}
|
||||
)
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
|
||||
request = Request(scope={"type": "http"})
|
||||
request._url = URL(url="/chat/completions")
|
||||
# Test case: No API Key passed in
|
||||
|
||||
await user_api_key_auth(request, api_key=None)
|
||||
await user_api_key_auth(request, api_key="my_token")
|
||||
await user_api_key_auth(request, api_key="")
|
||||
await user_api_key_auth(request, api_key="Bearer " + "1234")
|
||||
except Exception as e:
|
||||
print("Got Exception", e)
|
||||
pytest.fail(f"Got exception {e}")
|
||||
|
|
|
@ -65,3 +65,30 @@ def test_custom_auth(client):
|
|||
assert e.code == 401
|
||||
assert e.message == "Authentication Error, Failed custom auth"
|
||||
pass
|
||||
|
||||
|
||||
def test_custom_auth_bearer(client):
|
||||
try:
|
||||
# Your test data
|
||||
test_data = {
|
||||
"model": "openai-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "hi"},
|
||||
],
|
||||
"max_tokens": 10,
|
||||
}
|
||||
# Your bearer token
|
||||
token = os.getenv("PROXY_MASTER_KEY")
|
||||
|
||||
headers = {"Authorization": f"WITHOUT BEAR Er {token}"}
|
||||
response = client.post("/chat/completions", json=test_data, headers=headers)
|
||||
pytest.fail("LiteLLM Proxy test failed. This request should have been rejected")
|
||||
except Exception as e:
|
||||
print(vars(e))
|
||||
print("got an exception")
|
||||
assert e.code == 401
|
||||
assert (
|
||||
e.message
|
||||
== "Authentication Error, CustomAuth - Malformed API Key passed in. Ensure Key has `Bearer` prefix"
|
||||
)
|
||||
pass
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue