Merge pull request #9533 from BerriAI/litellm_stability_fixes
[Reliability Fixes] - Gracefully handle exceptions when DB is having an outage
Commit: 0155b0eba2 (10 changed files with 475 additions and 316 deletions)
@@ -1776,6 +1776,7 @@ response = completion(
)
```
</TabItem>

<TabItem value="proxy" label="PROXY">

1. Setup config.yaml

@@ -1820,11 +1821,13 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
```

</TabItem>

</Tabs>

### SSO Login (AWS Profile)
- Set `AWS_PROFILE` environment variable
- Make bedrock completion call

```python
import os
from litellm import completion

@@ -1940,9 +1943,6 @@ curl -L -X POST 'http://0.0.0.0:4000/v1/images/generations' \
    "colorGuidedGenerationParams":{"colors":["#FFFFFF"]}
}'
```
</TabItem>
</Tabs>


| Model Name | Function Call |
|-------------------------|---------------------------------------------|
@@ -160,7 +160,7 @@ general_settings:
| database_url | string | The URL for the database connection [Set up Virtual Keys](virtual_keys) |
| database_connection_pool_limit | integer | The limit for database connection pool [Setting DB Connection Pool limit](#configure-db-pool-limits--connection-timeouts) |
| database_connection_timeout | integer | The timeout for database connections in seconds [Setting DB Connection Pool limit, timeout](#configure-db-pool-limits--connection-timeouts) |
| allow_requests_on_db_unavailable | boolean | If true, allows requests to succeed even if the DB is unreachable. **Only use this if running LiteLLM in your VPC.** This will allow requests to work even when LiteLLM cannot connect to the DB to verify a Virtual Key. |
| allow_requests_on_db_unavailable | boolean | If true, allows requests to succeed even if the DB is unreachable. **Only use this if running LiteLLM in your VPC.** This will allow requests to work even when LiteLLM cannot connect to the DB to verify a Virtual Key. [Doc on graceful db unavailability](prod#5-if-running-litellm-on-vpc-gracefully-handle-db-unavailability) |
| custom_auth | string | Write your own custom authentication logic [Doc Custom Auth](virtual_keys#custom-auth) |
| max_parallel_requests | integer | The max parallel requests allowed per deployment |
| global_max_parallel_requests | integer | The max parallel requests allowed on the proxy overall |
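For orientation, here is a minimal sketch (not the proxy's actual wiring) of how a `general_settings` flag such as `allow_requests_on_db_unavailable` is typically consulted at runtime; the inline dict is an illustrative stand-in for the parsed `general_settings` block of a config.yaml.

```python
# Stand-in for the parsed general_settings block of config.yaml (illustrative only).
general_settings = {
    "allow_requests_on_db_unavailable": True,
    "database_connection_timeout": 60,
}


def should_allow_request_on_db_unavailable() -> bool:
    # Default to False so a missing key fails closed, matching the
    # `general_settings.get(..., False) is True` pattern used in this PR.
    return general_settings.get("allow_requests_on_db_unavailable", False) is True


print(should_allow_request_on_db_unavailable())  # True
```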
@@ -94,15 +94,29 @@ This disables the load_dotenv() functionality, which will automatically load you

## 5. If running LiteLLM on VPC, gracefully handle DB unavailability

This will allow LiteLLM to continue to process requests even if the DB is unavailable. This is better handling for DB unavailability.
When running LiteLLM on a VPC (and inaccessible from the public internet), you can enable graceful degradation so that request processing continues even if the database is temporarily unavailable.

**WARNING: Only do this if you're running LiteLLM in a VPC that cannot be accessed from the public internet.**

```yaml
#### Configuration

```yaml showLineNumbers title="litellm config.yaml"
general_settings:
  allow_requests_on_db_unavailable: True
```

#### Expected Behavior

When `allow_requests_on_db_unavailable` is set to `true`, LiteLLM will handle errors as follows:

| Type of Error | Expected Behavior | Details |
|---------------|-------------------|----------------|
| Prisma Errors | ✅ Request will be allowed | Covers issues like DB connection resets or rejections from the DB via Prisma, the ORM used by LiteLLM. |
| Httpx Errors | ✅ Request will be allowed | Occurs when the database is unreachable, allowing the request to proceed despite the DB outage. |
| LiteLLM Budget Errors or Model Errors | ❌ Request will be blocked | Triggered when the DB is reachable but the authentication token is invalid, lacks access, or exceeds budget limits. |
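To make the table above concrete, here is a small sketch of the same decision, mirroring the `is_database_connection_error` and `should_allow_request_on_db_unavailable` helpers added in this PR; the inline `general_settings` dict is an illustrative stand-in for the proxy's parsed config.

```python
import httpx
import prisma
import litellm

# Same transport-level error tuple this PR defines in litellm/proxy/_types.py.
DB_CONNECTION_ERROR_TYPES = (httpx.ConnectError, httpx.ReadError, httpx.ReadTimeout)

# Stand-in for the proxy's parsed config.
general_settings = {"allow_requests_on_db_unavailable": True}


def request_is_allowed_despite_error(e: Exception) -> bool:
    if general_settings.get("allow_requests_on_db_unavailable", False) is not True:
        return False
    # Prisma errors and httpx connection errors count as a DB outage -> allow.
    if isinstance(e, prisma.errors.PrismaError) or isinstance(e, DB_CONNECTION_ERROR_TYPES):
        return True
    # Budget / auth errors mean the DB answered and rejected the key -> block.
    return False


print(request_is_allowed_despite_error(httpx.ConnectError("db unreachable")))  # True
print(
    request_is_allowed_despite_error(
        litellm.BudgetExceededError(
            message="Budget exceeded", current_cost=100, max_budget=100
        )
    )
)  # False
```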
## 6. Disable spend_logs & error_logs if not using the LiteLLM UI

By default, LiteLLM writes several types of logs to the database:

@@ -182,94 +196,4 @@ You should only see the following level of details in logs on the proxy server
# INFO: 192.168.2.205:11774 - "POST /chat/completions HTTP/1.1" 200 OK
# INFO: 192.168.2.205:34717 - "POST /chat/completions HTTP/1.1" 200 OK
# INFO: 192.168.2.205:29734 - "POST /chat/completions HTTP/1.1" 200 OK
```


### Machine Specifications to Deploy LiteLLM

| Service | Spec | CPUs | Memory | Architecture | Version |
| --- | --- | --- | --- | --- | --- |
| Server | `t2.small` | `1vCPUs` | `8GB` | `x86` | - |
| Redis Cache | - | - | - | - | 7.0+ Redis Engine |


### Reference Kubernetes Deployment YAML

Reference Kubernetes `deployment.yaml` that was load tested by us

```yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: litellm-deployment
spec:
  replicas: 3
  selector:
    matchLabels:
      app: litellm
  template:
    metadata:
      labels:
        app: litellm
    spec:
      containers:
        - name: litellm-container
          image: ghcr.io/berriai/litellm:main-latest
          imagePullPolicy: Always
          env:
            - name: AZURE_API_KEY
              value: "d6******"
            - name: AZURE_API_BASE
              value: "https://ope******"
            - name: LITELLM_MASTER_KEY
              value: "sk-1234"
            - name: DATABASE_URL
              value: "po**********"
          args:
            - "--config"
            - "/app/proxy_config.yaml" # Update the path to mount the config file
          volumeMounts: # Define volume mount for proxy_config.yaml
            - name: config-volume
              mountPath: /app
              readOnly: true
          livenessProbe:
            httpGet:
              path: /health/liveliness
              port: 4000
            initialDelaySeconds: 120
            periodSeconds: 15
            successThreshold: 1
            failureThreshold: 3
            timeoutSeconds: 10
          readinessProbe:
            httpGet:
              path: /health/readiness
              port: 4000
            initialDelaySeconds: 120
            periodSeconds: 15
            successThreshold: 1
            failureThreshold: 3
            timeoutSeconds: 10
      volumes: # Define volume to mount proxy_config.yaml
        - name: config-volume
          configMap:
            name: litellm-config

```


Reference Kubernetes `service.yaml` that was load tested by us
```yaml
apiVersion: v1
kind: Service
metadata:
  name: litellm-service
spec:
  selector:
    app: litellm
  ports:
    - protocol: TCP
      port: 4000
      targetPort: 4000
  type: LoadBalancer
```
@@ -2067,16 +2067,68 @@ class SpendCalculateRequest(LiteLLMPydanticObjectBase):

class ProxyErrorTypes(str, enum.Enum):
    budget_exceeded = "budget_exceeded"
    """
    Object was over budget
    """
    no_db_connection = "no_db_connection"
    """
    No database connection
    """

    token_not_found_in_db = "token_not_found_in_db"
    """
    Requested token was not found in the database
    """

    key_model_access_denied = "key_model_access_denied"
    """
    Key does not have access to the model
    """

    team_model_access_denied = "team_model_access_denied"
    """
    Team does not have access to the model
    """

    user_model_access_denied = "user_model_access_denied"
    """
    User does not have access to the model
    """

    expired_key = "expired_key"
    """
    Key has expired
    """

    auth_error = "auth_error"
    """
    General authentication error
    """

    internal_server_error = "internal_server_error"
    """
    Internal server error
    """

    bad_request_error = "bad_request_error"
    """
    Bad request error
    """

    not_found_error = "not_found_error"
    validation_error = "bad_request_error"
    """
    Not found error
    """

    validation_error = "validation_error"
    """
    Validation error
    """

    cache_ping_error = "cache_ping_error"
    """
    Cache ping error
    """

    @classmethod
    def get_model_access_error_type_for_object(

@@ -2093,7 +2145,11 @@ class ProxyErrorTypes(str, enum.Enum):
        return cls.user_model_access_denied


DB_CONNECTION_ERROR_TYPES = (httpx.ConnectError, httpx.ReadError, httpx.ReadTimeout)
DB_CONNECTION_ERROR_TYPES = (
    httpx.ConnectError,
    httpx.ReadError,
    httpx.ReadTimeout,
)


class SSOUserDefinedValues(TypedDict):
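As an aside, grouping these exception classes into one tuple keeps outage handling to a single `except` clause at call sites. A small self-contained sketch of that pattern (the `fetch_key_from_db` helper is hypothetical, standing in for a `prisma_client.get_data(...)` call):

```python
import asyncio

import httpx

DB_CONNECTION_ERROR_TYPES = (
    httpx.ConnectError,
    httpx.ReadError,
    httpx.ReadTimeout,
)


async def fetch_key_from_db() -> dict:
    # Hypothetical stand-in for prisma_client.get_data(...); simulate an outage.
    raise httpx.ConnectError("database unreachable")


async def get_key_or_fallback() -> dict:
    try:
        return await fetch_key_from_db()
    except DB_CONNECTION_ERROR_TYPES:
        # Treat any transport-level failure as "DB unavailable" and degrade
        # gracefully instead of failing the request outright.
        return {"key_name": "failed-to-connect-to-db"}


print(asyncio.run(get_key_or_fallback()))  # {'key_name': 'failed-to-connect-to-db'}
```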
@@ -11,7 +11,6 @@ Run checks for:
import asyncio
import re
import time
import traceback
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, cast

from fastapi import Request, status

@@ -23,7 +22,6 @@ from litellm.caching.caching import DualCache
from litellm.caching.dual_cache import LimitedSizeOrderedDict
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
from litellm.proxy._types import (
    DB_CONNECTION_ERROR_TYPES,
    RBAC_ROLES,
    CallInfo,
    LiteLLM_EndUserTable,

@@ -45,7 +43,6 @@ from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.route_llm_request import route_request
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics
from litellm.router import Router
from litellm.types.services import ServiceTypes

from .auth_checks_organization import organization_role_based_access_check

@@ -987,75 +984,34 @@ async def get_key_object(
    )

    # else, check db
    try:
        _valid_token: Optional[BaseModel] = await prisma_client.get_data(
            token=hashed_token,
            table_name="combined_view",
            parent_otel_span=parent_otel_span,
            proxy_logging_obj=proxy_logging_obj,
        )

        if _valid_token is None:
            raise Exception

        _response = UserAPIKeyAuth(**_valid_token.model_dump(exclude_none=True))

        # save the key object to cache
        await _cache_key_object(
            hashed_token=hashed_token,
            user_api_key_obj=_response,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
        )

        return _response
    except DB_CONNECTION_ERROR_TYPES as e:
        return await _handle_failed_db_connection_for_get_key_object(e=e)
    except Exception:
        traceback.print_exc()
        raise Exception(
            f"Key doesn't exist in db. key={hashed_token}. Create key via `/key/generate` call."
        )


async def _handle_failed_db_connection_for_get_key_object(
    e: Exception,
) -> UserAPIKeyAuth:
    """
    Handles httpx.ConnectError when reading a Virtual Key from LiteLLM DB

    Use this if you don't want failed DB queries to block LLM API requests

    Returns:
    - UserAPIKeyAuth: If general_settings.allow_requests_on_db_unavailable is True

    Raises:
    - Original Exception in all other cases
    """
    from litellm.proxy.proxy_server import (
        general_settings,
        litellm_proxy_admin_name,
        proxy_logging_obj,
    _valid_token: Optional[BaseModel] = await prisma_client.get_data(
        token=hashed_token,
        table_name="combined_view",
        parent_otel_span=parent_otel_span,
        proxy_logging_obj=proxy_logging_obj,
    )

    # If this flag is on, requests failing to connect to the DB will be allowed
    if general_settings.get("allow_requests_on_db_unavailable", False) is True:
        # log this as a DB failure on prometheus
        proxy_logging_obj.service_logging_obj.service_failure_hook(
            service=ServiceTypes.DB,
            call_type="get_key_object",
            error=e,
            duration=0.0,
    if _valid_token is None:
        raise ProxyException(
            message="Authentication Error, Invalid proxy server token passed. key={}, not found in db. Create key via `/key/generate` call.".format(
                hashed_token
            ),
            type=ProxyErrorTypes.token_not_found_in_db,
            param="key",
            code=status.HTTP_401_UNAUTHORIZED,
        )

        return UserAPIKeyAuth(
            key_name="failed-to-connect-to-db",
            token="failed-to-connect-to-db",
            user_id=litellm_proxy_admin_name,
        )
    else:
        # raise the original exception, the wrapper on `get_key_object` handles logging db failure to prometheus
        raise e
    _response = UserAPIKeyAuth(**_valid_token.model_dump(exclude_none=True))

    # save the key object to cache
    await _cache_key_object(
        hashed_token=hashed_token,
        user_api_key_obj=_response,
        user_api_key_cache=user_api_key_cache,
        proxy_logging_obj=proxy_logging_obj,
    )

    return _response


@log_db_metrics
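One practical consequence of this hunk: a missing key now surfaces from `get_key_object` as a typed `ProxyException` rather than a bare `Exception`, so callers can tell "bad key" apart from "DB outage". A hedged sketch of that distinction (the `lookup` callable is a hypothetical stand-in for a `get_key_object(...)` call):

```python
from litellm.proxy._types import ProxyErrorTypes, ProxyException


async def classify_lookup_failure(lookup) -> str:
    try:
        await lookup()
    except ProxyException as e:
        if e.type == ProxyErrorTypes.token_not_found_in_db:
            return "invalid key"  # the DB answered: reject the request
        raise
    except Exception:
        # Connection-level failures fall through to UserAPIKeyAuthExceptionHandler.
        return "possible db outage"
    return "ok"
```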
litellm/proxy/auth/auth_exception_handler.py (new file, 153 lines)

@@ -0,0 +1,153 @@
"""
Handles Authentication Errors
"""

import asyncio
from typing import TYPE_CHECKING, Any, Optional

from fastapi import HTTPException, Request, status

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
    DB_CONNECTION_ERROR_TYPES,
    ProxyErrorTypes,
    ProxyException,
    UserAPIKeyAuth,
)
from litellm.proxy.auth.auth_utils import _get_request_ip_address
from litellm.types.services import ServiceTypes

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    Span = _Span
else:
    Span = Any


class UserAPIKeyAuthExceptionHandler:

    @staticmethod
    async def _handle_authentication_error(
        e: Exception,
        request: Request,
        request_data: dict,
        route: str,
        parent_otel_span: Optional[Span],
        api_key: str,
    ) -> UserAPIKeyAuth:
        """
        Handles Connection Errors when reading a Virtual Key from LiteLLM DB
        Use this if you don't want failed DB queries to block LLM API requests

        Reliability scenarios this covers:
        - DB is down and having an outage
        - Unable to read / recover a key from the DB

        Returns:
        - UserAPIKeyAuth: If general_settings.allow_requests_on_db_unavailable is True

        Raises:
        - Original Exception in all other cases
        """
        from litellm.proxy.proxy_server import (
            general_settings,
            litellm_proxy_admin_name,
            proxy_logging_obj,
        )

        if (
            UserAPIKeyAuthExceptionHandler.should_allow_request_on_db_unavailable()
            and UserAPIKeyAuthExceptionHandler.is_database_connection_error(e)
        ):
            # log this as a DB failure on prometheus
            proxy_logging_obj.service_logging_obj.service_failure_hook(
                service=ServiceTypes.DB,
                call_type="get_key_object",
                error=e,
                duration=0.0,
            )

            return UserAPIKeyAuth(
                key_name="failed-to-connect-to-db",
                token="failed-to-connect-to-db",
                user_id=litellm_proxy_admin_name,
            )
        else:
            # raise the exception to the caller
            requester_ip = _get_request_ip_address(
                request=request,
                use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
            )
            verbose_proxy_logger.exception(
                "litellm.proxy.proxy_server.user_api_key_auth(): Exception occured - {}\nRequester IP Address:{}".format(
                    str(e),
                    requester_ip,
                ),
                extra={"requester_ip": requester_ip},
            )

            # Log this exception to OTEL, Datadog etc
            user_api_key_dict = UserAPIKeyAuth(
                parent_otel_span=parent_otel_span,
                api_key=api_key,
            )
            asyncio.create_task(
                proxy_logging_obj.post_call_failure_hook(
                    request_data=request_data,
                    original_exception=e,
                    user_api_key_dict=user_api_key_dict,
                    error_type=ProxyErrorTypes.auth_error,
                    route=route,
                )
            )

            if isinstance(e, litellm.BudgetExceededError):
                raise ProxyException(
                    message=e.message,
                    type=ProxyErrorTypes.budget_exceeded,
                    param=None,
                    code=400,
                )
            if isinstance(e, HTTPException):
                raise ProxyException(
                    message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                    type=ProxyErrorTypes.auth_error,
                    param=getattr(e, "param", "None"),
                    code=getattr(e, "status_code", status.HTTP_401_UNAUTHORIZED),
                )
            elif isinstance(e, ProxyException):
                raise e
            raise ProxyException(
                message="Authentication Error, " + str(e),
                type=ProxyErrorTypes.auth_error,
                param=getattr(e, "param", "None"),
                code=status.HTTP_401_UNAUTHORIZED,
            )

    @staticmethod
    def should_allow_request_on_db_unavailable() -> bool:
        """
        Returns True if the request should be allowed to proceed despite the DB connection error
        """
        from litellm.proxy.proxy_server import general_settings

        if general_settings.get("allow_requests_on_db_unavailable", False) is True:
            return True
        return False

    @staticmethod
    def is_database_connection_error(e: Exception) -> bool:
        """
        Returns True if the exception is from a database outage / connection error
        """
        import prisma

        if isinstance(e, DB_CONNECTION_ERROR_TYPES):
            return True
        if isinstance(e, prisma.errors.PrismaError):
            return True
        if isinstance(e, ProxyException) and e.type == ProxyErrorTypes.no_db_connection:
            return True
        return False
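The call-site pattern this class enables appears in the `user_api_key_auth.py` hunk further down. Here is a condensed sketch of that usage, assuming `look_up_and_validate_key` as a hypothetical stand-in for the real key lookup and that the FastAPI `request`, `request_data`, `route`, `parent_otel_span`, and `api_key` values come from the caller:

```python
import httpx

from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.auth_exception_handler import UserAPIKeyAuthExceptionHandler


async def look_up_and_validate_key(api_key: str) -> UserAPIKeyAuth:
    # Hypothetical stand-in for get_key_object(...); real code reads the DB here.
    raise httpx.ConnectError("db unreachable")


async def resolve_key(request, request_data, route, parent_otel_span, api_key) -> UserAPIKeyAuth:
    try:
        return await look_up_and_validate_key(api_key)
    except Exception as e:
        # Returns a "failed-to-connect-to-db" UserAPIKeyAuth when the error is a DB
        # outage and allow_requests_on_db_unavailable is set; otherwise re-raises
        # the error as a ProxyException.
        return await UserAPIKeyAuthExceptionHandler._handle_authentication_error(
            e=e,
            request=request,
            request_data=request_data,
            route=route,
            parent_otel_span=parent_otel_span,
            api_key=api_key,
        )
```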
@@ -26,7 +26,6 @@ from litellm.proxy._types import *
from litellm.proxy.auth.auth_checks import (
    _cache_key_object,
    _get_user_role,
    _handle_failed_db_connection_for_get_key_object,
    _is_user_proxy_admin,
    _virtual_key_max_budget_check,
    _virtual_key_soft_budget_check,

@@ -38,8 +37,8 @@ from litellm.proxy.auth.auth_checks import (
    get_user_object,
    is_valid_fallback_model,
)
from litellm.proxy.auth.auth_exception_handler import UserAPIKeyAuthExceptionHandler
from litellm.proxy.auth.auth_utils import (
    _get_request_ip_address,
    get_end_user_id_from_request_body,
    get_request_route,
    is_pass_through_provider_route,

@@ -675,8 +674,11 @@ async def _user_api_key_auth_builder(  # noqa: PLR0915
        if (
            prisma_client is None
        ):  # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
            return await _handle_failed_db_connection_for_get_key_object(
                e=Exception("No connected db.")
            raise ProxyException(
                message="No connected db.",
                type=ProxyErrorTypes.no_db_connection,
                code=400,
                param=None,
            )

        ## check for cache hit (In-Memory Cache)

@@ -685,37 +687,25 @@ async def _user_api_key_auth_builder(  # noqa: PLR0915
            api_key = hash_token(token=api_key)

        if valid_token is None:
            try:
                valid_token = await get_key_object(
                    hashed_token=api_key,
                    prisma_client=prisma_client,
                    user_api_key_cache=user_api_key_cache,
                    parent_otel_span=parent_otel_span,
                    proxy_logging_obj=proxy_logging_obj,
                )
                # update end-user params on valid token
                # These can change per request - it's important to update them here
                valid_token.end_user_id = end_user_params.get("end_user_id")
                valid_token.end_user_tpm_limit = end_user_params.get(
                    "end_user_tpm_limit"
                )
                valid_token.end_user_rpm_limit = end_user_params.get(
                    "end_user_rpm_limit"
                )
                valid_token.allowed_model_region = end_user_params.get(
                    "allowed_model_region"
                )
                # update key budget with temp budget increase
                valid_token = _update_key_budget_with_temp_budget_increase(
                    valid_token
                )  # updating it here, allows all downstream reporting / checks to use the updated budget
            except Exception:
                verbose_logger.info(
                    "litellm.proxy.auth.user_api_key_auth.py::user_api_key_auth() - Unable to find token={} in cache or `LiteLLM_VerificationTokenTable`. Defaulting 'valid_token' to None'".format(
                        api_key
                    )
                )
                valid_token = None
            valid_token = await get_key_object(
                hashed_token=api_key,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                parent_otel_span=parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
            )
            # update end-user params on valid token
            # These can change per request - it's important to update them here
            valid_token.end_user_id = end_user_params.get("end_user_id")
            valid_token.end_user_tpm_limit = end_user_params.get("end_user_tpm_limit")
            valid_token.end_user_rpm_limit = end_user_params.get("end_user_rpm_limit")
            valid_token.allowed_model_region = end_user_params.get(
                "allowed_model_region"
            )
            # update key budget with temp budget increase
            valid_token = _update_key_budget_with_temp_budget_increase(
                valid_token
            )  # updating it here, allows all downstream reporting / checks to use the updated budget

        if valid_token is None:
            raise Exception(

@@ -1015,58 +1005,15 @@ async def _user_api_key_auth_builder(  # noqa: PLR0915
                route=route,
                start_time=start_time,
            )
        else:
            raise Exception()
    except Exception as e:
        requester_ip = _get_request_ip_address(
        return await UserAPIKeyAuthExceptionHandler._handle_authentication_error(
            e=e,
            request=request,
            use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
        )
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.user_api_key_auth(): Exception occured - {}\nRequester IP Address:{}".format(
                str(e),
                requester_ip,
            ),
            extra={"requester_ip": requester_ip},
        )

        # Log this exception to OTEL, Datadog etc
        user_api_key_dict = UserAPIKeyAuth(
            request_data=request_data,
            route=route,
            parent_otel_span=parent_otel_span,
            api_key=api_key,
        )
        asyncio.create_task(
            proxy_logging_obj.post_call_failure_hook(
                request_data=request_data,
                original_exception=e,
                user_api_key_dict=user_api_key_dict,
                error_type=ProxyErrorTypes.auth_error,
                route=route,
            )
        )

        if isinstance(e, litellm.BudgetExceededError):
            raise ProxyException(
                message=e.message,
                type=ProxyErrorTypes.budget_exceeded,
                param=None,
                code=400,
            )
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type=ProxyErrorTypes.auth_error,
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_401_UNAUTHORIZED),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type=ProxyErrorTypes.auth_error,
            param=getattr(e, "param", "None"),
            code=status.HTTP_401_UNAUTHORIZED,
        )


@tracer.wrap()
@@ -1,37 +1,9 @@
model_list:
  - model_name: gpt-3.5-turbo-end-user-test
  - model_name: fake-openai-endpoint
    litellm_params:
      model: azure/chatgpt-v-2
      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
      api_version: "2023-05-15"
      api_key: os.environ/AZURE_API_KEY
      model: openai/fake
      api_key: fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/


mcp_tools:
  - name: "get_current_time"
    description: "Get the current time"
    input_schema: {
      "type": "object",
      "properties": {
        "format": {
          "type": "string",
          "description": "The format of the time to return",
          "enum": ["short"]
        }
      }
    }
    handler: "mcp_tools.get_current_time"
  - name: "get_current_date"
    description: "Get the current date"
    input_schema: {
      "type": "object",
      "properties": {
        "format": {
          "type": "string",
          "description": "The format of the date to return",
          "enum": ["short"]
        }
      }
    }
    handler: "mcp_tools.get_current_date"
general_settings:
  allow_requests_on_db_unavailable: True
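A quick way to exercise this test config once the proxy is running: the LiteLLM proxy exposes an OpenAI-compatible API, so the OpenAI SDK can call the `fake-openai-endpoint` model directly. The base URL and the `sk-1234` key below are assumptions taken from the proxy examples elsewhere in this PR; substitute your own values.

```python
from openai import OpenAI

# Assumes the proxy is running locally on port 4000 with master key "sk-1234",
# matching the deployment examples elsewhere in this PR.
client = OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")

response = client.chat.completions.create(
    model="fake-openai-endpoint",  # model_name from the config above
    messages=[{"role": "user", "content": "ping"}],
)
print(response.choices[0].message.content)
```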
tests/litellm/proxy/auth/test_auth_exception_handler.py (new file, 184 lines)

@@ -0,0 +1,184 @@
import asyncio
import json
import os
import sys
from unittest.mock import MagicMock, patch

import pytest
from fastapi import HTTPException, Request, status
from prisma import errors as prisma_errors
from prisma.errors import (
    ClientNotConnectedError,
    DataError,
    ForeignKeyViolationError,
    HTTPClientClosedError,
    MissingRequiredValueError,
    PrismaError,
    RawQueryError,
    RecordNotFoundError,
    TableNotFoundError,
    UniqueViolationError,
)

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path

from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import ProxyErrorTypes, ProxyException
from litellm.proxy.auth.auth_exception_handler import UserAPIKeyAuthExceptionHandler


# Test is_database_connection_error method
@pytest.mark.parametrize(
    "prisma_error",
    [
        PrismaError(),
        DataError(data={"user_facing_error": {"meta": {"table": "test_table"}}}),
        UniqueViolationError(
            data={"user_facing_error": {"meta": {"table": "test_table"}}}
        ),
        ForeignKeyViolationError(
            data={"user_facing_error": {"meta": {"table": "test_table"}}}
        ),
        MissingRequiredValueError(
            data={"user_facing_error": {"meta": {"table": "test_table"}}}
        ),
        RawQueryError(data={"user_facing_error": {"meta": {"table": "test_table"}}}),
        TableNotFoundError(
            data={"user_facing_error": {"meta": {"table": "test_table"}}}
        ),
        RecordNotFoundError(
            data={"user_facing_error": {"meta": {"table": "test_table"}}}
        ),
        HTTPClientClosedError(),
        ClientNotConnectedError(),
    ],
)
def test_is_database_connection_error_prisma_errors(prisma_error):
    """
    Test that all Prisma errors are considered database connection errors
    """
    handler = UserAPIKeyAuthExceptionHandler()
    assert handler.is_database_connection_error(prisma_error) == True


def test_is_database_connection_generic_errors():
    """
    Test non-Prisma error cases for database connection checking
    """
    handler = UserAPIKeyAuthExceptionHandler()

    # Test with ProxyException (DB connection)
    db_proxy_exception = ProxyException(
        message="DB Connection Error",
        type=ProxyErrorTypes.no_db_connection,
        param="test-param",
    )
    assert handler.is_database_connection_error(db_proxy_exception) == True

    # Test with non-DB error
    regular_exception = Exception("Regular error")
    assert handler.is_database_connection_error(regular_exception) == False


# Test should_allow_request_on_db_unavailable method
@patch(
    "litellm.proxy.proxy_server.general_settings",
    {"allow_requests_on_db_unavailable": True},
)
def test_should_allow_request_on_db_unavailable_true():
    handler = UserAPIKeyAuthExceptionHandler()
    assert handler.should_allow_request_on_db_unavailable() == True


@patch(
    "litellm.proxy.proxy_server.general_settings",
    {"allow_requests_on_db_unavailable": False},
)
def test_should_allow_request_on_db_unavailable_false():
    handler = UserAPIKeyAuthExceptionHandler()
    assert handler.should_allow_request_on_db_unavailable() == False


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "prisma_error",
    [
        PrismaError(),
        DataError(data={"user_facing_error": {"meta": {"table": "test_table"}}}),
        UniqueViolationError(
            data={"user_facing_error": {"meta": {"table": "test_table"}}}
        ),
        ForeignKeyViolationError(
            data={"user_facing_error": {"meta": {"table": "test_table"}}}
        ),
        MissingRequiredValueError(
            data={"user_facing_error": {"meta": {"table": "test_table"}}}
        ),
        RawQueryError(data={"user_facing_error": {"meta": {"table": "test_table"}}}),
        TableNotFoundError(
            data={"user_facing_error": {"meta": {"table": "test_table"}}}
        ),
        RecordNotFoundError(
            data={"user_facing_error": {"meta": {"table": "test_table"}}}
        ),
        HTTPClientClosedError(),
        ClientNotConnectedError(),
    ],
)
async def test_handle_authentication_error_db_unavailable(prisma_error):
    handler = UserAPIKeyAuthExceptionHandler()

    # Mock request and other dependencies
    mock_request = MagicMock()
    mock_request_data = {}
    mock_route = "/test"
    mock_span = None
    mock_api_key = "test-key"

    # Test with DB connection error when requests are allowed
    with patch(
        "litellm.proxy.proxy_server.general_settings",
        {"allow_requests_on_db_unavailable": True},
    ):
        result = await handler._handle_authentication_error(
            prisma_error,
            mock_request,
            mock_request_data,
            mock_route,
            mock_span,
            mock_api_key,
        )
        assert result.key_name == "failed-to-connect-to-db"
        assert result.token == "failed-to-connect-to-db"


@pytest.mark.asyncio
async def test_handle_authentication_error_budget_exceeded():
    handler = UserAPIKeyAuthExceptionHandler()

    # Mock request and other dependencies
    mock_request = MagicMock()
    mock_request_data = {}
    mock_route = "/test"
    mock_span = None
    mock_api_key = "test-key"

    # Test with budget exceeded error
    with pytest.raises(ProxyException) as exc_info:
        from litellm.exceptions import BudgetExceededError

        budget_error = BudgetExceededError(
            message="Budget exceeded", current_cost=100, max_budget=100
        )
        await handler._handle_authentication_error(
            budget_error,
            mock_request,
            mock_request_data,
            mock_route,
            mock_span,
            mock_api_key,
        )

    assert exc_info.value.type == ProxyErrorTypes.budget_exceeded
@@ -13,9 +13,6 @@ sys.path.insert(
)  # Adds the parent directory to the system path
import pytest, litellm
import httpx
from litellm.proxy.auth.auth_checks import (
    _handle_failed_db_connection_for_get_key_object,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.auth_checks import get_end_user_object
from litellm.caching.caching import DualCache

@@ -78,36 +75,6 @@ async def test_get_end_user_object(customer_spend, customer_budget):
    )


@pytest.mark.asyncio
async def test_handle_failed_db_connection():
    """
    Test cases:
    1. When allow_requests_on_db_unavailable=True -> return UserAPIKeyAuth
    2. When allow_requests_on_db_unavailable=False -> raise original error
    """
    from litellm.proxy.proxy_server import general_settings, litellm_proxy_admin_name

    # Test case 1: allow_requests_on_db_unavailable=True
    general_settings["allow_requests_on_db_unavailable"] = True
    mock_error = httpx.ConnectError("Failed to connect to DB")

    result = await _handle_failed_db_connection_for_get_key_object(e=mock_error)

    assert isinstance(result, UserAPIKeyAuth)
    assert result.key_name == "failed-to-connect-to-db"
    assert result.token == "failed-to-connect-to-db"
    assert result.user_id == litellm_proxy_admin_name

    # Test case 2: allow_requests_on_db_unavailable=False
    general_settings["allow_requests_on_db_unavailable"] = False

    with pytest.raises(httpx.ConnectError) as exc_info:
        await _handle_failed_db_connection_for_get_key_object(e=mock_error)
    print("_handle_failed_db_connection_for_get_key_object got exception", exc_info)

    assert str(exc_info.value) == "Failed to connect to DB"


@pytest.mark.parametrize(
    "model, expect_to_work",
    [