From 89ee9fe4001a89bbe99f6924c18d41e56ccdec65 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 22 Dec 2023 15:10:38 +0530 Subject: [PATCH] fix(proxy_server.py): manage budget at user-level not key-level https://github.com/BerriAI/litellm/issues/1220 --- litellm/proxy/_types.py | 8 +- litellm/proxy/hooks/max_budget_limiter.py | 39 +++--- litellm/proxy/proxy_server.py | 138 ++++++++++++++++++---- litellm/proxy/schema.prisma | 7 +- litellm/proxy/utils.py | 107 ++++++++++++----- 5 files changed, 220 insertions(+), 79 deletions(-) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 76d37bddf..56fff3df4 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -121,7 +121,6 @@ class GenerateKeyRequest(LiteLLMBase): user_id: Optional[str] = None max_parallel_requests: Optional[int] = None metadata: Optional[dict] = {} - max_budget: Optional[float] = None class UpdateKeyRequest(LiteLLMBase): key: str @@ -133,7 +132,6 @@ class UpdateKeyRequest(LiteLLMBase): user_id: Optional[str] = None max_parallel_requests: Optional[int] = None metadata: Optional[dict] = {} - max_budget: Optional[float] = None class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth """ @@ -148,7 +146,6 @@ class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api k max_parallel_requests: Optional[int] = None duration: str = "1h" metadata: dict = {} - max_budget: Optional[float] = None class GenerateKeyResponse(LiteLLMBase): key: str @@ -161,6 +158,11 @@ class _DeleteKeyObject(LiteLLMBase): class DeleteKeyRequest(LiteLLMBase): keys: List[_DeleteKeyObject] +class NewUserRequest(GenerateKeyRequest): + max_budget: Optional[float] = None + +class NewUserResponse(GenerateKeyResponse): + max_budget: Optional[float] = None class ConfigGeneralSettings(LiteLLMBase): """ diff --git a/litellm/proxy/hooks/max_budget_limiter.py b/litellm/proxy/hooks/max_budget_limiter.py index b2ffbeea8..e4dbdd5e7 100644 --- a/litellm/proxy/hooks/max_budget_limiter.py +++ b/litellm/proxy/hooks/max_budget_limiter.py @@ -4,6 +4,7 @@ from litellm.caching import DualCache from litellm.proxy._types import UserAPIKeyAuth from litellm.integrations.custom_logger import CustomLogger from fastapi import HTTPException +import json, traceback class MaxBudgetLimiter(CustomLogger): # Class variables or attributes @@ -13,23 +14,27 @@ class MaxBudgetLimiter(CustomLogger): def print_verbose(self, print_statement): if litellm.set_verbose is True: print(print_statement) # noqa - async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str): - self.print_verbose(f"Inside Max Budget Limiter Pre-Call Hook") - api_key = user_api_key_dict.api_key - max_budget = user_api_key_dict.max_budget - curr_spend = user_api_key_dict.spend + try: + self.print_verbose(f"Inside Max Budget Limiter Pre-Call Hook") + cache_key = f"{user_api_key_dict.user_id}_user_api_key_user_id" + user_row = cache.get_cache(cache_key) + if user_row is None: # value not yet cached + return + max_budget = user_row["max_budget"] + curr_spend = user_row["spend"] - if api_key is None: - return - - if max_budget is None: - return - - if curr_spend is None: - return - - # CHECK IF REQUEST ALLOWED - if curr_spend >= max_budget: - raise HTTPException(status_code=429, detail="Max budget limit reached.") \ No newline at end of file + if max_budget is None: + return + + if curr_spend is None: + return + + # CHECK IF REQUEST ALLOWED + if curr_spend >= max_budget: + raise HTTPException(status_code=429, detail="Max budget limit reached.") + except HTTPException as e: + raise e + except Exception as e: + traceback.print_exc() \ No newline at end of file diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 51e3fa104..49687c147 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -92,7 +92,8 @@ import litellm from litellm.proxy.utils import ( PrismaClient, get_instance_fn, - ProxyLogging + ProxyLogging, + _cache_user_row ) import pydantic from litellm.proxy._types import * @@ -258,8 +259,8 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap if is_master_key_valid: return UserAPIKeyAuth(api_key=master_key) - if route.startswith("/key/") and not is_master_key_valid: - raise Exception(f"If master key is set, only master key can be used to generate, delete, update or get info for new keys") + if (route.startswith("/key/") or route.startswith("/user/")) and not is_master_key_valid: + raise Exception(f"If master key is set, only master key can be used to generate, delete, update or get info for new keys/users") if prisma_client is None: # if both master key + user key submitted, and user key != master key, and no db connected, raise an error raise Exception("No connected db.") @@ -283,10 +284,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap llm_model_list = model_list print("\n new llm router model list", llm_model_list) if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called - api_key = valid_token.token - valid_token_dict = _get_pydantic_json_dict(valid_token) - valid_token_dict.pop("token", None) - return UserAPIKeyAuth(api_key=api_key, **valid_token_dict) + pass else: try: data = await request.json() @@ -300,6 +298,12 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap api_key = valid_token.token valid_token_dict = _get_pydantic_json_dict(valid_token) valid_token_dict.pop("token", None) + """ + asyncio create task to update the user api key cache with the user db table as well + + This makes the user row data accessible to pre-api call hooks. + """ + asyncio.create_task(_cache_user_row(user_id=valid_token.user_id, cache=user_api_key_cache, db=prisma_client)) return UserAPIKeyAuth(api_key=api_key, **valid_token_dict) else: raise Exception(f"Invalid token") @@ -377,32 +381,57 @@ async def track_cost_callback( response_cost = litellm.completion_cost(completion_response=completion_response) print("streaming response_cost", response_cost) user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None) + user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None) if user_api_key and prisma_client: await update_prisma_database(token=user_api_key, response_cost=response_cost) elif kwargs["stream"] == False: # for non streaming responses response_cost = litellm.completion_cost(completion_response=completion_response) user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None) + user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None) if user_api_key and prisma_client: - await update_prisma_database(token=user_api_key, response_cost=response_cost) + await update_prisma_database(token=user_api_key, response_cost=response_cost, user_id=user_id) except Exception as e: print(f"error in tracking cost callback - {str(e)}") -async def update_prisma_database(token, response_cost): +async def update_prisma_database(token, response_cost, user_id=None): try: - print(f"Enters prisma db call, token: {token}") - # Fetch the existing cost for the given token - existing_spend_obj = await prisma_client.get_data(token=token) - print(f"existing spend: {existing_spend_obj}") - if existing_spend_obj is None: - existing_spend = 0 - else: - existing_spend = existing_spend_obj.spend - # Calculate the new cost by adding the existing cost and response_cost - new_spend = existing_spend + response_cost + print(f"Enters prisma db call, token: {token}; user_id: {user_id}") + ### UPDATE USER SPEND ### + async def _update_user_db(): + if user_id is None: + return + existing_spend_obj = await prisma_client.get_data(user_id=user_id) + if existing_spend_obj is None: + existing_spend = 0 + else: + existing_spend = existing_spend_obj.spend + + # Calculate the new cost by adding the existing cost and response_cost + new_spend = existing_spend + response_cost - print(f"new cost: {new_spend}") - # Update the cost column for the given token - await prisma_client.update_data(token=token, data={"spend": new_spend}) + print(f"new cost: {new_spend}") + # Update the cost column for the given user id + await prisma_client.update_data(user_id=user_id, data={"spend": new_spend}) + + ### UPDATE KEY SPEND ### + async def _update_key_db(): + # Fetch the existing cost for the given token + existing_spend_obj = await prisma_client.get_data(token=token) + print(f"existing spend: {existing_spend_obj}") + if existing_spend_obj is None: + existing_spend = 0 + else: + existing_spend = existing_spend_obj.spend + # Calculate the new cost by adding the existing cost and response_cost + new_spend = existing_spend + response_cost + + print(f"new cost: {new_spend}") + # Update the cost column for the given token + await prisma_client.update_data(token=token, data={"spend": new_spend}) + tasks = [] + tasks.append(_update_user_db()) + tasks.append(_update_key_db()) + await asyncio.gather(*tasks) except Exception as e: print(f"Error updating Prisma database: {traceback.format_exc()}") pass @@ -682,7 +711,7 @@ async def generate_key_helper_fn(duration: Optional[str], except Exception as e: traceback.print_exc() raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) - return {"token": token, "expires": new_verification_token.expires, "user_id": user_id} + return {"token": token, "expires": new_verification_token.expires, "user_id": user_id, "max_budget": max_budget} @@ -908,9 +937,11 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key data["model"] = user_model if "metadata" in data: data["metadata"]["user_api_key"] = user_api_key_dict.api_key + data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id + data["metadata"]["headers"] = dict(request.headers) else: - data["metadata"] = {"user_api_key": user_api_key_dict.api_key} - + data["metadata"] = {"user_api_key": user_api_key_dict.api_key, "user_api_key_user_id": user_api_key_dict.user_id} + data["metadata"]["headers"] = dict(request.headers) # override with user settings, these are params passed via cli if user_temperature: data["temperature"] = user_temperature @@ -993,10 +1024,12 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap if "metadata" in data: print(f'received metadata: {data["metadata"]}') data["metadata"]["user_api_key"] = user_api_key_dict.api_key + data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id data["metadata"]["headers"] = dict(request.headers) else: data["metadata"] = {"user_api_key": user_api_key_dict.api_key} data["metadata"]["headers"] = dict(request.headers) + data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id global user_temperature, user_request_timeout, user_max_tokens, user_api_base # override with user settings, these are params passed via cli @@ -1092,9 +1125,12 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen if "metadata" in data: data["metadata"]["user_api_key"] = user_api_key_dict.api_key data["metadata"]["headers"] = dict(request.headers) + data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id else: data["metadata"] = {"user_api_key": user_api_key_dict.api_key} data["metadata"]["headers"] = dict(request.headers) + data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id + router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else [] if "input" in data and isinstance(data['input'], list) and isinstance(data['input'][0], list) and isinstance(data['input'][0][0], int): # check if array of tokens passed in # check if non-openai/azure model called - e.g. for langchain integration @@ -1173,9 +1209,12 @@ async def image_generation(request: Request, user_api_key_dict: UserAPIKeyAuth = if "metadata" in data: data["metadata"]["user_api_key"] = user_api_key_dict.api_key data["metadata"]["headers"] = dict(request.headers) + data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id else: data["metadata"] = {"user_api_key": user_api_key_dict.api_key} data["metadata"]["headers"] = dict(request.headers) + data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id + router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else [] ### CALL HOOKS ### - modify incoming data / reject request before calling the model @@ -1231,7 +1270,6 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorizat - expires: (datetime) Datetime object for when key expires. - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id. """ - # data = await request.json() data_json = data.json() # type: ignore response = await generate_key_helper_fn(**data_json) return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"]) @@ -1287,6 +1325,52 @@ async def info_key_fn(key: str = fastapi.Query(..., description="Key in the requ detail={"error": str(e)}, ) +#### USER MANAGEMENT #### + +@router.post("/user/new", tags=["user management"], dependencies=[Depends(user_api_key_auth)], response_model=NewUserResponse) +async def new_user(data: NewUserRequest): + """ + Use this to create a new user with a budget. + + Returns user id, budget + new key. + + Parameters: + - user_id: Optional[str] - Specify a user id. If not set, a unique id will be generated. + - max_budget: Optional[float] - Specify max budget for a given user. + - duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). **(Default is set to 1 hour.)** + - models: Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models) + - aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models + - config: Optional[dict] - any key-specific configs, overrides config in config.yaml + - spend: Optional[int] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend + - max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x. + - metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" } + + Returns: + - key: (str) The generated api key + - expires: (datetime) Datetime object for when key expires. + - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id. + - max_budget: (float|None) Max budget for given user. + """ + data_json = data.json() # type: ignore + response = await generate_key_helper_fn(**data_json) + return NewUserResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"], max_budget=response["max_budget"]) + + + +@router.post("/user/info", tags=["user management"], dependencies=[Depends(user_api_key_auth)]) +async def user_info(request: Request): + """ + [TODO]: Use this to get user information. (user row + all user key info) + """ + pass + +@router.post("/user/update", tags=["user management"], dependencies=[Depends(user_api_key_auth)]) +async def user_update(request: Request): + """ + [TODO]: Use this to update user budget + """ + pass + #### MODEL MANAGEMENT #### #### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964 @@ -1512,9 +1596,11 @@ async def async_queue_request(request: Request, model: Optional[str] = None, use print(f'received metadata: {data["metadata"]}') data["metadata"]["user_api_key"] = user_api_key_dict.api_key data["metadata"]["headers"] = dict(request.headers) + data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id else: data["metadata"] = {"user_api_key": user_api_key_dict.api_key} data["metadata"]["headers"] = dict(request.headers) + data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id global user_temperature, user_request_timeout, user_max_tokens, user_api_base # override with user settings, these are params passed via cli diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index e4acd13e5..5fa0ea008 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -7,6 +7,12 @@ generator client { provider = "prisma-client-py" } +model LiteLLM_UserTable { + user_id String @unique + max_budget Float? + spend Float @default(0.0) +} + // required for token gen model LiteLLM_VerificationToken { token String @unique @@ -18,5 +24,4 @@ model LiteLLM_VerificationToken { user_id String? max_parallel_requests Int? metadata Json @default("{}") - max_budget Float? } \ No newline at end of file diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 3592593d5..2dc62d664 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -165,26 +165,35 @@ class PrismaClient: max_time=10, # maximum total time to retry for on_backoff=on_backoff, # specifying the function to call on backoff ) - async def get_data(self, token: str, expires: Optional[Any]=None): + async def get_data(self, token: Optional[str]=None, expires: Optional[Any]=None, user_id: Optional[str]=None): try: - # check if plain text or hash - hashed_token = token - if token.startswith("sk-"): - hashed_token = self.hash_token(token=token) - if expires: - response = await self.db.litellm_verificationtoken.find_first( + response = None + if token is not None: + # check if plain text or hash + hashed_token = token + if token.startswith("sk-"): + hashed_token = self.hash_token(token=token) + if expires: + response = await self.db.litellm_verificationtoken.find_first( + where={ + "token": hashed_token, + "expires": {"gte": expires} # Check if the token is not expired + } + ) + else: + response = await self.db.litellm_verificationtoken.find_unique( where={ - "token": hashed_token, - "expires": {"gte": expires} # Check if the token is not expired + "token": hashed_token } ) - else: - response = await self.db.litellm_verificationtoken.find_unique( - where={ - "token": hashed_token - } - ) - return response + return response + elif user_id is not None: + response = await self.db.litellm_usertable.find_first( # type: ignore + where={ + "user_id": user_id, + } + ) + return response except Exception as e: asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) raise e @@ -206,6 +215,7 @@ class PrismaClient: hashed_token = self.hash_token(token=token) db_data = self.jsonify_object(data=data) db_data["token"] = hashed_token + max_budget = db_data.pop("max_budget", None) new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore where={ 'token': hashed_token, @@ -215,6 +225,16 @@ class PrismaClient: "update": {} # don't do anything if it already exists } ) + + new_user_row = await self.db.litellm_usertable.upsert( + where={ + 'user_id': data['user_id'] + }, + data={ + "create": {"user_id": data['user_id'], "max_budget": max_budget}, + "update": {} # don't do anything if it already exists + } + ) return new_verification_token except Exception as e: asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) @@ -228,26 +248,37 @@ class PrismaClient: max_time=10, # maximum total time to retry for on_backoff=on_backoff, # specifying the function to call on backoff ) - async def update_data(self, token: str, data: dict): + async def update_data(self, token: Optional[str]=None, data: dict={}, user_id: Optional[str]=None): """ Update existing data """ try: - print_verbose(f"token: {token}") - # check if plain text or hash - if token.startswith("sk-"): - token = self.hash_token(token=token) - db_data = self.jsonify_object(data=data) - db_data["token"] = token - response = await self.db.litellm_verificationtoken.update( - where={ - "token": token - }, - data={**db_data} # type: ignore - ) - print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m") - return {"token": token, "data": db_data} + if token is not None: + print_verbose(f"token: {token}") + # check if plain text or hash + if token.startswith("sk-"): + token = self.hash_token(token=token) + db_data["token"] = token + response = await self.db.litellm_verificationtoken.update( + where={ + "token": token # type: ignore + }, + data={**db_data} # type: ignore + ) + print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m") + return {"token": token, "data": db_data} + elif user_id is not None: + """ + If data['spend'] + data['user'], update the user table with spend info as well + """ + update_user_row = await self.db.litellm_usertable.update( + where={ + 'user_id': user_id # type: ignore + }, + data={**db_data} # type: ignore + ) + return {"user_id": user_id, "data": db_data} except Exception as e: asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m") @@ -342,4 +373,16 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any: except Exception as e: raise e - \ No newline at end of file +### HELPER FUNCTIONS ### +async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient): + """ + Check if a user_id exists in cache, + if not retrieve it. + """ + cache_key = f"{user_id}_user_api_key_user_id" + response = cache.get_cache(key=cache_key) + if response is None: # Cache miss + user_row = await db.get_data(user_id=user_id) + cache_value = user_row.model_dump_json() + cache.set_cache(key=cache_key, value=cache_value, ttl=600) # store for 10 minutes + return \ No newline at end of file