mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
fixes for auth checks
This commit is contained in:
parent
59040167ac
commit
ce49e27217
4 changed files with 71 additions and 56 deletions
|
@ -2141,6 +2141,13 @@ class ProxyErrorTypes(str, enum.Enum):
|
||||||
return cls.user_model_access_denied
|
return cls.user_model_access_denied
|
||||||
|
|
||||||
|
|
||||||
|
DB_CONNECTION_ERROR_TYPES = (
|
||||||
|
httpx.ConnectError,
|
||||||
|
httpx.ReadError,
|
||||||
|
httpx.ReadTimeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SSOUserDefinedValues(TypedDict):
|
class SSOUserDefinedValues(TypedDict):
|
||||||
models: List[str]
|
models: List[str]
|
||||||
user_id: str
|
user_id: str
|
||||||
|
|
|
@ -987,33 +987,34 @@ async def get_key_object(
|
||||||
)
|
)
|
||||||
|
|
||||||
# else, check db
|
# else, check db
|
||||||
try:
|
_valid_token: Optional[BaseModel] = await prisma_client.get_data(
|
||||||
_valid_token: Optional[BaseModel] = await prisma_client.get_data(
|
token=hashed_token,
|
||||||
token=hashed_token,
|
table_name="combined_view",
|
||||||
table_name="combined_view",
|
parent_otel_span=parent_otel_span,
|
||||||
parent_otel_span=parent_otel_span,
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
)
|
||||||
|
|
||||||
|
if _valid_token is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="Key doesn't exist in db. key={}. Create key via `/key/generate` call.".format(
|
||||||
|
hashed_token
|
||||||
|
),
|
||||||
|
type=ProxyErrorTypes.token_not_found_in_db,
|
||||||
|
param="key",
|
||||||
|
code=status.HTTP_401_UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
|
||||||
if _valid_token is None:
|
_response = UserAPIKeyAuth(**_valid_token.model_dump(exclude_none=True))
|
||||||
raise Exception
|
|
||||||
|
|
||||||
_response = UserAPIKeyAuth(**_valid_token.model_dump(exclude_none=True))
|
# save the key object to cache
|
||||||
|
await _cache_key_object(
|
||||||
|
hashed_token=hashed_token,
|
||||||
|
user_api_key_obj=_response,
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
)
|
||||||
|
|
||||||
# save the key object to cache
|
return _response
|
||||||
await _cache_key_object(
|
|
||||||
hashed_token=hashed_token,
|
|
||||||
user_api_key_obj=_response,
|
|
||||||
user_api_key_cache=user_api_key_cache,
|
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
|
||||||
)
|
|
||||||
|
|
||||||
return _response
|
|
||||||
except Exception:
|
|
||||||
traceback.print_exc()
|
|
||||||
raise Exception(
|
|
||||||
f"Key doesn't exist in db. key={hashed_token}. Create key via `/key/generate` call."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@log_db_metrics
|
@log_db_metrics
|
||||||
|
|
|
@ -9,7 +9,12 @@ from fastapi import HTTPException, Request, status
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.proxy._types import ProxyErrorTypes, ProxyException, UserAPIKeyAuth
|
from litellm.proxy._types import (
|
||||||
|
DB_CONNECTION_ERROR_TYPES,
|
||||||
|
ProxyErrorTypes,
|
||||||
|
ProxyException,
|
||||||
|
UserAPIKeyAuth,
|
||||||
|
)
|
||||||
from litellm.proxy.auth.auth_utils import _get_request_ip_address
|
from litellm.proxy.auth.auth_utils import _get_request_ip_address
|
||||||
from litellm.types.services import ServiceTypes
|
from litellm.types.services import ServiceTypes
|
||||||
|
|
||||||
|
@ -52,7 +57,10 @@ class UserAPIKeyAuthExceptionHandler:
|
||||||
proxy_logging_obj,
|
proxy_logging_obj,
|
||||||
)
|
)
|
||||||
|
|
||||||
if UserAPIKeyAuthExceptionHandler.should_allow_request_on_db_unavailable():
|
if (
|
||||||
|
UserAPIKeyAuthExceptionHandler.should_allow_request_on_db_unavailable()
|
||||||
|
and UserAPIKeyAuthExceptionHandler.is_database_connection_error(e)
|
||||||
|
):
|
||||||
# log this as a DB failure on prometheus
|
# log this as a DB failure on prometheus
|
||||||
proxy_logging_obj.service_logging_obj.service_failure_hook(
|
proxy_logging_obj.service_logging_obj.service_failure_hook(
|
||||||
service=ServiceTypes.DB,
|
service=ServiceTypes.DB,
|
||||||
|
@ -128,3 +136,14 @@ class UserAPIKeyAuthExceptionHandler:
|
||||||
if general_settings.get("allow_requests_on_db_unavailable", False) is True:
|
if general_settings.get("allow_requests_on_db_unavailable", False) is True:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_database_connection_error(e: Exception) -> bool:
|
||||||
|
"""
|
||||||
|
Returns True if the exception is from a database outage / connection error
|
||||||
|
"""
|
||||||
|
import prisma
|
||||||
|
|
||||||
|
return isinstance(e, DB_CONNECTION_ERROR_TYPES) or isinstance(
|
||||||
|
e, prisma.errors.PrismaError
|
||||||
|
)
|
||||||
|
|
|
@ -683,37 +683,25 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
api_key = hash_token(token=api_key)
|
api_key = hash_token(token=api_key)
|
||||||
|
|
||||||
if valid_token is None:
|
if valid_token is None:
|
||||||
try:
|
valid_token = await get_key_object(
|
||||||
valid_token = await get_key_object(
|
hashed_token=api_key,
|
||||||
hashed_token=api_key,
|
prisma_client=prisma_client,
|
||||||
prisma_client=prisma_client,
|
user_api_key_cache=user_api_key_cache,
|
||||||
user_api_key_cache=user_api_key_cache,
|
parent_otel_span=parent_otel_span,
|
||||||
parent_otel_span=parent_otel_span,
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
)
|
||||||
)
|
# update end-user params on valid token
|
||||||
# update end-user params on valid token
|
# These can change per request - it's important to update them here
|
||||||
# These can change per request - it's important to update them here
|
valid_token.end_user_id = end_user_params.get("end_user_id")
|
||||||
valid_token.end_user_id = end_user_params.get("end_user_id")
|
valid_token.end_user_tpm_limit = end_user_params.get("end_user_tpm_limit")
|
||||||
valid_token.end_user_tpm_limit = end_user_params.get(
|
valid_token.end_user_rpm_limit = end_user_params.get("end_user_rpm_limit")
|
||||||
"end_user_tpm_limit"
|
valid_token.allowed_model_region = end_user_params.get(
|
||||||
)
|
"allowed_model_region"
|
||||||
valid_token.end_user_rpm_limit = end_user_params.get(
|
)
|
||||||
"end_user_rpm_limit"
|
# update key budget with temp budget increase
|
||||||
)
|
valid_token = _update_key_budget_with_temp_budget_increase(
|
||||||
valid_token.allowed_model_region = end_user_params.get(
|
valid_token
|
||||||
"allowed_model_region"
|
) # updating it here, allows all downstream reporting / checks to use the updated budget
|
||||||
)
|
|
||||||
# update key budget with temp budget increase
|
|
||||||
valid_token = _update_key_budget_with_temp_budget_increase(
|
|
||||||
valid_token
|
|
||||||
) # updating it here, allows all downstream reporting / checks to use the updated budget
|
|
||||||
except Exception:
|
|
||||||
verbose_logger.info(
|
|
||||||
"litellm.proxy.auth.user_api_key_auth.py::user_api_key_auth() - Unable to find token={} in cache or `LiteLLM_VerificationTokenTable`. Defaulting 'valid_token' to None'".format(
|
|
||||||
api_key
|
|
||||||
)
|
|
||||||
)
|
|
||||||
valid_token = None
|
|
||||||
|
|
||||||
if valid_token is None:
|
if valid_token is None:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue