diff --git a/litellm/llms/custom_httpx/bedrock_async.py b/litellm/llms/custom_httpx/bedrock_async.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 0f877087c..981028134 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -599,6 +599,8 @@ class LiteLLM_UserTable(LiteLLMBase): model_spend: Optional[Dict] = {} user_email: Optional[str] models: list = [] + tpm_limit: Optional[int] = None + rpm_limit: Optional[int] = None @root_validator(pre=True) def set_model_info(cls, values): @@ -617,6 +619,7 @@ class LiteLLM_EndUserTable(LiteLLMBase): blocked: bool alias: Optional[str] = None spend: float = 0.0 + litellm_budget_table: Optional[LiteLLM_BudgetTable] = None @root_validator(pre=True) def set_model_info(cls, values): diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py new file mode 100644 index 000000000..cd326cc6d --- /dev/null +++ b/litellm/proxy/auth/auth_checks.py @@ -0,0 +1,84 @@ +# What is this? +## Common auth checks between jwt + key based auth +""" +Got Valid Token from Cache, DB +Run checks for: + +1. If user can call model +2. If user is in budget +3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget +""" +from litellm.proxy._types import LiteLLM_UserTable, LiteLLM_EndUserTable +from typing import Optional +from litellm.proxy.utils import PrismaClient +from litellm.caching import DualCache + + +def common_checks( + request_body: dict, + user_object: LiteLLM_UserTable, + end_user_object: Optional[LiteLLM_EndUserTable], +) -> bool: + _model = request_body.get("model", None) + # 1. If user can call model + if ( + _model is not None + and len(user_object.models) > 0 + and _model not in user_object.models + ): + raise Exception( + f"User={user_object.user_id} not allowed to call model={_model}. Allowed user models = {user_object.models}" + ) + # 2. If user is in budget + if ( + user_object.max_budget is not None + and user_object.spend > user_object.max_budget + ): + raise Exception( + f"User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_object.max_budget}" + ) + # 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget + if end_user_object is not None and end_user_object.litellm_budget_table is not None: + end_user_budget = end_user_object.litellm_budget_table.max_budget + if end_user_budget is not None and end_user_object.spend > end_user_budget: + raise Exception( + f"End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}" + ) + return True + + +async def get_end_user_object( + end_user_id: Optional[str], + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, +) -> Optional[LiteLLM_EndUserTable]: + """ + Returns end user object, if in db. + + Do a isolated check for end user in table vs. doing a combined key + team + user + end-user check, as key might come in frequently for different end-users. Larger call will slowdown query time. This way we get to cache the constant (key/team/user info) and only update based on the changing value (end-user). + """ + if prisma_client is None: + raise Exception("No db connected") + + if end_user_id is None: + return None + + # check if in cache + cached_user_obj = user_api_key_cache.async_get_cache(key=end_user_id) + if cached_user_obj is not None: + if isinstance(cached_user_obj, dict): + return LiteLLM_EndUserTable(**cached_user_obj) + elif isinstance(cached_user_obj, LiteLLM_EndUserTable): + return cached_user_obj + # else, check db + try: + response = await prisma_client.db.litellm_endusertable.find_unique( + where={"user_id": end_user_id} + ) + + if response is None: + raise Exception + + return LiteLLM_EndUserTable(**response.dict()) + except Exception as e: # if end-user not in db + return None diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index 4342f3365..ad69543d5 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -8,23 +8,27 @@ JWT token must have 'litellm_proxy_admin' in scope. import httpx import jwt - -print(jwt.__version__) # noqa from jwt.algorithms import RSAAlgorithm import json import os -from litellm.proxy._types import LiteLLMProxyRoles +from litellm.caching import DualCache +from litellm.proxy._types import LiteLLMProxyRoles, LiteLLM_UserTable +from litellm.proxy.utils import PrismaClient from typing import Optional class HTTPHandler: - def __init__(self): - self.client = httpx.AsyncClient() + def __init__(self, concurrent_limit=1000): + # Create a client with a connection pool + self.client = httpx.AsyncClient( + limits=httpx.Limits( + max_connections=concurrent_limit, + max_keepalive_connections=concurrent_limit, + ) + ) - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def close(self): + # Close the client when you're done with it await self.client.aclose() async def get( @@ -47,10 +51,27 @@ class HTTPHandler: class JWTHandler: + """ + - treat the sub id passed in as the user id + - return an error if id making request doesn't exist in proxy user table + - track spend against the user id + - if role="litellm_proxy_user" -> allow making calls + info. Can not edit budgets + """ - def __init__(self) -> None: + prisma_client: Optional[PrismaClient] + user_api_key_cache: DualCache + + def __init__( + self, + ) -> None: self.http_handler = HTTPHandler() + def update_environment( + self, prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache + ) -> None: + self.prisma_client = prisma_client + self.user_api_key_cache = user_api_key_cache + def is_jwt(self, token: str): parts = token.split(".") return len(parts) == 3 @@ -67,6 +88,46 @@ class JWTHandler: user_id = default_value return user_id + def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: + try: + team_id = token["azp"] + except KeyError: + team_id = default_value + return team_id + + async def get_user_object(self, user_id: str) -> LiteLLM_UserTable: + """ + - Check if user id in proxy User Table + - if valid, return LiteLLM_UserTable object with defined limits + - if not, then raise an error + """ + if self.prisma_client is None: + raise Exception( + "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys" + ) + + # check if in cache + cached_user_obj = self.user_api_key_cache.async_get_cache(key=user_id) + if cached_user_obj is not None: + if isinstance(cached_user_obj, dict): + return LiteLLM_UserTable(**cached_user_obj) + elif isinstance(cached_user_obj, LiteLLM_UserTable): + return cached_user_obj + # else, check db + try: + response = await self.prisma_client.db.litellm_usertable.find_unique( + where={"user_id": user_id} + ) + + if response is None: + raise Exception + + return LiteLLM_UserTable(**response.dict()) + except Exception as e: + raise Exception( + f"User doesn't exist in db. User={user_id}. Create user via `/user/new` call." + ) + def get_scopes(self, token: dict) -> list: try: # Assuming the scopes are stored in 'scope' claim and are space-separated @@ -78,8 +139,10 @@ class JWTHandler: async def auth_jwt(self, token: str) -> dict: keys_url = os.getenv("JWT_PUBLIC_KEY_URL") - async with self.http_handler as http: - response = await http.get(keys_url) + if keys_url is None: + raise Exception("Missing JWT Public Key URL from environment.") + + response = await self.http_handler.get(keys_url) keys = response.json()["keys"] @@ -113,3 +176,6 @@ class JWTHandler: raise Exception(f"Validation fails: {str(e)}") raise Exception("Invalid JWT Submitted") + + async def close(self): + await self.http_handler.close() diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index dfaa99162..2c0cd92ed 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -107,6 +107,7 @@ from litellm.caching import DualCache from litellm.proxy.health_check import perform_health_check from litellm._logging import verbose_router_logger, verbose_proxy_logger from litellm.proxy.auth.handle_jwt import JWTHandler +from litellm.proxy.auth.auth_checks import common_checks, get_end_user_object try: from litellm._version import version @@ -360,18 +361,54 @@ async def user_api_key_auth( user_id = jwt_handler.get_user_id( token=valid_token, default_value=litellm_proxy_admin_name ) + + end_user_object = None + # get the request body + request_data = await _read_request_body(request=request) + # get user obj from cache/db -> run for admin too. Ensures, admin client id in db. + user_object = await jwt_handler.get_user_object(user_id=user_id) + if ( + request_data.get("user", None) + and request_data["user"] != user_object.user_id + ): + # get the end-user object + end_user_object = await get_end_user_object( + end_user_id=request_data["user"], + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + ) + # save the end-user object to cache + await user_api_key_cache.async_set_cache( + key=request_data["user"], value=end_user_object + ) + + # run through common checks + _ = common_checks( + request_body=request_data, + user_object=user_object, + end_user_object=end_user_object, + ) + # save user object in cache + await user_api_key_cache.async_set_cache( + key=user_object.user_id, value=user_object + ) # if admin return if is_admin: - _user_api_key_obj = UserAPIKeyAuth( + return UserAPIKeyAuth( api_key=api_key, user_role="proxy_admin", user_id=user_id, ) - user_api_key_cache.set_cache( - key=hash_token(api_key), value=_user_api_key_obj + else: + # return UserAPIKeyAuth object + return UserAPIKeyAuth( + api_key=None, + user_id=user_object.user_id, + tpm_limit=user_object.tpm_limit, + rpm_limit=user_object.rpm_limit, + models=user_object.models, + user_role="app_owner", ) - - return _user_api_key_obj else: raise Exception("Invalid key error!") #### ELSE #### @@ -438,7 +475,7 @@ async def user_api_key_auth( user_role="proxy_admin", user_id=litellm_proxy_admin_name, ) - user_api_key_cache.set_cache( + await user_api_key_cache.async_set_cache( key=hash_token(master_key), value=_user_api_key_obj ) @@ -603,7 +640,7 @@ async def user_api_key_auth( query_type="find_all", ) for _id in user_id_information: - user_api_key_cache.set_cache( + await user_api_key_cache.async_set_cache( key=_id["user_id"], value=_id, ttl=600 ) if custom_db_client is not None: @@ -791,7 +828,9 @@ async def user_api_key_auth( api_key = valid_token.token # Add hashed token to cache - user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=600) + await user_api_key_cache.async_set_cache( + key=api_key, value=valid_token, ttl=600 + ) valid_token_dict = _get_pydantic_json_dict(valid_token) valid_token_dict.pop("token", None) """ @@ -1073,7 +1112,10 @@ async def _PROXY_track_cost_callback( ) await update_cache( - token=user_api_key, user_id=user_id, response_cost=response_cost + token=user_api_key, + user_id=user_id, + end_user_id=end_user_id, + response_cost=response_cost, ) else: raise Exception("User API key missing from custom callback.") @@ -1348,9 +1390,10 @@ async def update_database( async def update_cache( - token, - user_id, - response_cost, + token: Optional[str], + user_id: Optional[str], + end_user_id: Optional[str], + response_cost: Optional[float], ): """ Use this to update the cache with new user spend. @@ -1365,12 +1408,17 @@ async def update_cache( hashed_token = hash_token(token=token) else: hashed_token = token + verbose_proxy_logger.debug(f"_update_key_cache: hashed_token={hashed_token}") existing_spend_obj = await user_api_key_cache.async_get_cache(key=hashed_token) verbose_proxy_logger.debug( - f"_update_key_db: existing spend: {existing_spend_obj}" + f"_update_key_cache: existing_spend_obj={existing_spend_obj}" + ) + verbose_proxy_logger.debug( + f"_update_key_cache: existing spend: {existing_spend_obj}" ) if existing_spend_obj is None: existing_spend = 0 + existing_spend_obj = LiteLLM_VerificationTokenView() else: existing_spend = existing_spend_obj.spend # Calculate the new cost by adding the existing cost and response_cost @@ -1426,18 +1474,7 @@ async def update_cache( async def _update_user_cache(): ## UPDATE CACHE FOR USER ID + GLOBAL PROXY - end_user_id = None - if isinstance(token, str) and token.startswith("sk-"): - hashed_token = hash_token(token=token) - else: - hashed_token = token - existing_token_obj = await user_api_key_cache.async_get_cache(key=hashed_token) - if existing_token_obj is None: - return - if existing_token_obj.user_id != user_id: # an end-user id was passed in - end_user_id = user_id - user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name, end_user_id] - + user_ids = [user_id, litellm_proxy_budget_name, end_user_id] try: for _id in user_ids: # Fetch the existing cost for the given user @@ -1483,9 +1520,59 @@ async def update_cache( f"An error occurred updating user cache: {str(e)}\n\n{traceback.format_exc()}" ) - asyncio.create_task(_update_key_cache()) + async def _update_end_user_cache(): + ## UPDATE CACHE FOR USER ID + GLOBAL PROXY + _id = end_user_id + try: + # Fetch the existing cost for the given user + existing_spend_obj = await user_api_key_cache.async_get_cache(key=_id) + if existing_spend_obj is None: + # if user does not exist in LiteLLM_UserTable, create a new user + existing_spend = 0 + max_user_budget = None + if litellm.max_user_budget is not None: + max_user_budget = litellm.max_user_budget + existing_spend_obj = LiteLLM_EndUserTable( + user_id=_id, + spend=0, + blocked=False, + litellm_budget_table=LiteLLM_BudgetTable( + max_budget=max_user_budget + ), + ) + verbose_proxy_logger.debug( + f"_update_end_user_db: existing spend: {existing_spend_obj}; response_cost: {response_cost}" + ) + if existing_spend_obj is None: + existing_spend = 0 + else: + if isinstance(existing_spend_obj, dict): + existing_spend = existing_spend_obj["spend"] + else: + existing_spend = existing_spend_obj.spend + # Calculate the new cost by adding the existing cost and response_cost + new_spend = existing_spend + response_cost + + # Update the cost column for the given user + if isinstance(existing_spend_obj, dict): + existing_spend_obj["spend"] = new_spend + user_api_key_cache.set_cache(key=_id, value=existing_spend_obj) + else: + existing_spend_obj.spend = new_spend + user_api_key_cache.set_cache(key=_id, value=existing_spend_obj.json()) + except Exception as e: + verbose_proxy_logger.debug( + f"An error occurred updating end user cache: {str(e)}\n\n{traceback.format_exc()}" + ) + + if token is not None: + asyncio.create_task(_update_key_cache()) + asyncio.create_task(_update_user_cache()) + if end_user_id is not None: + asyncio.create_task(_update_end_user_cache()) + def run_ollama_serve(): try: @@ -2587,6 +2674,11 @@ async def startup_event(): proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made + ## JWT AUTH ## + jwt_handler.update_environment( + prisma_client=prisma_client, user_api_key_cache=user_api_key_cache + ) + if use_background_health_checks: asyncio.create_task( _run_background_health_check() @@ -7750,6 +7842,8 @@ async def shutdown_event(): if litellm.cache is not None: await litellm.cache.disconnect() + await jwt_handler.close() + ## RESET CUSTOM VARIABLES ## cleanup_router_config_variables()