fix(handle_jwt.py): track spend for user using jwt auth

Krrish Dholakia 2024-03-20 10:55:52 -07:00
parent ca970a90c4
commit 90e17b5422
5 changed files with 285 additions and 38 deletions

litellm/proxy/_types.py

@@ -599,6 +599,8 @@ class LiteLLM_UserTable(LiteLLMBase):
     model_spend: Optional[Dict] = {}
     user_email: Optional[str]
     models: list = []
+    tpm_limit: Optional[int] = None
+    rpm_limit: Optional[int] = None

     @root_validator(pre=True)
     def set_model_info(cls, values):
@@ -617,6 +619,7 @@ class LiteLLM_EndUserTable(LiteLLMBase):
     blocked: bool
     alias: Optional[str] = None
     spend: float = 0.0
+    litellm_budget_table: Optional[LiteLLM_BudgetTable] = None

     @root_validator(pre=True)
     def set_model_info(cls, values):
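The two new user-level limits plus the end-user budget link are what the JWT path reads later in this commit. A minimal sketch of the new fields in use (constructor values are assumptions for illustration, not from the diff):

# Minimal sketch: constructing a user row with the new rate-limit fields.
# All concrete values here are made up for illustration.
from litellm.proxy._types import LiteLLM_UserTable

user = LiteLLM_UserTable(
    user_id="krrish-1234",     # hypothetical user id
    user_email=None,
    models=["gpt-3.5-turbo"],  # models this user may call
    tpm_limit=1000,            # new: tokens-per-minute cap for the user
    rpm_limit=60,              # new: requests-per-minute cap for the user
)
print(user.tpm_limit, user.rpm_limit)  # -> 1000 60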

litellm/proxy/auth/auth_checks.py

@@ -0,0 +1,84 @@
+# What is this?
+## Common auth checks between jwt + key based auth
+"""
+Got Valid Token from Cache, DB
+Run checks for:
+1. If user can call model
+2. If user is in budget
+3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
+"""
+from litellm.proxy._types import LiteLLM_UserTable, LiteLLM_EndUserTable
+from typing import Optional
+from litellm.proxy.utils import PrismaClient
+from litellm.caching import DualCache
+
+
+def common_checks(
+    request_body: dict,
+    user_object: LiteLLM_UserTable,
+    end_user_object: Optional[LiteLLM_EndUserTable],
+) -> bool:
+    _model = request_body.get("model", None)
+    # 1. If user can call model
+    if (
+        _model is not None
+        and len(user_object.models) > 0
+        and _model not in user_object.models
+    ):
+        raise Exception(
+            f"User={user_object.user_id} not allowed to call model={_model}. Allowed user models = {user_object.models}"
+        )
+    # 2. If user is in budget
+    if (
+        user_object.max_budget is not None
+        and user_object.spend > user_object.max_budget
+    ):
+        raise Exception(
+            f"User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_object.max_budget}"
+        )
+    # 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
+    if end_user_object is not None and end_user_object.litellm_budget_table is not None:
+        end_user_budget = end_user_object.litellm_budget_table.max_budget
+        if end_user_budget is not None and end_user_object.spend > end_user_budget:
+            raise Exception(
+                f"End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}"
+            )
+    return True
+
+
+async def get_end_user_object(
+    end_user_id: Optional[str],
+    prisma_client: Optional[PrismaClient],
+    user_api_key_cache: DualCache,
+) -> Optional[LiteLLM_EndUserTable]:
+    """
+    Returns end user object, if in db.
+
+    Do an isolated check for the end user in its table vs. doing a combined key + team + user + end-user check, as the key might come in frequently for different end-users. A larger call would slow down query time. This way we get to cache the constant (key/team/user info) and only update based on the changing value (end-user).
+    """
+    if prisma_client is None:
+        raise Exception("No db connected")
+    if end_user_id is None:
+        return None
+
+    # check if in cache
+    cached_user_obj = await user_api_key_cache.async_get_cache(key=end_user_id)
+    if cached_user_obj is not None:
+        if isinstance(cached_user_obj, dict):
+            return LiteLLM_EndUserTable(**cached_user_obj)
+        elif isinstance(cached_user_obj, LiteLLM_EndUserTable):
+            return cached_user_obj
+    # else, check db
+    try:
+        response = await prisma_client.db.litellm_endusertable.find_unique(
+            where={"user_id": end_user_id}
+        )
+        if response is None:
+            raise Exception
+        return LiteLLM_EndUserTable(**response.dict())
+    except Exception as e:  # if end-user not in db
+        return None
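A short sketch of how these two helpers compose on the request path (assumptions: `prisma_client` is an initialized PrismaClient and the request body is made up):

# Illustrative composition of the helpers above; not part of the diff.
from litellm.caching import DualCache
from litellm.proxy._types import LiteLLM_UserTable
from litellm.proxy.auth.auth_checks import common_checks, get_end_user_object

async def run_auth_checks(user_object: LiteLLM_UserTable, prisma_client):
    request_data = {"model": "gpt-3.5-turbo", "user": "end-user-42"}  # made up
    end_user_object = await get_end_user_object(
        end_user_id=request_data.get("user"),
        prisma_client=prisma_client,
        user_api_key_cache=DualCache(),
    )
    # Raises on a disallowed model or an exceeded user/end-user budget.
    common_checks(
        request_body=request_data,
        user_object=user_object,
        end_user_object=end_user_object,
    )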

litellm/proxy/auth/handle_jwt.py

@@ -8,23 +8,27 @@ JWT token must have 'litellm_proxy_admin' in scope.
 import httpx
 import jwt
-print(jwt.__version__)  # noqa
 from jwt.algorithms import RSAAlgorithm
 import json
 import os
-from litellm.proxy._types import LiteLLMProxyRoles
+from litellm.caching import DualCache
+from litellm.proxy._types import LiteLLMProxyRoles, LiteLLM_UserTable
+from litellm.proxy.utils import PrismaClient
 from typing import Optional


 class HTTPHandler:
-    def __init__(self):
-        self.client = httpx.AsyncClient()
+    def __init__(self, concurrent_limit=1000):
+        # Create a client with a connection pool
+        self.client = httpx.AsyncClient(
+            limits=httpx.Limits(
+                max_connections=concurrent_limit,
+                max_keepalive_connections=concurrent_limit,
+            )
+        )

-    async def __aenter__(self):
-        return self
-
-    async def __aexit__(self, exc_type, exc_val, exc_tb):
+    async def close(self):
+        # Close the client when you're done with it
         await self.client.aclose()

     async def get(
@@ -47,10 +51,27 @@ class HTTPHandler:


 class JWTHandler:
+    """
+    - treat the sub id passed in as the user id
+    - return an error if id making request doesn't exist in proxy user table
+    - track spend against the user id
+    - if role="litellm_proxy_user" -> allow making calls + info. Can not edit budgets
+    """
+
+    prisma_client: Optional[PrismaClient]
+    user_api_key_cache: DualCache
+
-    def __init__(self) -> None:
+    def __init__(
+        self,
+    ) -> None:
         self.http_handler = HTTPHandler()

+    def update_environment(
+        self, prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache
+    ) -> None:
+        self.prisma_client = prisma_client
+        self.user_api_key_cache = user_api_key_cache
+
     def is_jwt(self, token: str):
         parts = token.split(".")
         return len(parts) == 3
@@ -67,6 +88,46 @@ class JWTHandler:
             user_id = default_value
         return user_id

+    def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
+        try:
+            team_id = token["azp"]
+        except KeyError:
+            team_id = default_value
+        return team_id
+
+    async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
+        """
+        - Check if user id in proxy User Table
+        - if valid, return LiteLLM_UserTable object with defined limits
+        - if not, then raise an error
+        """
+        if self.prisma_client is None:
+            raise Exception(
+                "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
+            )
+
+        # check if in cache
+        cached_user_obj = await self.user_api_key_cache.async_get_cache(key=user_id)
+        if cached_user_obj is not None:
+            if isinstance(cached_user_obj, dict):
+                return LiteLLM_UserTable(**cached_user_obj)
+            elif isinstance(cached_user_obj, LiteLLM_UserTable):
+                return cached_user_obj
+        # else, check db
+        try:
+            response = await self.prisma_client.db.litellm_usertable.find_unique(
+                where={"user_id": user_id}
+            )
+            if response is None:
+                raise Exception
+            return LiteLLM_UserTable(**response.dict())
+        except Exception as e:
+            raise Exception(
+                f"User doesn't exist in db. User={user_id}. Create user via `/user/new` call."
+            )
+
     def get_scopes(self, token: dict) -> list:
         try:
             # Assuming the scopes are stored in 'scope' claim and are space-separated
@@ -78,8 +139,10 @@ class JWTHandler:

     async def auth_jwt(self, token: str) -> dict:
         keys_url = os.getenv("JWT_PUBLIC_KEY_URL")
-        async with self.http_handler as http:
-            response = await http.get(keys_url)
+        if keys_url is None:
+            raise Exception("Missing JWT Public Key URL from environment.")
+
+        response = await self.http_handler.get(keys_url)

         keys = response.json()["keys"]
@@ -113,3 +176,6 @@ class JWTHandler:
             raise Exception(f"Validation fails: {str(e)}")

         raise Exception("Invalid JWT Submitted")
+
+    async def close(self):
+        await self.http_handler.close()
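A hedged sketch of the new handler lifecycle: wire the DB client and cache in once at startup, validate per request, close on shutdown (the bearer token and `prisma_client` here are stand-ins):

# Illustrative lifecycle for the JWTHandler changes; not part of the diff.
from litellm.caching import DualCache
from litellm.proxy.auth.handle_jwt import JWTHandler

async def authenticate(prisma_client, bearer_token: str):
    user_object = None
    jwt_handler = JWTHandler()
    # startup: hand the handler its DB client + cache
    jwt_handler.update_environment(
        prisma_client=prisma_client, user_api_key_cache=DualCache()
    )
    if jwt_handler.is_jwt(bearer_token):
        # requires JWT_PUBLIC_KEY_URL to be set in the environment
        claims = await jwt_handler.auth_jwt(token=bearer_token)
        user_id = jwt_handler.get_user_id(token=claims, default_value=None)
        if user_id is not None:
            # cache-then-DB lookup; raises if the user isn't in the user table
            user_object = await jwt_handler.get_user_object(user_id=user_id)
    # shutdown: release the pooled HTTP connections
    await jwt_handler.close()
    return user_object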

litellm/proxy/proxy_server.py

@@ -107,6 +107,7 @@ from litellm.caching import DualCache
 from litellm.proxy.health_check import perform_health_check
 from litellm._logging import verbose_router_logger, verbose_proxy_logger
 from litellm.proxy.auth.handle_jwt import JWTHandler
+from litellm.proxy.auth.auth_checks import common_checks, get_end_user_object

 try:
     from litellm._version import version
@@ -360,18 +361,54 @@ async def user_api_key_auth(
                 user_id = jwt_handler.get_user_id(
                     token=valid_token, default_value=litellm_proxy_admin_name
                 )

+                end_user_object = None
+                # get the request body
+                request_data = await _read_request_body(request=request)
+                # get user obj from cache/db -> run for admin too. Ensures admin client id is in db.
+                user_object = await jwt_handler.get_user_object(user_id=user_id)
+                if (
+                    request_data.get("user", None)
+                    and request_data["user"] != user_object.user_id
+                ):
+                    # get the end-user object
+                    end_user_object = await get_end_user_object(
+                        end_user_id=request_data["user"],
+                        prisma_client=prisma_client,
+                        user_api_key_cache=user_api_key_cache,
+                    )
+                    # save the end-user object to cache
+                    await user_api_key_cache.async_set_cache(
+                        key=request_data["user"], value=end_user_object
+                    )
+
+                # run through common checks
+                _ = common_checks(
+                    request_body=request_data,
+                    user_object=user_object,
+                    end_user_object=end_user_object,
+                )
+                # save user object in cache
+                await user_api_key_cache.async_set_cache(
+                    key=user_object.user_id, value=user_object
+                )
+
                 # if admin return
                 if is_admin:
-                    _user_api_key_obj = UserAPIKeyAuth(
+                    return UserAPIKeyAuth(
                         api_key=api_key,
                         user_role="proxy_admin",
                         user_id=user_id,
                     )
-                    user_api_key_cache.set_cache(
-                        key=hash_token(api_key), value=_user_api_key_obj
-                    )
-
-                    return _user_api_key_obj
+                else:
+                    # return UserAPIKeyAuth object
+                    return UserAPIKeyAuth(
+                        api_key=None,
+                        user_id=user_object.user_id,
+                        tpm_limit=user_object.tpm_limit,
+                        rpm_limit=user_object.rpm_limit,
+                        models=user_object.models,
+                        user_role="app_owner",
+                    )
             else:
                 raise Exception("Invalid key error!")
 #### ELSE ####
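Worth noting in the branch above: admins keep the proxy_admin role, while everyone else gets a UserAPIKeyAuth built from their user row, so downstream limit enforcement reads the user-level tpm/rpm caps. A sketch of the two return shapes (field values are assumptions):

# Illustrative only: the two UserAPIKeyAuth shapes returned above.
from litellm.proxy._types import UserAPIKeyAuth

admin_auth = UserAPIKeyAuth(
    api_key="jwt-bearer-token",  # hypothetical
    user_role="proxy_admin",
    user_id="litellm-proxy-admin",  # hypothetical
)
app_owner_auth = UserAPIKeyAuth(
    api_key=None,              # no virtual key involved in JWT auth
    user_id="krrish-1234",     # hypothetical
    tpm_limit=1000,            # read from LiteLLM_UserTable in the real flow
    rpm_limit=60,
    models=["gpt-3.5-turbo"],
    user_role="app_owner",
)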
@@ -438,7 +475,7 @@ async def user_api_key_auth(
                 user_role="proxy_admin",
                 user_id=litellm_proxy_admin_name,
             )
-            user_api_key_cache.set_cache(
+            await user_api_key_cache.async_set_cache(
                 key=hash_token(master_key), value=_user_api_key_obj
             )
@@ -603,7 +640,7 @@ async def user_api_key_auth(
                     query_type="find_all",
                 )
                 for _id in user_id_information:
-                    user_api_key_cache.set_cache(
+                    await user_api_key_cache.async_set_cache(
                         key=_id["user_id"], value=_id, ttl=600
                     )
         if custom_db_client is not None:
@@ -791,7 +828,9 @@ async def user_api_key_auth(
             api_key = valid_token.token

             # Add hashed token to cache
-            user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=600)
+            await user_api_key_cache.async_set_cache(
+                key=api_key, value=valid_token, ttl=600
+            )
         valid_token_dict = _get_pydantic_json_dict(valid_token)
         valid_token_dict.pop("token", None)
         """
@@ -1073,7 +1112,10 @@ async def _PROXY_track_cost_callback(
             )
             await update_cache(
-                token=user_api_key, user_id=user_id, response_cost=response_cost
+                token=user_api_key,
+                user_id=user_id,
+                end_user_id=end_user_id,
+                response_cost=response_cost,
             )
         else:
             raise Exception("User API key missing from custom callback.")
@@ -1348,9 +1390,10 @@ async def update_database(

 async def update_cache(
-    token,
-    user_id,
-    response_cost,
+    token: Optional[str],
+    user_id: Optional[str],
+    end_user_id: Optional[str],
+    response_cost: Optional[float],
 ):
     """
     Use this to update the cache with new user spend.
@@ -1365,12 +1408,17 @@ async def update_cache(
             hashed_token = hash_token(token=token)
         else:
             hashed_token = token
+        verbose_proxy_logger.debug(f"_update_key_cache: hashed_token={hashed_token}")
         existing_spend_obj = await user_api_key_cache.async_get_cache(key=hashed_token)
         verbose_proxy_logger.debug(
-            f"_update_key_db: existing spend: {existing_spend_obj}"
+            f"_update_key_cache: existing_spend_obj={existing_spend_obj}"
+        )
+        verbose_proxy_logger.debug(
+            f"_update_key_cache: existing spend: {existing_spend_obj}"
         )
         if existing_spend_obj is None:
             existing_spend = 0
+            existing_spend_obj = LiteLLM_VerificationTokenView()
         else:
             existing_spend = existing_spend_obj.spend
         # Calculate the new cost by adding the existing cost and response_cost
@@ -1426,18 +1474,7 @@ async def update_cache(

     async def _update_user_cache():
         ## UPDATE CACHE FOR USER ID + GLOBAL PROXY
-        end_user_id = None
-        if isinstance(token, str) and token.startswith("sk-"):
-            hashed_token = hash_token(token=token)
-        else:
-            hashed_token = token
-        existing_token_obj = await user_api_key_cache.async_get_cache(key=hashed_token)
-        if existing_token_obj is None:
-            return
-        if existing_token_obj.user_id != user_id:  # an end-user id was passed in
-            end_user_id = user_id
-        user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name, end_user_id]
+        user_ids = [user_id, litellm_proxy_budget_name, end_user_id]
         try:
             for _id in user_ids:
                 # Fetch the existing cost for the given user
@@ -1483,9 +1520,59 @@ async def update_cache(
                 f"An error occurred updating user cache: {str(e)}\n\n{traceback.format_exc()}"
             )

-    asyncio.create_task(_update_key_cache())
+    async def _update_end_user_cache():
+        ## UPDATE CACHE FOR END-USER ID ##
+        _id = end_user_id
+        try:
+            # Fetch the existing cost for the given end-user
+            existing_spend_obj = await user_api_key_cache.async_get_cache(key=_id)
+            if existing_spend_obj is None:
+                # if end-user does not exist in LiteLLM_EndUserTable, create a new entry
+                existing_spend = 0
+                max_user_budget = None
+                if litellm.max_user_budget is not None:
+                    max_user_budget = litellm.max_user_budget
+                existing_spend_obj = LiteLLM_EndUserTable(
+                    user_id=_id,
+                    spend=0,
+                    blocked=False,
+                    litellm_budget_table=LiteLLM_BudgetTable(
+                        max_budget=max_user_budget
+                    ),
+                )
+            verbose_proxy_logger.debug(
+                f"_update_end_user_db: existing spend: {existing_spend_obj}; response_cost: {response_cost}"
+            )
+            if existing_spend_obj is None:
+                existing_spend = 0
+            else:
+                if isinstance(existing_spend_obj, dict):
+                    existing_spend = existing_spend_obj["spend"]
+                else:
+                    existing_spend = existing_spend_obj.spend
+            # Calculate the new cost by adding the existing cost and response_cost
+            new_spend = existing_spend + response_cost
+
+            # Update the cost column for the given end-user
+            if isinstance(existing_spend_obj, dict):
+                existing_spend_obj["spend"] = new_spend
+                user_api_key_cache.set_cache(key=_id, value=existing_spend_obj)
+            else:
+                existing_spend_obj.spend = new_spend
+                user_api_key_cache.set_cache(key=_id, value=existing_spend_obj.json())
+        except Exception as e:
+            verbose_proxy_logger.debug(
+                f"An error occurred updating end user cache: {str(e)}\n\n{traceback.format_exc()}"
+            )
+
+    if token is not None:
+        asyncio.create_task(_update_key_cache())
     asyncio.create_task(_update_user_cache())
+    if end_user_id is not None:
+        asyncio.create_task(_update_end_user_cache())


 def run_ollama_serve():
     try:
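A hedged sketch of the updated call contract (assumes update_cache is in scope, as it lives in proxy_server.py; ids and cost are made up). With token=None, the new guards skip the key cache and update only the user and end-user spend caches:

# Illustrative only: the new update_cache signature in use.
async def _after_successful_call():
    await update_cache(
        token=None,                 # JWT-auth'd request: no virtual key to update
        user_id="krrish-1234",      # hypothetical internal user id
        end_user_id="end-user-42",  # the `user` field from the request body
        response_cost=0.00042,      # hypothetical cost of this completion
    )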
@@ -2587,6 +2674,11 @@ async def startup_event():
     proxy_logging_obj._init_litellm_callbacks()  # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made

+    ## JWT AUTH ##
+    jwt_handler.update_environment(
+        prisma_client=prisma_client, user_api_key_cache=user_api_key_cache
+    )
+
     if use_background_health_checks:
         asyncio.create_task(
             _run_background_health_check()
@@ -7750,6 +7842,8 @@ async def shutdown_event():
     if litellm.cache is not None:
         await litellm.cache.disconnect()

+    await jwt_handler.close()
+
     ## RESET CUSTOM VARIABLES ##
     cleanup_router_config_variables()