fix(proxy_server.py): track team spend for cached team object

fixes an issue where team budgets for jwt tokens weren't enforced: the team object was cached on auth, but its spend was never updated in the cache after a request, so budget checks always compared against stale spend
Krrish Dholakia 2024-06-18 17:10:12 -07:00
parent 5ad095ad9d
commit 6558abf845
5 changed files with 316 additions and 156 deletions
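
The heart of the fix: get_team_object and the new _update_team_cache coroutine must agree on a single cache key per team, so the spend recorded by the cost-tracking callback lands on the same object the JWT budget check reads. A minimal sketch of the shared key scheme (the team_cache_key helper is hypothetical, added for illustration; the diff inlines "team_id:{}".format(team_id) at both sites):

def team_cache_key(team_id: str) -> str:
    # single source of truth for the team cache key, e.g. "team_id:team123"
    return "team_id:{}".format(team_id)

# reader (get_team_object):    await user_api_key_cache.async_get_cache(key=team_cache_key(team_id))
# writer (_update_team_cache): user_api_key_cache.set_cache(key=team_cache_key(team_id), value=team_obj)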

litellm/proxy/proxy_config.yaml

@@ -73,16 +73,12 @@ assistant_settings:
router_settings:
enable_pre_call_checks: true
litellm_settings:
callbacks: ["lago"]
success_callback: ["langfuse"]
failure_callback: ["langfuse"]
cache: true
json_logs: true
general_settings:
alerting: ["slack"]
enable_jwt_auth: True
litellm_jwtauth:
team_id_jwt_field: "client_id"
# key_management_system: "aws_kms"
# key_management_settings:
# hosted_keys: ["LITELLM_MASTER_KEY"]
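
With enable_jwt_auth on and team_id_jwt_field set to "client_id", the proxy reads the team id for budget checks from the token's client_id claim. An illustrative token payload (example values, mirroring the test fixture added later in this commit):

payload = {
    "sub": "user123",         # end user
    "exp": 1718755200,        # unix-timestamp expiry (example value)
    "scope": "litellm_team",  # team-scoped token
    "client_id": "team123",   # read as the LiteLLM team_id via team_id_jwt_field
    "aud": None,
}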

litellm/proxy/auth/auth_checks.py

@@ -8,21 +8,22 @@ Run checks for:
2. If user is in budget
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
"""
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, Optional
import litellm
from litellm.caching import DualCache
from litellm.proxy._types import (
LiteLLM_UserTable,
LiteLLM_EndUserTable,
LiteLLM_JWTAuth,
LiteLLM_TeamTable,
LiteLLMRoutes,
LiteLLM_OrganizationTable,
LiteLLM_TeamTable,
LiteLLM_UserTable,
LiteLLMRoutes,
LitellmUserRoles,
)
from typing import Optional, Literal, TYPE_CHECKING, Any
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_to_opentelemetry
from litellm.caching import DualCache
import litellm
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
from datetime import datetime
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
@@ -110,7 +111,7 @@ def common_checks(
# Enterprise ONLY Feature
# we already validate if user is premium_user when reading the config
# Add an extra premium_user check here too, just in case
from litellm.proxy.proxy_server import premium_user, CommonProxyErrors
from litellm.proxy.proxy_server import CommonProxyErrors, premium_user
if premium_user is not True:
raise ValueError(
@@ -364,7 +365,8 @@ async def get_team_object(
)
# check if in cache
cached_team_obj = await user_api_key_cache.async_get_cache(key=team_id)
key = "team_id:{}".format(team_id)
cached_team_obj = await user_api_key_cache.async_get_cache(key=key)
if cached_team_obj is not None:
if isinstance(cached_team_obj, dict):
return LiteLLM_TeamTable(**cached_team_obj)
@@ -381,7 +383,7 @@ async def get_team_object(
_response = LiteLLM_TeamTable(**response.dict())
# save the team object to cache
await user_api_key_cache.async_set_cache(key=response.team_id, value=_response)
await user_api_key_cache.async_set_cache(key=key, value=_response)
return _response
except Exception as e:
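
The substance of this hunk: the team object used to be cached under the bare team_id, and the new spend updater needs a predictable, namespaced key that reader and writer share. A condensed sketch of the corrected read-through path (simplified; error handling elided, and the DB call is modeled on surrounding litellm code, so treat it as illustrative):

from litellm.proxy._types import LiteLLM_TeamTable

async def get_team_object_sketch(team_id, user_api_key_cache, prisma_client):
    key = "team_id:{}".format(team_id)
    # 1. serve from cache when possible
    cached_team_obj = await user_api_key_cache.async_get_cache(key=key)
    if cached_team_obj is not None:
        if isinstance(cached_team_obj, dict):
            return LiteLLM_TeamTable(**cached_team_obj)
        return cached_team_obj
    # 2. fall back to the DB, then cache under the same namespaced key
    response = await prisma_client.db.litellm_teamtable.find_unique(
        where={"team_id": team_id}
    )
    _response = LiteLLM_TeamTable(**response.dict())
    await user_api_key_cache.async_set_cache(key=key, value=_response)
    return _response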

litellm/proxy/auth/user_api_key_auth.py

@@ -120,7 +120,7 @@ async def user_api_key_auth(
)
### USER-DEFINED AUTH FUNCTION ###
if user_custom_auth is not None:
response = await user_custom_auth(request=request, api_key=api_key)
response = await user_custom_auth(request=request, api_key=api_key) # type: ignore
return UserAPIKeyAuth.model_validate(response)
### LITELLM-DEFINED AUTH FUNCTION ###
@@ -140,7 +140,7 @@ async def user_api_key_auth(
# check if public endpoint
return UserAPIKeyAuth(user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY)
if general_settings.get("enable_jwt_auth", False) == True:
if general_settings.get("enable_jwt_auth", False) is True:
is_jwt = jwt_handler.is_jwt(token=api_key)
verbose_proxy_logger.debug("is_jwt: %s", is_jwt)
if is_jwt:
@@ -177,7 +177,7 @@ async def user_api_key_auth(
token=jwt_valid_token, default_value=None
)
if team_id is None and jwt_handler.is_required_team_id() == True:
if team_id is None and jwt_handler.is_required_team_id() is True:
raise Exception(
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
)
@@ -190,7 +190,7 @@ async def user_api_key_auth(
user_route=route,
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
)
if is_allowed == False:
if is_allowed is False:
allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes # type: ignore
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
raise Exception(

litellm/proxy/proxy_server.py

@@ -1,12 +1,26 @@
import sys, os, platform, time, copy, re, asyncio, inspect
import threading, ast
import shutil, random, traceback, requests
from datetime import datetime, timedelta, timezone
from typing import Optional, List, Callable, get_args, Set, Any, TYPE_CHECKING
import secrets, subprocess
import hashlib, uuid
import warnings
import ast
import asyncio
import copy
import hashlib
import importlib
import inspect
import os
import platform
import random
import re
import secrets
import shutil
import subprocess
import sys
import threading
import time
import traceback
import uuid
import warnings
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Set, get_args
import requests
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
@@ -34,11 +48,12 @@ sys.path.insert(
) # Adds the parent directory to the system path - for litellm local dev
try:
import fastapi
import backoff
import yaml # type: ignore
import orjson
import logging
import backoff
import fastapi
import orjson
import yaml # type: ignore
from apscheduler.schedulers.asyncio import AsyncIOScheduler
except ImportError as e:
raise ImportError(f"Missing dependency {e}. Run `pip install 'litellm[proxy]'`")
@@ -85,60 +100,31 @@ def generate_feedback_box():
print() # noqa
import litellm
from litellm.types.llms.openai import (
HttpxBinaryResponseContent,
)
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
from litellm.proxy.utils import (
PrismaClient,
DBClient,
get_instance_fn,
ProxyLogging,
_cache_user_row,
send_email,
get_logging_payload,
reset_budget,
hash_token,
html_form,
missing_keys_html_form,
_is_valid_team_configs,
_is_projected_spend_over_limit,
_get_projected_spend_over_limit,
update_spend,
encrypt_value,
decrypt_value,
get_error_message_str,
)
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm import (
CreateBatchRequest,
RetrieveBatchRequest,
ListBatchRequest,
CancelBatchRequest,
CreateFileRequest,
)
from litellm.proxy.secret_managers.google_kms import load_google_kms
from litellm.proxy.secret_managers.aws_secret_manager import (
load_aws_secret_manager,
load_aws_kms,
)
import pydantic
from litellm.proxy._types import *
from litellm.caching import DualCache, RedisCache
from litellm.proxy.health_check import perform_health_check
from litellm.router import (
LiteLLM_Params,
Deployment,
updateDeployment,
ModelGroupInfo,
AssistantsTypedDict,
import litellm
from litellm import (
CancelBatchRequest,
CreateBatchRequest,
CreateFileRequest,
ListBatchRequest,
RetrieveBatchRequest,
)
from litellm.router import ModelInfo as RouterModelInfo
from litellm._logging import (
verbose_router_logger,
verbose_proxy_logger,
from litellm._logging import verbose_proxy_logger, verbose_router_logger
from litellm.caching import DualCache, RedisCache
from litellm.exceptions import RejectedRequestError
from litellm.integrations.slack_alerting import SlackAlerting, SlackAlertingArgs
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy._types import *
from litellm.proxy.auth.auth_checks import (
allowed_routes_check,
common_checks,
get_actual_routes,
get_end_user_object,
get_org_object,
get_team_object,
get_user_object,
log_to_opentelemetry,
)
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.auth.litellm_license import LicenseCheck
@@ -148,78 +134,105 @@ from litellm.proxy.auth.model_checks import (
get_team_models,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.hooks.prompt_injection_detection import (
_OPTIONAL_PromptInjectionDetection,
)
from litellm.proxy.auth.auth_checks import (
common_checks,
get_end_user_object,
get_org_object,
get_team_object,
get_user_object,
allowed_routes_check,
get_actual_routes,
log_to_opentelemetry,
)
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.exceptions import RejectedRequestError
from litellm.integrations.slack_alerting import SlackAlertingArgs, SlackAlerting
from litellm.scheduler import Scheduler, FlowItem, DefaultPriorities
## Import All Misc routes here ##
from litellm.proxy.caching_routes import router as caching_router
from litellm.proxy.management_endpoints.team_endpoints import router as team_router
from litellm.proxy.spend_reporting_endpoints.spend_management_endpoints import (
router as spend_management_router,
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.health_check import perform_health_check
from litellm.proxy.health_endpoints._health_endpoints import router as health_router
from litellm.proxy.hooks.prompt_injection_detection import (
_OPTIONAL_PromptInjectionDetection,
)
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
from litellm.proxy.management_endpoints.internal_user_endpoints import (
router as internal_user_router,
)
from litellm.proxy.management_endpoints.internal_user_endpoints import user_update
from litellm.proxy.management_endpoints.key_management_endpoints import (
_duration_in_seconds,
delete_verification_token,
generate_key_helper_fn,
)
from litellm.proxy.management_endpoints.key_management_endpoints import (
router as key_management_router,
_duration_in_seconds,
generate_key_helper_fn,
delete_verification_token,
)
from litellm.proxy.management_endpoints.internal_user_endpoints import (
router as internal_user_router,
user_update,
from litellm.proxy.management_endpoints.team_endpoints import router as team_router
from litellm.proxy.secret_managers.aws_secret_manager import (
load_aws_kms,
load_aws_secret_manager,
)
from litellm.proxy.health_endpoints._health_endpoints import router as health_router
from litellm.proxy.secret_managers.google_kms import load_google_kms
from litellm.proxy.spend_reporting_endpoints.spend_management_endpoints import (
router as spend_management_router,
)
from litellm.proxy.utils import (
DBClient,
PrismaClient,
ProxyLogging,
_cache_user_row,
_get_projected_spend_over_limit,
_is_projected_spend_over_limit,
_is_valid_team_configs,
decrypt_value,
encrypt_value,
get_error_message_str,
get_instance_fn,
get_logging_payload,
hash_token,
html_form,
missing_keys_html_form,
reset_budget,
send_email,
update_spend,
)
from litellm.router import (
AssistantsTypedDict,
Deployment,
LiteLLM_Params,
ModelGroupInfo,
)
from litellm.router import ModelInfo as RouterModelInfo
from litellm.router import updateDeployment
from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler
from litellm.types.llms.openai import HttpxBinaryResponseContent
try:
from litellm._version import version
except:
version = "0.0.0"
litellm.suppress_debug_info = True
from fastapi import (
FastAPI,
Request,
HTTPException,
status,
Path,
Depends,
Header,
Response,
Form,
UploadFile,
File,
)
from fastapi.routing import APIRouter
from fastapi.security import OAuth2PasswordBearer
from fastapi.encoders import jsonable_encoder
from fastapi.responses import (
StreamingResponse,
FileResponse,
ORJSONResponse,
JSONResponse,
)
from fastapi.openapi.utils import get_openapi
from fastapi.responses import RedirectResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.security.api_key import APIKeyHeader
import json
import logging
from typing import Union
from fastapi import (
Depends,
FastAPI,
File,
Form,
Header,
HTTPException,
Path,
Request,
Response,
UploadFile,
status,
)
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from fastapi.responses import (
FileResponse,
JSONResponse,
ORJSONResponse,
RedirectResponse,
StreamingResponse,
)
from fastapi.routing import APIRouter
from fastapi.security import OAuth2PasswordBearer
from fastapi.security.api_key import APIKeyHeader
from fastapi.staticfiles import StaticFiles
# import enterprise folder
try:
# when using litellm cli
@@ -488,8 +501,8 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
return
try:
from azure.keyvault.secrets import SecretClient
from azure.identity import ClientSecretCredential
from azure.keyvault.secrets import SecretClient
# Set your Azure Key Vault URI
KVUri = os.getenv("AZURE_KEY_VAULT_URI", None)
@@ -655,6 +668,7 @@ async def _PROXY_track_cost_callback(
user_id=user_id,
end_user_id=end_user_id,
response_cost=response_cost,
team_id=team_id,
)
await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
@@ -897,6 +911,7 @@ async def update_cache(
token: Optional[str],
user_id: Optional[str],
end_user_id: Optional[str],
team_id: Optional[str],
response_cost: Optional[float],
):
"""
@@ -993,6 +1008,7 @@ async def update_cache(
existing_spend_obj.spend = new_spend
user_api_key_cache.set_cache(key=hashed_token, value=existing_spend_obj)
### UPDATE USER SPEND ###
async def _update_user_cache():
## UPDATE CACHE FOR USER ID + GLOBAL PROXY
user_ids = [user_id]
@@ -1054,6 +1070,7 @@ async def update_cache(
f"An error occurred updating user cache: {str(e)}\n\n{traceback.format_exc()}"
)
### UPDATE END-USER SPEND ###
async def _update_end_user_cache():
if end_user_id is None or response_cost is None:
return
@@ -1102,14 +1119,59 @@ async def update_cache(
f"An error occurred updating end user cache: {str(e)}\n\n{traceback.format_exc()}"
)
### UPDATE TEAM SPEND ###
async def _update_team_cache():
if team_id is None or response_cost is None:
return
_id = "team_id:{}".format(team_id)
try:
# Fetch the existing spend for the given team
existing_spend_obj: Optional[LiteLLM_TeamTable] = (
await user_api_key_cache.async_get_cache(key=_id)
)
if existing_spend_obj is None:
return
verbose_proxy_logger.debug(
f"_update_team_cache: existing spend: {existing_spend_obj}; response_cost: {response_cost}"
)
if isinstance(existing_spend_obj, dict):
existing_spend: Optional[float] = existing_spend_obj["spend"]
else:
existing_spend = existing_spend_obj.spend
if existing_spend is None:
existing_spend = 0.0
# Calculate the new cost by adding the existing cost and response_cost
new_spend = existing_spend + response_cost
# Write the new spend back to the team cache
if isinstance(existing_spend_obj, dict):
existing_spend_obj["spend"] = new_spend
user_api_key_cache.set_cache(key=_id, value=existing_spend_obj)
else:
existing_spend_obj.spend = new_spend
user_api_key_cache.set_cache(key=_id, value=existing_spend_obj)
except Exception as e:
verbose_proxy_logger.error(
f"An error occurred updating end user cache: {str(e)}\n\n{traceback.format_exc()}"
)
if token is not None and response_cost is not None:
asyncio.create_task(_update_key_cache(token=token, response_cost=response_cost))
asyncio.create_task(_update_user_cache())
if user_id is not None:
asyncio.create_task(_update_user_cache())
if end_user_id is not None:
asyncio.create_task(_update_end_user_cache())
if team_id is not None:
asyncio.create_task(_update_team_cache())
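
Putting the pieces together: _PROXY_track_cost_callback now forwards team_id, and update_cache schedules one fire-and-forget task per entity whose spend changed. A hedged usage sketch, run inside an async context (argument values mirror the new test at the bottom of this commit):

from litellm.proxy.proxy_server import update_cache

await update_cache(
    token=None,          # per-key spend (no key in this example)
    user_id=None,        # per-user spend
    end_user_id=None,    # per-customer spend
    team_id="team123",   # NEW: bumps .spend on the object cached at "team_id:team123"
    response_cost=20.0,
)
# With team_id set, update_cache schedules _update_team_cache(), which updates
# the same cached object that get_team_object returns to the JWT budget check.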
def run_ollama_serve():
try:
@@ -2297,25 +2359,27 @@ async def initialize(
user_model = model
user_debug = debug
if debug == True: # this needs to be first, so users can see Router init debug logs
from litellm._logging import (
verbose_router_logger,
verbose_proxy_logger,
verbose_logger,
)
import logging
from litellm._logging import (
verbose_logger,
verbose_proxy_logger,
verbose_router_logger,
)
# this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
verbose_logger.setLevel(level=logging.INFO) # sets package logs to info
verbose_router_logger.setLevel(level=logging.INFO) # set router logs to info
verbose_proxy_logger.setLevel(level=logging.INFO) # set proxy logs to info
if detailed_debug == True:
from litellm._logging import (
verbose_router_logger,
verbose_proxy_logger,
verbose_logger,
)
import logging
from litellm._logging import (
verbose_logger,
verbose_proxy_logger,
verbose_router_logger,
)
verbose_logger.setLevel(level=logging.DEBUG) # set package log to debug
verbose_router_logger.setLevel(level=logging.DEBUG) # set router logs to debug
verbose_proxy_logger.setLevel(level=logging.DEBUG) # set proxy logs to debug
@@ -2324,9 +2388,10 @@ async def initialize(
litellm_log_setting = os.environ.get("LITELLM_LOG", "")
if litellm_log_setting != None:
if litellm_log_setting.upper() == "INFO":
from litellm._logging import verbose_router_logger, verbose_proxy_logger
import logging
from litellm._logging import verbose_proxy_logger, verbose_router_logger
# this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
verbose_router_logger.setLevel(
@@ -2336,9 +2401,10 @@ async def initialize(
level=logging.INFO
) # set proxy logs to info
elif litellm_log_setting.upper() == "DEBUG":
from litellm._logging import verbose_router_logger, verbose_proxy_logger
import logging
from litellm._logging import verbose_proxy_logger, verbose_router_logger
verbose_router_logger.setLevel(
level=logging.DEBUG
) # set router logs to debug
@@ -7036,7 +7102,7 @@ async def google_login(request: Request):
with microsoft_sso:
return await microsoft_sso.get_login_redirect()
elif generic_client_id is not None:
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument
from fastapi_sso.sso.generic import DiscoveryDocument, create_provider
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
@@ -7611,7 +7677,7 @@ async def auth_callback(request: Request):
result = await microsoft_sso.verify_and_process(request)
elif generic_client_id is not None:
# make generic sso provider
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument, OpenID
from fastapi_sso.sso.generic import DiscoveryDocument, OpenID, create_provider
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")

litellm/tests/test_jwt.py

@@ -18,6 +18,7 @@ sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import Request
@@ -26,6 +27,7 @@ from litellm.caching import DualCache
from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLMRoutes
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.management_endpoints.team_endpoints import new_team
from litellm.proxy.proxy_server import chat_completion
public_key = {
"kty": "RSA",
@@ -220,6 +222,70 @@ def prisma_client():
return prisma_client
@pytest.fixture
def team_token_tuple():
import json
import uuid
import jwt
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from fastapi import Request
from starlette.datastructures import URL
import litellm
from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth
from litellm.proxy.proxy_server import user_api_key_auth
# Generate a private / public key pair using RSA algorithm
key = rsa.generate_private_key(
public_exponent=65537, key_size=2048, backend=default_backend()
)
# Get private key in PEM format
private_key = key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
# Get public key in PEM format
public_key = key.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
public_key_obj = serialization.load_pem_public_key(
public_key, backend=default_backend()
)
# Convert RSA public key object to JWK (JSON Web Key)
public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj))
# VALID TOKEN
## GENERATE A TOKEN
# Assuming the current time is in UTC
expiration_time = int((datetime.utcnow() + timedelta(minutes=10)).timestamp())
team_id = f"team123_{uuid.uuid4()}"
payload = {
"sub": "user123",
"exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm_team",
"client_id": team_id,
"aud": None,
}
# Convert the private-key bytes to a string before signing the JWT
private_key_str = private_key.decode("utf-8")
## team token
token = jwt.encode(payload, private_key_str, algorithm="RS256")
return team_id, token, public_jwk
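
A hedged usage sketch for this fixture: verifying the token it returns against the JWK it ships with (assumes PyJWT; audience verification is disabled because the payload sets "aud" to None):

import json

import jwt  # PyJWT

def test_fixture_roundtrip(team_token_tuple):  # fixture injected by pytest
    team_id, token, public_jwk = team_token_tuple
    public_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(public_jwk))
    claims = jwt.decode(
        token,
        key=public_key,
        algorithms=["RS256"],
        options={"verify_aud": False},  # the payload carries "aud": None
    )
    assert claims["client_id"] == team_id  # the claim litellm_jwtauth maps to team_id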
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.asyncio
async def test_team_token_output(prisma_client, audience):
@@ -750,3 +816,33 @@ async def test_allowed_routes_admin(prisma_client, audience):
result = await user_api_key_auth(request=request, api_key=bearer_token)
except Exception as e:
raise e
@pytest.mark.asyncio
async def test_team_cache_update_called():
import litellm
from litellm.proxy.proxy_server import user_api_key_cache
# Swap the proxy module's user_api_key_cache for a fresh DualCache
cache = DualCache()
setattr(
litellm.proxy.proxy_server,
"user_api_key_cache",
cache,
)
with patch.object(cache, "async_get_cache", new=AsyncMock()) as mock_call_cache:
# Call the function under test
await litellm.proxy.proxy_server.update_cache(
token=None, user_id=None, end_user_id=None, team_id="1234", response_cost=20
) # type: ignore
await asyncio.sleep(3)
mock_call_cache.assert_awaited_once()