forked from phoenix/litellm-mirror
fix(proxy_server.py): track team spend for cached team object
fixes issue where team budgets for jwt tokens weren't asserted
This commit is contained in:
parent
5ad095ad9d
commit
6558abf845
5 changed files with 316 additions and 156 deletions
|
@ -18,6 +18,7 @@ sys.path.insert(
|
|||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import Request
|
||||
|
@ -26,6 +27,7 @@ from litellm.caching import DualCache
|
|||
from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLMRoutes
|
||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||
from litellm.proxy.management_endpoints.team_endpoints import new_team
|
||||
from litellm.proxy.proxy_server import chat_completion
|
||||
|
||||
public_key = {
|
||||
"kty": "RSA",
|
||||
|
@ -220,6 +222,70 @@ def prisma_client():
|
|||
return prisma_client
|
||||
|
||||
|
||||
@pytest.fixture
def team_token_tuple():
    """Build a signed, team-scoped JWT together with the JWK that verifies it.

    A throwaway 2048-bit RSA key pair is generated on every invocation, so
    each test gets its own key material and a unique team id.

    Returns:
        tuple: ``(team_id, token, public_jwk)`` where ``team_id`` is the
        unique id placed in the token's ``client_id`` claim, ``token`` is the
        RS256-signed JWT, and ``public_jwk`` is the matching public key
        parsed from its JWK JSON representation.
    """
    import json
    import uuid
    from datetime import timezone

    import jwt
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa

    # Generate a private / public key pair using RSA algorithm
    key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )

    # Get private key in PEM format
    private_key = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )

    # Get public key in PEM format
    public_key = key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )

    public_key_obj = serialization.load_pem_public_key(
        public_key, backend=default_backend()
    )

    # Convert RSA public key object to JWK (JSON Web Key)
    public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj))

    # VALID TOKEN
    ## GENERATE A TOKEN
    # Use a timezone-aware "now": calling .timestamp() on a naive
    # datetime.utcnow() treats the value as *local* time, which shifts the
    # expiry by the host's UTC offset (and can yield an already-expired token).
    expiration_time = int(
        (datetime.now(timezone.utc) + timedelta(minutes=10)).timestamp()
    )

    team_id = f"team123_{uuid.uuid4()}"
    payload = {
        "sub": "user123",
        "exp": expiration_time,  # set the token to expire in 10 minutes
        "scope": "litellm_team",
        "client_id": team_id,
        "aud": None,
    }

    # jwt.encode is given the PEM key as text, so decode the bytes first
    private_key_str = private_key.decode("utf-8")

    ## team token
    token = jwt.encode(payload, private_key_str, algorithm="RS256")

    return team_id, token, public_jwk
|
||||
|
||||
|
||||
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_token_output(prisma_client, audience):
|
||||
|
@ -750,3 +816,33 @@ async def test_allowed_routes_admin(prisma_client, audience):
|
|||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_team_cache_update_called():
    """Check that ``update_cache`` reads the cache when given only a team id.

    The proxy server's module-level ``user_api_key_cache`` is swapped for a
    fresh ``DualCache`` whose ``async_get_cache`` is an ``AsyncMock``; the
    test then asserts that the mock was awaited exactly once.

    Both replacements are applied through ``patch.object`` so the original
    module global is restored when the test finishes and does not leak into
    other tests.
    """
    import asyncio

    # Importing the submodule binds ``litellm`` locally and guarantees that
    # ``litellm.proxy.proxy_server`` is loaded before it is patched.
    import litellm.proxy.proxy_server

    cache = DualCache()

    with patch.object(
        litellm.proxy.proxy_server, "user_api_key_cache", cache
    ), patch.object(cache, "async_get_cache", new=AsyncMock()) as mock_call_cache:
        # Call the function under test
        await litellm.proxy.proxy_server.update_cache(
            token=None, user_id=None, end_user_id=None, team_id="1234", response_cost=20
        )  # type: ignore

        # NOTE(review): the wait suggests update_cache does its work in a
        # background task rather than before returning - confirm.
        await asyncio.sleep(3)
        mock_call_cache.assert_awaited_once()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue