# What is this?
## Common auth checks between jwt + key based auth
"""
Got Valid Token from Cache, DB
Run checks for:
1. If user can call model
2. If user is in budget
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
"""
from litellm.proxy._types import (
LiteLLM_UserTable,
LiteLLM_EndUserTable,
LiteLLM_JWTAuth,
LiteLLM_TeamTable,
LiteLLMRoutes,
LiteLLM_OrganizationTable,
LitellmUserRoles,
)
from typing import Optional, Literal, Union
from litellm.proxy.utils import PrismaClient, ProxyLogging
from litellm.caching import DualCache
import litellm
from opentelemetry.trace import Span
from functools import wraps
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
from datetime import datetime
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
def log_to_opentelemetry(func):
@wraps(func)
async def wrapper(*args, **kwargs):
start_time = datetime.now()
result = await func(*args, **kwargs)
end_time = datetime.now()
# Log to OTEL only if "parent_otel_span" is in kwargs and is not None
if (
"parent_otel_span" in kwargs
and kwargs["parent_otel_span"] is not None
and "proxy_logging_obj" in kwargs
and kwargs["proxy_logging_obj"] is not None
):
proxy_logging_obj = kwargs["proxy_logging_obj"]
await proxy_logging_obj.service_logging_obj.async_service_success_hook(
service=ServiceTypes.DB,
call_type=func.__name__,
parent_otel_span=kwargs["parent_otel_span"],
start_time=start_time,
end_time=end_time,
)
# end of logging to otel
return result
return wrapper
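
# Illustrative sketch (not part of this module): how `log_to_opentelemetry` might be
# applied to a DB lookup helper. `my_prisma_client`, `my_cache`, `request_span`, and
# `my_proxy_logging` are hypothetical objects supplied by the caller.
#
#   @log_to_opentelemetry
#   async def lookup_key_in_db(
#       token: str,
#       prisma_client: Optional[PrismaClient],
#       user_api_key_cache: DualCache,
#       parent_otel_span: Optional[Span] = None,
#       proxy_logging_obj: Optional[ProxyLogging] = None,
#   ):
#       ...
#
#   # The call is timed and emitted as a ServiceTypes.DB success event only when both
#   # `parent_otel_span` and `proxy_logging_obj` are passed as keyword arguments.
#   await lookup_key_in_db(
#       token="sk-...",
#       prisma_client=my_prisma_client,
#       user_api_key_cache=my_cache,
#       parent_otel_span=request_span,
#       proxy_logging_obj=my_proxy_logging,
#   )
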
def common_checks(
request_body: dict,
team_object: Optional[LiteLLM_TeamTable],
user_object: Optional[LiteLLM_UserTable],
end_user_object: Optional[LiteLLM_EndUserTable],
global_proxy_spend: Optional[float],
general_settings: dict,
route: str,
) -> bool:
"""
Common checks across jwt + key-based auth.
1. If team is blocked
2. If team can call model
3. If team is in budget
5. If user passed in (JWT or key.user_id) - is in budget
4. If end_user (either via JWT or 'user' passed to /chat/completions, /embeddings endpoint) is in budget
5. [OPTIONAL] If 'enforce_end_user' enabled - did developer pass in 'user' param for openai endpoints
6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
"""
    _model = request_body.get("model", None)
    # 1. If team is blocked
    if team_object is not None and team_object.blocked is True:
        raise Exception(
            f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if you're an admin."
        )
    # 2. If team can call model
    if (
        _model is not None
        and team_object is not None
        and len(team_object.models) > 0
        and _model not in team_object.models
    ):
        # "all-proxy-models" in the team's model list means the team has access to all models on the proxy
        if "all-proxy-models" in team_object.models:
            pass
        else:
            raise Exception(
                f"Team={team_object.team_id} not allowed to call model={_model}. Allowed team models = {team_object.models}"
            )
# 3. If team is in budget
if (
team_object is not None
and team_object.max_budget is not None
and team_object.spend is not None
and team_object.spend > team_object.max_budget
):
raise Exception(
f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}"
)
    # 4. If user is in budget
    if user_object is not None and user_object.max_budget is not None:
        user_budget = user_object.max_budget
        if user_object.spend > user_budget:
            raise Exception(
                f"ExceededBudget: User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_budget}"
            )
# 5. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
if end_user_object is not None and end_user_object.litellm_budget_table is not None:
end_user_budget = end_user_object.litellm_budget_table.max_budget
if end_user_budget is not None and end_user_object.spend > end_user_budget:
raise Exception(
f"ExceededBudget: End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}"
)
# 6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
if (
general_settings.get("enforce_user_param", None) is not None
and general_settings["enforce_user_param"] == True
):
if route in LiteLLMRoutes.openai_routes.value and "user" not in request_body:
raise Exception(
f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
)
# 7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
if (
litellm.max_budget > 0
and global_proxy_spend is not None
        # only run global budget checks for OpenAI routes
        # Reason - the Admin UI should continue working if the proxy crosses its global budget
and route in LiteLLMRoutes.openai_routes.value
and route != "/v1/models"
and route != "/models"
):
if global_proxy_spend > litellm.max_budget:
raise Exception(
f"ExceededBudget: LiteLLM Proxy has exceeded its budget. Current spend: {global_proxy_spend}; Max Budget: {litellm.max_budget}"
)
return True
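
# Illustrative sketch: how `common_checks` might be called for a /chat/completions
# request. The `team`, `user`, and `end_user` objects below are hypothetical values
# that would normally come from get_team_object / get_user_object / get_end_user_object.
#
#   common_checks(
#       request_body={"model": "gpt-3.5-turbo", "user": "end-user-123"},
#       team_object=team,
#       user_object=user,
#       end_user_object=end_user,
#       global_proxy_spend=42.0,
#       general_settings={"enforce_user_param": True},
#       route="/chat/completions",
#   )  # returns True, or raises an Exception describing the failed check
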
def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
"""
    Return whether a user is allowed to access a route. Helper function for `allowed_routes_check`.
Parameters:
- user_route: str - the route the user is trying to call
- allowed_routes: List[str|LiteLLMRoutes] - the list of allowed routes for the user.
"""
for allowed_route in allowed_routes:
if (
allowed_route in LiteLLMRoutes.__members__
and user_route in LiteLLMRoutes[allowed_route].value
):
return True
elif allowed_route == user_route:
return True
return False
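
# Illustrative sketch: `allowed_routes` entries may be LiteLLMRoutes member names
# (expanded to their route lists) or literal route strings.
#
#   _allowed_routes_check("/chat/completions", ["openai_routes"])   # True
#   _allowed_routes_check("/team/new", ["openai_routes"])           # False
#   _allowed_routes_check("/health", ["/health"])                   # True
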
def allowed_routes_check(
user_role: Literal[
LitellmUserRoles.PROXY_ADMIN,
LitellmUserRoles.TEAM,
LitellmUserRoles.INTERNAL_USER,
],
user_route: str,
litellm_proxy_roles: LiteLLM_JWTAuth,
) -> bool:
"""
Check if user -> not admin - allowed to access these routes
"""
if user_role == LitellmUserRoles.PROXY_ADMIN:
is_allowed = _allowed_routes_check(
user_route=user_route,
allowed_routes=litellm_proxy_roles.admin_allowed_routes,
)
return is_allowed
elif user_role == LitellmUserRoles.TEAM:
if litellm_proxy_roles.team_allowed_routes is None:
"""
By default allow a team to call openai + info routes
"""
is_allowed = _allowed_routes_check(
user_route=user_route, allowed_routes=["openai_routes", "info_routes"]
)
return is_allowed
elif litellm_proxy_roles.team_allowed_routes is not None:
is_allowed = _allowed_routes_check(
user_route=user_route,
allowed_routes=litellm_proxy_roles.team_allowed_routes,
)
return is_allowed
return False
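
# Illustrative sketch: checking a JWT-authenticated team's access to an OpenAI route.
# `jwt_auth_config` is a hypothetical LiteLLM_JWTAuth instance from the proxy config.
#
#   allowed_routes_check(
#       user_role=LitellmUserRoles.TEAM,
#       user_route="/chat/completions",
#       litellm_proxy_roles=jwt_auth_config,
#   )  # True if team_allowed_routes is unset (defaults to openai + info routes)
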
def get_actual_routes(allowed_routes: list) -> list:
actual_routes: list = []
for route_name in allowed_routes:
try:
route_value = LiteLLMRoutes[route_name].value
actual_routes = actual_routes + route_value
except KeyError:
actual_routes.append(route_name)
return actual_routes
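
# Illustrative sketch: expanding route-group names into concrete route strings.
# Names that aren't LiteLLMRoutes members are passed through unchanged.
#
#   get_actual_routes(["info_routes", "/custom/route"])
#   # -> [...all info routes..., "/custom/route"]
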
@log_to_opentelemetry
async def get_end_user_object(
end_user_id: Optional[str],
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
) -> Optional[LiteLLM_EndUserTable]:
"""
Returns end user object, if in db.
Do a isolated check for end user in table vs. doing a combined key + team + user + end-user check, as key might come in frequently for different end-users. Larger call will slowdown query time. This way we get to cache the constant (key/team/user info) and only update based on the changing value (end-user).
"""
if prisma_client is None:
raise Exception("No db connected")
if end_user_id is None:
return None
_key = "end_user_id:{}".format(end_user_id)
def check_in_budget(end_user_obj: LiteLLM_EndUserTable):
if end_user_obj.litellm_budget_table is None:
return
end_user_budget = end_user_obj.litellm_budget_table.max_budget
if end_user_budget is not None and end_user_obj.spend > end_user_budget:
raise litellm.BudgetExceededError(
current_cost=end_user_obj.spend, max_budget=end_user_budget
)
# check if in cache
cached_user_obj = await user_api_key_cache.async_get_cache(key=_key)
if cached_user_obj is not None:
if isinstance(cached_user_obj, dict):
return_obj = LiteLLM_EndUserTable(**cached_user_obj)
check_in_budget(end_user_obj=return_obj)
return return_obj
elif isinstance(cached_user_obj, LiteLLM_EndUserTable):
return_obj = cached_user_obj
check_in_budget(end_user_obj=return_obj)
return return_obj
# else, check db
try:
response = await prisma_client.db.litellm_endusertable.find_unique(
where={"user_id": end_user_id},
include={"litellm_budget_table": True},
)
if response is None:
raise Exception
# save the end-user object to cache
        await user_api_key_cache.async_set_cache(key=_key, value=response)
_response = LiteLLM_EndUserTable(**response.dict())
check_in_budget(end_user_obj=_response)
return _response
except Exception as e: # if end-user not in db
if isinstance(e, litellm.BudgetExceededError):
raise e
return None
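
# Illustrative sketch: fetching (and caching) the end-user record for budget checks.
# `prisma_client`, `cache`, `span`, and `proxy_logging_obj` are hypothetical objects
# owned by the proxy.
#
#   end_user = await get_end_user_object(
#       end_user_id="end-user-123",
#       prisma_client=prisma_client,
#       user_api_key_cache=cache,
#       parent_otel_span=span,
#       proxy_logging_obj=proxy_logging_obj,
#   )  # None if not found; raises litellm.BudgetExceededError if over budget
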
@log_to_opentelemetry
async def get_user_object(
user_id: str,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
user_id_upsert: bool,
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
) -> Optional[LiteLLM_UserTable]:
"""
- Check if user id in proxy User Table
- if valid, return LiteLLM_UserTable object with defined limits
- if not, then raise an error
"""
if prisma_client is None:
raise Exception("No db connected")
if user_id is None:
return None
# check if in cache
cached_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
if cached_user_obj is not None:
if isinstance(cached_user_obj, dict):
return LiteLLM_UserTable(**cached_user_obj)
elif isinstance(cached_user_obj, LiteLLM_UserTable):
return cached_user_obj
# else, check db
try:
response = await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_id}
)
if response is None:
if user_id_upsert:
response = await prisma_client.db.litellm_usertable.create(
data={"user_id": user_id}
)
else:
raise Exception
_response = LiteLLM_UserTable(**dict(response))
# save the user object to cache
await user_api_key_cache.async_set_cache(key=user_id, value=_response)
return _response
except Exception as e: # if user not in db
raise ValueError(
f"User doesn't exist in db. 'user_id'={user_id}. Create user via `/user/new` call."
)
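
# Illustrative sketch: resolving an internal user, optionally creating the row if it is
# missing. The `prisma_client` and `cache` objects below are hypothetical.
#
#   user = await get_user_object(
#       user_id="user-abc",
#       prisma_client=prisma_client,
#       user_api_key_cache=cache,
#       user_id_upsert=True,  # create the user row if it doesn't exist
#   )
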
@log_to_opentelemetry
async def get_team_object(
team_id: str,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
) -> LiteLLM_TeamTable:
"""
- Check if team id in proxy Team Table
- if valid, return LiteLLM_TeamTable object with defined limits
- if not, then raise an error
"""
if prisma_client is None:
raise Exception(
"No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
)
# check if in cache
cached_team_obj = await user_api_key_cache.async_get_cache(key=team_id)
if cached_team_obj is not None:
if isinstance(cached_team_obj, dict):
return LiteLLM_TeamTable(**cached_team_obj)
elif isinstance(cached_team_obj, LiteLLM_TeamTable):
return cached_team_obj
# else, check db
try:
response = await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": team_id}
)
if response is None:
raise Exception
_response = LiteLLM_TeamTable(**response.dict())
# save the team object to cache
await user_api_key_cache.async_set_cache(key=response.team_id, value=_response)
return _response
except Exception as e:
raise Exception(
f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
)
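
# Illustrative sketch: resolving a team before running `common_checks`. Raises if the
# team is not in the DB. The objects below are hypothetical.
#
#   team = await get_team_object(
#       team_id="team-xyz",
#       prisma_client=prisma_client,
#       user_api_key_cache=cache,
#   )
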
@log_to_opentelemetry
async def get_org_object(
org_id: str,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
):
"""
- Check if org id in proxy Org Table
- if valid, return LiteLLM_OrganizationTable object
- if not, then raise an error
"""
if prisma_client is None:
raise Exception(
"No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
)
# check if in cache
    cached_org_obj = await user_api_key_cache.async_get_cache(
        key="org_id:{}".format(org_id)
    )
if cached_org_obj is not None:
        if isinstance(cached_org_obj, dict):
            return LiteLLM_OrganizationTable(**cached_org_obj)
elif isinstance(cached_org_obj, LiteLLM_OrganizationTable):
return cached_org_obj
# else, check db
try:
response = await prisma_client.db.litellm_organizationtable.find_unique(
where={"organization_id": org_id}
)
if response is None:
raise Exception
return response
except Exception as e:
raise Exception(
f"Organization doesn't exist in db. Organization={org_id}. Create organization via `/organization/new` call."
)
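
# Illustrative sketch: resolving an organization by id. Raises if the organization is
# not in the DB. The objects below are hypothetical.
#
#   org = await get_org_object(
#       org_id="org-1",
#       prisma_client=prisma_client,
#       user_api_key_cache=cache,
#   )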