fix(proxy_server.py): manage budget at user-level not key-level

https://github.com/BerriAI/litellm/issues/1220
Krrish Dholakia 2023-12-22 15:10:38 +05:30
parent 979575a2a6
commit 89ee9fe400
5 changed files with 220 additions and 79 deletions
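Taken together, the five files below move budget enforcement onto a LiteLLM_UserTable row keyed by user_id, rather than the verification token. A rough end-to-end sketch of the intended flow against a running proxy (the base URL and master key here are placeholders, not part of the commit):

import requests

# Assumptions: proxy running locally, master key "sk-1234" (both hypothetical).
BASE_URL = "http://0.0.0.0:8000"
MASTER_HEADERS = {"Authorization": "Bearer sk-1234"}

# 1. Create a user with a budget via the new /user/new endpoint.
resp = requests.post(
    f"{BASE_URL}/user/new",
    headers=MASTER_HEADERS,
    json={"user_id": "user-1", "max_budget": 10.0},
)
user = resp.json()  # {"key": ..., "expires": ..., "user_id": "user-1", "max_budget": 10.0}

# 2. Calls made with the returned key accrue spend on both the key row and the
#    user row; once the user's spend >= max_budget, MaxBudgetLimiter rejects
#    the request pre-call with a 429.
chat = requests.post(
    f"{BASE_URL}/chat/completions",
    headers={"Authorization": f"Bearer {user['key']}"},
    json={"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]},
)
print(chat.status_code)  # 429 once the budget is exhausted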

View file

@@ -121,7 +121,6 @@ class GenerateKeyRequest(LiteLLMBase):
     user_id: Optional[str] = None
     max_parallel_requests: Optional[int] = None
     metadata: Optional[dict] = {}
-    max_budget: Optional[float] = None

 class UpdateKeyRequest(LiteLLMBase):
     key: str
@@ -133,7 +132,6 @@ class UpdateKeyRequest(LiteLLMBase):
     user_id: Optional[str] = None
     max_parallel_requests: Optional[int] = None
     metadata: Optional[dict] = {}
-    max_budget: Optional[float] = None

 class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
     """
@@ -148,7 +146,6 @@ class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
     max_parallel_requests: Optional[int] = None
     duration: str = "1h"
     metadata: dict = {}
-    max_budget: Optional[float] = None

 class GenerateKeyResponse(LiteLLMBase):
     key: str
@@ -161,6 +158,11 @@ class _DeleteKeyObject(LiteLLMBase):
 class DeleteKeyRequest(LiteLLMBase):
     keys: List[_DeleteKeyObject]

+class NewUserRequest(GenerateKeyRequest):
+    max_budget: Optional[float] = None
+
+class NewUserResponse(GenerateKeyResponse):
+    max_budget: Optional[float] = None
+
 class ConfigGeneralSettings(LiteLLMBase):
     """

View file

@@ -4,6 +4,7 @@ from litellm.caching import DualCache
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.integrations.custom_logger import CustomLogger
 from fastapi import HTTPException
+import json, traceback

 class MaxBudgetLimiter(CustomLogger):
     # Class variables or attributes
@@ -14,22 +15,26 @@ class MaxBudgetLimiter(CustomLogger):
         if litellm.set_verbose is True:
             print(print_statement) # noqa

     async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str):
-        self.print_verbose(f"Inside Max Budget Limiter Pre-Call Hook")
-        api_key = user_api_key_dict.api_key
-        max_budget = user_api_key_dict.max_budget
-        curr_spend = user_api_key_dict.spend
-
-        if api_key is None:
-            return
-
-        if max_budget is None:
-            return
-
-        if curr_spend is None:
-            return
-
-        # CHECK IF REQUEST ALLOWED
-        if curr_spend >= max_budget:
-            raise HTTPException(status_code=429, detail="Max budget limit reached.")
+        try:
+            self.print_verbose(f"Inside Max Budget Limiter Pre-Call Hook")
+            cache_key = f"{user_api_key_dict.user_id}_user_api_key_user_id"
+            user_row = cache.get_cache(cache_key)
+            if user_row is None: # value not yet cached
+                return
+            max_budget = user_row["max_budget"]
+            curr_spend = user_row["spend"]
+            if max_budget is None:
+                return
+            if curr_spend is None:
+                return
+            # CHECK IF REQUEST ALLOWED
+            if curr_spend >= max_budget:
+                raise HTTPException(status_code=429, detail="Max budget limit reached.")
+        except HTTPException as e:
+            raise e
+        except Exception as e:
+            traceback.print_exc()
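The new hook body reduces to: fetch the cached user row under the `{user_id}_user_api_key_user_id` key, then reject with a 429 once spend reaches the budget. A minimal standalone sketch of that check, with a plain dict standing in for the DualCache:

from fastapi import HTTPException

def check_user_budget(cache: dict, user_id: str) -> None:
    """Reject the request when the cached user row shows spend >= max_budget."""
    user_row = cache.get(f"{user_id}_user_api_key_user_id")
    if user_row is None:  # row not cached yet: let the request through
        return
    max_budget = user_row.get("max_budget")
    curr_spend = user_row.get("spend")
    if max_budget is None or curr_spend is None:  # no budget configured
        return
    if curr_spend >= max_budget:
        raise HTTPException(status_code=429, detail="Max budget limit reached.")

cache = {"user-1_user_api_key_user_id": {"max_budget": 10.0, "spend": 10.5}}
check_user_budget(cache, "user-1")  # raises HTTPException (429)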

View file

@@ -92,7 +92,8 @@ import litellm
 from litellm.proxy.utils import (
     PrismaClient,
     get_instance_fn,
-    ProxyLogging
+    ProxyLogging,
+    _cache_user_row
 )
 import pydantic
 from litellm.proxy._types import *
@@ -258,8 +259,8 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
         if is_master_key_valid:
             return UserAPIKeyAuth(api_key=master_key)

-        if route.startswith("/key/") and not is_master_key_valid:
-            raise Exception(f"If master key is set, only master key can be used to generate, delete, update or get info for new keys")
+        if (route.startswith("/key/") or route.startswith("/user/")) and not is_master_key_valid:
+            raise Exception(f"If master key is set, only master key can be used to generate, delete, update or get info for new keys/users")

     if prisma_client is None: # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
         raise Exception("No connected db.")
@@ -283,10 +284,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
                 llm_model_list = model_list
                 print("\n new llm router model list", llm_model_list)
             if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
-                api_key = valid_token.token
-                valid_token_dict = _get_pydantic_json_dict(valid_token)
-                valid_token_dict.pop("token", None)
-                return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
+                pass
             else:
                 try:
                     data = await request.json()
@@ -300,6 +298,12 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
             api_key = valid_token.token
             valid_token_dict = _get_pydantic_json_dict(valid_token)
             valid_token_dict.pop("token", None)
+            """
+            asyncio create task to update the user api key cache with the user db table as well.
+            This makes the user row data accessible to pre-api call hooks.
+            """
+            asyncio.create_task(_cache_user_row(user_id=valid_token.user_id, cache=user_api_key_cache, db=prisma_client))
             return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
         else:
             raise Exception(f"Invalid token")
@@ -377,32 +381,57 @@ async def track_cost_callback(
                 response_cost = litellm.completion_cost(completion_response=completion_response)
                 print("streaming response_cost", response_cost)
                 user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
+                user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None)
                 if user_api_key and prisma_client:
                     await update_prisma_database(token=user_api_key, response_cost=response_cost)
         elif kwargs["stream"] == False: # for non streaming responses
             response_cost = litellm.completion_cost(completion_response=completion_response)
             user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
+            user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None)
             if user_api_key and prisma_client:
-                await update_prisma_database(token=user_api_key, response_cost=response_cost)
+                await update_prisma_database(token=user_api_key, response_cost=response_cost, user_id=user_id)
     except Exception as e:
         print(f"error in tracking cost callback - {str(e)}")

-async def update_prisma_database(token, response_cost):
+async def update_prisma_database(token, response_cost, user_id=None):
     try:
-        print(f"Enters prisma db call, token: {token}")
-        # Fetch the existing cost for the given token
-        existing_spend_obj = await prisma_client.get_data(token=token)
-        print(f"existing spend: {existing_spend_obj}")
-        if existing_spend_obj is None:
-            existing_spend = 0
-        else:
-            existing_spend = existing_spend_obj.spend
-        # Calculate the new cost by adding the existing cost and response_cost
-        new_spend = existing_spend + response_cost
-        print(f"new cost: {new_spend}")
-        # Update the cost column for the given token
-        await prisma_client.update_data(token=token, data={"spend": new_spend})
+        print(f"Enters prisma db call, token: {token}; user_id: {user_id}")
+        ### UPDATE USER SPEND ###
+        async def _update_user_db():
+            if user_id is None:
+                return
+            existing_spend_obj = await prisma_client.get_data(user_id=user_id)
+            if existing_spend_obj is None:
+                existing_spend = 0
+            else:
+                existing_spend = existing_spend_obj.spend
+            # Calculate the new cost by adding the existing cost and response_cost
+            new_spend = existing_spend + response_cost
+            print(f"new cost: {new_spend}")
+            # Update the cost column for the given user id
+            await prisma_client.update_data(user_id=user_id, data={"spend": new_spend})
+        ### UPDATE KEY SPEND ###
+        async def _update_key_db():
+            # Fetch the existing cost for the given token
+            existing_spend_obj = await prisma_client.get_data(token=token)
+            print(f"existing spend: {existing_spend_obj}")
+            if existing_spend_obj is None:
+                existing_spend = 0
+            else:
+                existing_spend = existing_spend_obj.spend
+            # Calculate the new cost by adding the existing cost and response_cost
+            new_spend = existing_spend + response_cost
+            print(f"new cost: {new_spend}")
+            # Update the cost column for the given token
+            await prisma_client.update_data(token=token, data={"spend": new_spend})
+        tasks = []
+        tasks.append(_update_user_db())
+        tasks.append(_update_key_db())
+        await asyncio.gather(*tasks)
     except Exception as e:
         print(f"Error updating Prisma database: {traceback.format_exc()}")
         pass
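update_prisma_database now fans each cost out to two independent writes, one per table, and awaits them together. A self-contained sketch of the same read-increment-write pattern with asyncio.gather, using in-memory dicts in place of the Prisma tables:

import asyncio
from typing import Optional

KEYS = {"sk-hashed-abc": {"spend": 0.0}}   # stand-in for LiteLLM_VerificationToken
USERS = {"user-1": {"spend": 0.0}}         # stand-in for LiteLLM_UserTable

async def update_spend(token: str, response_cost: float, user_id: Optional[str] = None):
    async def _update_user_db():
        if user_id is None:
            return
        row = USERS.setdefault(user_id, {"spend": 0.0})
        row["spend"] += response_cost  # read existing spend, add cost, write back

    async def _update_key_db():
        row = KEYS.setdefault(token, {"spend": 0.0})
        row["spend"] += response_cost

    # the key update and the user update are independent, so run them concurrently
    await asyncio.gather(_update_user_db(), _update_key_db())

asyncio.run(update_spend("sk-hashed-abc", 0.002, user_id="user-1"))
print(KEYS["sk-hashed-abc"]["spend"], USERS["user-1"]["spend"])  # 0.002 0.002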
@@ -682,7 +711,7 @@ async def generate_key_helper_fn(duration: Optional[str],
     except Exception as e:
         traceback.print_exc()
         raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
-    return {"token": token, "expires": new_verification_token.expires, "user_id": user_id}
+    return {"token": token, "expires": new_verification_token.expires, "user_id": user_id, "max_budget": max_budget}
@@ -908,9 +937,11 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
         data["model"] = user_model
     if "metadata" in data:
         data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+        data["metadata"]["headers"] = dict(request.headers)
     else:
-        data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
+        data["metadata"] = {"user_api_key": user_api_key_dict.api_key, "user_api_key_user_id": user_api_key_dict.user_id}
+        data["metadata"]["headers"] = dict(request.headers)
     # override with user settings, these are params passed via cli
     if user_temperature:
         data["temperature"] = user_temperature
@@ -993,10 +1024,12 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
     if "metadata" in data:
         print(f'received metadata: {data["metadata"]}')
         data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
         data["metadata"]["headers"] = dict(request.headers)
     else:
         data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
         data["metadata"]["headers"] = dict(request.headers)
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id

     global user_temperature, user_request_timeout, user_max_tokens, user_api_base
     # override with user settings, these are params passed via cli
@@ -1092,9 +1125,12 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
     if "metadata" in data:
         data["metadata"]["user_api_key"] = user_api_key_dict.api_key
         data["metadata"]["headers"] = dict(request.headers)
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
     else:
         data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
         data["metadata"]["headers"] = dict(request.headers)
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id

     router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
     if "input" in data and isinstance(data['input'], list) and isinstance(data['input'][0], list) and isinstance(data['input'][0][0], int): # check if array of tokens passed in
         # check if non-openai/azure model called - e.g. for langchain integration
@@ -1173,9 +1209,12 @@ async def image_generation(request: Request, user_api_key_dict: UserAPIKeyAuth =
     if "metadata" in data:
         data["metadata"]["user_api_key"] = user_api_key_dict.api_key
         data["metadata"]["headers"] = dict(request.headers)
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
     else:
         data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
         data["metadata"]["headers"] = dict(request.headers)
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id

     router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []

     ### CALL HOOKS ### - modify incoming data / reject request before calling the model
@@ -1231,7 +1270,6 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorizat
     - expires: (datetime) Datetime object for when key expires.
     - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
     """
-    # data = await request.json()
     data_json = data.json()  # type: ignore
     response = await generate_key_helper_fn(**data_json)
     return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])
@@ -1287,6 +1325,52 @@ async def info_key_fn(key: str = fastapi.Query(..., description="Key in the requ
             detail={"error": str(e)},
         )

+#### USER MANAGEMENT ####
+@router.post("/user/new", tags=["user management"], dependencies=[Depends(user_api_key_auth)], response_model=NewUserResponse)
+async def new_user(data: NewUserRequest):
+    """
+    Use this to create a new user with a budget.
+
+    Returns user id, budget + new key.
+
+    Parameters:
+    - user_id: Optional[str] - Specify a user id. If not set, a unique id will be generated.
+    - max_budget: Optional[float] - Specify max budget for a given user.
+    - duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). **(Default is set to 1 hour.)**
+    - models: Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models)
+    - aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models
+    - config: Optional[dict] - any key-specific configs, overrides config in config.yaml
+    - spend: Optional[int] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend
+    - max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x.
+    - metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" }
+
+    Returns:
+    - key: (str) The generated api key
+    - expires: (datetime) Datetime object for when key expires.
+    - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
+    - max_budget: (float|None) Max budget for given user.
+    """
+    data_json = data.json()  # type: ignore
+    response = await generate_key_helper_fn(**data_json)
+    return NewUserResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"], max_budget=response["max_budget"])
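Note that /user/new can reuse generate_key_helper_fn unchanged: NewUserRequest inherits every key-generation field from GenerateKeyRequest and only adds max_budget, so data.json() already has the shape the helper expects. A trimmed sketch of that relationship (BaseModel stands in for the real LiteLLMBase, with only a few fields shown):

from typing import Optional
from pydantic import BaseModel

class GenerateKeyRequest(BaseModel):  # trimmed stand-in for the real model
    duration: Optional[str] = "1h"
    models: Optional[list] = []
    user_id: Optional[str] = None

class NewUserRequest(GenerateKeyRequest):
    max_budget: Optional[float] = None

req = NewUserRequest(user_id="user-1", max_budget=10.0)
print(req.json())
# {"duration": "1h", "models": [], "user_id": "user-1", "max_budget": 10.0}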
@router.post("/user/info", tags=["user management"], dependencies=[Depends(user_api_key_auth)])
async def user_info(request: Request):
"""
[TODO]: Use this to get user information. (user row + all user key info)
"""
pass
@router.post("/user/update", tags=["user management"], dependencies=[Depends(user_api_key_auth)])
async def user_update(request: Request):
"""
[TODO]: Use this to update user budget
"""
pass
#### MODEL MANAGEMENT #### #### MODEL MANAGEMENT ####
#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964 #### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
@@ -1512,9 +1596,11 @@ async def async_queue_request(request: Request, model: Optional[str] = None, use
         print(f'received metadata: {data["metadata"]}')
         data["metadata"]["user_api_key"] = user_api_key_dict.api_key
         data["metadata"]["headers"] = dict(request.headers)
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
     else:
         data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
         data["metadata"]["headers"] = dict(request.headers)
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id

     global user_temperature, user_request_timeout, user_max_tokens, user_api_base
     # override with user settings, these are params passed via cli

View file

@@ -7,6 +7,12 @@ generator client {
   provider = "prisma-client-py"
 }

+model LiteLLM_UserTable {
+  user_id    String @unique
+  max_budget Float?
+  spend      Float  @default(0.0)
+}
+
 // required for token gen
 model LiteLLM_VerificationToken {
   token String @unique
@@ -18,5 +24,4 @@ model LiteLLM_VerificationToken {
   user_id String?
   max_parallel_requests Int?
   metadata Json @default("{}")
-  max_budget Float?
 }
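Once the schema is regenerated (prisma generate / prisma db push), prisma-client-py exposes the new table as db.litellm_usertable, which is the accessor the utils changes below rely on. A minimal usage sketch (the user_id value is a placeholder):

import asyncio
from prisma import Prisma

async def main():
    db = Prisma()
    await db.connect()
    # same upsert shape the proxy uses: create the row if missing, else leave it alone
    row = await db.litellm_usertable.upsert(
        where={"user_id": "user-1"},
        data={
            "create": {"user_id": "user-1", "max_budget": 10.0},
            "update": {},
        },
    )
    print(row.user_id, row.max_budget, row.spend)  # spend defaults to 0.0
    await db.disconnect()

asyncio.run(main())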

View file

@@ -165,26 +165,35 @@
         max_time=10, # maximum total time to retry for
         on_backoff=on_backoff, # specifying the function to call on backoff
     )
-    async def get_data(self, token: str, expires: Optional[Any]=None):
+    async def get_data(self, token: Optional[str]=None, expires: Optional[Any]=None, user_id: Optional[str]=None):
         try:
-            # check if plain text or hash
-            hashed_token = token
-            if token.startswith("sk-"):
-                hashed_token = self.hash_token(token=token)
-            if expires:
-                response = await self.db.litellm_verificationtoken.find_first(
-                    where={
-                        "token": hashed_token,
-                        "expires": {"gte": expires} # Check if the token is not expired
-                    }
-                )
-            else:
-                response = await self.db.litellm_verificationtoken.find_unique(
-                    where={
-                        "token": hashed_token
-                    }
-                )
-            return response
+            response = None
+            if token is not None:
+                # check if plain text or hash
+                hashed_token = token
+                if token.startswith("sk-"):
+                    hashed_token = self.hash_token(token=token)
+                if expires:
+                    response = await self.db.litellm_verificationtoken.find_first(
+                        where={
+                            "token": hashed_token,
+                            "expires": {"gte": expires} # Check if the token is not expired
+                        }
+                    )
+                else:
+                    response = await self.db.litellm_verificationtoken.find_unique(
+                        where={
+                            "token": hashed_token
+                        }
+                    )
+                return response
+            elif user_id is not None:
+                response = await self.db.litellm_usertable.find_first( # type: ignore
+                    where={
+                        "user_id": user_id,
+                    }
+                )
+            return response
         except Exception as e:
             asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
             raise e
@@ -206,6 +215,7 @@ class PrismaClient:
             hashed_token = self.hash_token(token=token)
             db_data = self.jsonify_object(data=data)
             db_data["token"] = hashed_token
+            max_budget = db_data.pop("max_budget", None)
             new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
                 where={
                     'token': hashed_token,
@@ -215,6 +225,16 @@
                     "update": {} # don't do anything if it already exists
                 }
             )
+            new_user_row = await self.db.litellm_usertable.upsert(
+                where={
+                    'user_id': data['user_id']
+                },
+                data={
+                    "create": {"user_id": data['user_id'], "max_budget": max_budget},
+                    "update": {} # don't do anything if it already exists
+                }
+            )
             return new_verification_token
         except Exception as e:
             asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
@@ -228,26 +248,37 @@
         max_time=10, # maximum total time to retry for
         on_backoff=on_backoff, # specifying the function to call on backoff
     )
-    async def update_data(self, token: str, data: dict):
+    async def update_data(self, token: Optional[str]=None, data: dict={}, user_id: Optional[str]=None):
         """
         Update existing data
         """
         try:
-            print_verbose(f"token: {token}")
-            # check if plain text or hash
-            if token.startswith("sk-"):
-                token = self.hash_token(token=token)
             db_data = self.jsonify_object(data=data)
-            db_data["token"] = token
-            response = await self.db.litellm_verificationtoken.update(
-                where={
-                    "token": token
-                },
-                data={**db_data} # type: ignore
-            )
-            print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
-            return {"token": token, "data": db_data}
+            if token is not None:
+                print_verbose(f"token: {token}")
+                # check if plain text or hash
+                if token.startswith("sk-"):
+                    token = self.hash_token(token=token)
+                db_data["token"] = token
+                response = await self.db.litellm_verificationtoken.update(
+                    where={
+                        "token": token # type: ignore
+                    },
+                    data={**db_data} # type: ignore
+                )
+                print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
+                return {"token": token, "data": db_data}
+            elif user_id is not None:
+                """
+                If data['spend'] + data['user'], update the user table with spend info as well
+                """
+                update_user_row = await self.db.litellm_usertable.update(
+                    where={
+                        'user_id': user_id # type: ignore
+                    },
+                    data={**db_data} # type: ignore
+                )
+                return {"user_id": user_id, "data": db_data}
         except Exception as e:
             asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
             print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")
@@ -342,4 +373,16 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
     except Exception as e:
         raise e

+### HELPER FUNCTIONS ###
+async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
+    """
+    Check if a user_id exists in cache,
+    if not retrieve it.
+    """
+    cache_key = f"{user_id}_user_api_key_user_id"
+    response = cache.get_cache(key=cache_key)
+    if response is None: # Cache miss
+        user_row = await db.get_data(user_id=user_id)
+        cache_value = user_row.model_dump_json()
+        cache.set_cache(key=cache_key, value=cache_value, ttl=600) # store for 10 minutes
+    return
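One detail when consuming this cache: _cache_user_row stores user_row.model_dump_json(), a JSON string, so a reader may need to deserialize before indexing into the row the way MaxBudgetLimiter does. A small sketch of the read side (get_cached_user_row is a hypothetical helper, not part of the commit; whether the cache returns a parsed dict or the raw string depends on the backend, so the sketch handles both):

import json
from litellm.caching import DualCache

cache = DualCache()
# _cache_user_row stores user_row.model_dump_json(), i.e. a JSON *string*
cache.set_cache(
    key="user-1_user_api_key_user_id",
    value='{"user_id": "user-1", "max_budget": 10.0, "spend": 2.5}',
    ttl=600,
)

def get_cached_user_row(cache: DualCache, user_id: str):
    cached = cache.get_cache(key=f"{user_id}_user_api_key_user_id")
    if cached is None:
        return None
    # deserialize before field access, as the budget hook's user_row["max_budget"] requires
    return json.loads(cached) if isinstance(cached, str) else cached

row = get_cached_user_row(cache, "user-1")
print(row["spend"], row["max_budget"])  # 2.5 10.0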