mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
Merge branch 'main' into litellm_spend_logging_high_traffic
This commit is contained in:
commit
df60edfa07
36 changed files with 638 additions and 291 deletions
|
@ -91,6 +91,7 @@ from litellm.proxy.utils import (
|
|||
reset_budget,
|
||||
hash_token,
|
||||
html_form,
|
||||
_read_request_body,
|
||||
)
|
||||
from litellm.proxy.secret_managers.google_kms import load_google_kms
|
||||
import pydantic
|
||||
|
@ -322,6 +323,7 @@ async def user_api_key_auth(
|
|||
f"Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: {passed_in_key}"
|
||||
)
|
||||
|
||||
### CHECK IF ADMIN ###
|
||||
# note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
|
||||
is_master_key_valid = secrets.compare_digest(api_key, master_key)
|
||||
if is_master_key_valid:
|
||||
|
@ -370,8 +372,9 @@ async def user_api_key_auth(
|
|||
# Run checks for
|
||||
# 1. If token can call model
|
||||
# 2. If user_id for this token is in budget
|
||||
# 3. If token is expired
|
||||
# 4. If token spend is under Budget for the token
|
||||
# 3. If 'user' passed to /chat/completions, /embeddings endpoint is in budget
|
||||
# 4. If token is expired
|
||||
# 5. If token spend is under Budget for the token
|
||||
|
||||
# Check 1. If token can call model
|
||||
litellm.model_alias_map = valid_token.aliases
|
||||
|
@ -430,9 +433,18 @@ async def user_api_key_auth(
|
|||
)
|
||||
|
||||
# Check 2. If user_id for this token is in budget
|
||||
## Check 2.5 If global proxy is in budget
|
||||
## Check 2.1 If global proxy is in budget
|
||||
## Check 2.2 [OPTIONAL - checked only if litellm.max_user_budget is not None] If 'user' passed in /chat/completions is in budget
|
||||
if valid_token.user_id is not None:
|
||||
user_id_list = [valid_token.user_id, litellm_proxy_budget_name]
|
||||
if (
|
||||
litellm.max_user_budget is not None
|
||||
): # Check if 'user' passed in /chat/completions is in budget, only checked if litellm.max_user_budget is set
|
||||
request_data = await _read_request_body(request=request)
|
||||
user_passed_to_chat_completions = request_data.get("user", None)
|
||||
if user_passed_to_chat_completions is not None:
|
||||
user_id_list.append(user_passed_to_chat_completions)
|
||||
|
||||
user_id_information = None
|
||||
for id in user_id_list:
|
||||
value = user_api_key_cache.get_cache(key=id)
|
||||
|
@ -461,6 +473,7 @@ async def user_api_key_auth(
|
|||
user_id_information = await custom_db_client.get_data(
|
||||
key=valid_token.user_id, table_name="user"
|
||||
)
|
||||
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"user_id_information: {user_id_information}"
|
||||
|
@ -473,12 +486,18 @@ async def user_api_key_auth(
|
|||
if _user is None:
|
||||
continue
|
||||
assert isinstance(_user, dict)
|
||||
# check if user is admin #
|
||||
if (
|
||||
_user.get("user_role", None) is not None
|
||||
and _user.get("user_role") == "proxy_admin"
|
||||
):
|
||||
return UserAPIKeyAuth(api_key=master_key)
|
||||
# Token exists, not expired now check if its in budget for the user
|
||||
user_max_budget = _user.get("max_budget", None)
|
||||
user_current_spend = _user.get("spend", None)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"user_max_budget: {user_max_budget}; user_current_spend: {user_current_spend}"
|
||||
f"user_id: {_user.get('user_id', None)}; user_max_budget: {user_max_budget}; user_current_spend: {user_current_spend}"
|
||||
)
|
||||
|
||||
if (
|
||||
|
@ -616,10 +635,13 @@ async def user_api_key_auth(
|
|||
# check if user can access this route
|
||||
query_params = request.query_params
|
||||
user_id = query_params.get("user_id")
|
||||
verbose_proxy_logger.debug(
|
||||
f"user_id: {user_id} & valid_token.user_id: {valid_token.user_id}"
|
||||
)
|
||||
if user_id != valid_token.user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="user not allowed to access this key's info",
|
||||
detail="key not allowed to access this user's info",
|
||||
)
|
||||
elif route == "/user/update":
|
||||
raise HTTPException(
|
||||
|
@ -860,22 +882,28 @@ async def update_database(
|
|||
- Update that user's row
|
||||
- Update litellm-proxy-budget row (global proxy spend)
|
||||
"""
|
||||
try:
|
||||
user_ids = [user_id, litellm_proxy_budget_name]
|
||||
data_list = []
|
||||
for id in user_ids:
|
||||
if id is None:
|
||||
continue
|
||||
if prisma_client is not None:
|
||||
existing_spend_obj = await prisma_client.get_data(user_id=id)
|
||||
elif (
|
||||
custom_db_client is not None and id != litellm_proxy_budget_name
|
||||
):
|
||||
existing_spend_obj = await custom_db_client.get_data(
|
||||
key=id, table_name="user"
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"Updating existing_spend_obj: {existing_spend_obj}"
|
||||
user_ids = [user_id, litellm_proxy_budget_name]
|
||||
data_list = []
|
||||
for id in user_ids:
|
||||
if id is None:
|
||||
continue
|
||||
if prisma_client is not None:
|
||||
existing_spend_obj = await prisma_client.get_data(user_id=id)
|
||||
elif custom_db_client is not None and id != litellm_proxy_budget_name:
|
||||
existing_spend_obj = await custom_db_client.get_data(
|
||||
key=id, table_name="user"
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"Updating existing_spend_obj: {existing_spend_obj}"
|
||||
)
|
||||
if existing_spend_obj is None:
|
||||
# if user does not exist in LiteLLM_UserTable, create a new user
|
||||
existing_spend = 0
|
||||
max_user_budget = None
|
||||
if litellm.max_user_budget is not None:
|
||||
max_user_budget = litellm.max_user_budget
|
||||
existing_spend_obj = LiteLLM_UserTable(
|
||||
user_id=id, spend=0, max_budget=max_user_budget, user_email=None
|
||||
)
|
||||
if existing_spend_obj is None:
|
||||
existing_spend = 0
|
||||
|
@ -1147,7 +1175,7 @@ class ProxyConfig:
|
|||
# load existing config
|
||||
config = await self.get_config()
|
||||
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
||||
litellm_settings = config.get("litellm_settings", None)
|
||||
litellm_settings = config.get("litellm_settings", {})
|
||||
all_teams_config = litellm_settings.get("default_team_settings", None)
|
||||
team_config: dict = {}
|
||||
if all_teams_config is None:
|
||||
|
@ -1791,7 +1819,33 @@ async def async_data_generator(response, user_api_key_dict):
|
|||
done_message = "[DONE]"
|
||||
yield f"data: {done_message}\n\n"
|
||||
except Exception as e:
|
||||
yield f"data: {str(e)}\n\n"
|
||||
traceback.print_exc()
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
|
||||
)
|
||||
router_model_names = (
|
||||
[m["model_name"] for m in llm_model_list]
|
||||
if llm_model_list is not None
|
||||
else []
|
||||
)
|
||||
if user_debug:
|
||||
traceback.print_exc()
|
||||
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
else:
|
||||
error_traceback = traceback.format_exc()
|
||||
error_msg = f"{str(e)}\n\n{error_traceback}"
|
||||
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", 500),
|
||||
)
|
||||
|
||||
|
||||
def select_data_generator(response, user_api_key_dict):
|
||||
|
@ -1799,7 +1853,7 @@ def select_data_generator(response, user_api_key_dict):
|
|||
# since boto3 - sagemaker does not support async calls, we should use a sync data_generator
|
||||
if hasattr(
|
||||
response, "custom_llm_provider"
|
||||
) and response.custom_llm_provider in ["sagemaker", "together_ai"]:
|
||||
) and response.custom_llm_provider in ["sagemaker"]:
|
||||
return data_generator(
|
||||
response=response,
|
||||
)
|
||||
|
@ -1892,6 +1946,10 @@ async def startup_event():
|
|||
|
||||
if prisma_client is not None and master_key is not None:
|
||||
# add master key to db
|
||||
user_id = "default_user_id"
|
||||
if os.getenv("PROXY_ADMIN_ID", None) is not None:
|
||||
user_id = os.getenv("PROXY_ADMIN_ID")
|
||||
|
||||
asyncio.create_task(
|
||||
generate_key_helper_fn(
|
||||
duration=None,
|
||||
|
@ -1900,7 +1958,12 @@ async def startup_event():
|
|||
config={},
|
||||
spend=0,
|
||||
token=master_key,
|
||||
user_id="default_user_id",
|
||||
user_id=user_id,
|
||||
user_role="proxy_admin",
|
||||
query_type="update_data",
|
||||
update_key_values={
|
||||
"user_role": "proxy_admin",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -2284,7 +2347,6 @@ async def chat_completion(
|
|||
selected_data_generator = select_data_generator(
|
||||
response=response, user_api_key_dict=user_api_key_dict
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
selected_data_generator,
|
||||
media_type="text/event-stream",
|
||||
|
@ -3459,23 +3521,38 @@ async def auth_callback(request: Request):
|
|||
result = await microsoft_sso.verify_and_process(request)
|
||||
|
||||
# User is Authe'd in - generate key for the UI to access Proxy
|
||||
user_id = getattr(result, "email", None)
|
||||
user_email = getattr(result, "email", None)
|
||||
user_id = getattr(result, "id", None)
|
||||
if user_id is None:
|
||||
user_id = getattr(result, "first_name", "") + getattr(result, "last_name", "")
|
||||
|
||||
response = await generate_key_helper_fn(
|
||||
**{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id, "team_id": "litellm-dashboard"} # type: ignore
|
||||
**{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id, "team_id": "litellm-dashboard", "user_email": user_email} # type: ignore
|
||||
)
|
||||
|
||||
key = response["token"] # type: ignore
|
||||
user_id = response["user_id"] # type: ignore
|
||||
|
||||
litellm_dashboard_ui = "/ui/"
|
||||
|
||||
user_role = "app_owner"
|
||||
if (
|
||||
os.getenv("PROXY_ADMIN_ID", None) is not None
|
||||
and os.environ["PROXY_ADMIN_ID"] == user_id
|
||||
):
|
||||
# checks if user is admin
|
||||
user_role = "app_admin"
|
||||
|
||||
import jwt
|
||||
|
||||
jwt_token = jwt.encode(
|
||||
{"user_id": user_id, "key": key}, "secret", algorithm="HS256"
|
||||
{
|
||||
"user_id": user_id,
|
||||
"key": key,
|
||||
"user_email": user_email,
|
||||
"user_role": user_role,
|
||||
},
|
||||
"secret",
|
||||
algorithm="HS256",
|
||||
)
|
||||
litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token
|
||||
|
||||
|
@ -3488,10 +3565,18 @@ async def auth_callback(request: Request):
|
|||
"/user/info", tags=["user management"], dependencies=[Depends(user_api_key_auth)]
|
||||
)
|
||||
async def user_info(
|
||||
user_id: str = fastapi.Query(..., description="User ID in the request parameters")
|
||||
user_id: Optional[str] = fastapi.Query(
|
||||
default=None, description="User ID in the request parameters"
|
||||
)
|
||||
):
|
||||
"""
|
||||
Use this to get user information. (user row + all user key info)
|
||||
|
||||
Example request
|
||||
```
|
||||
curl -X GET 'http://localhost:8000/user/info?user_id=krrish7%40berri.ai' \
|
||||
--header 'Authorization: Bearer sk-1234'
|
||||
```
|
||||
"""
|
||||
global prisma_client
|
||||
try:
|
||||
|
@ -3500,11 +3585,25 @@ async def user_info(
|
|||
f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
|
||||
)
|
||||
## GET USER ROW ##
|
||||
user_info = await prisma_client.get_data(user_id=user_id)
|
||||
if user_id is not None:
|
||||
user_info = await prisma_client.get_data(user_id=user_id)
|
||||
else:
|
||||
user_info = None
|
||||
## GET ALL KEYS ##
|
||||
keys = await prisma_client.get_data(
|
||||
user_id=user_id, table_name="key", query_type="find_all"
|
||||
user_id=user_id,
|
||||
table_name="key",
|
||||
query_type="find_all",
|
||||
expires=datetime.now(),
|
||||
)
|
||||
|
||||
if user_info is None:
|
||||
## make sure we still return a total spend ##
|
||||
spend = 0
|
||||
for k in keys:
|
||||
spend += getattr(k, "spend", 0)
|
||||
user_info = {"spend": spend}
|
||||
|
||||
## REMOVE HASHED TOKEN INFO before returning ##
|
||||
for key in keys:
|
||||
try:
|
||||
|
@ -4109,16 +4208,16 @@ async def health_readiness():
|
|||
cache_type = {"type": cache_type, "index_info": index_info}
|
||||
|
||||
if prisma_client is not None: # if db passed in, check if it's connected
|
||||
if prisma_client.db.is_connected() == True:
|
||||
response_object = {"db": "connected"}
|
||||
await prisma_client.health_check() # test the db connection
|
||||
response_object = {"db": "connected"}
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
"db": "connected",
|
||||
"cache": cache_type,
|
||||
"litellm_version": version,
|
||||
"success_callbacks": litellm.success_callback,
|
||||
}
|
||||
return {
|
||||
"status": "healthy",
|
||||
"db": "connected",
|
||||
"cache": cache_type,
|
||||
"litellm_version": version,
|
||||
"success_callbacks": litellm.success_callback,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status": "healthy",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue