fix(proxy_server.py): track team spend for cached team object

Fixes an issue where team budgets for JWT tokens weren't being enforced.
This commit is contained in:
Krrish Dholakia 2024-06-18 17:10:12 -07:00
parent 5ad095ad9d
commit 6558abf845
5 changed files with 316 additions and 156 deletions

View file

@ -18,6 +18,7 @@ sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import Request
@ -26,6 +27,7 @@ from litellm.caching import DualCache
from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLMRoutes
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.management_endpoints.team_endpoints import new_team
from litellm.proxy.proxy_server import chat_completion
public_key = {
"kty": "RSA",
@ -220,6 +222,70 @@ def prisma_client():
return prisma_client
@pytest.fixture
def team_token_tuple():
    """Build a freshly signed, team-scoped JWT and the JWK that verifies it.

    A new RSA key pair is generated on every use, so tokens from different
    tests never share signing material.

    Returns:
        tuple: ``(team_id, token, public_jwk)`` where ``team_id`` is the
        unique id placed in the token's ``client_id`` claim, ``token`` is the
        RS256-signed JWT, and ``public_jwk`` is the matching public key as a
        JWK dict.
    """
    import json
    import uuid
    from datetime import timezone

    import jwt
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa

    # Generate a private / public key pair using RSA algorithm
    key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )

    # Get private key in PEM format (used to sign the token)
    private_key = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )

    # Convert the RSA public key object to JWK (JSON Web Key).
    # The key object is passed directly; serialising it to PEM and loading it
    # back yields the same public key.
    public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(key.public_key()))

    # VALID TOKEN
    ## GENERATE A TOKEN
    # Use a timezone-aware "now": calling .timestamp() on the naive value from
    # datetime.utcnow() interprets it as *local* time, which skews `exp` by
    # the host's UTC offset and can produce an already-expired token.
    expiration_time = int(
        (datetime.now(timezone.utc) + timedelta(minutes=10)).timestamp()
    )

    team_id = f"team123_{uuid.uuid4()}"
    payload = {
        "sub": "user123",
        "exp": expiration_time,  # set the token to expire in 10 minutes
        "scope": "litellm_team",
        "client_id": team_id,
        "aud": None,
    }

    # jwt.encode is given the PEM key as text, so convert bytes to str first
    private_key_str = private_key.decode("utf-8")

    ## team token
    token = jwt.encode(payload, private_key_str, algorithm="RS256")

    return team_id, token, public_jwk
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.asyncio
async def test_team_token_output(prisma_client, audience):
@ -750,3 +816,33 @@ async def test_allowed_routes_admin(prisma_client, audience):
result = await user_api_key_auth(request=request, api_key=bearer_token)
except Exception as e:
raise e
from unittest.mock import AsyncMock
import pytest
@pytest.mark.asyncio
async def test_team_cache_update_called():
    """Verify a team-only spend update awaits the cache lookup exactly once.

    ``update_cache`` is called with only ``team_id`` set (token, user and
    end-user ids are ``None``), and the test asserts that the cache's
    ``async_get_cache`` was awaited once as a result.
    """
    import asyncio

    # Importing the submodule explicitly guarantees that
    # `litellm.proxy.proxy_server` is resolvable as an attribute below.
    import litellm.proxy.proxy_server

    cache = DualCache()

    # Swap the proxy's module-level cache object for a fresh one. Using
    # patch.object (instead of a bare setattr) restores the original cache
    # when the block exits, so the replacement does not leak into other tests.
    with patch.object(litellm.proxy.proxy_server, "user_api_key_cache", cache):
        with patch.object(
            cache, "async_get_cache", new=AsyncMock()
        ) as mock_call_cache:
            # Call the function under test
            await litellm.proxy.proxy_server.update_cache(
                token=None,
                user_id=None,
                end_user_id=None,
                team_id="1234",
                response_cost=20,
            )  # type: ignore

            # NOTE(review): presumably update_cache performs its cache work in
            # a background task, hence the wait before asserting - confirm.
            await asyncio.sleep(3)
            mock_call_cache.assert_awaited_once()