fix(proxy_server.py): track team spend for cached team object

fixes issue where team budgets for jwt tokens weren't enforced
Krrish Dholakia 2024-06-18 17:10:12 -07:00
parent 5ad095ad9d
commit 6558abf845
5 changed files with 316 additions and 156 deletions
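Why this matters: on the JWT auth path the team is resolved from the in-memory cache (get_team_object), and the budget check compares the cached team object's spend against its max_budget. Before this commit, update_cache refreshed key, user, and end-user spend but never the cached team entry, so a team's cached spend stayed frozen and its budget was never tripped. The sketch below is illustrative only; CachedTeam and assert_team_budget are hypothetical stand-ins (not litellm APIs) that show why the check is a no-op unless spend is written back after each request.

# Illustrative sketch -- CachedTeam and assert_team_budget are hypothetical
# stand-ins for the cached team object and its budget check, not litellm code.
from typing import Optional


class CachedTeam:
    def __init__(self, team_id: str, spend: float = 0.0, max_budget: Optional[float] = None):
        self.team_id = team_id
        self.spend = spend              # must grow per request for the check to mean anything
        self.max_budget = max_budget


def assert_team_budget(team: CachedTeam) -> None:
    """Raise once the team's tracked spend reaches its budget."""
    if team.max_budget is not None and team.spend >= team.max_budget:
        raise RuntimeError(
            f"Team {team.team_id} over budget: spend={team.spend} >= max_budget={team.max_budget}"
        )


team = CachedTeam(team_id="team-1", max_budget=10.0)
assert_team_budget(team)                # passes forever while the cached spend stays at 0.0
team.spend += 12.5                      # what _update_team_cache now does after each response
try:
    assert_team_budget(team)            # now raises -> the budget is actually enforced
except RuntimeError as err:
    print(err)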


@@ -1,12 +1,26 @@
import sys, os, platform, time, copy, re, asyncio, inspect
import threading, ast
import shutil, random, traceback, requests
from datetime import datetime, timedelta, timezone
from typing import Optional, List, Callable, get_args, Set, Any, TYPE_CHECKING
import secrets, subprocess
import hashlib, uuid
import warnings
import ast
import asyncio
import copy
import hashlib
import importlib
import inspect
import os
import platform
import random
import re
import secrets
import shutil
import subprocess
import sys
import threading
import time
import traceback
import uuid
import warnings
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Set, get_args
import requests
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
@@ -34,11 +48,12 @@ sys.path.insert(
) # Adds the parent directory to the system path - for litellm local dev
try:
import fastapi
import backoff
import yaml # type: ignore
import orjson
import logging
import backoff
import fastapi
import orjson
import yaml # type: ignore
from apscheduler.schedulers.asyncio import AsyncIOScheduler
except ImportError as e:
raise ImportError(f"Missing dependency {e}. Run `pip install 'litellm[proxy]'`")
@@ -85,60 +100,31 @@ def generate_feedback_box():
print() # noqa
import litellm
from litellm.types.llms.openai import (
HttpxBinaryResponseContent,
)
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
from litellm.proxy.utils import (
PrismaClient,
DBClient,
get_instance_fn,
ProxyLogging,
_cache_user_row,
send_email,
get_logging_payload,
reset_budget,
hash_token,
html_form,
missing_keys_html_form,
_is_valid_team_configs,
_is_projected_spend_over_limit,
_get_projected_spend_over_limit,
update_spend,
encrypt_value,
decrypt_value,
get_error_message_str,
)
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm import (
CreateBatchRequest,
RetrieveBatchRequest,
ListBatchRequest,
CancelBatchRequest,
CreateFileRequest,
)
from litellm.proxy.secret_managers.google_kms import load_google_kms
from litellm.proxy.secret_managers.aws_secret_manager import (
load_aws_secret_manager,
load_aws_kms,
)
import pydantic
from litellm.proxy._types import *
from litellm.caching import DualCache, RedisCache
from litellm.proxy.health_check import perform_health_check
from litellm.router import (
LiteLLM_Params,
Deployment,
updateDeployment,
ModelGroupInfo,
AssistantsTypedDict,
import litellm
from litellm import (
CancelBatchRequest,
CreateBatchRequest,
CreateFileRequest,
ListBatchRequest,
RetrieveBatchRequest,
)
from litellm.router import ModelInfo as RouterModelInfo
from litellm._logging import (
verbose_router_logger,
verbose_proxy_logger,
from litellm._logging import verbose_proxy_logger, verbose_router_logger
from litellm.caching import DualCache, RedisCache
from litellm.exceptions import RejectedRequestError
from litellm.integrations.slack_alerting import SlackAlerting, SlackAlertingArgs
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy._types import *
from litellm.proxy.auth.auth_checks import (
allowed_routes_check,
common_checks,
get_actual_routes,
get_end_user_object,
get_org_object,
get_team_object,
get_user_object,
log_to_opentelemetry,
)
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.auth.litellm_license import LicenseCheck
@@ -148,78 +134,105 @@ from litellm.proxy.auth.model_checks import (
get_team_models,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.hooks.prompt_injection_detection import (
_OPTIONAL_PromptInjectionDetection,
)
from litellm.proxy.auth.auth_checks import (
common_checks,
get_end_user_object,
get_org_object,
get_team_object,
get_user_object,
allowed_routes_check,
get_actual_routes,
log_to_opentelemetry,
)
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.exceptions import RejectedRequestError
from litellm.integrations.slack_alerting import SlackAlertingArgs, SlackAlerting
from litellm.scheduler import Scheduler, FlowItem, DefaultPriorities
## Import All Misc routes here ##
from litellm.proxy.caching_routes import router as caching_router
from litellm.proxy.management_endpoints.team_endpoints import router as team_router
from litellm.proxy.spend_reporting_endpoints.spend_management_endpoints import (
router as spend_management_router,
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.health_check import perform_health_check
from litellm.proxy.health_endpoints._health_endpoints import router as health_router
from litellm.proxy.hooks.prompt_injection_detection import (
_OPTIONAL_PromptInjectionDetection,
)
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
from litellm.proxy.management_endpoints.internal_user_endpoints import (
router as internal_user_router,
)
from litellm.proxy.management_endpoints.internal_user_endpoints import user_update
from litellm.proxy.management_endpoints.key_management_endpoints import (
_duration_in_seconds,
delete_verification_token,
generate_key_helper_fn,
)
from litellm.proxy.management_endpoints.key_management_endpoints import (
router as key_management_router,
_duration_in_seconds,
generate_key_helper_fn,
delete_verification_token,
)
from litellm.proxy.management_endpoints.internal_user_endpoints import (
router as internal_user_router,
user_update,
from litellm.proxy.management_endpoints.team_endpoints import router as team_router
from litellm.proxy.secret_managers.aws_secret_manager import (
load_aws_kms,
load_aws_secret_manager,
)
from litellm.proxy.health_endpoints._health_endpoints import router as health_router
from litellm.proxy.secret_managers.google_kms import load_google_kms
from litellm.proxy.spend_reporting_endpoints.spend_management_endpoints import (
router as spend_management_router,
)
from litellm.proxy.utils import (
DBClient,
PrismaClient,
ProxyLogging,
_cache_user_row,
_get_projected_spend_over_limit,
_is_projected_spend_over_limit,
_is_valid_team_configs,
decrypt_value,
encrypt_value,
get_error_message_str,
get_instance_fn,
get_logging_payload,
hash_token,
html_form,
missing_keys_html_form,
reset_budget,
send_email,
update_spend,
)
from litellm.router import (
AssistantsTypedDict,
Deployment,
LiteLLM_Params,
ModelGroupInfo,
)
from litellm.router import ModelInfo as RouterModelInfo
from litellm.router import updateDeployment
from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler
from litellm.types.llms.openai import HttpxBinaryResponseContent
try:
from litellm._version import version
except:
version = "0.0.0"
litellm.suppress_debug_info = True
from fastapi import (
FastAPI,
Request,
HTTPException,
status,
Path,
Depends,
Header,
Response,
Form,
UploadFile,
File,
)
from fastapi.routing import APIRouter
from fastapi.security import OAuth2PasswordBearer
from fastapi.encoders import jsonable_encoder
from fastapi.responses import (
StreamingResponse,
FileResponse,
ORJSONResponse,
JSONResponse,
)
from fastapi.openapi.utils import get_openapi
from fastapi.responses import RedirectResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.security.api_key import APIKeyHeader
import json
import logging
from typing import Union
from fastapi import (
Depends,
FastAPI,
File,
Form,
Header,
HTTPException,
Path,
Request,
Response,
UploadFile,
status,
)
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from fastapi.responses import (
FileResponse,
JSONResponse,
ORJSONResponse,
RedirectResponse,
StreamingResponse,
)
from fastapi.routing import APIRouter
from fastapi.security import OAuth2PasswordBearer
from fastapi.security.api_key import APIKeyHeader
from fastapi.staticfiles import StaticFiles
# import enterprise folder
try:
# when using litellm cli
@@ -488,8 +501,8 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
return
try:
from azure.keyvault.secrets import SecretClient
from azure.identity import ClientSecretCredential
from azure.keyvault.secrets import SecretClient
# Set your Azure Key Vault URI
KVUri = os.getenv("AZURE_KEY_VAULT_URI", None)
@@ -655,6 +668,7 @@ async def _PROXY_track_cost_callback(
user_id=user_id,
end_user_id=end_user_id,
response_cost=response_cost,
team_id=team_id,
)
await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
@@ -897,6 +911,7 @@ async def update_cache(
token: Optional[str],
user_id: Optional[str],
end_user_id: Optional[str],
team_id: Optional[str],
response_cost: Optional[float],
):
"""
@@ -993,6 +1008,7 @@ async def update_cache(
existing_spend_obj.spend = new_spend
user_api_key_cache.set_cache(key=hashed_token, value=existing_spend_obj)
### UPDATE USER SPEND ###
async def _update_user_cache():
## UPDATE CACHE FOR USER ID + GLOBAL PROXY
user_ids = [user_id]
@@ -1054,6 +1070,7 @@ async def update_cache(
f"An error occurred updating user cache: {str(e)}\n\n{traceback.format_exc()}"
)
### UPDATE END-USER SPEND ###
async def _update_end_user_cache():
if end_user_id is None or response_cost is None:
return
@@ -1102,14 +1119,59 @@ async def update_cache(
f"An error occurred updating end user cache: {str(e)}\n\n{traceback.format_exc()}"
)
### UPDATE TEAM SPEND ###
async def _update_team_cache():
if team_id is None or response_cost is None:
return
_id = "team_id:{}".format(team_id)
try:
# Fetch the existing spend for the given team
existing_spend_obj: Optional[LiteLLM_TeamTable] = (
await user_api_key_cache.async_get_cache(key=_id)
)
if existing_spend_obj is None:
return
verbose_proxy_logger.debug(
f"_update_team_db: existing spend: {existing_spend_obj}; response_cost: {response_cost}"
)
if existing_spend_obj is None:
existing_spend: Optional[float] = 0.0
else:
if isinstance(existing_spend_obj, dict):
existing_spend = existing_spend_obj["spend"]
else:
existing_spend = existing_spend_obj.spend
if existing_spend is None:
existing_spend = 0.0
# Calculate the new cost by adding the existing cost and response_cost
new_spend = existing_spend + response_cost
# Update the spend for the cached team object
if isinstance(existing_spend_obj, dict):
existing_spend_obj["spend"] = new_spend
user_api_key_cache.set_cache(key=_id, value=existing_spend_obj)
else:
existing_spend_obj.spend = new_spend
user_api_key_cache.set_cache(key=_id, value=existing_spend_obj)
except Exception as e:
verbose_proxy_logger.error(
f"An error occurred updating end user cache: {str(e)}\n\n{traceback.format_exc()}"
)
if token is not None and response_cost is not None:
asyncio.create_task(_update_key_cache(token=token, response_cost=response_cost))
asyncio.create_task(_update_user_cache())
if user_id is not None:
asyncio.create_task(_update_user_cache())
if end_user_id is not None:
asyncio.create_task(_update_end_user_cache())
if team_id is not None:
asyncio.create_task(_update_team_cache())
def run_ollama_serve():
try:
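The cache entry above is keyed as "team_id:<team_id>", presumably the same key the auth layer reads when it loads the cached team object, so the next request for that team sees the new spend without a database round trip. Below is a minimal sketch of the same read-modify-write pattern, with a plain dict standing in for the DualCache; bump_team_spend is a hypothetical helper written for illustration, not litellm code.

# Minimal sketch of _update_team_cache's read-modify-write pattern. A plain dict
# stands in for litellm's DualCache; bump_team_spend is a hypothetical helper.
from typing import Any, Dict

cache: Dict[str, Any] = {}


def bump_team_spend(team_id: str, response_cost: float) -> None:
    key = f"team_id:{team_id}"          # same key format the diff constructs above
    team = cache.get(key)
    if team is None:                    # team not cached yet -> nothing to update
        return
    if isinstance(team, dict):
        team["spend"] = (team.get("spend") or 0.0) + response_cost
    else:
        team.spend = (getattr(team, "spend", None) or 0.0) + response_cost
    cache[key] = team                   # write back so the next budget check sees it

As in the diff, the real update is dispatched with asyncio.create_task so cost tracking never blocks the response path, and any failure inside the helper is caught and logged rather than surfacing to the caller.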
@@ -2297,25 +2359,27 @@ async def initialize(
user_model = model
user_debug = debug
if debug == True: # this needs to be first, so users can see Router init debugg
from litellm._logging import (
verbose_router_logger,
verbose_proxy_logger,
verbose_logger,
)
import logging
from litellm._logging import (
verbose_logger,
verbose_proxy_logger,
verbose_router_logger,
)
# this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
verbose_logger.setLevel(level=logging.INFO) # sets package logs to info
verbose_router_logger.setLevel(level=logging.INFO) # set router logs to info
verbose_proxy_logger.setLevel(level=logging.INFO) # set proxy logs to info
if detailed_debug == True:
from litellm._logging import (
verbose_router_logger,
verbose_proxy_logger,
verbose_logger,
)
import logging
from litellm._logging import (
verbose_logger,
verbose_proxy_logger,
verbose_router_logger,
)
verbose_logger.setLevel(level=logging.DEBUG) # set package log to debug
verbose_router_logger.setLevel(level=logging.DEBUG) # set router logs to debug
verbose_proxy_logger.setLevel(level=logging.DEBUG) # set proxy logs to debug
@@ -2324,9 +2388,10 @@ async def initialize(
litellm_log_setting = os.environ.get("LITELLM_LOG", "")
if litellm_log_setting != None:
if litellm_log_setting.upper() == "INFO":
from litellm._logging import verbose_router_logger, verbose_proxy_logger
import logging
from litellm._logging import verbose_proxy_logger, verbose_router_logger
# this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
verbose_router_logger.setLevel(
@@ -2336,9 +2401,10 @@ async def initialize(
level=logging.INFO
) # set proxy logs to info
elif litellm_log_setting.upper() == "DEBUG":
from litellm._logging import verbose_router_logger, verbose_proxy_logger
import logging
from litellm._logging import verbose_proxy_logger, verbose_router_logger
verbose_router_logger.setLevel(
level=logging.DEBUG
) # set router logs to info
@@ -7036,7 +7102,7 @@ async def google_login(request: Request):
with microsoft_sso:
return await microsoft_sso.get_login_redirect()
elif generic_client_id is not None:
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument
from fastapi_sso.sso.generic import DiscoveryDocument, create_provider
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
@@ -7611,7 +7677,7 @@ async def auth_callback(request: Request):
result = await microsoft_sso.verify_and_process(request)
elif generic_client_id is not None:
# make generic sso provider
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument, OpenID
from fastapi_sso.sso.generic import DiscoveryDocument, OpenID, create_provider
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")