fixes for auth checks

This commit is contained in:
Ishaan Jaff 2025-03-25 15:44:13 -07:00
parent 59040167ac
commit ce49e27217
4 changed files with 71 additions and 56 deletions

View file

@ -2141,6 +2141,13 @@ class ProxyErrorTypes(str, enum.Enum):
return cls.user_model_access_denied
DB_CONNECTION_ERROR_TYPES = (
httpx.ConnectError,
httpx.ReadError,
httpx.ReadTimeout,
)
class SSOUserDefinedValues(TypedDict):
models: List[str]
user_id: str

View file

@ -987,33 +987,34 @@ async def get_key_object(
)
# else, check db
try:
_valid_token: Optional[BaseModel] = await prisma_client.get_data(
token=hashed_token,
table_name="combined_view",
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
_valid_token: Optional[BaseModel] = await prisma_client.get_data(
token=hashed_token,
table_name="combined_view",
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
if _valid_token is None:
raise ProxyException(
message="Key doesn't exist in db. key={}. Create key via `/key/generate` call.".format(
hashed_token
),
type=ProxyErrorTypes.token_not_found_in_db,
param="key",
code=status.HTTP_401_UNAUTHORIZED,
)
if _valid_token is None:
raise Exception
_response = UserAPIKeyAuth(**_valid_token.model_dump(exclude_none=True))
_response = UserAPIKeyAuth(**_valid_token.model_dump(exclude_none=True))
# save the key object to cache
await _cache_key_object(
hashed_token=hashed_token,
user_api_key_obj=_response,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)
# save the key object to cache
await _cache_key_object(
hashed_token=hashed_token,
user_api_key_obj=_response,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)
return _response
except Exception:
traceback.print_exc()
raise Exception(
f"Key doesn't exist in db. key={hashed_token}. Create key via `/key/generate` call."
)
return _response
@log_db_metrics

View file

@ -9,7 +9,12 @@ from fastapi import HTTPException, Request, status
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import ProxyErrorTypes, ProxyException, UserAPIKeyAuth
from litellm.proxy._types import (
DB_CONNECTION_ERROR_TYPES,
ProxyErrorTypes,
ProxyException,
UserAPIKeyAuth,
)
from litellm.proxy.auth.auth_utils import _get_request_ip_address
from litellm.types.services import ServiceTypes
@ -52,7 +57,10 @@ class UserAPIKeyAuthExceptionHandler:
proxy_logging_obj,
)
if UserAPIKeyAuthExceptionHandler.should_allow_request_on_db_unavailable():
if (
UserAPIKeyAuthExceptionHandler.should_allow_request_on_db_unavailable()
and UserAPIKeyAuthExceptionHandler.is_database_connection_error(e)
):
# log this as a DB failure on prometheus
proxy_logging_obj.service_logging_obj.service_failure_hook(
service=ServiceTypes.DB,
@ -128,3 +136,14 @@ class UserAPIKeyAuthExceptionHandler:
if general_settings.get("allow_requests_on_db_unavailable", False) is True:
return True
return False
@staticmethod
def is_database_connection_error(e: Exception) -> bool:
"""
Returns True if the exception is from a database outage / connection error
"""
import prisma
return isinstance(e, DB_CONNECTION_ERROR_TYPES) or isinstance(
e, prisma.errors.PrismaError
)

View file

@ -683,37 +683,25 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
api_key = hash_token(token=api_key)
if valid_token is None:
try:
valid_token = await get_key_object(
hashed_token=api_key,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
# update end-user params on valid token
# These can change per request - it's important to update them here
valid_token.end_user_id = end_user_params.get("end_user_id")
valid_token.end_user_tpm_limit = end_user_params.get(
"end_user_tpm_limit"
)
valid_token.end_user_rpm_limit = end_user_params.get(
"end_user_rpm_limit"
)
valid_token.allowed_model_region = end_user_params.get(
"allowed_model_region"
)
# update key budget with temp budget increase
valid_token = _update_key_budget_with_temp_budget_increase(
valid_token
) # updating it here, allows all downstream reporting / checks to use the updated budget
except Exception:
verbose_logger.info(
"litellm.proxy.auth.user_api_key_auth.py::user_api_key_auth() - Unable to find token={} in cache or `LiteLLM_VerificationTokenTable`. Defaulting 'valid_token' to None'".format(
api_key
)
)
valid_token = None
valid_token = await get_key_object(
hashed_token=api_key,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
# update end-user params on valid token
# These can change per request - it's important to update them here
valid_token.end_user_id = end_user_params.get("end_user_id")
valid_token.end_user_tpm_limit = end_user_params.get("end_user_tpm_limit")
valid_token.end_user_rpm_limit = end_user_params.get("end_user_rpm_limit")
valid_token.allowed_model_region = end_user_params.get(
"allowed_model_region"
)
# update key budget with temp budget increase
valid_token = _update_key_budget_with_temp_budget_increase(
valid_token
) # updating it here, allows all downstream reporting / checks to use the updated budget
if valid_token is None:
raise Exception(