forked from phoenix/litellm-mirror
fix(handle_jwt.py): track spend for user using jwt auth
This commit is contained in:
parent
ca970a90c4
commit
90e17b5422
5 changed files with 285 additions and 38 deletions
|
@ -599,6 +599,8 @@ class LiteLLM_UserTable(LiteLLMBase):
|
|||
model_spend: Optional[Dict] = {}
|
||||
user_email: Optional[str]
|
||||
models: list = []
|
||||
tpm_limit: Optional[int] = None
|
||||
rpm_limit: Optional[int] = None
|
||||
|
||||
@root_validator(pre=True)
|
||||
def set_model_info(cls, values):
|
||||
|
@ -617,6 +619,7 @@ class LiteLLM_EndUserTable(LiteLLMBase):
|
|||
blocked: bool
|
||||
alias: Optional[str] = None
|
||||
spend: float = 0.0
|
||||
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
|
||||
|
||||
@root_validator(pre=True)
|
||||
def set_model_info(cls, values):
|
||||
|
|
84
litellm/proxy/auth/auth_checks.py
Normal file
84
litellm/proxy/auth/auth_checks.py
Normal file
|
@ -0,0 +1,84 @@
|
|||
# What is this?
|
||||
## Common auth checks between jwt + key based auth
|
||||
"""
|
||||
Got Valid Token from Cache, DB
|
||||
Run checks for:
|
||||
|
||||
1. If user can call model
|
||||
2. If user is in budget
|
||||
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
||||
"""
|
||||
from litellm.proxy._types import LiteLLM_UserTable, LiteLLM_EndUserTable
|
||||
from typing import Optional
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.caching import DualCache
|
||||
|
||||
|
||||
def common_checks(
    request_body: dict,
    user_object: LiteLLM_UserTable,
    end_user_object: Optional[LiteLLM_EndUserTable],
) -> bool:
    """
    Shared authorization checks for jwt + key based auth.

    Runs, in order:
      1. model access  - is the requested model in the user's allow-list?
      2. user budget   - has the user's spend exceeded their max_budget?
      3. end-user budget - has the end-user ('user' in the request body)
         exceeded the budget attached to their budget table?

    Returns True when all checks pass; raises Exception on the first failure.
    """
    requested_model = request_body.get("model", None)

    # 1. Model access. An empty allow-list means "all models permitted".
    allowed_models = user_object.models
    if requested_model is not None and len(allowed_models) > 0:
        if requested_model not in allowed_models:
            raise Exception(
                f"User={user_object.user_id} not allowed to call model={requested_model}. Allowed user models = {allowed_models}"
            )

    # 2. User budget. Only enforced when a max_budget is set.
    if user_object.max_budget is not None:
        if user_object.spend > user_object.max_budget:
            raise Exception(
                f"User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_object.max_budget}"
            )

    # 3. End-user budget. Only enforced when the end-user has a budget table
    #    with a max_budget defined.
    if end_user_object is not None:
        budget_table = end_user_object.litellm_budget_table
        if budget_table is not None:
            end_user_budget = budget_table.max_budget
            if end_user_budget is not None and end_user_object.spend > end_user_budget:
                raise Exception(
                    f"End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}"
                )

    return True
|
||||
|
||||
|
||||
async def get_end_user_object(
    end_user_id: Optional[str],
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
) -> Optional[LiteLLM_EndUserTable]:
    """
    Returns end user object, if in db.

    Do a isolated check for end user in table vs. doing a combined key + team + user + end-user check, as key might come in frequently for different end-users. Larger call will slowdown query time. This way we get to cache the constant (key/team/user info) and only update based on the changing value (end-user).

    Raises:
        Exception: if no prisma_client (db) is connected.
    Returns:
        None when end_user_id is None or the end-user is not found in the db.
    """
    if prisma_client is None:
        raise Exception("No db connected")

    if end_user_id is None:
        return None

    # check if in cache.
    # BUG FIX: `async_get_cache` is a coroutine function — the original called it
    # without `await`, so `cached_user_obj` was always a coroutine object: both
    # isinstance checks failed, the cache was silently bypassed on every call,
    # and the un-awaited coroutine leaked (RuntimeWarning).
    cached_user_obj = await user_api_key_cache.async_get_cache(key=end_user_id)
    if cached_user_obj is not None:
        if isinstance(cached_user_obj, dict):
            return LiteLLM_EndUserTable(**cached_user_obj)
        elif isinstance(cached_user_obj, LiteLLM_EndUserTable):
            return cached_user_obj
    # else, check db
    try:
        response = await prisma_client.db.litellm_endusertable.find_unique(
            where={"user_id": end_user_id}
        )

        # treat "not found" the same as any other lookup failure below
        if response is None:
            raise Exception

        return LiteLLM_EndUserTable(**response.dict())
    except Exception:  # if end-user not in db, there is no end-user budget to enforce
        return None
|
|
@ -8,23 +8,27 @@ JWT token must have 'litellm_proxy_admin' in scope.
|
|||
|
||||
import httpx
|
||||
import jwt
|
||||
|
||||
print(jwt.__version__) # noqa
|
||||
from jwt.algorithms import RSAAlgorithm
|
||||
import json
|
||||
import os
|
||||
from litellm.proxy._types import LiteLLMProxyRoles
|
||||
from litellm.caching import DualCache
|
||||
from litellm.proxy._types import LiteLLMProxyRoles, LiteLLM_UserTable
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class HTTPHandler:
|
||||
def __init__(self):
|
||||
self.client = httpx.AsyncClient()
|
||||
def __init__(self, concurrent_limit=1000):
|
||||
# Create a client with a connection pool
|
||||
self.client = httpx.AsyncClient(
|
||||
limits=httpx.Limits(
|
||||
max_connections=concurrent_limit,
|
||||
max_keepalive_connections=concurrent_limit,
|
||||
)
|
||||
)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
async def close(self):
|
||||
# Close the client when you're done with it
|
||||
await self.client.aclose()
|
||||
|
||||
async def get(
|
||||
|
@ -47,10 +51,27 @@ class HTTPHandler:
|
|||
|
||||
|
||||
class JWTHandler:
|
||||
"""
|
||||
- treat the sub id passed in as the user id
|
||||
- return an error if id making request doesn't exist in proxy user table
|
||||
- track spend against the user id
|
||||
- if role="litellm_proxy_user" -> allow making calls + info. Can not edit budgets
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
prisma_client: Optional[PrismaClient]
|
||||
user_api_key_cache: DualCache
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
) -> None:
|
||||
self.http_handler = HTTPHandler()
|
||||
|
||||
def update_environment(
|
||||
self, prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache
|
||||
) -> None:
|
||||
self.prisma_client = prisma_client
|
||||
self.user_api_key_cache = user_api_key_cache
|
||||
|
||||
def is_jwt(self, token: str):
|
||||
parts = token.split(".")
|
||||
return len(parts) == 3
|
||||
|
@ -67,6 +88,46 @@ class JWTHandler:
|
|||
user_id = default_value
|
||||
return user_id
|
||||
|
||||
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
    """Return the team id from the JWT's 'azp' (authorized party) claim,
    or *default_value* when the claim is absent."""
    return token.get("azp", default_value)
|
||||
|
||||
async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
    """
    - Check if user id in proxy User Table
    - if valid, return LiteLLM_UserTable object with defined limits
    - if not, then raise an error

    Raises:
        Exception: if no db is connected, or if the user is not found in the db.
    """
    if self.prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )

    # check if in cache.
    # BUG FIX: `async_get_cache` is a coroutine function — the original called it
    # without `await`, so `cached_user_obj` was always a coroutine object: both
    # isinstance checks failed, the cache was silently bypassed on every request
    # (forcing a db hit each time), and the un-awaited coroutine leaked.
    cached_user_obj = await self.user_api_key_cache.async_get_cache(key=user_id)
    if cached_user_obj is not None:
        if isinstance(cached_user_obj, dict):
            return LiteLLM_UserTable(**cached_user_obj)
        elif isinstance(cached_user_obj, LiteLLM_UserTable):
            return cached_user_obj
    # else, check db
    try:
        response = await self.prisma_client.db.litellm_usertable.find_unique(
            where={"user_id": user_id}
        )

        # treat "not found" the same as any other lookup failure below
        if response is None:
            raise Exception

        return LiteLLM_UserTable(**response.dict())
    except Exception:
        raise Exception(
            f"User doesn't exist in db. User={user_id}. Create user via `/user/new` call."
        )
|
||||
|
||||
def get_scopes(self, token: dict) -> list:
|
||||
try:
|
||||
# Assuming the scopes are stored in 'scope' claim and are space-separated
|
||||
|
@ -78,8 +139,10 @@ class JWTHandler:
|
|||
async def auth_jwt(self, token: str) -> dict:
|
||||
keys_url = os.getenv("JWT_PUBLIC_KEY_URL")
|
||||
|
||||
async with self.http_handler as http:
|
||||
response = await http.get(keys_url)
|
||||
if keys_url is None:
|
||||
raise Exception("Missing JWT Public Key URL from environment.")
|
||||
|
||||
response = await self.http_handler.get(keys_url)
|
||||
|
||||
keys = response.json()["keys"]
|
||||
|
||||
|
@ -113,3 +176,6 @@ class JWTHandler:
|
|||
raise Exception(f"Validation fails: {str(e)}")
|
||||
|
||||
raise Exception("Invalid JWT Submitted")
|
||||
|
||||
async def close(self):
|
||||
await self.http_handler.close()
|
||||
|
|
|
@ -107,6 +107,7 @@ from litellm.caching import DualCache
|
|||
from litellm.proxy.health_check import perform_health_check
|
||||
from litellm._logging import verbose_router_logger, verbose_proxy_logger
|
||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||
from litellm.proxy.auth.auth_checks import common_checks, get_end_user_object
|
||||
|
||||
try:
|
||||
from litellm._version import version
|
||||
|
@ -360,18 +361,54 @@ async def user_api_key_auth(
|
|||
user_id = jwt_handler.get_user_id(
|
||||
token=valid_token, default_value=litellm_proxy_admin_name
|
||||
)
|
||||
|
||||
end_user_object = None
|
||||
# get the request body
|
||||
request_data = await _read_request_body(request=request)
|
||||
# get user obj from cache/db -> run for admin too. Ensures, admin client id in db.
|
||||
user_object = await jwt_handler.get_user_object(user_id=user_id)
|
||||
if (
|
||||
request_data.get("user", None)
|
||||
and request_data["user"] != user_object.user_id
|
||||
):
|
||||
# get the end-user object
|
||||
end_user_object = await get_end_user_object(
|
||||
end_user_id=request_data["user"],
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
)
|
||||
# save the end-user object to cache
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key=request_data["user"], value=end_user_object
|
||||
)
|
||||
|
||||
# run through common checks
|
||||
_ = common_checks(
|
||||
request_body=request_data,
|
||||
user_object=user_object,
|
||||
end_user_object=end_user_object,
|
||||
)
|
||||
# save user object in cache
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key=user_object.user_id, value=user_object
|
||||
)
|
||||
# if admin return
|
||||
if is_admin:
|
||||
_user_api_key_obj = UserAPIKeyAuth(
|
||||
return UserAPIKeyAuth(
|
||||
api_key=api_key,
|
||||
user_role="proxy_admin",
|
||||
user_id=user_id,
|
||||
)
|
||||
user_api_key_cache.set_cache(
|
||||
key=hash_token(api_key), value=_user_api_key_obj
|
||||
else:
|
||||
# return UserAPIKeyAuth object
|
||||
return UserAPIKeyAuth(
|
||||
api_key=None,
|
||||
user_id=user_object.user_id,
|
||||
tpm_limit=user_object.tpm_limit,
|
||||
rpm_limit=user_object.rpm_limit,
|
||||
models=user_object.models,
|
||||
user_role="app_owner",
|
||||
)
|
||||
|
||||
return _user_api_key_obj
|
||||
else:
|
||||
raise Exception("Invalid key error!")
|
||||
#### ELSE ####
|
||||
|
@ -438,7 +475,7 @@ async def user_api_key_auth(
|
|||
user_role="proxy_admin",
|
||||
user_id=litellm_proxy_admin_name,
|
||||
)
|
||||
user_api_key_cache.set_cache(
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key=hash_token(master_key), value=_user_api_key_obj
|
||||
)
|
||||
|
||||
|
@ -603,7 +640,7 @@ async def user_api_key_auth(
|
|||
query_type="find_all",
|
||||
)
|
||||
for _id in user_id_information:
|
||||
user_api_key_cache.set_cache(
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key=_id["user_id"], value=_id, ttl=600
|
||||
)
|
||||
if custom_db_client is not None:
|
||||
|
@ -791,7 +828,9 @@ async def user_api_key_auth(
|
|||
api_key = valid_token.token
|
||||
|
||||
# Add hashed token to cache
|
||||
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=600)
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key=api_key, value=valid_token, ttl=600
|
||||
)
|
||||
valid_token_dict = _get_pydantic_json_dict(valid_token)
|
||||
valid_token_dict.pop("token", None)
|
||||
"""
|
||||
|
@ -1073,7 +1112,10 @@ async def _PROXY_track_cost_callback(
|
|||
)
|
||||
|
||||
await update_cache(
|
||||
token=user_api_key, user_id=user_id, response_cost=response_cost
|
||||
token=user_api_key,
|
||||
user_id=user_id,
|
||||
end_user_id=end_user_id,
|
||||
response_cost=response_cost,
|
||||
)
|
||||
else:
|
||||
raise Exception("User API key missing from custom callback.")
|
||||
|
@ -1348,9 +1390,10 @@ async def update_database(
|
|||
|
||||
|
||||
async def update_cache(
|
||||
token,
|
||||
user_id,
|
||||
response_cost,
|
||||
token: Optional[str],
|
||||
user_id: Optional[str],
|
||||
end_user_id: Optional[str],
|
||||
response_cost: Optional[float],
|
||||
):
|
||||
"""
|
||||
Use this to update the cache with new user spend.
|
||||
|
@ -1365,12 +1408,17 @@ async def update_cache(
|
|||
hashed_token = hash_token(token=token)
|
||||
else:
|
||||
hashed_token = token
|
||||
verbose_proxy_logger.debug(f"_update_key_cache: hashed_token={hashed_token}")
|
||||
existing_spend_obj = await user_api_key_cache.async_get_cache(key=hashed_token)
|
||||
verbose_proxy_logger.debug(
|
||||
f"_update_key_db: existing spend: {existing_spend_obj}"
|
||||
f"_update_key_cache: existing_spend_obj={existing_spend_obj}"
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"_update_key_cache: existing spend: {existing_spend_obj}"
|
||||
)
|
||||
if existing_spend_obj is None:
|
||||
existing_spend = 0
|
||||
existing_spend_obj = LiteLLM_VerificationTokenView()
|
||||
else:
|
||||
existing_spend = existing_spend_obj.spend
|
||||
# Calculate the new cost by adding the existing cost and response_cost
|
||||
|
@ -1426,18 +1474,7 @@ async def update_cache(
|
|||
|
||||
async def _update_user_cache():
|
||||
## UPDATE CACHE FOR USER ID + GLOBAL PROXY
|
||||
end_user_id = None
|
||||
if isinstance(token, str) and token.startswith("sk-"):
|
||||
hashed_token = hash_token(token=token)
|
||||
else:
|
||||
hashed_token = token
|
||||
existing_token_obj = await user_api_key_cache.async_get_cache(key=hashed_token)
|
||||
if existing_token_obj is None:
|
||||
return
|
||||
if existing_token_obj.user_id != user_id: # an end-user id was passed in
|
||||
end_user_id = user_id
|
||||
user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name, end_user_id]
|
||||
|
||||
user_ids = [user_id, litellm_proxy_budget_name, end_user_id]
|
||||
try:
|
||||
for _id in user_ids:
|
||||
# Fetch the existing cost for the given user
|
||||
|
@ -1483,9 +1520,59 @@ async def update_cache(
|
|||
f"An error occurred updating user cache: {str(e)}\n\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
async def _update_end_user_cache():
|
||||
## UPDATE CACHE FOR USER ID + GLOBAL PROXY
|
||||
_id = end_user_id
|
||||
try:
|
||||
# Fetch the existing cost for the given user
|
||||
existing_spend_obj = await user_api_key_cache.async_get_cache(key=_id)
|
||||
if existing_spend_obj is None:
|
||||
# if user does not exist in LiteLLM_UserTable, create a new user
|
||||
existing_spend = 0
|
||||
max_user_budget = None
|
||||
if litellm.max_user_budget is not None:
|
||||
max_user_budget = litellm.max_user_budget
|
||||
existing_spend_obj = LiteLLM_EndUserTable(
|
||||
user_id=_id,
|
||||
spend=0,
|
||||
blocked=False,
|
||||
litellm_budget_table=LiteLLM_BudgetTable(
|
||||
max_budget=max_user_budget
|
||||
),
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"_update_end_user_db: existing spend: {existing_spend_obj}; response_cost: {response_cost}"
|
||||
)
|
||||
if existing_spend_obj is None:
|
||||
existing_spend = 0
|
||||
else:
|
||||
if isinstance(existing_spend_obj, dict):
|
||||
existing_spend = existing_spend_obj["spend"]
|
||||
else:
|
||||
existing_spend = existing_spend_obj.spend
|
||||
# Calculate the new cost by adding the existing cost and response_cost
|
||||
new_spend = existing_spend + response_cost
|
||||
|
||||
# Update the cost column for the given user
|
||||
if isinstance(existing_spend_obj, dict):
|
||||
existing_spend_obj["spend"] = new_spend
|
||||
user_api_key_cache.set_cache(key=_id, value=existing_spend_obj)
|
||||
else:
|
||||
existing_spend_obj.spend = new_spend
|
||||
user_api_key_cache.set_cache(key=_id, value=existing_spend_obj.json())
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(
|
||||
f"An error occurred updating end user cache: {str(e)}\n\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
if token is not None:
|
||||
asyncio.create_task(_update_key_cache())
|
||||
|
||||
asyncio.create_task(_update_user_cache())
|
||||
|
||||
if end_user_id is not None:
|
||||
asyncio.create_task(_update_end_user_cache())
|
||||
|
||||
|
||||
def run_ollama_serve():
|
||||
try:
|
||||
|
@ -2587,6 +2674,11 @@ async def startup_event():
|
|||
|
||||
proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
|
||||
|
||||
## JWT AUTH ##
|
||||
jwt_handler.update_environment(
|
||||
prisma_client=prisma_client, user_api_key_cache=user_api_key_cache
|
||||
)
|
||||
|
||||
if use_background_health_checks:
|
||||
asyncio.create_task(
|
||||
_run_background_health_check()
|
||||
|
@ -7750,6 +7842,8 @@ async def shutdown_event():
|
|||
if litellm.cache is not None:
|
||||
await litellm.cache.disconnect()
|
||||
|
||||
await jwt_handler.close()
|
||||
|
||||
## RESET CUSTOM VARIABLES ##
|
||||
cleanup_router_config_variables()
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue