forked from phoenix/litellm-mirror
Merge pull request #2606 from BerriAI/litellm_jwt_auth_updates
fix(handle_jwt.py): track spend for user using jwt auth
This commit is contained in:
commit
007d439017
6 changed files with 296 additions and 40 deletions
|
@ -841,6 +841,17 @@ class DualCache(BaseCache):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
    """
    Asynchronously write `value` under `key` to the configured cache layers.

    - writes to the in-memory cache when one is configured
    - writes to the redis cache when one is configured and `local_only` is falsy
    - extra keyword arguments are forwarded unchanged to each layer

    Failures are best-effort: any exception is logged and its traceback
    printed, but it is not re-raised to the caller.
    """
    try:
        if self.in_memory_cache is not None:
            await self.in_memory_cache.async_set_cache(key, value, **kwargs)

        # `local_only` keeps the write out of the redis layer
        if self.redis_cache is not None and not local_only:
            await self.redis_cache.async_set_cache(key, value, **kwargs)
    except Exception as e:
        print_verbose(f"LiteLLM Cache: Exception async set_cache: {str(e)}")
        traceback.print_exc()
|
||||||
|
|
||||||
def flush_cache(self):
|
def flush_cache(self):
|
||||||
if self.in_memory_cache is not None:
|
if self.in_memory_cache is not None:
|
||||||
self.in_memory_cache.flush_cache()
|
self.in_memory_cache.flush_cache()
|
||||||
|
|
|
@ -599,6 +599,8 @@ class LiteLLM_UserTable(LiteLLMBase):
|
||||||
model_spend: Optional[Dict] = {}
|
model_spend: Optional[Dict] = {}
|
||||||
user_email: Optional[str]
|
user_email: Optional[str]
|
||||||
models: list = []
|
models: list = []
|
||||||
|
tpm_limit: Optional[int] = None
|
||||||
|
rpm_limit: Optional[int] = None
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def set_model_info(cls, values):
|
def set_model_info(cls, values):
|
||||||
|
@ -617,6 +619,7 @@ class LiteLLM_EndUserTable(LiteLLMBase):
|
||||||
blocked: bool
|
blocked: bool
|
||||||
alias: Optional[str] = None
|
alias: Optional[str] = None
|
||||||
spend: float = 0.0
|
spend: float = 0.0
|
||||||
|
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def set_model_info(cls, values):
|
def set_model_info(cls, values):
|
||||||
|
|
84
litellm/proxy/auth/auth_checks.py
Normal file
84
litellm/proxy/auth/auth_checks.py
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
# What is this?
|
||||||
|
## Common auth checks between jwt + key based auth
|
||||||
|
"""
|
||||||
|
Got Valid Token from Cache, DB
|
||||||
|
Run checks for:
|
||||||
|
|
||||||
|
1. If user can call model
|
||||||
|
2. If user is in budget
|
||||||
|
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
||||||
|
"""
|
||||||
|
from litellm.proxy._types import LiteLLM_UserTable, LiteLLM_EndUserTable
|
||||||
|
from typing import Optional
|
||||||
|
from litellm.proxy.utils import PrismaClient
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
|
||||||
|
|
||||||
|
def common_checks(
    request_body: dict,
    user_object: LiteLLM_UserTable,
    end_user_object: Optional[LiteLLM_EndUserTable],
) -> bool:
    """
    Run the auth checks shared by jwt + key based auth.

    Checks, in order:
    1. If user can call model (an empty `user_object.models` list means no restriction)
    2. If user is in budget
    3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget

    Returns True when every check passes.
    Raises Exception with a descriptive message on the first failed check.
    """
    _model = request_body.get("model", None)
    # 1. If user can call model
    if (
        _model is not None
        and len(user_object.models) > 0
        and _model not in user_object.models
    ):
        raise Exception(
            f"User={user_object.user_id} not allowed to call model={_model}. Allowed user models = {user_object.models}"
        )
    # 2. If user is in budget
    if (
        user_object.max_budget is not None
        and user_object.spend > user_object.max_budget
    ):
        raise Exception(
            f"User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_object.max_budget}"
        )
    # 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
    if end_user_object is not None and end_user_object.litellm_budget_table is not None:
        end_user_budget = end_user_object.litellm_budget_table.max_budget
        # a budget table without a max_budget imposes no limit
        if end_user_budget is not None and end_user_object.spend > end_user_budget:
            raise Exception(
                f"End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}"
            )
    return True
|
||||||
|
|
||||||
|
|
||||||
|
async def get_end_user_object(
    end_user_id: Optional[str],
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
) -> Optional[LiteLLM_EndUserTable]:
    """
    Returns end user object, if in cache or db. Returns None otherwise.

    Do an isolated check for end user in table vs. doing a combined key + team + user + end-user check, as key might come in frequently for different end-users. Larger call will slowdown query time. This way we get to cache the constant (key/team/user info) and only update based on the changing value (end-user).

    Raises Exception if no db is connected (checked before `end_user_id`).
    """
    if prisma_client is None:
        raise Exception("No db connected")

    if end_user_id is None:
        return None

    # check if in cache
    # async_get_cache is a coroutine function; it must be awaited to get the cached value
    cached_user_obj = await user_api_key_cache.async_get_cache(key=end_user_id)
    if cached_user_obj is not None:
        if isinstance(cached_user_obj, dict):
            return LiteLLM_EndUserTable(**cached_user_obj)
        elif isinstance(cached_user_obj, LiteLLM_EndUserTable):
            return cached_user_obj
        # any other cached shape falls through to the db lookup

    # else, check db
    try:
        response = await prisma_client.db.litellm_endusertable.find_unique(
            where={"user_id": end_user_id}
        )

        if response is None:  # end-user not in db
            return None

        return LiteLLM_EndUserTable(**response.dict())
    except Exception:
        # best-effort lookup: a failed query or unparsable row is treated as "no end-user"
        return None
|
|
@ -8,23 +8,27 @@ JWT token must have 'litellm_proxy_admin' in scope.
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import jwt
|
import jwt
|
||||||
|
|
||||||
print(jwt.__version__) # noqa
|
|
||||||
from jwt.algorithms import RSAAlgorithm
|
from jwt.algorithms import RSAAlgorithm
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from litellm.proxy._types import LiteLLMProxyRoles
|
from litellm.caching import DualCache
|
||||||
|
from litellm.proxy._types import LiteLLMProxyRoles, LiteLLM_UserTable
|
||||||
|
from litellm.proxy.utils import PrismaClient
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
class HTTPHandler:
|
class HTTPHandler:
|
||||||
def __init__(self):
|
def __init__(self, concurrent_limit=1000):
|
||||||
self.client = httpx.AsyncClient()
|
# Create a client with a connection pool
|
||||||
|
self.client = httpx.AsyncClient(
|
||||||
|
limits=httpx.Limits(
|
||||||
|
max_connections=concurrent_limit,
|
||||||
|
max_keepalive_connections=concurrent_limit,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def close(self):
|
||||||
return self
|
# Close the client when you're done with it
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
await self.client.aclose()
|
await self.client.aclose()
|
||||||
|
|
||||||
async def get(
|
async def get(
|
||||||
|
@ -47,10 +51,27 @@ class HTTPHandler:
|
||||||
|
|
||||||
|
|
||||||
class JWTHandler:
|
class JWTHandler:
|
||||||
|
"""
|
||||||
|
- treat the sub id passed in as the user id
|
||||||
|
- return an error if id making request doesn't exist in proxy user table
|
||||||
|
- track spend against the user id
|
||||||
|
- if role="litellm_proxy_user" -> allow making calls + info. Can not edit budgets
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
prisma_client: Optional[PrismaClient]
|
||||||
|
user_api_key_cache: DualCache
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
) -> None:
|
||||||
self.http_handler = HTTPHandler()
|
self.http_handler = HTTPHandler()
|
||||||
|
|
||||||
|
def update_environment(
    self, prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache
) -> None:
    """
    Store the db client and user/key cache this handler uses for user lookups.

    - `prisma_client` may be None (no db connected)
    - `user_api_key_cache` is the cache consulted before the db
    """
    self.prisma_client = prisma_client
    self.user_api_key_cache = user_api_key_cache
|
||||||
|
|
||||||
def is_jwt(self, token: str):
|
def is_jwt(self, token: str):
|
||||||
parts = token.split(".")
|
parts = token.split(".")
|
||||||
return len(parts) == 3
|
return len(parts) == 3
|
||||||
|
@ -67,6 +88,46 @@ class JWTHandler:
|
||||||
user_id = default_value
|
user_id = default_value
|
||||||
return user_id
|
return user_id
|
||||||
|
|
||||||
|
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
    """
    Return the team id carried by a decoded JWT.

    - reads the 'azp' claim of `token`
    - falls back to `default_value` when the claim is absent
    """
    try:
        team_id = token["azp"]
    except KeyError:
        team_id = default_value
    return team_id
|
||||||
|
|
||||||
|
async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
    """
    - Check if user id in proxy User Table (cache first, then db)
    - if valid, return LiteLLM_UserTable object with defined limits
    - if not, then raise an error

    Raises Exception if no db is connected, or if the user cannot be
    loaded from the db (missing row or failed query).
    """
    if self.prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )

    # check if in cache
    # async_get_cache is a coroutine function; it must be awaited to get the cached value
    cached_user_obj = await self.user_api_key_cache.async_get_cache(key=user_id)
    if cached_user_obj is not None:
        if isinstance(cached_user_obj, dict):
            return LiteLLM_UserTable(**cached_user_obj)
        elif isinstance(cached_user_obj, LiteLLM_UserTable):
            return cached_user_obj
        # any other cached shape falls through to the db lookup

    # else, check db
    try:
        response = await self.prisma_client.db.litellm_usertable.find_unique(
            where={"user_id": user_id}
        )

        if response is None:
            raise Exception

        return LiteLLM_UserTable(**response.dict())
    except Exception as e:
        # every failure is reported as "user doesn't exist"; keep the cause chained for debugging
        raise Exception(
            f"User doesn't exist in db. User={user_id}. Create user via `/user/new` call."
        ) from e
|
||||||
|
|
||||||
def get_scopes(self, token: dict) -> list:
|
def get_scopes(self, token: dict) -> list:
|
||||||
try:
|
try:
|
||||||
# Assuming the scopes are stored in 'scope' claim and are space-separated
|
# Assuming the scopes are stored in 'scope' claim and are space-separated
|
||||||
|
@ -78,8 +139,10 @@ class JWTHandler:
|
||||||
async def auth_jwt(self, token: str) -> dict:
|
async def auth_jwt(self, token: str) -> dict:
|
||||||
keys_url = os.getenv("JWT_PUBLIC_KEY_URL")
|
keys_url = os.getenv("JWT_PUBLIC_KEY_URL")
|
||||||
|
|
||||||
async with self.http_handler as http:
|
if keys_url is None:
|
||||||
response = await http.get(keys_url)
|
raise Exception("Missing JWT Public Key URL from environment.")
|
||||||
|
|
||||||
|
response = await self.http_handler.get(keys_url)
|
||||||
|
|
||||||
keys = response.json()["keys"]
|
keys = response.json()["keys"]
|
||||||
|
|
||||||
|
@ -113,3 +176,6 @@ class JWTHandler:
|
||||||
raise Exception(f"Validation fails: {str(e)}")
|
raise Exception(f"Validation fails: {str(e)}")
|
||||||
|
|
||||||
raise Exception("Invalid JWT Submitted")
|
raise Exception("Invalid JWT Submitted")
|
||||||
|
|
||||||
|
async def close(self):
    """Close this handler's HTTP handler; awaited from the proxy shutdown event."""
    await self.http_handler.close()
|
||||||
|
|
|
@ -107,6 +107,7 @@ from litellm.caching import DualCache
|
||||||
from litellm.proxy.health_check import perform_health_check
|
from litellm.proxy.health_check import perform_health_check
|
||||||
from litellm._logging import verbose_router_logger, verbose_proxy_logger
|
from litellm._logging import verbose_router_logger, verbose_proxy_logger
|
||||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||||
|
from litellm.proxy.auth.auth_checks import common_checks, get_end_user_object
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from litellm._version import version
|
from litellm._version import version
|
||||||
|
@ -360,20 +361,54 @@ async def user_api_key_auth(
|
||||||
user_id = jwt_handler.get_user_id(
|
user_id = jwt_handler.get_user_id(
|
||||||
token=valid_token, default_value=litellm_proxy_admin_name
|
token=valid_token, default_value=litellm_proxy_admin_name
|
||||||
)
|
)
|
||||||
|
|
||||||
|
end_user_object = None
|
||||||
|
# get the request body
|
||||||
|
request_data = await _read_request_body(request=request)
|
||||||
|
# get user obj from cache/db -> run for admin too. Ensures, admin client id in db.
|
||||||
|
user_object = await jwt_handler.get_user_object(user_id=user_id)
|
||||||
|
if (
|
||||||
|
request_data.get("user", None)
|
||||||
|
and request_data["user"] != user_object.user_id
|
||||||
|
):
|
||||||
|
# get the end-user object
|
||||||
|
end_user_object = await get_end_user_object(
|
||||||
|
end_user_id=request_data["user"],
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
)
|
||||||
|
# save the end-user object to cache
|
||||||
|
await user_api_key_cache.async_set_cache(
|
||||||
|
key=request_data["user"], value=end_user_object
|
||||||
|
)
|
||||||
|
|
||||||
|
# run through common checks
|
||||||
|
_ = common_checks(
|
||||||
|
request_body=request_data,
|
||||||
|
user_object=user_object,
|
||||||
|
end_user_object=end_user_object,
|
||||||
|
)
|
||||||
|
# save user object in cache
|
||||||
|
await user_api_key_cache.async_set_cache(
|
||||||
|
key=user_object.user_id, value=user_object
|
||||||
|
)
|
||||||
# if admin return
|
# if admin return
|
||||||
if is_admin:
|
if is_admin:
|
||||||
_user_api_key_obj = UserAPIKeyAuth(
|
return UserAPIKeyAuth(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
user_role="proxy_admin",
|
user_role="proxy_admin",
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
user_api_key_cache.set_cache(
|
else:
|
||||||
key=hash_token(api_key), value=_user_api_key_obj
|
# return UserAPIKeyAuth object
|
||||||
|
return UserAPIKeyAuth(
|
||||||
|
api_key=None,
|
||||||
|
user_id=user_object.user_id,
|
||||||
|
tpm_limit=user_object.tpm_limit,
|
||||||
|
rpm_limit=user_object.rpm_limit,
|
||||||
|
models=user_object.models,
|
||||||
|
user_role="app_owner",
|
||||||
)
|
)
|
||||||
|
|
||||||
return _user_api_key_obj
|
|
||||||
else:
|
|
||||||
raise Exception("Invalid key error!")
|
|
||||||
#### ELSE ####
|
#### ELSE ####
|
||||||
if master_key is None:
|
if master_key is None:
|
||||||
if isinstance(api_key, str):
|
if isinstance(api_key, str):
|
||||||
|
@ -438,7 +473,7 @@ async def user_api_key_auth(
|
||||||
user_role="proxy_admin",
|
user_role="proxy_admin",
|
||||||
user_id=litellm_proxy_admin_name,
|
user_id=litellm_proxy_admin_name,
|
||||||
)
|
)
|
||||||
user_api_key_cache.set_cache(
|
await user_api_key_cache.async_set_cache(
|
||||||
key=hash_token(master_key), value=_user_api_key_obj
|
key=hash_token(master_key), value=_user_api_key_obj
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -603,7 +638,7 @@ async def user_api_key_auth(
|
||||||
query_type="find_all",
|
query_type="find_all",
|
||||||
)
|
)
|
||||||
for _id in user_id_information:
|
for _id in user_id_information:
|
||||||
user_api_key_cache.set_cache(
|
await user_api_key_cache.async_set_cache(
|
||||||
key=_id["user_id"], value=_id, ttl=600
|
key=_id["user_id"], value=_id, ttl=600
|
||||||
)
|
)
|
||||||
if custom_db_client is not None:
|
if custom_db_client is not None:
|
||||||
|
@ -791,7 +826,9 @@ async def user_api_key_auth(
|
||||||
api_key = valid_token.token
|
api_key = valid_token.token
|
||||||
|
|
||||||
# Add hashed token to cache
|
# Add hashed token to cache
|
||||||
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=600)
|
await user_api_key_cache.async_set_cache(
|
||||||
|
key=api_key, value=valid_token, ttl=600
|
||||||
|
)
|
||||||
valid_token_dict = _get_pydantic_json_dict(valid_token)
|
valid_token_dict = _get_pydantic_json_dict(valid_token)
|
||||||
valid_token_dict.pop("token", None)
|
valid_token_dict.pop("token", None)
|
||||||
"""
|
"""
|
||||||
|
@ -1073,7 +1110,10 @@ async def _PROXY_track_cost_callback(
|
||||||
)
|
)
|
||||||
|
|
||||||
await update_cache(
|
await update_cache(
|
||||||
token=user_api_key, user_id=user_id, response_cost=response_cost
|
token=user_api_key,
|
||||||
|
user_id=user_id,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
response_cost=response_cost,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception("User API key missing from custom callback.")
|
raise Exception("User API key missing from custom callback.")
|
||||||
|
@ -1348,9 +1388,10 @@ async def update_database(
|
||||||
|
|
||||||
|
|
||||||
async def update_cache(
|
async def update_cache(
|
||||||
token,
|
token: Optional[str],
|
||||||
user_id,
|
user_id: Optional[str],
|
||||||
response_cost,
|
end_user_id: Optional[str],
|
||||||
|
response_cost: Optional[float],
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Use this to update the cache with new user spend.
|
Use this to update the cache with new user spend.
|
||||||
|
@ -1365,12 +1406,17 @@ async def update_cache(
|
||||||
hashed_token = hash_token(token=token)
|
hashed_token = hash_token(token=token)
|
||||||
else:
|
else:
|
||||||
hashed_token = token
|
hashed_token = token
|
||||||
|
verbose_proxy_logger.debug(f"_update_key_cache: hashed_token={hashed_token}")
|
||||||
existing_spend_obj = await user_api_key_cache.async_get_cache(key=hashed_token)
|
existing_spend_obj = await user_api_key_cache.async_get_cache(key=hashed_token)
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"_update_key_db: existing spend: {existing_spend_obj}"
|
f"_update_key_cache: existing_spend_obj={existing_spend_obj}"
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"_update_key_cache: existing spend: {existing_spend_obj}"
|
||||||
)
|
)
|
||||||
if existing_spend_obj is None:
|
if existing_spend_obj is None:
|
||||||
existing_spend = 0
|
existing_spend = 0
|
||||||
|
existing_spend_obj = LiteLLM_VerificationTokenView()
|
||||||
else:
|
else:
|
||||||
existing_spend = existing_spend_obj.spend
|
existing_spend = existing_spend_obj.spend
|
||||||
# Calculate the new cost by adding the existing cost and response_cost
|
# Calculate the new cost by adding the existing cost and response_cost
|
||||||
|
@ -1426,18 +1472,7 @@ async def update_cache(
|
||||||
|
|
||||||
async def _update_user_cache():
|
async def _update_user_cache():
|
||||||
## UPDATE CACHE FOR USER ID + GLOBAL PROXY
|
## UPDATE CACHE FOR USER ID + GLOBAL PROXY
|
||||||
end_user_id = None
|
user_ids = [user_id, litellm_proxy_budget_name, end_user_id]
|
||||||
if isinstance(token, str) and token.startswith("sk-"):
|
|
||||||
hashed_token = hash_token(token=token)
|
|
||||||
else:
|
|
||||||
hashed_token = token
|
|
||||||
existing_token_obj = await user_api_key_cache.async_get_cache(key=hashed_token)
|
|
||||||
if existing_token_obj is None:
|
|
||||||
return
|
|
||||||
if existing_token_obj.user_id != user_id: # an end-user id was passed in
|
|
||||||
end_user_id = user_id
|
|
||||||
user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name, end_user_id]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for _id in user_ids:
|
for _id in user_ids:
|
||||||
# Fetch the existing cost for the given user
|
# Fetch the existing cost for the given user
|
||||||
|
@ -1483,9 +1518,59 @@ async def update_cache(
|
||||||
f"An error occurred updating user cache: {str(e)}\n\n{traceback.format_exc()}"
|
f"An error occurred updating user cache: {str(e)}\n\n{traceback.format_exc()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
asyncio.create_task(_update_key_cache())
|
async def _update_end_user_cache():
    """
    Add `response_cost` to the cached spend of the end user (`end_user_id`).

    - reads the cached end-user entry; if there is none, starts from a
      zero-spend LiteLLM_EndUserTable whose budget is `litellm.max_user_budget`
    - a cached dict is updated in place and stored back as a dict;
      any other cached object is stored back as its `.json()` form

    Errors are logged at debug level and never propagated.
    """
    ## UPDATE CACHE FOR END USER ID
    _id = end_user_id
    try:
        # Fetch the existing cost for the given end user
        existing_spend_obj = await user_api_key_cache.async_get_cache(key=_id)
        if existing_spend_obj is None:
            # end user not in cache yet, start from a fresh zero-spend record
            existing_spend_obj = LiteLLM_EndUserTable(
                user_id=_id,
                spend=0,
                blocked=False,
                litellm_budget_table=LiteLLM_BudgetTable(
                    max_budget=litellm.max_user_budget
                ),
            )
        verbose_proxy_logger.debug(
            f"_update_end_user_db: existing spend: {existing_spend_obj}; response_cost: {response_cost}"
        )
        # existing_spend_obj is never None past this point
        if isinstance(existing_spend_obj, dict):
            existing_spend = existing_spend_obj["spend"]
        else:
            existing_spend = existing_spend_obj.spend
        # Calculate the new cost by adding the existing cost and response_cost
        new_spend = existing_spend + response_cost

        # Update the cost column for the given end user
        if isinstance(existing_spend_obj, dict):
            existing_spend_obj["spend"] = new_spend
            user_api_key_cache.set_cache(key=_id, value=existing_spend_obj)
        else:
            existing_spend_obj.spend = new_spend
            user_api_key_cache.set_cache(key=_id, value=existing_spend_obj.json())
    except Exception as e:
        verbose_proxy_logger.debug(
            f"An error occurred updating end user cache: {str(e)}\n\n{traceback.format_exc()}"
        )
|
||||||
|
|
||||||
|
if token is not None:
|
||||||
|
asyncio.create_task(_update_key_cache())
|
||||||
|
|
||||||
asyncio.create_task(_update_user_cache())
|
asyncio.create_task(_update_user_cache())
|
||||||
|
|
||||||
|
if end_user_id is not None:
|
||||||
|
asyncio.create_task(_update_end_user_cache())
|
||||||
|
|
||||||
|
|
||||||
def run_ollama_serve():
|
def run_ollama_serve():
|
||||||
try:
|
try:
|
||||||
|
@ -2587,6 +2672,11 @@ async def startup_event():
|
||||||
|
|
||||||
proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
|
proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
|
||||||
|
|
||||||
|
## JWT AUTH ##
|
||||||
|
jwt_handler.update_environment(
|
||||||
|
prisma_client=prisma_client, user_api_key_cache=user_api_key_cache
|
||||||
|
)
|
||||||
|
|
||||||
if use_background_health_checks:
|
if use_background_health_checks:
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
_run_background_health_check()
|
_run_background_health_check()
|
||||||
|
@ -7750,6 +7840,8 @@ async def shutdown_event():
|
||||||
if litellm.cache is not None:
|
if litellm.cache is not None:
|
||||||
await litellm.cache.disconnect()
|
await litellm.cache.disconnect()
|
||||||
|
|
||||||
|
await jwt_handler.close()
|
||||||
|
|
||||||
## RESET CUSTOM VARIABLES ##
|
## RESET CUSTOM VARIABLES ##
|
||||||
cleanup_router_config_variables()
|
cleanup_router_config_variables()
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue