forked from phoenix/litellm-mirror
Merge pull request #1748 from BerriAI/litellm_custom_auth_fixes
[Fix] user_custom_auth fixes when user passed bad api_keys
This commit is contained in:
commit
b6a709fd8d
4 changed files with 157 additions and 4 deletions
|
@ -233,9 +233,13 @@ def usage_telemetry(
|
||||||
).start()
|
).start()
|
||||||
|
|
||||||
|
|
||||||
def _get_bearer_token(api_key: str):
|
def _get_bearer_token(
|
||||||
assert api_key.startswith("Bearer ") # ensure Bearer token passed in
|
api_key: str,
|
||||||
|
):
|
||||||
|
if api_key.startswith("Bearer "): # ensure Bearer token passed in
|
||||||
api_key = api_key.replace("Bearer ", "") # extract the token
|
api_key = api_key.replace("Bearer ", "") # extract the token
|
||||||
|
else:
|
||||||
|
api_key = ""
|
||||||
return api_key
|
return api_key
|
||||||
|
|
||||||
|
|
||||||
|
@ -253,11 +257,14 @@ async def user_api_key_auth(
|
||||||
global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client
|
global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client
|
||||||
try:
|
try:
|
||||||
if isinstance(api_key, str):
|
if isinstance(api_key, str):
|
||||||
|
passed_in_key = api_key
|
||||||
api_key = _get_bearer_token(api_key=api_key)
|
api_key = _get_bearer_token(api_key=api_key)
|
||||||
|
|
||||||
### USER-DEFINED AUTH FUNCTION ###
|
### USER-DEFINED AUTH FUNCTION ###
|
||||||
if user_custom_auth is not None:
|
if user_custom_auth is not None:
|
||||||
response = await user_custom_auth(request=request, api_key=api_key)
|
response = await user_custom_auth(request=request, api_key=api_key)
|
||||||
return UserAPIKeyAuth.model_validate(response)
|
return UserAPIKeyAuth.model_validate(response)
|
||||||
|
|
||||||
### LITELLM-DEFINED AUTH FUNCTION ###
|
### LITELLM-DEFINED AUTH FUNCTION ###
|
||||||
if master_key is None:
|
if master_key is None:
|
||||||
if isinstance(api_key, str):
|
if isinstance(api_key, str):
|
||||||
|
@ -265,6 +272,14 @@ async def user_api_key_auth(
|
||||||
else:
|
else:
|
||||||
return UserAPIKeyAuth()
|
return UserAPIKeyAuth()
|
||||||
|
|
||||||
|
if api_key is None:
|
||||||
|
raise Exception("No API Key passed in. api_key is None")
|
||||||
|
if secrets.compare_digest(api_key, ""):
|
||||||
|
# missing 'Bearer ' prefix
|
||||||
|
raise Exception(
|
||||||
|
f"Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: {passed_in_key}"
|
||||||
|
)
|
||||||
|
|
||||||
route: str = request.url.path
|
route: str = request.url.path
|
||||||
if route == "/user/auth":
|
if route == "/user/auth":
|
||||||
if general_settings.get("allow_user_auth", False) == True:
|
if general_settings.get("allow_user_auth", False) == True:
|
||||||
|
|
|
@ -9,8 +9,14 @@ load_dotenv()
|
||||||
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
|
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
|
||||||
try:
|
try:
|
||||||
print(f"api_key: {api_key}")
|
print(f"api_key: {api_key}")
|
||||||
|
if api_key == "":
|
||||||
|
raise Exception(
|
||||||
|
f"CustomAuth - Malformed API Key passed in. Ensure Key has `Bearer` prefix"
|
||||||
|
)
|
||||||
if api_key == f"{os.getenv('PROXY_MASTER_KEY')}-1234":
|
if api_key == f"{os.getenv('PROXY_MASTER_KEY')}-1234":
|
||||||
return UserAPIKeyAuth(api_key=api_key)
|
return UserAPIKeyAuth(api_key=api_key)
|
||||||
raise Exception
|
raise Exception
|
||||||
except:
|
except Exception as e:
|
||||||
|
if len(str(e)) > 0:
|
||||||
|
raise e
|
||||||
raise Exception("Failed custom auth")
|
raise Exception("Failed custom auth")
|
||||||
|
|
|
@ -1154,6 +1154,7 @@ async def test_key_name_null(prisma_client):
|
||||||
"""
|
"""
|
||||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||||
|
setattr(litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": False})
|
||||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
try:
|
try:
|
||||||
request = GenerateKeyRequest()
|
request = GenerateKeyRequest()
|
||||||
|
@ -1212,3 +1213,107 @@ async def test_default_key_params(prisma_client):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Got Exception", e)
|
print("Got Exception", e)
|
||||||
pytest.fail(f"Got exception {e}")
|
pytest.fail(f"Got exception {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_bearer_token():
|
||||||
|
from litellm.proxy.proxy_server import _get_bearer_token
|
||||||
|
|
||||||
|
# Test valid Bearer token
|
||||||
|
api_key = "Bearer valid_token"
|
||||||
|
result = _get_bearer_token(api_key)
|
||||||
|
assert result == "valid_token", f"Expected 'valid_token', got '{result}'"
|
||||||
|
|
||||||
|
# Test empty API key
|
||||||
|
api_key = ""
|
||||||
|
result = _get_bearer_token(api_key)
|
||||||
|
assert result == "", f"Expected '', got '{result}'"
|
||||||
|
|
||||||
|
# Test API key without Bearer prefix
|
||||||
|
api_key = "invalid_token"
|
||||||
|
result = _get_bearer_token(api_key)
|
||||||
|
assert result == "", f"Expected '', got '{result}'"
|
||||||
|
|
||||||
|
# Test API key with Bearer prefix in lowercase
|
||||||
|
api_key = "bearer valid_token"
|
||||||
|
result = _get_bearer_token(api_key)
|
||||||
|
assert result == "", f"Expected '', got '{result}'"
|
||||||
|
|
||||||
|
# Test API key with Bearer prefix and extra spaces
|
||||||
|
api_key = " Bearer valid_token "
|
||||||
|
result = _get_bearer_token(api_key)
|
||||||
|
assert result == "", f"Expected '', got '{result}'"
|
||||||
|
|
||||||
|
# Test API key with Bearer prefix and no token
|
||||||
|
api_key = "Bearer sk-1234"
|
||||||
|
result = _get_bearer_token(api_key)
|
||||||
|
assert result == "sk-1234", f"Expected 'valid_token', got '{result}'"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_api_key_auth(prisma_client):
|
||||||
|
from litellm.proxy.proxy_server import ProxyException
|
||||||
|
|
||||||
|
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||||
|
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||||
|
setattr(litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": True})
|
||||||
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
|
|
||||||
|
request = Request(scope={"type": "http"})
|
||||||
|
request._url = URL(url="/chat/completions")
|
||||||
|
# Test case: No API Key passed in
|
||||||
|
try:
|
||||||
|
await user_api_key_auth(request, api_key=None)
|
||||||
|
pytest.fail(f"This should have failed!. IT's an invalid key")
|
||||||
|
except ProxyException as exc:
|
||||||
|
print(exc.message)
|
||||||
|
assert (
|
||||||
|
exc.message == "Authentication Error, No API Key passed in. api_key is None"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test case: Malformed API Key (missing 'Bearer ' prefix)
|
||||||
|
try:
|
||||||
|
await user_api_key_auth(request, api_key="my_token")
|
||||||
|
pytest.fail(f"This should have failed!. IT's an invalid key")
|
||||||
|
except ProxyException as exc:
|
||||||
|
print(exc.message)
|
||||||
|
assert (
|
||||||
|
exc.message
|
||||||
|
== "Authentication Error, Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: my_token"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test case: User passes empty string API Key
|
||||||
|
try:
|
||||||
|
await user_api_key_auth(request, api_key="")
|
||||||
|
pytest.fail(f"This should have failed!. IT's an invalid key")
|
||||||
|
except ProxyException as exc:
|
||||||
|
print(exc.message)
|
||||||
|
assert (
|
||||||
|
exc.message
|
||||||
|
== "Authentication Error, Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: "
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_api_key_auth_without_master_key(prisma_client):
|
||||||
|
# if master key is not set, expect all calls to go through
|
||||||
|
try:
|
||||||
|
from litellm.proxy.proxy_server import ProxyException
|
||||||
|
|
||||||
|
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||||
|
setattr(litellm.proxy.proxy_server, "master_key", None)
|
||||||
|
setattr(
|
||||||
|
litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": True}
|
||||||
|
)
|
||||||
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
|
|
||||||
|
request = Request(scope={"type": "http"})
|
||||||
|
request._url = URL(url="/chat/completions")
|
||||||
|
# Test case: No API Key passed in
|
||||||
|
|
||||||
|
await user_api_key_auth(request, api_key=None)
|
||||||
|
await user_api_key_auth(request, api_key="my_token")
|
||||||
|
await user_api_key_auth(request, api_key="")
|
||||||
|
await user_api_key_auth(request, api_key="Bearer " + "1234")
|
||||||
|
except Exception as e:
|
||||||
|
print("Got Exception", e)
|
||||||
|
pytest.fail(f"Got exception {e}")
|
||||||
|
|
|
@ -65,3 +65,30 @@ def test_custom_auth(client):
|
||||||
assert e.code == 401
|
assert e.code == 401
|
||||||
assert e.message == "Authentication Error, Failed custom auth"
|
assert e.message == "Authentication Error, Failed custom auth"
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_auth_bearer(client):
|
||||||
|
try:
|
||||||
|
# Your test data
|
||||||
|
test_data = {
|
||||||
|
"model": "openai-model",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "hi"},
|
||||||
|
],
|
||||||
|
"max_tokens": 10,
|
||||||
|
}
|
||||||
|
# Your bearer token
|
||||||
|
token = os.getenv("PROXY_MASTER_KEY")
|
||||||
|
|
||||||
|
headers = {"Authorization": f"WITHOUT BEAR Er {token}"}
|
||||||
|
response = client.post("/chat/completions", json=test_data, headers=headers)
|
||||||
|
pytest.fail("LiteLLM Proxy test failed. This request should have been rejected")
|
||||||
|
except Exception as e:
|
||||||
|
print(vars(e))
|
||||||
|
print("got an exception")
|
||||||
|
assert e.code == 401
|
||||||
|
assert (
|
||||||
|
e.message
|
||||||
|
== "Authentication Error, CustomAuth - Malformed API Key passed in. Ensure Key has `Bearer` prefix"
|
||||||
|
)
|
||||||
|
pass
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue