forked from phoenix/litellm-mirror

fix(proxy_server.py): manage budget at user-level not key-level

https://github.com/BerriAI/litellm/issues/1220

In short: budgets move from the verification-token (key) level to a new LiteLLM_UserTable row, the user row is cached at auth time, and the max-budget pre-call hook now checks the cached user row. The five changed files are the proxy's Pydantic request/response types, the MaxBudgetLimiter pre-call hook, proxy_server.py, the Prisma schema, and the PrismaClient DB utilities.

parent 979575a2a6
commit 89ee9fe400

5 changed files with 220 additions and 79 deletions
@@ -121,7 +121,6 @@ class GenerateKeyRequest(LiteLLMBase):
     user_id: Optional[str] = None
     max_parallel_requests: Optional[int] = None
     metadata: Optional[dict] = {}
-    max_budget: Optional[float] = None


 class UpdateKeyRequest(LiteLLMBase):
     key: str

@@ -133,7 +132,6 @@ class UpdateKeyRequest(LiteLLMBase):
     user_id: Optional[str] = None
     max_parallel_requests: Optional[int] = None
     metadata: Optional[dict] = {}
-    max_budget: Optional[float] = None


 class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
     """

@@ -148,7 +146,6 @@ class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api k
     max_parallel_requests: Optional[int] = None
     duration: str = "1h"
     metadata: dict = {}
-    max_budget: Optional[float] = None


 class GenerateKeyResponse(LiteLLMBase):
     key: str

@@ -161,6 +158,11 @@ class _DeleteKeyObject(LiteLLMBase):
 class DeleteKeyRequest(LiteLLMBase):
     keys: List[_DeleteKeyObject]

+class NewUserRequest(GenerateKeyRequest):
+    max_budget: Optional[float] = None
+
+
+class NewUserResponse(GenerateKeyResponse):
+    max_budget: Optional[float] = None

 class ConfigGeneralSettings(LiteLLMBase):
     """
@@ -4,6 +4,7 @@ from litellm.caching import DualCache
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.integrations.custom_logger import CustomLogger
 from fastapi import HTTPException
+import json, traceback


 class MaxBudgetLimiter(CustomLogger):
     # Class variables or attributes

@@ -14,22 +15,26 @@ class MaxBudgetLimiter(CustomLogger):
         if litellm.set_verbose is True:
             print(print_statement) # noqa

     async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str):
-        self.print_verbose(f"Inside Max Budget Limiter Pre-Call Hook")
-        api_key = user_api_key_dict.api_key
-        max_budget = user_api_key_dict.max_budget
-        curr_spend = user_api_key_dict.spend
-
-        if api_key is None:
-            return
-
-        if max_budget is None:
-            return
-
-        if curr_spend is None:
-            return
-
-        # CHECK IF REQUEST ALLOWED
-        if curr_spend >= max_budget:
-            raise HTTPException(status_code=429, detail="Max budget limit reached.")
+        try:
+            self.print_verbose(f"Inside Max Budget Limiter Pre-Call Hook")
+            cache_key = f"{user_api_key_dict.user_id}_user_api_key_user_id"
+            user_row = cache.get_cache(cache_key)
+            if user_row is None: # value not yet cached
+                return
+            max_budget = user_row["max_budget"]
+            curr_spend = user_row["spend"]
+
+            if max_budget is None:
+                return
+
+            if curr_spend is None:
+                return
+
+            # CHECK IF REQUEST ALLOWED
+            if curr_spend >= max_budget:
+                raise HTTPException(status_code=429, detail="Max budget limit reached.")
+        except HTTPException as e:
+            raise e
+        except Exception as e:
+            traceback.print_exc()
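To make the reworked hook concrete, here is a small sketch of exercising it directly. The cache is seeded with a user row shaped like the new LiteLLM_UserTable; a parsed dict is assumed here so the hook can index it, and the import path for MaxBudgetLimiter is an assumption, not taken from this diff.

import asyncio
from fastapi import HTTPException
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter  # import path assumed

async def demo():
    cache = DualCache()
    # seed the cache the way the new _cache_user_row helper would,
    # with a row for a user who has already exceeded their budget
    cache.set_cache(key="user-123_user_api_key_user_id", value={"max_budget": 10.0, "spend": 12.5})
    hook = MaxBudgetLimiter()
    try:
        await hook.async_pre_call_hook(
            user_api_key_dict=UserAPIKeyAuth(api_key="sk-...", user_id="user-123"),
            cache=cache,
            data={},
            call_type="completion",
        )
    except HTTPException as e:
        print(e.status_code, e.detail)  # 429 Max budget limit reached.

asyncio.run(demo())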
@@ -92,7 +92,8 @@ import litellm
 from litellm.proxy.utils import (
     PrismaClient,
     get_instance_fn,
-    ProxyLogging
+    ProxyLogging,
+    _cache_user_row
 )
 import pydantic
 from litellm.proxy._types import *
@@ -258,8 +259,8 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
         if is_master_key_valid:
             return UserAPIKeyAuth(api_key=master_key)

-        if route.startswith("/key/") and not is_master_key_valid:
-            raise Exception(f"If master key is set, only master key can be used to generate, delete, update or get info for new keys")
+        if (route.startswith("/key/") or route.startswith("/user/")) and not is_master_key_valid:
+            raise Exception(f"If master key is set, only master key can be used to generate, delete, update or get info for new keys/users")

         if prisma_client is None: # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
             raise Exception("No connected db.")

@@ -283,10 +284,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
                 llm_model_list = model_list
                 print("\n new llm router model list", llm_model_list)
             if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
-                api_key = valid_token.token
-                valid_token_dict = _get_pydantic_json_dict(valid_token)
-                valid_token_dict.pop("token", None)
-                return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
+                pass
             else:
                 try:
                     data = await request.json()

@@ -300,6 +298,12 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
             api_key = valid_token.token
             valid_token_dict = _get_pydantic_json_dict(valid_token)
             valid_token_dict.pop("token", None)
+            """
+            asyncio create task to update the user api key cache with the user db table as well
+
+            This makes the user row data accessible to pre-api call hooks.
+            """
+            asyncio.create_task(_cache_user_row(user_id=valid_token.user_id, cache=user_api_key_cache, db=prisma_client))
             return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
         else:
             raise Exception(f"Invalid token")
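The cache refresh above is deliberately scheduled with asyncio.create_task rather than awaited, so authentication returns without blocking on a database read. A self-contained sketch of that fire-and-forget pattern (the sleep stands in for the DB fetch; names are illustrative):

import asyncio

async def _refresh_user_cache(user_id: str):
    await asyncio.sleep(0.05)  # stand-in for db.get_data(...) + cache.set_cache(...)
    print(f"user row cached for {user_id}")

async def auth_path() -> dict:
    # schedule the refresh in the background; do not await it
    asyncio.create_task(_refresh_user_cache("user-123"))
    return {"api_key": "sk-..."}  # returns before the task finishes

async def main():
    print(await auth_path())
    await asyncio.sleep(0.1)  # keep the loop alive so the demo task can complete

asyncio.run(main())

The tradeoff is that the very first request for a user may be evaluated before the row lands in the cache; the hook handles that by returning early on a cache miss.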
@@ -377,32 +381,57 @@ async def track_cost_callback(
             response_cost = litellm.completion_cost(completion_response=completion_response)
             print("streaming response_cost", response_cost)
             user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
+            user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None)
             if user_api_key and prisma_client:
                 await update_prisma_database(token=user_api_key, response_cost=response_cost)
         elif kwargs["stream"] == False: # for non streaming responses
             response_cost = litellm.completion_cost(completion_response=completion_response)
             user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
+            user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None)
             if user_api_key and prisma_client:
-                await update_prisma_database(token=user_api_key, response_cost=response_cost)
+                await update_prisma_database(token=user_api_key, response_cost=response_cost, user_id=user_id)
     except Exception as e:
         print(f"error in tracking cost callback - {str(e)}")

-async def update_prisma_database(token, response_cost):
+async def update_prisma_database(token, response_cost, user_id=None):
     try:
-        print(f"Enters prisma db call, token: {token}")
-        # Fetch the existing cost for the given token
-        existing_spend_obj = await prisma_client.get_data(token=token)
-        print(f"existing spend: {existing_spend_obj}")
-        if existing_spend_obj is None:
-            existing_spend = 0
-        else:
-            existing_spend = existing_spend_obj.spend
-        # Calculate the new cost by adding the existing cost and response_cost
-        new_spend = existing_spend + response_cost
-
-        print(f"new cost: {new_spend}")
-        # Update the cost column for the given token
-        await prisma_client.update_data(token=token, data={"spend": new_spend})
+        print(f"Enters prisma db call, token: {token}; user_id: {user_id}")
+        ### UPDATE USER SPEND ###
+        async def _update_user_db():
+            if user_id is None:
+                return
+            existing_spend_obj = await prisma_client.get_data(user_id=user_id)
+            if existing_spend_obj is None:
+                existing_spend = 0
+            else:
+                existing_spend = existing_spend_obj.spend
+
+            # Calculate the new cost by adding the existing cost and response_cost
+            new_spend = existing_spend + response_cost
+
+            print(f"new cost: {new_spend}")
+            # Update the cost column for the given user id
+            await prisma_client.update_data(user_id=user_id, data={"spend": new_spend})
+
+        ### UPDATE KEY SPEND ###
+        async def _update_key_db():
+            # Fetch the existing cost for the given token
+            existing_spend_obj = await prisma_client.get_data(token=token)
+            print(f"existing spend: {existing_spend_obj}")
+            if existing_spend_obj is None:
+                existing_spend = 0
+            else:
+                existing_spend = existing_spend_obj.spend
+            # Calculate the new cost by adding the existing cost and response_cost
+            new_spend = existing_spend + response_cost
+
+            print(f"new cost: {new_spend}")
+            # Update the cost column for the given token
+            await prisma_client.update_data(token=token, data={"spend": new_spend})
+
+        tasks = []
+        tasks.append(_update_user_db())
+        tasks.append(_update_key_db())
+        await asyncio.gather(*tasks)
     except Exception as e:
         print(f"Error updating Prisma database: {traceback.format_exc()}")
         pass
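update_prisma_database now wraps the two writes in inner coroutines and runs them together, so the per-user and per-key spend updates do not serialize on each other. A minimal sketch of the asyncio.gather fan-out (sleeps stand in for the Prisma calls):

import asyncio

async def _update_user_db():
    await asyncio.sleep(0.1)  # stand-in for the LiteLLM_UserTable write
    print("user spend updated")

async def _update_key_db():
    await asyncio.sleep(0.1)  # stand-in for the LiteLLM_VerificationToken write
    print("key spend updated")

async def main():
    # both coroutines are awaited as one unit, so the writes overlap
    # instead of running back-to-back
    await asyncio.gather(_update_user_db(), _update_key_db())

asyncio.run(main())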
@@ -682,7 +711,7 @@ async def generate_key_helper_fn(duration: Optional[str],
     except Exception as e:
         traceback.print_exc()
         raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
-    return {"token": token, "expires": new_verification_token.expires, "user_id": user_id}
+    return {"token": token, "expires": new_verification_token.expires, "user_id": user_id, "max_budget": max_budget}
@@ -908,9 +937,11 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
         data["model"] = user_model
     if "metadata" in data:
         data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+        data["metadata"]["headers"] = dict(request.headers)
     else:
-        data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
+        data["metadata"] = {"user_api_key": user_api_key_dict.api_key, "user_api_key_user_id": user_api_key_dict.user_id}
+        data["metadata"]["headers"] = dict(request.headers)
     # override with user settings, these are params passed via cli
     if user_temperature:
         data["temperature"] = user_temperature
@@ -993,10 +1024,12 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
     if "metadata" in data:
         print(f'received metadata: {data["metadata"]}')
         data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
         data["metadata"]["headers"] = dict(request.headers)
     else:
         data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
         data["metadata"]["headers"] = dict(request.headers)
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id

     global user_temperature, user_request_timeout, user_max_tokens, user_api_base
     # override with user settings, these are params passed via cli
@@ -1092,9 +1125,12 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
     if "metadata" in data:
         data["metadata"]["user_api_key"] = user_api_key_dict.api_key
         data["metadata"]["headers"] = dict(request.headers)
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
     else:
         data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
         data["metadata"]["headers"] = dict(request.headers)
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id

     router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
     if "input" in data and isinstance(data['input'], list) and isinstance(data['input'][0], list) and isinstance(data['input'][0][0], int): # check if array of tokens passed in
         # check if non-openai/azure model called - e.g. for langchain integration
@@ -1173,9 +1209,12 @@ async def image_generation(request: Request, user_api_key_dict: UserAPIKeyAuth =
     if "metadata" in data:
         data["metadata"]["user_api_key"] = user_api_key_dict.api_key
         data["metadata"]["headers"] = dict(request.headers)
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
     else:
         data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
         data["metadata"]["headers"] = dict(request.headers)
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id

     router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []

     ### CALL HOOKS ### - modify incoming data / reject request before calling the model
@@ -1231,7 +1270,6 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorizat
     - expires: (datetime) Datetime object for when key expires.
     - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
     """
-    # data = await request.json()
     data_json = data.json() # type: ignore
     response = await generate_key_helper_fn(**data_json)
     return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])
@@ -1287,6 +1325,52 @@ async def info_key_fn(key: str = fastapi.Query(..., description="Key in the requ
             detail={"error": str(e)},
         )

+#### USER MANAGEMENT ####
+
+@router.post("/user/new", tags=["user management"], dependencies=[Depends(user_api_key_auth)], response_model=NewUserResponse)
+async def new_user(data: NewUserRequest):
+    """
+    Use this to create a new user with a budget.
+
+    Returns user id, budget + new key.
+
+    Parameters:
+    - user_id: Optional[str] - Specify a user id. If not set, a unique id will be generated.
+    - max_budget: Optional[float] - Specify max budget for a given user.
+    - duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). **(Default is set to 1 hour.)**
+    - models: Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models)
+    - aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models
+    - config: Optional[dict] - any key-specific configs, overrides config in config.yaml
+    - spend: Optional[int] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend
+    - max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x.
+    - metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" }
+
+    Returns:
+    - key: (str) The generated api key
+    - expires: (datetime) Datetime object for when key expires.
+    - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
+    - max_budget: (float|None) Max budget for given user.
+    """
+    data_json = data.json() # type: ignore
+    response = await generate_key_helper_fn(**data_json)
+    return NewUserResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"], max_budget=response["max_budget"])
+
+
+@router.post("/user/info", tags=["user management"], dependencies=[Depends(user_api_key_auth)])
+async def user_info(request: Request):
+    """
+    [TODO]: Use this to get user information. (user row + all user key info)
+    """
+    pass
+
+
+@router.post("/user/update", tags=["user management"], dependencies=[Depends(user_api_key_auth)])
+async def user_update(request: Request):
+    """
+    [TODO]: Use this to update user budget
+    """
+    pass
+
 #### MODEL MANAGEMENT ####

 #### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
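A quick sketch of calling the new endpoint once the proxy is up; the base URL and master key value here are placeholders, not values from this commit:

import requests

resp = requests.post(
    "http://0.0.0.0:8000/user/new",               # assumed local proxy address
    headers={"Authorization": "Bearer sk-1234"},  # placeholder master key
    json={"user_id": "user-123", "max_budget": 10.0},
)
print(resp.json())  # fields per NewUserResponse: key, expires, user_id, max_budget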
@@ -1512,9 +1596,11 @@ async def async_queue_request(request: Request, model: Optional[str] = None, use
         print(f'received metadata: {data["metadata"]}')
         data["metadata"]["user_api_key"] = user_api_key_dict.api_key
         data["metadata"]["headers"] = dict(request.headers)
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
     else:
         data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
         data["metadata"]["headers"] = dict(request.headers)
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id

     global user_temperature, user_request_timeout, user_max_tokens, user_api_base
     # override with user settings, these are params passed via cli
@@ -7,6 +7,12 @@ generator client {
   provider = "prisma-client-py"
 }

+model LiteLLM_UserTable {
+  user_id    String @unique
+  max_budget Float?
+  spend      Float  @default(0.0)
+}
+
 // required for token gen
 model LiteLLM_VerificationToken {
   token String @unique

@@ -18,5 +24,4 @@ model LiteLLM_VerificationToken {
   user_id String?
   max_parallel_requests Int?
   metadata Json @default("{}")
-  max_budget Float?
 }
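After this schema change the Prisma client has to be regenerated, at which point the new table is exposed to Python as db.litellm_usertable (prisma-client-py lower-cases model names) — which is how get_data and update_data below address it. A minimal sketch of a direct query, assuming a generated prisma-client-py client:

import asyncio
from prisma import Prisma  # prisma-client-py generated client

async def read_user(user_id: str):
    db = Prisma()
    await db.connect()
    # model LiteLLM_UserTable is exposed as db.litellm_usertable
    row = await db.litellm_usertable.find_first(where={"user_id": user_id})
    await db.disconnect()
    return row

print(asyncio.run(read_user("user-123")))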
@@ -165,26 +165,35 @@ class PrismaClient:
         max_time=10, # maximum total time to retry for
         on_backoff=on_backoff, # specifying the function to call on backoff
     )
-    async def get_data(self, token: str, expires: Optional[Any]=None):
+    async def get_data(self, token: Optional[str]=None, expires: Optional[Any]=None, user_id: Optional[str]=None):
         try:
-            # check if plain text or hash
-            hashed_token = token
-            if token.startswith("sk-"):
-                hashed_token = self.hash_token(token=token)
-            if expires:
-                response = await self.db.litellm_verificationtoken.find_first(
-                    where={
-                        "token": hashed_token,
-                        "expires": {"gte": expires} # Check if the token is not expired
-                    }
-                )
-            else:
-                response = await self.db.litellm_verificationtoken.find_unique(
-                    where={
-                        "token": hashed_token
-                    }
-                )
-            return response
+            response = None
+            if token is not None:
+                # check if plain text or hash
+                hashed_token = token
+                if token.startswith("sk-"):
+                    hashed_token = self.hash_token(token=token)
+                if expires:
+                    response = await self.db.litellm_verificationtoken.find_first(
+                        where={
+                            "token": hashed_token,
+                            "expires": {"gte": expires} # Check if the token is not expired
+                        }
+                    )
+                else:
+                    response = await self.db.litellm_verificationtoken.find_unique(
+                        where={
+                            "token": hashed_token
+                        }
+                    )
+                return response
+            elif user_id is not None:
+                response = await self.db.litellm_usertable.find_first( # type: ignore
+                    where={
+                        "user_id": user_id,
+                    }
+                )
+                return response
         except Exception as e:
             asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
             raise e
@@ -206,6 +215,7 @@ class PrismaClient:
             hashed_token = self.hash_token(token=token)
             db_data = self.jsonify_object(data=data)
             db_data["token"] = hashed_token
+            max_budget = db_data.pop("max_budget", None)
             new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
                 where={
                     'token': hashed_token,

@@ -215,6 +225,16 @@ class PrismaClient:
                     "update": {} # don't do anything if it already exists
                 }
             )
+
+            new_user_row = await self.db.litellm_usertable.upsert(
+                where={
+                    'user_id': data['user_id']
+                },
+                data={
+                    "create": {"user_id": data['user_id'], "max_budget": max_budget},
+                    "update": {} # don't do anything if it already exists
+                }
+            )
             return new_verification_token
         except Exception as e:
             asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
@@ -228,26 +248,37 @@ class PrismaClient:
         max_time=10, # maximum total time to retry for
         on_backoff=on_backoff, # specifying the function to call on backoff
     )
-    async def update_data(self, token: str, data: dict):
+    async def update_data(self, token: Optional[str]=None, data: dict={}, user_id: Optional[str]=None):
         """
         Update existing data
         """
         try:
-            print_verbose(f"token: {token}")
-            # check if plain text or hash
-            if token.startswith("sk-"):
-                token = self.hash_token(token=token)
             db_data = self.jsonify_object(data=data)
-            db_data["token"] = token
-            response = await self.db.litellm_verificationtoken.update(
-                where={
-                    "token": token
-                },
-                data={**db_data} # type: ignore
-            )
-            print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
-            return {"token": token, "data": db_data}
+            if token is not None:
+                print_verbose(f"token: {token}")
+                # check if plain text or hash
+                if token.startswith("sk-"):
+                    token = self.hash_token(token=token)
+                db_data["token"] = token
+                response = await self.db.litellm_verificationtoken.update(
+                    where={
+                        "token": token # type: ignore
+                    },
+                    data={**db_data} # type: ignore
+                )
+                print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
+                return {"token": token, "data": db_data}
+            elif user_id is not None:
+                """
+                If data['spend'] + data['user'], update the user table with spend info as well
+                """
+                update_user_row = await self.db.litellm_usertable.update(
+                    where={
+                        'user_id': user_id # type: ignore
+                    },
+                    data={**db_data} # type: ignore
+                )
+                return {"user_id": user_id, "data": db_data}
         except Exception as e:
             asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
             print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")
@@ -342,4 +373,16 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
     except Exception as e:
         raise e

+### HELPER FUNCTIONS ###
+async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
+    """
+    Check if a user_id exists in cache,
+    if not retrieve it.
+    """
+    cache_key = f"{user_id}_user_api_key_user_id"
+    response = cache.get_cache(key=cache_key)
+    if response is None: # Cache miss
+        user_row = await db.get_data(user_id=user_id)
+        cache_value = user_row.model_dump_json()
+        cache.set_cache(key=cache_key, value=cache_value, ttl=600) # store for 10 minutes
+    return