litellm-mirror/litellm/tests/test_user_api_key_auth.py
Krrish Dholakia f610fba58f fix(user_api_key_auth.py): handle older user_role's
Fixes issue where older user_role's (e.g. app_user) weren't being recognized. + Adds testing for it
2024-08-05 08:57:06 -07:00

118 lines
3.5 KiB
Python

# What is this?
## Unit tests for user_api_key_auth helper functions
import os
import sys
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the directory two levels up (relative to the current working directory) to the system path
from typing import List, Optional
from unittest.mock import MagicMock
import pytest
import litellm
class Request:
    """
    Lightweight request stand-in for the IP-check tests.

    Exposes a mocked `client` whose `host` attribute is the given client IP
    (or None when no IP is supplied).
    """

    def __init__(self, client_ip: Optional[str] = None):
        mock_client = MagicMock()
        mock_client.host = client_ip
        self.client = mock_client
@pytest.mark.parametrize(
    "allowed_ips, client_ip, expected_result",
    [
        # No IP restrictions, should be allowed
        (None, "127.0.0.1", True),
        # IP in allowed list
        (["127.0.0.1"], "127.0.0.1", True),
        # IP not in allowed list
        (["192.168.1.1"], "127.0.0.1", False),
        # Empty allowed list, no IP should be allowed
        ([], "127.0.0.1", False),
        # IP in allowed list
        (["192.168.1.1", "10.0.0.1"], "10.0.0.1", True),
        # Request with no client IP should not be allowed
        (["192.168.1.1"], None, False),
    ],
)
def test_check_valid_ip(
    allowed_ips: Optional[List[str]], client_ip: Optional[str], expected_result: bool
):
    """
    Run `_check_valid_ip` against a fake request carrying `client_ip` and
    compare its result with `expected_result` for each parametrized case.
    """
    from litellm.proxy.auth.user_api_key_auth import _check_valid_ip

    fake_request = Request(client_ip)
    outcome = _check_valid_ip(allowed_ips, fake_request)  # type: ignore

    assert outcome == expected_result
@pytest.mark.asyncio
async def test_check_blocked_team():
    """
    Scenario:
    - the cached key object (valid_token) says team_blocked = true
    - the cached team object says blocked = false, and is refreshed later

    The test passes when `user_api_key_auth` completes without raising,
    i.e. the request is not rejected as coming from a blocked team.
    """
    import asyncio
    import time

    from fastapi import Request
    from starlette.datastructures import URL

    from litellm.proxy._types import (
        LiteLLM_TeamTable,
        LiteLLM_TeamTableCachedObj,
        UserAPIKeyAuth,
    )
    from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
    from litellm.proxy.proxy_server import hash_token, user_api_key_cache

    _team_id = "1234"
    user_key = "sk-12345678"

    # Key object is stamped first ...
    stale_token = UserAPIKeyAuth(
        team_id=_team_id,
        team_blocked=True,
        token=hash_token(user_key),
        last_refreshed_at=time.time(),
    )

    # ... then wait, so the team object's last_refreshed_at is strictly later.
    await asyncio.sleep(1)

    fresh_team = LiteLLM_TeamTableCachedObj(
        team_id=_team_id, blocked=False, last_refreshed_at=time.time()
    )

    # Seed the shared cache with both objects.
    user_api_key_cache.set_cache(key=hash_token(user_key), value=stale_token)
    user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=fresh_team)

    # Install the proxy_server module globals, in the same order as before.
    proxy_overrides = (
        ("user_api_key_cache", user_api_key_cache),
        ("master_key", "sk-1234"),
        ("prisma_client", "hello-world"),
    )
    for attr_name, attr_value in proxy_overrides:
        setattr(litellm.proxy.proxy_server, attr_name, attr_value)

    http_request = Request(scope={"type": "http"})
    http_request._url = URL(url="/chat/completions")

    await user_api_key_auth(request=http_request, api_key="Bearer " + user_key)
@pytest.mark.parametrize(
    "user_role", ["app_user", "internal_user", "proxy_admin_viewer"]
)
def test_returned_user_api_key_auth(user_role):
    """
    Check the `user_role` on the object built by `_return_user_api_key_auth_obj`.

    Expectations:
    - a role that is a current `LitellmUserRoles` value comes back unchanged
    - any other role (e.g. the older "app_user") comes back as "internal_user"
    """
    from litellm.proxy._types import LitellmUserRoles
    from litellm.proxy.auth.user_api_key_auth import _return_user_api_key_auth_obj

    # Enum members are ordinary class attributes, not annotations, so
    # `LitellmUserRoles.__annotations__` does not list the role values and the
    # "role is preserved" expectation was never actually checked. Collect the
    # recognized values by iterating the enum instead.
    # NOTE(review): relies on LitellmUserRoles being an Enum - confirm.
    recognized_roles = {role.value for role in LitellmUserRoles}

    user_id_information = [{"user_role": user_role}]

    new_obj = _return_user_api_key_auth_obj(
        user_id_information,
        api_key="hello-world",
        parent_otel_span=None,
        valid_token_dict={},
        route="/chat/completion",
    )

    expected_role = user_role if user_role in recognized_roles else "internal_user"
    assert new_obj.user_role == expected_role