forked from phoenix/litellm-mirror
fix(proxy_server.py): track team spend for cached team object
fixes issue where team budgets for jwt tokens weren't asserted
This commit is contained in:
parent
5ad095ad9d
commit
6558abf845
5 changed files with 316 additions and 156 deletions
|
@ -1,12 +1,26 @@
|
|||
import sys, os, platform, time, copy, re, asyncio, inspect
|
||||
import threading, ast
|
||||
import shutil, random, traceback, requests
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, List, Callable, get_args, Set, Any, TYPE_CHECKING
|
||||
import secrets, subprocess
|
||||
import hashlib, uuid
|
||||
import warnings
|
||||
import ast
|
||||
import asyncio
|
||||
import copy
|
||||
import hashlib
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
import platform
|
||||
import random
|
||||
import re
|
||||
import secrets
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
import warnings
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Set, get_args
|
||||
|
||||
import requests
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
@ -34,11 +48,12 @@ sys.path.insert(
|
|||
) # Adds the parent directory to the system path - for litellm local dev
|
||||
|
||||
try:
|
||||
import fastapi
|
||||
import backoff
|
||||
import yaml # type: ignore
|
||||
import orjson
|
||||
import logging
|
||||
|
||||
import backoff
|
||||
import fastapi
|
||||
import orjson
|
||||
import yaml # type: ignore
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
except ImportError as e:
|
||||
raise ImportError(f"Missing dependency {e}. Run `pip install 'litellm[proxy]'`")
|
||||
|
@ -85,60 +100,31 @@ def generate_feedback_box():
|
|||
print() # noqa
|
||||
|
||||
|
||||
import litellm
|
||||
from litellm.types.llms.openai import (
|
||||
HttpxBinaryResponseContent,
|
||||
)
|
||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||
from litellm.proxy.utils import (
|
||||
PrismaClient,
|
||||
DBClient,
|
||||
get_instance_fn,
|
||||
ProxyLogging,
|
||||
_cache_user_row,
|
||||
send_email,
|
||||
get_logging_payload,
|
||||
reset_budget,
|
||||
hash_token,
|
||||
html_form,
|
||||
missing_keys_html_form,
|
||||
_is_valid_team_configs,
|
||||
_is_projected_spend_over_limit,
|
||||
_get_projected_spend_over_limit,
|
||||
update_spend,
|
||||
encrypt_value,
|
||||
decrypt_value,
|
||||
get_error_message_str,
|
||||
)
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
|
||||
from litellm import (
|
||||
CreateBatchRequest,
|
||||
RetrieveBatchRequest,
|
||||
ListBatchRequest,
|
||||
CancelBatchRequest,
|
||||
CreateFileRequest,
|
||||
)
|
||||
from litellm.proxy.secret_managers.google_kms import load_google_kms
|
||||
from litellm.proxy.secret_managers.aws_secret_manager import (
|
||||
load_aws_secret_manager,
|
||||
load_aws_kms,
|
||||
)
|
||||
import pydantic
|
||||
from litellm.proxy._types import *
|
||||
from litellm.caching import DualCache, RedisCache
|
||||
from litellm.proxy.health_check import perform_health_check
|
||||
from litellm.router import (
|
||||
LiteLLM_Params,
|
||||
Deployment,
|
||||
updateDeployment,
|
||||
ModelGroupInfo,
|
||||
AssistantsTypedDict,
|
||||
|
||||
import litellm
|
||||
from litellm import (
|
||||
CancelBatchRequest,
|
||||
CreateBatchRequest,
|
||||
CreateFileRequest,
|
||||
ListBatchRequest,
|
||||
RetrieveBatchRequest,
|
||||
)
|
||||
from litellm.router import ModelInfo as RouterModelInfo
|
||||
from litellm._logging import (
|
||||
verbose_router_logger,
|
||||
verbose_proxy_logger,
|
||||
from litellm._logging import verbose_proxy_logger, verbose_router_logger
|
||||
from litellm.caching import DualCache, RedisCache
|
||||
from litellm.exceptions import RejectedRequestError
|
||||
from litellm.integrations.slack_alerting import SlackAlerting, SlackAlertingArgs
|
||||
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.auth_checks import (
|
||||
allowed_routes_check,
|
||||
common_checks,
|
||||
get_actual_routes,
|
||||
get_end_user_object,
|
||||
get_org_object,
|
||||
get_team_object,
|
||||
get_user_object,
|
||||
log_to_opentelemetry,
|
||||
)
|
||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||
from litellm.proxy.auth.litellm_license import LicenseCheck
|
||||
|
@ -148,78 +134,105 @@ from litellm.proxy.auth.model_checks import (
|
|||
get_team_models,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.hooks.prompt_injection_detection import (
|
||||
_OPTIONAL_PromptInjectionDetection,
|
||||
)
|
||||
from litellm.proxy.auth.auth_checks import (
|
||||
common_checks,
|
||||
get_end_user_object,
|
||||
get_org_object,
|
||||
get_team_object,
|
||||
get_user_object,
|
||||
allowed_routes_check,
|
||||
get_actual_routes,
|
||||
log_to_opentelemetry,
|
||||
)
|
||||
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
|
||||
from litellm.exceptions import RejectedRequestError
|
||||
from litellm.integrations.slack_alerting import SlackAlertingArgs, SlackAlerting
|
||||
from litellm.scheduler import Scheduler, FlowItem, DefaultPriorities
|
||||
|
||||
## Import All Misc routes here ##
|
||||
from litellm.proxy.caching_routes import router as caching_router
|
||||
from litellm.proxy.management_endpoints.team_endpoints import router as team_router
|
||||
from litellm.proxy.spend_reporting_endpoints.spend_management_endpoints import (
|
||||
router as spend_management_router,
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
from litellm.proxy.health_check import perform_health_check
|
||||
from litellm.proxy.health_endpoints._health_endpoints import router as health_router
|
||||
from litellm.proxy.hooks.prompt_injection_detection import (
|
||||
_OPTIONAL_PromptInjectionDetection,
|
||||
)
|
||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||
from litellm.proxy.management_endpoints.internal_user_endpoints import (
|
||||
router as internal_user_router,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.internal_user_endpoints import user_update
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
_duration_in_seconds,
|
||||
delete_verification_token,
|
||||
generate_key_helper_fn,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
router as key_management_router,
|
||||
_duration_in_seconds,
|
||||
generate_key_helper_fn,
|
||||
delete_verification_token,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.internal_user_endpoints import (
|
||||
router as internal_user_router,
|
||||
user_update,
|
||||
from litellm.proxy.management_endpoints.team_endpoints import router as team_router
|
||||
from litellm.proxy.secret_managers.aws_secret_manager import (
|
||||
load_aws_kms,
|
||||
load_aws_secret_manager,
|
||||
)
|
||||
from litellm.proxy.health_endpoints._health_endpoints import router as health_router
|
||||
from litellm.proxy.secret_managers.google_kms import load_google_kms
|
||||
from litellm.proxy.spend_reporting_endpoints.spend_management_endpoints import (
|
||||
router as spend_management_router,
|
||||
)
|
||||
from litellm.proxy.utils import (
|
||||
DBClient,
|
||||
PrismaClient,
|
||||
ProxyLogging,
|
||||
_cache_user_row,
|
||||
_get_projected_spend_over_limit,
|
||||
_is_projected_spend_over_limit,
|
||||
_is_valid_team_configs,
|
||||
decrypt_value,
|
||||
encrypt_value,
|
||||
get_error_message_str,
|
||||
get_instance_fn,
|
||||
get_logging_payload,
|
||||
hash_token,
|
||||
html_form,
|
||||
missing_keys_html_form,
|
||||
reset_budget,
|
||||
send_email,
|
||||
update_spend,
|
||||
)
|
||||
from litellm.router import (
|
||||
AssistantsTypedDict,
|
||||
Deployment,
|
||||
LiteLLM_Params,
|
||||
ModelGroupInfo,
|
||||
)
|
||||
from litellm.router import ModelInfo as RouterModelInfo
|
||||
from litellm.router import updateDeployment
|
||||
from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler
|
||||
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||
|
||||
try:
|
||||
from litellm._version import version
|
||||
except:
|
||||
version = "0.0.0"
|
||||
litellm.suppress_debug_info = True
|
||||
from fastapi import (
|
||||
FastAPI,
|
||||
Request,
|
||||
HTTPException,
|
||||
status,
|
||||
Path,
|
||||
Depends,
|
||||
Header,
|
||||
Response,
|
||||
Form,
|
||||
UploadFile,
|
||||
File,
|
||||
)
|
||||
from fastapi.routing import APIRouter
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import (
|
||||
StreamingResponse,
|
||||
FileResponse,
|
||||
ORJSONResponse,
|
||||
JSONResponse,
|
||||
)
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.security.api_key import APIKeyHeader
|
||||
import json
|
||||
import logging
|
||||
from typing import Union
|
||||
|
||||
from fastapi import (
|
||||
Depends,
|
||||
FastAPI,
|
||||
File,
|
||||
Form,
|
||||
Header,
|
||||
HTTPException,
|
||||
Path,
|
||||
Request,
|
||||
Response,
|
||||
UploadFile,
|
||||
status,
|
||||
)
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.responses import (
|
||||
FileResponse,
|
||||
JSONResponse,
|
||||
ORJSONResponse,
|
||||
RedirectResponse,
|
||||
StreamingResponse,
|
||||
)
|
||||
from fastapi.routing import APIRouter
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from fastapi.security.api_key import APIKeyHeader
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
# import enterprise folder
|
||||
try:
|
||||
# when using litellm cli
|
||||
|
@ -488,8 +501,8 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
|
|||
return
|
||||
|
||||
try:
|
||||
from azure.keyvault.secrets import SecretClient
|
||||
from azure.identity import ClientSecretCredential
|
||||
from azure.keyvault.secrets import SecretClient
|
||||
|
||||
# Set your Azure Key Vault URI
|
||||
KVUri = os.getenv("AZURE_KEY_VAULT_URI", None)
|
||||
|
@ -655,6 +668,7 @@ async def _PROXY_track_cost_callback(
|
|||
user_id=user_id,
|
||||
end_user_id=end_user_id,
|
||||
response_cost=response_cost,
|
||||
team_id=team_id,
|
||||
)
|
||||
|
||||
await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
|
||||
|
@ -897,6 +911,7 @@ async def update_cache(
|
|||
token: Optional[str],
|
||||
user_id: Optional[str],
|
||||
end_user_id: Optional[str],
|
||||
team_id: Optional[str],
|
||||
response_cost: Optional[float],
|
||||
):
|
||||
"""
|
||||
|
@ -993,6 +1008,7 @@ async def update_cache(
|
|||
existing_spend_obj.spend = new_spend
|
||||
user_api_key_cache.set_cache(key=hashed_token, value=existing_spend_obj)
|
||||
|
||||
### UPDATE USER SPEND ###
|
||||
async def _update_user_cache():
|
||||
## UPDATE CACHE FOR USER ID + GLOBAL PROXY
|
||||
user_ids = [user_id]
|
||||
|
@ -1054,6 +1070,7 @@ async def update_cache(
|
|||
f"An error occurred updating user cache: {str(e)}\n\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
### UPDATE END-USER SPEND ###
|
||||
async def _update_end_user_cache():
|
||||
if end_user_id is None or response_cost is None:
|
||||
return
|
||||
|
@ -1102,14 +1119,59 @@ async def update_cache(
|
|||
f"An error occurred updating end user cache: {str(e)}\n\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
### UPDATE TEAM SPEND ###
|
||||
async def _update_team_cache():
|
||||
if team_id is None or response_cost is None:
|
||||
return
|
||||
|
||||
_id = "team_id:{}".format(team_id)
|
||||
try:
|
||||
# Fetch the existing cost for the given user
|
||||
existing_spend_obj: Optional[LiteLLM_TeamTable] = (
|
||||
await user_api_key_cache.async_get_cache(key=_id)
|
||||
)
|
||||
if existing_spend_obj is None:
|
||||
return
|
||||
verbose_proxy_logger.debug(
|
||||
f"_update_team_db: existing spend: {existing_spend_obj}; response_cost: {response_cost}"
|
||||
)
|
||||
if existing_spend_obj is None:
|
||||
existing_spend: Optional[float] = 0.0
|
||||
else:
|
||||
if isinstance(existing_spend_obj, dict):
|
||||
existing_spend = existing_spend_obj["spend"]
|
||||
else:
|
||||
existing_spend = existing_spend_obj.spend
|
||||
|
||||
if existing_spend is None:
|
||||
existing_spend = 0.0
|
||||
# Calculate the new cost by adding the existing cost and response_cost
|
||||
new_spend = existing_spend + response_cost
|
||||
|
||||
# Update the cost column for the given user
|
||||
if isinstance(existing_spend_obj, dict):
|
||||
existing_spend_obj["spend"] = new_spend
|
||||
user_api_key_cache.set_cache(key=_id, value=existing_spend_obj)
|
||||
else:
|
||||
existing_spend_obj.spend = new_spend
|
||||
user_api_key_cache.set_cache(key=_id, value=existing_spend_obj)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
f"An error occurred updating end user cache: {str(e)}\n\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
if token is not None and response_cost is not None:
|
||||
asyncio.create_task(_update_key_cache(token=token, response_cost=response_cost))
|
||||
|
||||
asyncio.create_task(_update_user_cache())
|
||||
if user_id is not None:
|
||||
asyncio.create_task(_update_user_cache())
|
||||
|
||||
if end_user_id is not None:
|
||||
asyncio.create_task(_update_end_user_cache())
|
||||
|
||||
if team_id is not None:
|
||||
asyncio.create_task(_update_team_cache())
|
||||
|
||||
|
||||
def run_ollama_serve():
|
||||
try:
|
||||
|
@ -2297,25 +2359,27 @@ async def initialize(
|
|||
user_model = model
|
||||
user_debug = debug
|
||||
if debug == True: # this needs to be first, so users can see Router init debugg
|
||||
from litellm._logging import (
|
||||
verbose_router_logger,
|
||||
verbose_proxy_logger,
|
||||
verbose_logger,
|
||||
)
|
||||
import logging
|
||||
|
||||
from litellm._logging import (
|
||||
verbose_logger,
|
||||
verbose_proxy_logger,
|
||||
verbose_router_logger,
|
||||
)
|
||||
|
||||
# this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
|
||||
verbose_logger.setLevel(level=logging.INFO) # sets package logs to info
|
||||
verbose_router_logger.setLevel(level=logging.INFO) # set router logs to info
|
||||
verbose_proxy_logger.setLevel(level=logging.INFO) # set proxy logs to info
|
||||
if detailed_debug == True:
|
||||
from litellm._logging import (
|
||||
verbose_router_logger,
|
||||
verbose_proxy_logger,
|
||||
verbose_logger,
|
||||
)
|
||||
import logging
|
||||
|
||||
from litellm._logging import (
|
||||
verbose_logger,
|
||||
verbose_proxy_logger,
|
||||
verbose_router_logger,
|
||||
)
|
||||
|
||||
verbose_logger.setLevel(level=logging.DEBUG) # set package log to debug
|
||||
verbose_router_logger.setLevel(level=logging.DEBUG) # set router logs to debug
|
||||
verbose_proxy_logger.setLevel(level=logging.DEBUG) # set proxy logs to debug
|
||||
|
@ -2324,9 +2388,10 @@ async def initialize(
|
|||
litellm_log_setting = os.environ.get("LITELLM_LOG", "")
|
||||
if litellm_log_setting != None:
|
||||
if litellm_log_setting.upper() == "INFO":
|
||||
from litellm._logging import verbose_router_logger, verbose_proxy_logger
|
||||
import logging
|
||||
|
||||
from litellm._logging import verbose_proxy_logger, verbose_router_logger
|
||||
|
||||
# this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
|
||||
|
||||
verbose_router_logger.setLevel(
|
||||
|
@ -2336,9 +2401,10 @@ async def initialize(
|
|||
level=logging.INFO
|
||||
) # set proxy logs to info
|
||||
elif litellm_log_setting.upper() == "DEBUG":
|
||||
from litellm._logging import verbose_router_logger, verbose_proxy_logger
|
||||
import logging
|
||||
|
||||
from litellm._logging import verbose_proxy_logger, verbose_router_logger
|
||||
|
||||
verbose_router_logger.setLevel(
|
||||
level=logging.DEBUG
|
||||
) # set router logs to info
|
||||
|
@ -7036,7 +7102,7 @@ async def google_login(request: Request):
|
|||
with microsoft_sso:
|
||||
return await microsoft_sso.get_login_redirect()
|
||||
elif generic_client_id is not None:
|
||||
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument
|
||||
from fastapi_sso.sso.generic import DiscoveryDocument, create_provider
|
||||
|
||||
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
||||
|
@ -7611,7 +7677,7 @@ async def auth_callback(request: Request):
|
|||
result = await microsoft_sso.verify_and_process(request)
|
||||
elif generic_client_id is not None:
|
||||
# make generic sso provider
|
||||
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument, OpenID
|
||||
from fastapi_sso.sso.generic import DiscoveryDocument, OpenID, create_provider
|
||||
|
||||
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue