Merge branch 'main' into litellm_llm_api_prompt_injection_check

Krish Dholakia 2024-03-21 09:57:10 -07:00 committed by GitHub
commit 33a433eb0a
10 changed files with 500 additions and 159 deletions


@@ -110,6 +110,7 @@ from litellm.proxy.auth.handle_jwt import JWTHandler
 from litellm.proxy.hooks.prompt_injection_detection import (
     _OPTIONAL_PromptInjectionDetection,
 )
+from litellm.proxy.auth.auth_checks import common_checks, get_end_user_object
 
 try:
     from litellm._version import version
@@ -364,20 +365,54 @@ async def user_api_key_auth(
                 user_id = jwt_handler.get_user_id(
                     token=valid_token, default_value=litellm_proxy_admin_name
                 )
+                end_user_object = None
+                # get the request body
+                request_data = await _read_request_body(request=request)
+                # get user obj from cache/db -> run for admin too. Ensures, admin client id in db.
+                user_object = await jwt_handler.get_user_object(user_id=user_id)
+                if (
+                    request_data.get("user", None)
+                    and request_data["user"] != user_object.user_id
+                ):
+                    # get the end-user object
+                    end_user_object = await get_end_user_object(
+                        end_user_id=request_data["user"],
+                        prisma_client=prisma_client,
+                        user_api_key_cache=user_api_key_cache,
+                    )
+                    # save the end-user object to cache
+                    await user_api_key_cache.async_set_cache(
+                        key=request_data["user"], value=end_user_object
+                    )
+                # run through common checks
+                _ = common_checks(
+                    request_body=request_data,
+                    user_object=user_object,
+                    end_user_object=end_user_object,
+                )
+                # save user object in cache
+                await user_api_key_cache.async_set_cache(
+                    key=user_object.user_id, value=user_object
+                )
                 # if admin return
                 if is_admin:
-                    _user_api_key_obj = UserAPIKeyAuth(
+                    return UserAPIKeyAuth(
                         api_key=api_key,
                         user_role="proxy_admin",
                         user_id=user_id,
                     )
-                    user_api_key_cache.set_cache(
-                        key=hash_token(api_key), value=_user_api_key_obj
+                else:
+                    # return UserAPIKeyAuth object
+                    return UserAPIKeyAuth(
+                        api_key=None,
+                        user_id=user_object.user_id,
+                        tpm_limit=user_object.tpm_limit,
+                        rpm_limit=user_object.rpm_limit,
+                        models=user_object.models,
+                        user_role="app_owner",
                     )
-                    return _user_api_key_obj
             else:
                 raise Exception("Invalid key error!")
         #### ELSE ####
         if master_key is None:
             if isinstance(api_key, str):
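
[editor's note] The JWT branch above is the heart of this merge: read the request body, resolve the calling user and any distinct end-user, run both through common_checks, and cache the results before issuing a UserAPIKeyAuth. A minimal runnable sketch of that control flow, with stand-in types and stub lookups (every name below is illustrative, not litellm's actual implementation):

import asyncio
from dataclasses import dataclass
from typing import Optional

@dataclass
class UserRecord:  # stand-in for the proxy's user / end-user objects
    user_id: str
    tpm_limit: Optional[int] = None

async def get_user(user_id: str) -> UserRecord:
    return UserRecord(user_id=user_id, tpm_limit=100)  # stub DB/cache lookup

async def get_end_user(end_user_id: str) -> UserRecord:
    return UserRecord(user_id=end_user_id)  # stub DB/cache lookup

def run_common_checks(body: dict, user: UserRecord, end_user: Optional[UserRecord]) -> None:
    # placeholder for budget / blocked-user checks; raise on failure
    pass

async def authenticate(user_id: str, request_body: dict) -> UserRecord:
    user = await get_user(user_id)
    end_user = None
    requested = request_body.get("user", None)
    if requested and requested != user.user_id:
        # the caller is acting on behalf of a distinct end-user
        end_user = await get_end_user(requested)
    run_common_checks(request_body, user, end_user)
    return user

print(asyncio.run(authenticate("team-admin", {"user": "customer-1"})))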
@@ -442,7 +477,7 @@ async def user_api_key_auth(
             user_role="proxy_admin",
             user_id=litellm_proxy_admin_name,
         )
-        user_api_key_cache.set_cache(
+        await user_api_key_cache.async_set_cache(
             key=hash_token(master_key), value=_user_api_key_obj
         )
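
[editor's note] This hunk, like several below, replaces the synchronous set_cache with an awaited async_set_cache. A rough sketch of why a cache might expose both write paths (hypothetical class, not litellm's DualCache API):

import asyncio

class SimpleDualCache:
    """In-memory cache with sync and async write paths (illustrative only)."""

    def __init__(self):
        self._store = {}

    def set_cache(self, key, value, ttl=None):
        # fine for pure in-memory writes, but would block the event loop
        # if a networked backend (e.g. Redis) sat behind this call
        self._store[key] = value

    async def async_set_cache(self, key, value, ttl=None):
        # the async path lets a backend round-trip be awaited instead
        self._store[key] = value

async def main():
    cache = SimpleDualCache()
    await cache.async_set_cache("hashed-token", {"spend": 0.0}, ttl=600)
    print(cache._store)

asyncio.run(main())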
@@ -607,7 +642,7 @@ async def user_api_key_auth(
                         query_type="find_all",
                     )
                     for _id in user_id_information:
-                        user_api_key_cache.set_cache(
+                        await user_api_key_cache.async_set_cache(
                             key=_id["user_id"], value=_id, ttl=600
                         )
             if custom_db_client is not None:
@@ -795,7 +830,9 @@ async def user_api_key_auth(
             api_key = valid_token.token
 
             # Add hashed token to cache
-            user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=600)
+            await user_api_key_cache.async_set_cache(
+                key=api_key, value=valid_token, ttl=600
+            )
             valid_token_dict = _get_pydantic_json_dict(valid_token)
             valid_token_dict.pop("token", None)
             """
@@ -1077,7 +1114,10 @@ async def _PROXY_track_cost_callback(
             )
             await update_cache(
-                token=user_api_key, user_id=user_id, response_cost=response_cost
+                token=user_api_key,
+                user_id=user_id,
+                end_user_id=end_user_id,
+                response_cost=response_cost,
             )
         else:
             raise Exception("User API key missing from custom callback.")
@@ -1352,9 +1392,10 @@ async def update_database(
 
 
 async def update_cache(
-    token,
-    user_id,
-    response_cost,
+    token: Optional[str],
+    user_id: Optional[str],
+    end_user_id: Optional[str],
+    response_cost: Optional[float],
 ):
     """
     Use this to update the cache with new user spend.
@@ -1369,12 +1410,17 @@ async def update_cache(
             hashed_token = hash_token(token=token)
         else:
             hashed_token = token
+        verbose_proxy_logger.debug(f"_update_key_cache: hashed_token={hashed_token}")
         existing_spend_obj = await user_api_key_cache.async_get_cache(key=hashed_token)
         verbose_proxy_logger.debug(
-            f"_update_key_db: existing spend: {existing_spend_obj}"
+            f"_update_key_cache: existing_spend_obj={existing_spend_obj}"
         )
+        verbose_proxy_logger.debug(
+            f"_update_key_cache: existing spend: {existing_spend_obj}"
+        )
         if existing_spend_obj is None:
             existing_spend = 0
+            existing_spend_obj = LiteLLM_VerificationTokenView()
         else:
             existing_spend = existing_spend_obj.spend
         # Calculate the new cost by adding the existing cost and response_cost
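
[editor's note] On a cache miss the key-spend helper now substitutes an empty LiteLLM_VerificationTokenView instead of carrying None forward, so later attribute writes need no None-checks. The pattern, reduced to a sketch with a hypothetical SpendRecord standing in for the pydantic view:

from dataclasses import dataclass
from typing import Optional

@dataclass
class SpendRecord:  # stands in for LiteLLM_VerificationTokenView
    spend: float = 0.0

def accumulate(cached: Optional[SpendRecord], response_cost: float) -> SpendRecord:
    # substituting a default object on a miss means downstream code can
    # read and write .spend without branching on None
    record = cached if cached is not None else SpendRecord()
    record.spend += response_cost
    return record

print(accumulate(None, 0.25))             # SpendRecord(spend=0.25)
print(accumulate(SpendRecord(1.0), 0.5))  # SpendRecord(spend=1.5)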
@@ -1430,18 +1476,7 @@ async def update_cache(
 
     async def _update_user_cache():
         ## UPDATE CACHE FOR USER ID + GLOBAL PROXY
-        end_user_id = None
-        if isinstance(token, str) and token.startswith("sk-"):
-            hashed_token = hash_token(token=token)
-        else:
-            hashed_token = token
-        existing_token_obj = await user_api_key_cache.async_get_cache(key=hashed_token)
-        if existing_token_obj is None:
-            return
-        if existing_token_obj.user_id != user_id:  # an end-user id was passed in
-            end_user_id = user_id
-        user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name, end_user_id]
+        user_ids = [user_id, litellm_proxy_budget_name, end_user_id]
         try:
             for _id in user_ids:
                 # Fetch the existing cost for the given user
@@ -1487,9 +1522,59 @@ async def update_cache(
                 f"An error occurred updating user cache: {str(e)}\n\n{traceback.format_exc()}"
             )
 
-    asyncio.create_task(_update_key_cache())
+    async def _update_end_user_cache():
+        ## UPDATE CACHE FOR USER ID + GLOBAL PROXY
+        _id = end_user_id
+        try:
+            # Fetch the existing cost for the given user
+            existing_spend_obj = await user_api_key_cache.async_get_cache(key=_id)
+            if existing_spend_obj is None:
+                # if user does not exist in LiteLLM_UserTable, create a new user
+                existing_spend = 0
+                max_user_budget = None
+                if litellm.max_user_budget is not None:
+                    max_user_budget = litellm.max_user_budget
+                existing_spend_obj = LiteLLM_EndUserTable(
+                    user_id=_id,
+                    spend=0,
+                    blocked=False,
+                    litellm_budget_table=LiteLLM_BudgetTable(
+                        max_budget=max_user_budget
+                    ),
+                )
+            verbose_proxy_logger.debug(
+                f"_update_end_user_db: existing spend: {existing_spend_obj}; response_cost: {response_cost}"
+            )
+            if existing_spend_obj is None:
+                existing_spend = 0
+            else:
+                if isinstance(existing_spend_obj, dict):
+                    existing_spend = existing_spend_obj["spend"]
+                else:
+                    existing_spend = existing_spend_obj.spend
+            # Calculate the new cost by adding the existing cost and response_cost
+            new_spend = existing_spend + response_cost
+            # Update the cost column for the given user
+            if isinstance(existing_spend_obj, dict):
+                existing_spend_obj["spend"] = new_spend
+                user_api_key_cache.set_cache(key=_id, value=existing_spend_obj)
+            else:
+                existing_spend_obj.spend = new_spend
+                user_api_key_cache.set_cache(key=_id, value=existing_spend_obj.json())
+        except Exception as e:
+            verbose_proxy_logger.debug(
+                f"An error occurred updating end user cache: {str(e)}\n\n{traceback.format_exc()}"
+            )
+
+    if token is not None:
+        asyncio.create_task(_update_key_cache())
     asyncio.create_task(_update_user_cache())
+    if end_user_id is not None:
+        asyncio.create_task(_update_end_user_cache())
 
 
 def run_ollama_serve():
     try:
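
[editor's note] The new _update_end_user_cache tolerates the cached value arriving either as a plain dict or as a pydantic model, reading and writing spend through whichever interface applies. Reduced to a sketch (EndUserSketch is a hypothetical stand-in for LiteLLM_EndUserTable):

import json
from typing import Union

class EndUserSketch:  # stands in for LiteLLM_EndUserTable
    def __init__(self, user_id: str, spend: float = 0.0):
        self.user_id = user_id
        self.spend = spend

    def json(self) -> str:
        return json.dumps({"user_id": self.user_id, "spend": self.spend})

def add_spend(obj: Union[dict, EndUserSketch], cost: float):
    # cached entries may be deserialized dicts or live model objects
    if isinstance(obj, dict):
        obj["spend"] = obj.get("spend", 0.0) + cost
        return obj            # dicts are stored back as-is
    obj.spend += cost
    return obj.json()         # models are stored serialized, as the hunk does

print(add_spend({"user_id": "eu-1", "spend": 1.0}, 0.2))
print(add_spend(EndUserSketch("eu-2"), 0.2))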
@@ -1881,7 +1966,7 @@ class ProxyConfig:
                 elif key == "success_callback":
                     litellm.success_callback = []
 
-                    # intialize success callbacks
+                    # initialize success callbacks
                     for callback in value:
                         # user passed custom_callbacks.async_on_succes_logger. They need us to import a function
                         if "." in callback:
@@ -1906,7 +1991,7 @@ class ProxyConfig:
                 elif key == "failure_callback":
                     litellm.failure_callback = []
 
-                    # intialize success callbacks
+                    # initialize success callbacks
                     for callback in value:
                         # user passed custom_callbacks.async_on_succes_logger. They need us to import a function
                         if "." in callback:
@@ -2604,6 +2689,11 @@ async def startup_event():
     proxy_logging_obj._init_litellm_callbacks()  # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
 
+    ## JWT AUTH ##
+    jwt_handler.update_environment(
+        prisma_client=prisma_client, user_api_key_cache=user_api_key_cache
+    )
+
     if use_background_health_checks:
         asyncio.create_task(
             _run_background_health_check()
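
[editor's note] startup_event injects the Prisma client and the shared key cache into the module-level JWT handler once both exist, since the handler is constructed before either. A minimal sketch of that late-binding pattern (hypothetical handler class):

class JWTHandlerSketch:
    """Holds references injected after startup (illustrative only)."""

    def __init__(self):
        self.prisma_client = None
        self.user_api_key_cache = None

    def update_environment(self, prisma_client=None, user_api_key_cache=None):
        # called from startup once the DB client and cache are built,
        # so auth-time lookups have somewhere to go
        self.prisma_client = prisma_client
        self.user_api_key_cache = user_api_key_cache

handler = JWTHandlerSketch()
handler.update_environment(prisma_client=object(), user_api_key_cache={})
print(handler.prisma_client is not None)  # True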
@@ -7771,6 +7861,8 @@ async def shutdown_event():
     if litellm.cache is not None:
         await litellm.cache.disconnect()
 
+    await jwt_handler.close()
+
     ## RESET CUSTOM VARIABLES ##
     cleanup_router_config_variables()