mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00

Merge branch 'main' into litellm_llm_api_prompt_injection_check

commit 33a433eb0a
10 changed files with 500 additions and 159 deletions
@@ -110,6 +110,7 @@ from litellm.proxy.auth.handle_jwt import JWTHandler
 from litellm.proxy.hooks.prompt_injection_detection import (
     _OPTIONAL_PromptInjectionDetection,
 )
+from litellm.proxy.auth.auth_checks import common_checks, get_end_user_object
 
 try:
     from litellm._version import version
@@ -364,20 +365,54 @@ async def user_api_key_auth(
             user_id = jwt_handler.get_user_id(
                 token=valid_token, default_value=litellm_proxy_admin_name
             )
+
+            end_user_object = None
             # get the request body
             request_data = await _read_request_body(request=request)
             # get user obj from cache/db -> run for admin too. Ensures, admin client id in db.
             user_object = await jwt_handler.get_user_object(user_id=user_id)
+            if (
+                request_data.get("user", None)
+                and request_data["user"] != user_object.user_id
+            ):
+                # get the end-user object
+                end_user_object = await get_end_user_object(
+                    end_user_id=request_data["user"],
+                    prisma_client=prisma_client,
+                    user_api_key_cache=user_api_key_cache,
+                )
+                # save the end-user object to cache
+                await user_api_key_cache.async_set_cache(
+                    key=request_data["user"], value=end_user_object
+                )
+
+            # run through common checks
+            _ = common_checks(
+                request_body=request_data,
+                user_object=user_object,
+                end_user_object=end_user_object,
+            )
+            # save user object in cache
+            await user_api_key_cache.async_set_cache(
+                key=user_object.user_id, value=user_object
+            )
             # if admin return
             if is_admin:
-                _user_api_key_obj = UserAPIKeyAuth(
+                return UserAPIKeyAuth(
                     api_key=api_key,
                     user_role="proxy_admin",
                     user_id=user_id,
                 )
-                user_api_key_cache.set_cache(
-                    key=hash_token(api_key), value=_user_api_key_obj
-                )
-
-                return _user_api_key_obj
+            else:
+                # return UserAPIKeyAuth object
+                return UserAPIKeyAuth(
+                    api_key=None,
+                    user_id=user_object.user_id,
+                    tpm_limit=user_object.tpm_limit,
+                    rpm_limit=user_object.rpm_limit,
+                    models=user_object.models,
+                    user_role="app_owner",
+                )
         else:
             raise Exception("Invalid key error!")
     #### ELSE ####
     if master_key is None:
         if isinstance(api_key, str):
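Note on the hunk above: the merged JWT path now distinguishes the authenticated app owner (taken from the JWT's user id) from an optional end user named in the request body's `user` field, fetching and caching an end-user object only when the two differ. A minimal sketch of that control flow; `lookup_user` and `lookup_end_user` are hypothetical stand-ins for `jwt_handler.get_user_object` and `get_end_user_object`, which hit the database behind a cache:

import asyncio
from typing import Optional, Tuple

# Hypothetical stand-ins for jwt_handler.get_user_object / get_end_user_object.
async def lookup_user(user_id: str) -> dict:
    return {"user_id": user_id, "models": [], "tpm_limit": None}

async def lookup_end_user(end_user_id: str) -> Optional[dict]:
    return {"user_id": end_user_id, "blocked": False}

async def resolve_identities(jwt_user_id: str, request_data: dict) -> Tuple[dict, Optional[dict]]:
    # Always resolve the JWT principal (runs for admins too).
    user_object = await lookup_user(user_id=jwt_user_id)
    end_user_object = None
    requested = request_data.get("user", None)
    if requested and requested != user_object["user_id"]:
        # Request is made on behalf of a distinct end user -> resolve them separately.
        end_user_object = await lookup_end_user(end_user_id=requested)
    return user_object, end_user_object

print(asyncio.run(resolve_identities("owner-1", {"user": "end-user-42"})))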
@@ -442,7 +477,7 @@ async def user_api_key_auth(
             user_role="proxy_admin",
             user_id=litellm_proxy_admin_name,
         )
-        user_api_key_cache.set_cache(
+        await user_api_key_cache.async_set_cache(
             key=hash_token(master_key), value=_user_api_key_obj
         )
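This hunk, and several below, swap the synchronous `set_cache` for an awaited `async_set_cache` inside the async auth path. A toy sketch of why a cache exposes both variants; this is a simplified stand-in, not LiteLLM's actual DualCache, whose async path can also write to a network-backed layer such as Redis without blocking the event loop:

import asyncio
from typing import Any, Optional

class ToyDualCache:
    def __init__(self) -> None:
        self._store: dict = {}

    def set_cache(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        # Sync setter: fine for in-memory writes, but would block the
        # event loop if the cache were backed by network I/O.
        self._store[key] = value

    async def async_set_cache(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        # Async setter: an I/O-backed implementation can await here instead.
        self._store[key] = value

async def main() -> None:
    cache = ToyDualCache()
    await cache.async_set_cache(key="hashed-master-key", value={"user_role": "proxy_admin"})
    print(cache._store)

asyncio.run(main())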
@@ -607,7 +642,7 @@ async def user_api_key_auth(
                     query_type="find_all",
                 )
                 for _id in user_id_information:
-                    user_api_key_cache.set_cache(
+                    await user_api_key_cache.async_set_cache(
                         key=_id["user_id"], value=_id, ttl=600
                     )
         if custom_db_client is not None:
@@ -795,7 +830,9 @@ async def user_api_key_auth(
             api_key = valid_token.token
 
             # Add hashed token to cache
-            user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=600)
+            await user_api_key_cache.async_set_cache(
+                key=api_key, value=valid_token, ttl=600
+            )
             valid_token_dict = _get_pydantic_json_dict(valid_token)
             valid_token_dict.pop("token", None)
             """
@@ -1077,7 +1114,10 @@ async def _PROXY_track_cost_callback(
                 )
 
                 await update_cache(
-                    token=user_api_key, user_id=user_id, response_cost=response_cost
+                    token=user_api_key,
+                    user_id=user_id,
+                    end_user_id=end_user_id,
+                    response_cost=response_cost,
                 )
         else:
             raise Exception("User API key missing from custom callback.")
@@ -1352,9 +1392,10 @@ async def update_database(
 
 
 async def update_cache(
-    token,
-    user_id,
-    response_cost,
+    token: Optional[str],
+    user_id: Optional[str],
+    end_user_id: Optional[str],
+    response_cost: Optional[float],
 ):
     """
     Use this to update the cache with new user spend.
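With this hunk, `update_cache` gains type hints and an explicit optional `end_user_id`, matching the expanded call in the `_PROXY_track_cost_callback` hunk above. A runnable sketch of the new call shape; the body and values here are illustrative only:

import asyncio
from typing import Optional

async def update_cache(
    token: Optional[str],
    user_id: Optional[str],
    end_user_id: Optional[str],
    response_cost: Optional[float],
) -> None:
    # The real body updates key-, user-, and end-user-level spend caches.
    print(f"spend +{response_cost} (token={token}, user={user_id}, end_user={end_user_id})")

asyncio.run(
    update_cache(
        token="hashed-key",          # hashed API key, may be None
        user_id="app-owner-1",
        end_user_id="end-user-42",   # newly threaded through from the cost callback
        response_cost=0.00042,
    )
)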
@@ -1369,12 +1410,17 @@ async def update_cache(
             hashed_token = hash_token(token=token)
         else:
             hashed_token = token
+        verbose_proxy_logger.debug(f"_update_key_cache: hashed_token={hashed_token}")
         existing_spend_obj = await user_api_key_cache.async_get_cache(key=hashed_token)
         verbose_proxy_logger.debug(
-            f"_update_key_db: existing spend: {existing_spend_obj}"
+            f"_update_key_cache: existing_spend_obj={existing_spend_obj}"
+        )
+        verbose_proxy_logger.debug(
+            f"_update_key_cache: existing spend: {existing_spend_obj}"
         )
         if existing_spend_obj is None:
             existing_spend = 0
+            existing_spend_obj = LiteLLM_VerificationTokenView()
         else:
             existing_spend = existing_spend_obj.spend
         # Calculate the new cost by adding the existing cost and response_cost
@@ -1430,18 +1476,7 @@ async def update_cache(
 
     async def _update_user_cache():
         ## UPDATE CACHE FOR USER ID + GLOBAL PROXY
-        end_user_id = None
-        if isinstance(token, str) and token.startswith("sk-"):
-            hashed_token = hash_token(token=token)
-        else:
-            hashed_token = token
-        existing_token_obj = await user_api_key_cache.async_get_cache(key=hashed_token)
-        if existing_token_obj is None:
-            return
-        if existing_token_obj.user_id != user_id:  # an end-user id was passed in
-            end_user_id = user_id
-        user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name, end_user_id]
-
+        user_ids = [user_id, litellm_proxy_budget_name, end_user_id]
         try:
             for _id in user_ids:
                 # Fetch the existing cost for the given user
@@ -1487,9 +1522,59 @@ async def update_cache(
                 f"An error occurred updating user cache: {str(e)}\n\n{traceback.format_exc()}"
             )
 
-    asyncio.create_task(_update_key_cache())
+    async def _update_end_user_cache():
+        ## UPDATE CACHE FOR END-USER ID
+        _id = end_user_id
+        try:
+            # Fetch the existing cost for the given user
+            existing_spend_obj = await user_api_key_cache.async_get_cache(key=_id)
+            if existing_spend_obj is None:
+                # if user does not exist in LiteLLM_UserTable, create a new user
+                existing_spend = 0
+                max_user_budget = None
+                if litellm.max_user_budget is not None:
+                    max_user_budget = litellm.max_user_budget
+                existing_spend_obj = LiteLLM_EndUserTable(
+                    user_id=_id,
+                    spend=0,
+                    blocked=False,
+                    litellm_budget_table=LiteLLM_BudgetTable(
+                        max_budget=max_user_budget
+                    ),
+                )
+            verbose_proxy_logger.debug(
+                f"_update_end_user_db: existing spend: {existing_spend_obj}; response_cost: {response_cost}"
+            )
+            if existing_spend_obj is None:
+                existing_spend = 0
+            else:
+                if isinstance(existing_spend_obj, dict):
+                    existing_spend = existing_spend_obj["spend"]
+                else:
+                    existing_spend = existing_spend_obj.spend
+            # Calculate the new cost by adding the existing cost and response_cost
+            new_spend = existing_spend + response_cost
+
+            # Update the cost column for the given user
+            if isinstance(existing_spend_obj, dict):
+                existing_spend_obj["spend"] = new_spend
+                user_api_key_cache.set_cache(key=_id, value=existing_spend_obj)
+            else:
+                existing_spend_obj.spend = new_spend
+                user_api_key_cache.set_cache(key=_id, value=existing_spend_obj.json())
+        except Exception as e:
+            verbose_proxy_logger.debug(
+                f"An error occurred updating end user cache: {str(e)}\n\n{traceback.format_exc()}"
+            )
+
+    if token is not None:
+        asyncio.create_task(_update_key_cache())
 
     asyncio.create_task(_update_user_cache())
 
+    if end_user_id is not None:
+        asyncio.create_task(_update_end_user_cache())
 
 
 def run_ollama_serve():
     try:
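The new `_update_end_user_cache` above has to tolerate cache hits that come back either as a plain dict (deserialized JSON) or as a pydantic object, and the three spend updates are now scheduled as independent tasks gated on which ids are present. A small sketch of the dict-or-model spend accumulation; `EndUserRecord` is a toy stand-in for `LiteLLM_EndUserTable`:

from typing import Union

class EndUserRecord:
    def __init__(self, user_id: str, spend: float = 0.0) -> None:
        self.user_id = user_id
        self.spend = spend

def add_spend(existing: Union[dict, EndUserRecord, None], user_id: str, cost: float) -> Union[dict, EndUserRecord]:
    if existing is None:
        # No cache entry yet: start a fresh record at zero spend.
        existing = EndUserRecord(user_id=user_id)
    if isinstance(existing, dict):
        existing["spend"] = existing.get("spend", 0.0) + cost
    else:
        existing.spend += cost
    return existing

print(add_spend({"spend": 1.0}, "end-user-42", 0.5))   # dict path -> {'spend': 1.5}
print(add_spend(None, "end-user-42", 0.5).spend)       # model path -> 0.5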
@@ -1881,7 +1966,7 @@ class ProxyConfig:
                 elif key == "success_callback":
                     litellm.success_callback = []
 
-                    # intialize success callbacks
+                    # initialize success callbacks
                     for callback in value:
                         # user passed custom_callbacks.async_on_succes_logger. They need us to import a function
                         if "." in callback:
@@ -1906,7 +1991,7 @@ class ProxyConfig:
                 elif key == "failure_callback":
                     litellm.failure_callback = []
 
-                    # intialize success callbacks
+                    # initialize failure callbacks
                     for callback in value:
                         # user passed custom_callbacks.async_on_succes_logger. They need us to import a function
                         if "." in callback:
@@ -2604,6 +2689,11 @@ async def startup_event():
 
     proxy_logging_obj._init_litellm_callbacks()  # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
 
+    ## JWT AUTH ##
+    jwt_handler.update_environment(
+        prisma_client=prisma_client, user_api_key_cache=user_api_key_cache
+    )
+
     if use_background_health_checks:
         asyncio.create_task(
             _run_background_health_check()
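Startup now hands the JWT handler its DB client and cache, and shutdown (next hunk) closes it. A hypothetical sketch of that lifecycle pairing; the attribute names are illustrative, not LiteLLM's actual JWTHandler internals:

class JWTHandlerSketch:
    def update_environment(self, prisma_client=None, user_api_key_cache=None) -> None:
        # Wire in dependencies that only exist once the server has started.
        self.prisma_client = prisma_client
        self.user_api_key_cache = user_api_key_cache

    async def close(self) -> None:
        # Release any resources (e.g. an HTTP client) before shutdown completes.
        pass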
@@ -7771,6 +7861,8 @@ async def shutdown_event():
     if litellm.cache is not None:
         await litellm.cache.disconnect()
 
+    await jwt_handler.close()
+
     ## RESET CUSTOM VARIABLES ##
     cleanup_router_config_variables()