fix(proxy_server.py): manage budget at user-level not key-level

https://github.com/BerriAI/litellm/issues/1220
This commit is contained in:
Krrish Dholakia 2023-12-22 15:10:38 +05:30
parent 979575a2a6
commit 89ee9fe400
5 changed files with 220 additions and 79 deletions

View file

@ -92,7 +92,8 @@ import litellm
from litellm.proxy.utils import (
PrismaClient,
get_instance_fn,
ProxyLogging
ProxyLogging,
_cache_user_row
)
import pydantic
from litellm.proxy._types import *
@ -258,8 +259,8 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
if is_master_key_valid:
return UserAPIKeyAuth(api_key=master_key)
if route.startswith("/key/") and not is_master_key_valid:
raise Exception(f"If master key is set, only master key can be used to generate, delete, update or get info for new keys")
if (route.startswith("/key/") or route.startswith("/user/")) and not is_master_key_valid:
raise Exception(f"If master key is set, only master key can be used to generate, delete, update or get info for new keys/users")
if prisma_client is None: # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
raise Exception("No connected db.")
@ -283,10 +284,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
llm_model_list = model_list
print("\n new llm router model list", llm_model_list)
if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
api_key = valid_token.token
valid_token_dict = _get_pydantic_json_dict(valid_token)
valid_token_dict.pop("token", None)
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
pass
else:
try:
data = await request.json()
@ -300,6 +298,12 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
api_key = valid_token.token
valid_token_dict = _get_pydantic_json_dict(valid_token)
valid_token_dict.pop("token", None)
"""
asyncio create task to update the user api key cache with the user db table as well
This makes the user row data accessible to pre-api call hooks.
"""
asyncio.create_task(_cache_user_row(user_id=valid_token.user_id, cache=user_api_key_cache, db=prisma_client))
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
else:
raise Exception(f"Invalid token")
@ -377,32 +381,57 @@ async def track_cost_callback(
response_cost = litellm.completion_cost(completion_response=completion_response)
print("streaming response_cost", response_cost)
user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None)
if user_api_key and prisma_client:
await update_prisma_database(token=user_api_key, response_cost=response_cost)
elif kwargs["stream"] == False: # for non streaming responses
response_cost = litellm.completion_cost(completion_response=completion_response)
user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None)
if user_api_key and prisma_client:
await update_prisma_database(token=user_api_key, response_cost=response_cost)
await update_prisma_database(token=user_api_key, response_cost=response_cost, user_id=user_id)
except Exception as e:
print(f"error in tracking cost callback - {str(e)}")
async def update_prisma_database(token, response_cost):
async def update_prisma_database(token, response_cost, user_id=None):
try:
print(f"Enters prisma db call, token: {token}")
# Fetch the existing cost for the given token
existing_spend_obj = await prisma_client.get_data(token=token)
print(f"existing spend: {existing_spend_obj}")
if existing_spend_obj is None:
existing_spend = 0
else:
existing_spend = existing_spend_obj.spend
# Calculate the new cost by adding the existing cost and response_cost
new_spend = existing_spend + response_cost
print(f"Enters prisma db call, token: {token}; user_id: {user_id}")
### UPDATE USER SPEND ###
async def _update_user_db():
if user_id is None:
return
existing_spend_obj = await prisma_client.get_data(user_id=user_id)
if existing_spend_obj is None:
existing_spend = 0
else:
existing_spend = existing_spend_obj.spend
# Calculate the new cost by adding the existing cost and response_cost
new_spend = existing_spend + response_cost
print(f"new cost: {new_spend}")
# Update the cost column for the given token
await prisma_client.update_data(token=token, data={"spend": new_spend})
print(f"new cost: {new_spend}")
# Update the cost column for the given user id
await prisma_client.update_data(user_id=user_id, data={"spend": new_spend})
### UPDATE KEY SPEND ###
async def _update_key_db():
# Fetch the existing cost for the given token
existing_spend_obj = await prisma_client.get_data(token=token)
print(f"existing spend: {existing_spend_obj}")
if existing_spend_obj is None:
existing_spend = 0
else:
existing_spend = existing_spend_obj.spend
# Calculate the new cost by adding the existing cost and response_cost
new_spend = existing_spend + response_cost
print(f"new cost: {new_spend}")
# Update the cost column for the given token
await prisma_client.update_data(token=token, data={"spend": new_spend})
tasks = []
tasks.append(_update_user_db())
tasks.append(_update_key_db())
await asyncio.gather(*tasks)
except Exception as e:
print(f"Error updating Prisma database: {traceback.format_exc()}")
pass
@ -682,7 +711,7 @@ async def generate_key_helper_fn(duration: Optional[str],
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
return {"token": token, "expires": new_verification_token.expires, "user_id": user_id}
return {"token": token, "expires": new_verification_token.expires, "user_id": user_id, "max_budget": max_budget}
@ -908,9 +937,11 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
data["model"] = user_model
if "metadata" in data:
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"]["headers"] = dict(request.headers)
else:
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
data["metadata"] = {"user_api_key": user_api_key_dict.api_key, "user_api_key_user_id": user_api_key_dict.user_id}
data["metadata"]["headers"] = dict(request.headers)
# override with user settings, these are params passed via cli
if user_temperature:
data["temperature"] = user_temperature
@ -993,10 +1024,12 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
if "metadata" in data:
print(f'received metadata: {data["metadata"]}')
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"]["headers"] = dict(request.headers)
else:
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
data["metadata"]["headers"] = dict(request.headers)
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
# override with user settings, these are params passed via cli
@ -1092,9 +1125,12 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
if "metadata" in data:
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["headers"] = dict(request.headers)
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
else:
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
data["metadata"]["headers"] = dict(request.headers)
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
if "input" in data and isinstance(data['input'], list) and isinstance(data['input'][0], list) and isinstance(data['input'][0][0], int): # check if array of tokens passed in
# check if non-openai/azure model called - e.g. for langchain integration
@ -1173,9 +1209,12 @@ async def image_generation(request: Request, user_api_key_dict: UserAPIKeyAuth =
if "metadata" in data:
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["headers"] = dict(request.headers)
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
else:
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
data["metadata"]["headers"] = dict(request.headers)
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
### CALL HOOKS ### - modify incoming data / reject request before calling the model
@ -1231,7 +1270,6 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorizat
- expires: (datetime) Datetime object for when key expires.
- user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
"""
# data = await request.json()
data_json = data.json() # type: ignore
response = await generate_key_helper_fn(**data_json)
return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])
@ -1287,6 +1325,52 @@ async def info_key_fn(key: str = fastapi.Query(..., description="Key in the requ
detail={"error": str(e)},
)
#### USER MANAGEMENT ####
@router.post("/user/new", tags=["user management"], dependencies=[Depends(user_api_key_auth)], response_model=NewUserResponse)
async def new_user(data: NewUserRequest):
"""
Use this to create a new user with a budget.
Returns user id, budget + new key.
Parameters:
- user_id: Optional[str] - Specify a user id. If not set, a unique id will be generated.
- max_budget: Optional[float] - Specify max budget for a given user.
- duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). **(Default is set to 1 hour.)**
- models: Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models)
- aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models
- config: Optional[dict] - any key-specific configs, overrides config in config.yaml
- spend: Optional[int] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend
- max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x.
- metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" }
Returns:
- key: (str) The generated api key
- expires: (datetime) Datetime object for when key expires.
- user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
- max_budget: (float|None) Max budget for given user.
"""
data_json = data.json() # type: ignore
response = await generate_key_helper_fn(**data_json)
return NewUserResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"], max_budget=response["max_budget"])
@router.post("/user/info", tags=["user management"], dependencies=[Depends(user_api_key_auth)])
async def user_info(request: Request):
"""
[TODO]: Use this to get user information. (user row + all user key info)
"""
pass
@router.post("/user/update", tags=["user management"], dependencies=[Depends(user_api_key_auth)])
async def user_update(request: Request):
"""
[TODO]: Use this to update user budget
"""
pass
#### MODEL MANAGEMENT ####
#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
@ -1512,9 +1596,11 @@ async def async_queue_request(request: Request, model: Optional[str] = None, use
print(f'received metadata: {data["metadata"]}')
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["headers"] = dict(request.headers)
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
else:
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
data["metadata"]["headers"] = dict(request.headers)
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
# override with user settings, these are params passed via cli