fixes for auth checks

This commit is contained in:
Ishaan Jaff 2025-03-25 15:44:13 -07:00
parent 59040167ac
commit ce49e27217
4 changed files with 71 additions and 56 deletions

View file

@ -2141,6 +2141,13 @@ class ProxyErrorTypes(str, enum.Enum):
return cls.user_model_access_denied return cls.user_model_access_denied
DB_CONNECTION_ERROR_TYPES = (
httpx.ConnectError,
httpx.ReadError,
httpx.ReadTimeout,
)
class SSOUserDefinedValues(TypedDict): class SSOUserDefinedValues(TypedDict):
models: List[str] models: List[str]
user_id: str user_id: str

View file

@ -987,33 +987,34 @@ async def get_key_object(
) )
# else, check db # else, check db
try: _valid_token: Optional[BaseModel] = await prisma_client.get_data(
_valid_token: Optional[BaseModel] = await prisma_client.get_data( token=hashed_token,
token=hashed_token, table_name="combined_view",
table_name="combined_view", parent_otel_span=parent_otel_span,
parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj,
proxy_logging_obj=proxy_logging_obj, )
if _valid_token is None:
raise ProxyException(
message="Key doesn't exist in db. key={}. Create key via `/key/generate` call.".format(
hashed_token
),
type=ProxyErrorTypes.token_not_found_in_db,
param="key",
code=status.HTTP_401_UNAUTHORIZED,
) )
if _valid_token is None: _response = UserAPIKeyAuth(**_valid_token.model_dump(exclude_none=True))
raise Exception
_response = UserAPIKeyAuth(**_valid_token.model_dump(exclude_none=True)) # save the key object to cache
await _cache_key_object(
hashed_token=hashed_token,
user_api_key_obj=_response,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)
# save the key object to cache return _response
await _cache_key_object(
hashed_token=hashed_token,
user_api_key_obj=_response,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)
return _response
except Exception:
traceback.print_exc()
raise Exception(
f"Key doesn't exist in db. key={hashed_token}. Create key via `/key/generate` call."
)
@log_db_metrics @log_db_metrics

View file

@ -9,7 +9,12 @@ from fastapi import HTTPException, Request, status
import litellm import litellm
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import ProxyErrorTypes, ProxyException, UserAPIKeyAuth from litellm.proxy._types import (
DB_CONNECTION_ERROR_TYPES,
ProxyErrorTypes,
ProxyException,
UserAPIKeyAuth,
)
from litellm.proxy.auth.auth_utils import _get_request_ip_address from litellm.proxy.auth.auth_utils import _get_request_ip_address
from litellm.types.services import ServiceTypes from litellm.types.services import ServiceTypes
@ -52,7 +57,10 @@ class UserAPIKeyAuthExceptionHandler:
proxy_logging_obj, proxy_logging_obj,
) )
if UserAPIKeyAuthExceptionHandler.should_allow_request_on_db_unavailable(): if (
UserAPIKeyAuthExceptionHandler.should_allow_request_on_db_unavailable()
and UserAPIKeyAuthExceptionHandler.is_database_connection_error(e)
):
# log this as a DB failure on prometheus # log this as a DB failure on prometheus
proxy_logging_obj.service_logging_obj.service_failure_hook( proxy_logging_obj.service_logging_obj.service_failure_hook(
service=ServiceTypes.DB, service=ServiceTypes.DB,
@ -128,3 +136,14 @@ class UserAPIKeyAuthExceptionHandler:
if general_settings.get("allow_requests_on_db_unavailable", False) is True: if general_settings.get("allow_requests_on_db_unavailable", False) is True:
return True return True
return False return False
@staticmethod
def is_database_connection_error(e: Exception) -> bool:
"""
Returns True if the exception is from a database outage / connection error
"""
import prisma
return isinstance(e, DB_CONNECTION_ERROR_TYPES) or isinstance(
e, prisma.errors.PrismaError
)

View file

@ -683,37 +683,25 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
api_key = hash_token(token=api_key) api_key = hash_token(token=api_key)
if valid_token is None: if valid_token is None:
try: valid_token = await get_key_object(
valid_token = await get_key_object( hashed_token=api_key,
hashed_token=api_key, prisma_client=prisma_client,
prisma_client=prisma_client, user_api_key_cache=user_api_key_cache,
user_api_key_cache=user_api_key_cache, parent_otel_span=parent_otel_span,
parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj,
proxy_logging_obj=proxy_logging_obj, )
) # update end-user params on valid token
# update end-user params on valid token # These can change per request - it's important to update them here
# These can change per request - it's important to update them here valid_token.end_user_id = end_user_params.get("end_user_id")
valid_token.end_user_id = end_user_params.get("end_user_id") valid_token.end_user_tpm_limit = end_user_params.get("end_user_tpm_limit")
valid_token.end_user_tpm_limit = end_user_params.get( valid_token.end_user_rpm_limit = end_user_params.get("end_user_rpm_limit")
"end_user_tpm_limit" valid_token.allowed_model_region = end_user_params.get(
) "allowed_model_region"
valid_token.end_user_rpm_limit = end_user_params.get( )
"end_user_rpm_limit" # update key budget with temp budget increase
) valid_token = _update_key_budget_with_temp_budget_increase(
valid_token.allowed_model_region = end_user_params.get( valid_token
"allowed_model_region" ) # updating it here, allows all downstream reporting / checks to use the updated budget
)
# update key budget with temp budget increase
valid_token = _update_key_budget_with_temp_budget_increase(
valid_token
) # updating it here, allows all downstream reporting / checks to use the updated budget
except Exception:
verbose_logger.info(
"litellm.proxy.auth.user_api_key_auth.py::user_api_key_auth() - Unable to find token={} in cache or `LiteLLM_VerificationTokenTable`. Defaulting 'valid_token' to None'".format(
api_key
)
)
valid_token = None
if valid_token is None: if valid_token is None:
raise Exception( raise Exception(