mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
Merge pull request #9533 from BerriAI/litellm_stability_fixes
[Reliability Fixes] - Gracefully handle exceptions when DB is having an outage
This commit is contained in:
commit
0155b0eba2
10 changed files with 475 additions and 316 deletions
|
@ -1776,6 +1776,7 @@ response = completion(
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
</TabItem>
|
</TabItem>
|
||||||
|
|
||||||
<TabItem value="proxy" label="PROXY">
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
1. Setup config.yaml
|
1. Setup config.yaml
|
||||||
|
@ -1820,11 +1821,13 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
|
||||||
```
|
```
|
||||||
|
|
||||||
</TabItem>
|
</TabItem>
|
||||||
|
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
### SSO Login (AWS Profile)
|
### SSO Login (AWS Profile)
|
||||||
- Set `AWS_PROFILE` environment variable
|
- Set `AWS_PROFILE` environment variable
|
||||||
- Make bedrock completion call
|
- Make bedrock completion call
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import os
|
import os
|
||||||
from litellm import completion
|
from litellm import completion
|
||||||
|
@ -1940,9 +1943,6 @@ curl -L -X POST 'http://0.0.0.0:4000/v1/images/generations' \
|
||||||
"colorGuidedGenerationParams":{"colors":["#FFFFFF"]}
|
"colorGuidedGenerationParams":{"colors":["#FFFFFF"]}
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
</TabItem>
|
|
||||||
</Tabs>
|
|
||||||
|
|
||||||
|
|
||||||
| Model Name | Function Call |
|
| Model Name | Function Call |
|
||||||
|-------------------------|---------------------------------------------|
|
|-------------------------|---------------------------------------------|
|
||||||
|
|
|
@ -160,7 +160,7 @@ general_settings:
|
||||||
| database_url | string | The URL for the database connection [Set up Virtual Keys](virtual_keys) |
|
| database_url | string | The URL for the database connection [Set up Virtual Keys](virtual_keys) |
|
||||||
| database_connection_pool_limit | integer | The limit for database connection pool [Setting DB Connection Pool limit](#configure-db-pool-limits--connection-timeouts) |
|
| database_connection_pool_limit | integer | The limit for database connection pool [Setting DB Connection Pool limit](#configure-db-pool-limits--connection-timeouts) |
|
||||||
| database_connection_timeout | integer | The timeout for database connections in seconds [Setting DB Connection Pool limit, timeout](#configure-db-pool-limits--connection-timeouts) |
|
| database_connection_timeout | integer | The timeout for database connections in seconds [Setting DB Connection Pool limit, timeout](#configure-db-pool-limits--connection-timeouts) |
|
||||||
| allow_requests_on_db_unavailable | boolean | If true, allows requests to succeed even if DB is unreachable. **Only use this if running LiteLLM in your VPC** This will allow requests to work even when LiteLLM cannot connect to the DB to verify a Virtual Key |
|
| allow_requests_on_db_unavailable | boolean | If true, allows requests to succeed even if DB is unreachable. **Only use this if running LiteLLM in your VPC** This will allow requests to work even when LiteLLM cannot connect to the DB to verify a Virtual Key [Doc on graceful db unavailability](prod#5-if-running-litellm-on-vpc-gracefully-handle-db-unavailability) |
|
||||||
| custom_auth | string | Write your own custom authentication logic [Doc Custom Auth](virtual_keys#custom-auth) |
|
| custom_auth | string | Write your own custom authentication logic [Doc Custom Auth](virtual_keys#custom-auth) |
|
||||||
| max_parallel_requests | integer | The max parallel requests allowed per deployment |
|
| max_parallel_requests | integer | The max parallel requests allowed per deployment |
|
||||||
| global_max_parallel_requests | integer | The max parallel requests allowed on the proxy overall |
|
| global_max_parallel_requests | integer | The max parallel requests allowed on the proxy overall |
|
||||||
|
|
|
@ -94,15 +94,29 @@ This disables the load_dotenv() functionality, which will automatically load you
|
||||||
|
|
||||||
## 5. If running LiteLLM on VPC, gracefully handle DB unavailability
|
## 5. If running LiteLLM on VPC, gracefully handle DB unavailability
|
||||||
|
|
||||||
This will allow LiteLLM to continue to process requests even if the DB is unavailable. This is better handling for DB unavailability.
|
When running LiteLLM on a VPC (and inaccessible from the public internet), you can enable graceful degradation so that request processing continues even if the database is temporarily unavailable.
|
||||||
|
|
||||||
|
|
||||||
**WARNING: Only do this if you're running LiteLLM on VPC, that cannot be accessed from the public internet.**
|
**WARNING: Only do this if you're running LiteLLM on VPC, that cannot be accessed from the public internet.**
|
||||||
|
|
||||||
```yaml
|
#### Configuration
|
||||||
|
|
||||||
|
```yaml showLineNumbers title="litellm config.yaml"
|
||||||
general_settings:
|
general_settings:
|
||||||
allow_requests_on_db_unavailable: True
|
allow_requests_on_db_unavailable: True
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Expected Behavior
|
||||||
|
|
||||||
|
When `allow_requests_on_db_unavailable` is set to `true`, LiteLLM will handle errors as follows:
|
||||||
|
|
||||||
|
| Type of Error | Expected Behavior | Details |
|
||||||
|
|---------------|-------------------|----------------|
|
||||||
|
| Prisma Errors | ✅ Request will be allowed | Covers issues like DB connection resets or rejections from the DB via Prisma, the ORM used by LiteLLM. |
|
||||||
|
| Httpx Errors | ✅ Request will be allowed | Occurs when the database is unreachable, allowing the request to proceed despite the DB outage. |
|
||||||
|
| LiteLLM Budget Errors or Model Errors | ❌ Request will be blocked | Triggered when the DB is reachable but the authentication token is invalid, lacks access, or exceeds budget limits. |
|
||||||
|
|
||||||
|
|
||||||
## 6. Disable spend_logs & error_logs if not using the LiteLLM UI
|
## 6. Disable spend_logs & error_logs if not using the LiteLLM UI
|
||||||
|
|
||||||
By default, LiteLLM writes several types of logs to the database:
|
By default, LiteLLM writes several types of logs to the database:
|
||||||
|
@ -182,94 +196,4 @@ You should only see the following level of details in logs on the proxy server
|
||||||
# INFO: 192.168.2.205:11774 - "POST /chat/completions HTTP/1.1" 200 OK
|
# INFO: 192.168.2.205:11774 - "POST /chat/completions HTTP/1.1" 200 OK
|
||||||
# INFO: 192.168.2.205:34717 - "POST /chat/completions HTTP/1.1" 200 OK
|
# INFO: 192.168.2.205:34717 - "POST /chat/completions HTTP/1.1" 200 OK
|
||||||
# INFO: 192.168.2.205:29734 - "POST /chat/completions HTTP/1.1" 200 OK
|
# INFO: 192.168.2.205:29734 - "POST /chat/completions HTTP/1.1" 200 OK
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
### Machine Specifications to Deploy LiteLLM
|
|
||||||
|
|
||||||
| Service | Spec | CPUs | Memory | Architecture | Version|
|
|
||||||
| --- | --- | --- | --- | --- | --- |
|
|
||||||
| Server | `t2.small`. | `1vCPUs` | `8GB` | `x86` |
|
|
||||||
| Redis Cache | - | - | - | - | 7.0+ Redis Engine|
|
|
||||||
|
|
||||||
|
|
||||||
### Reference Kubernetes Deployment YAML
|
|
||||||
|
|
||||||
Reference Kubernetes `deployment.yaml` that was load tested by us
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
apiVersion: apps/v1
|
|
||||||
kind: Deployment
|
|
||||||
metadata:
|
|
||||||
name: litellm-deployment
|
|
||||||
spec:
|
|
||||||
replicas: 3
|
|
||||||
selector:
|
|
||||||
matchLabels:
|
|
||||||
app: litellm
|
|
||||||
template:
|
|
||||||
metadata:
|
|
||||||
labels:
|
|
||||||
app: litellm
|
|
||||||
spec:
|
|
||||||
containers:
|
|
||||||
- name: litellm-container
|
|
||||||
image: ghcr.io/berriai/litellm:main-latest
|
|
||||||
imagePullPolicy: Always
|
|
||||||
env:
|
|
||||||
- name: AZURE_API_KEY
|
|
||||||
value: "d6******"
|
|
||||||
- name: AZURE_API_BASE
|
|
||||||
value: "https://ope******"
|
|
||||||
- name: LITELLM_MASTER_KEY
|
|
||||||
value: "sk-1234"
|
|
||||||
- name: DATABASE_URL
|
|
||||||
value: "po**********"
|
|
||||||
args:
|
|
||||||
- "--config"
|
|
||||||
- "/app/proxy_config.yaml" # Update the path to mount the config file
|
|
||||||
volumeMounts: # Define volume mount for proxy_config.yaml
|
|
||||||
- name: config-volume
|
|
||||||
mountPath: /app
|
|
||||||
readOnly: true
|
|
||||||
livenessProbe:
|
|
||||||
httpGet:
|
|
||||||
path: /health/liveliness
|
|
||||||
port: 4000
|
|
||||||
initialDelaySeconds: 120
|
|
||||||
periodSeconds: 15
|
|
||||||
successThreshold: 1
|
|
||||||
failureThreshold: 3
|
|
||||||
timeoutSeconds: 10
|
|
||||||
readinessProbe:
|
|
||||||
httpGet:
|
|
||||||
path: /health/readiness
|
|
||||||
port: 4000
|
|
||||||
initialDelaySeconds: 120
|
|
||||||
periodSeconds: 15
|
|
||||||
successThreshold: 1
|
|
||||||
failureThreshold: 3
|
|
||||||
timeoutSeconds: 10
|
|
||||||
volumes: # Define volume to mount proxy_config.yaml
|
|
||||||
- name: config-volume
|
|
||||||
configMap:
|
|
||||||
name: litellm-config
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
Reference Kubernetes `service.yaml` that was load tested by us
|
|
||||||
```yaml
|
|
||||||
apiVersion: v1
|
|
||||||
kind: Service
|
|
||||||
metadata:
|
|
||||||
name: litellm-service
|
|
||||||
spec:
|
|
||||||
selector:
|
|
||||||
app: litellm
|
|
||||||
ports:
|
|
||||||
- protocol: TCP
|
|
||||||
port: 4000
|
|
||||||
targetPort: 4000
|
|
||||||
type: LoadBalancer
|
|
||||||
```
|
|
|
@ -2067,16 +2067,68 @@ class SpendCalculateRequest(LiteLLMPydanticObjectBase):
|
||||||
|
|
||||||
class ProxyErrorTypes(str, enum.Enum):
|
class ProxyErrorTypes(str, enum.Enum):
|
||||||
budget_exceeded = "budget_exceeded"
|
budget_exceeded = "budget_exceeded"
|
||||||
|
"""
|
||||||
|
Object was over budget
|
||||||
|
"""
|
||||||
|
no_db_connection = "no_db_connection"
|
||||||
|
"""
|
||||||
|
No database connection
|
||||||
|
"""
|
||||||
|
|
||||||
|
token_not_found_in_db = "token_not_found_in_db"
|
||||||
|
"""
|
||||||
|
Requested token was not found in the database
|
||||||
|
"""
|
||||||
|
|
||||||
key_model_access_denied = "key_model_access_denied"
|
key_model_access_denied = "key_model_access_denied"
|
||||||
|
"""
|
||||||
|
Key does not have access to the model
|
||||||
|
"""
|
||||||
|
|
||||||
team_model_access_denied = "team_model_access_denied"
|
team_model_access_denied = "team_model_access_denied"
|
||||||
|
"""
|
||||||
|
Team does not have access to the model
|
||||||
|
"""
|
||||||
|
|
||||||
user_model_access_denied = "user_model_access_denied"
|
user_model_access_denied = "user_model_access_denied"
|
||||||
|
"""
|
||||||
|
User does not have access to the model
|
||||||
|
"""
|
||||||
|
|
||||||
expired_key = "expired_key"
|
expired_key = "expired_key"
|
||||||
|
"""
|
||||||
|
Key has expired
|
||||||
|
"""
|
||||||
|
|
||||||
auth_error = "auth_error"
|
auth_error = "auth_error"
|
||||||
|
"""
|
||||||
|
General authentication error
|
||||||
|
"""
|
||||||
|
|
||||||
internal_server_error = "internal_server_error"
|
internal_server_error = "internal_server_error"
|
||||||
|
"""
|
||||||
|
Internal server error
|
||||||
|
"""
|
||||||
|
|
||||||
bad_request_error = "bad_request_error"
|
bad_request_error = "bad_request_error"
|
||||||
|
"""
|
||||||
|
Bad request error
|
||||||
|
"""
|
||||||
|
|
||||||
not_found_error = "not_found_error"
|
not_found_error = "not_found_error"
|
||||||
validation_error = "bad_request_error"
|
"""
|
||||||
|
Not found error
|
||||||
|
"""
|
||||||
|
|
||||||
|
validation_error = "validation_error"
|
||||||
|
"""
|
||||||
|
Validation error
|
||||||
|
"""
|
||||||
|
|
||||||
cache_ping_error = "cache_ping_error"
|
cache_ping_error = "cache_ping_error"
|
||||||
|
"""
|
||||||
|
Cache ping error
|
||||||
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_model_access_error_type_for_object(
|
def get_model_access_error_type_for_object(
|
||||||
|
@ -2093,7 +2145,11 @@ class ProxyErrorTypes(str, enum.Enum):
|
||||||
return cls.user_model_access_denied
|
return cls.user_model_access_denied
|
||||||
|
|
||||||
|
|
||||||
DB_CONNECTION_ERROR_TYPES = (httpx.ConnectError, httpx.ReadError, httpx.ReadTimeout)
|
DB_CONNECTION_ERROR_TYPES = (
|
||||||
|
httpx.ConnectError,
|
||||||
|
httpx.ReadError,
|
||||||
|
httpx.ReadTimeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SSOUserDefinedValues(TypedDict):
|
class SSOUserDefinedValues(TypedDict):
|
||||||
|
|
|
@ -11,7 +11,6 @@ Run checks for:
|
||||||
import asyncio
|
import asyncio
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import traceback
|
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, cast
|
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, cast
|
||||||
|
|
||||||
from fastapi import Request, status
|
from fastapi import Request, status
|
||||||
|
@ -23,7 +22,6 @@ from litellm.caching.caching import DualCache
|
||||||
from litellm.caching.dual_cache import LimitedSizeOrderedDict
|
from litellm.caching.dual_cache import LimitedSizeOrderedDict
|
||||||
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
|
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
|
||||||
from litellm.proxy._types import (
|
from litellm.proxy._types import (
|
||||||
DB_CONNECTION_ERROR_TYPES,
|
|
||||||
RBAC_ROLES,
|
RBAC_ROLES,
|
||||||
CallInfo,
|
CallInfo,
|
||||||
LiteLLM_EndUserTable,
|
LiteLLM_EndUserTable,
|
||||||
|
@ -45,7 +43,6 @@ from litellm.proxy.auth.route_checks import RouteChecks
|
||||||
from litellm.proxy.route_llm_request import route_request
|
from litellm.proxy.route_llm_request import route_request
|
||||||
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics
|
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics
|
||||||
from litellm.router import Router
|
from litellm.router import Router
|
||||||
from litellm.types.services import ServiceTypes
|
|
||||||
|
|
||||||
from .auth_checks_organization import organization_role_based_access_check
|
from .auth_checks_organization import organization_role_based_access_check
|
||||||
|
|
||||||
|
@ -987,75 +984,34 @@ async def get_key_object(
|
||||||
)
|
)
|
||||||
|
|
||||||
# else, check db
|
# else, check db
|
||||||
try:
|
_valid_token: Optional[BaseModel] = await prisma_client.get_data(
|
||||||
_valid_token: Optional[BaseModel] = await prisma_client.get_data(
|
token=hashed_token,
|
||||||
token=hashed_token,
|
table_name="combined_view",
|
||||||
table_name="combined_view",
|
parent_otel_span=parent_otel_span,
|
||||||
parent_otel_span=parent_otel_span,
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
|
||||||
)
|
|
||||||
|
|
||||||
if _valid_token is None:
|
|
||||||
raise Exception
|
|
||||||
|
|
||||||
_response = UserAPIKeyAuth(**_valid_token.model_dump(exclude_none=True))
|
|
||||||
|
|
||||||
# save the key object to cache
|
|
||||||
await _cache_key_object(
|
|
||||||
hashed_token=hashed_token,
|
|
||||||
user_api_key_obj=_response,
|
|
||||||
user_api_key_cache=user_api_key_cache,
|
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
|
||||||
)
|
|
||||||
|
|
||||||
return _response
|
|
||||||
except DB_CONNECTION_ERROR_TYPES as e:
|
|
||||||
return await _handle_failed_db_connection_for_get_key_object(e=e)
|
|
||||||
except Exception:
|
|
||||||
traceback.print_exc()
|
|
||||||
raise Exception(
|
|
||||||
f"Key doesn't exist in db. key={hashed_token}. Create key via `/key/generate` call."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _handle_failed_db_connection_for_get_key_object(
|
|
||||||
e: Exception,
|
|
||||||
) -> UserAPIKeyAuth:
|
|
||||||
"""
|
|
||||||
Handles httpx.ConnectError when reading a Virtual Key from LiteLLM DB
|
|
||||||
|
|
||||||
Use this if you don't want failed DB queries to block LLM API reqiests
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- UserAPIKeyAuth: If general_settings.allow_requests_on_db_unavailable is True
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
- Orignal Exception in all other cases
|
|
||||||
"""
|
|
||||||
from litellm.proxy.proxy_server import (
|
|
||||||
general_settings,
|
|
||||||
litellm_proxy_admin_name,
|
|
||||||
proxy_logging_obj,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# If this flag is on, requests failing to connect to the DB will be allowed
|
if _valid_token is None:
|
||||||
if general_settings.get("allow_requests_on_db_unavailable", False) is True:
|
raise ProxyException(
|
||||||
# log this as a DB failure on prometheus
|
message="Authentication Error, Invalid proxy server token passed. key={}, not found in db. Create key via `/key/generate` call.".format(
|
||||||
proxy_logging_obj.service_logging_obj.service_failure_hook(
|
hashed_token
|
||||||
service=ServiceTypes.DB,
|
),
|
||||||
call_type="get_key_object",
|
type=ProxyErrorTypes.token_not_found_in_db,
|
||||||
error=e,
|
param="key",
|
||||||
duration=0.0,
|
code=status.HTTP_401_UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
|
||||||
return UserAPIKeyAuth(
|
_response = UserAPIKeyAuth(**_valid_token.model_dump(exclude_none=True))
|
||||||
key_name="failed-to-connect-to-db",
|
|
||||||
token="failed-to-connect-to-db",
|
# save the key object to cache
|
||||||
user_id=litellm_proxy_admin_name,
|
await _cache_key_object(
|
||||||
)
|
hashed_token=hashed_token,
|
||||||
else:
|
user_api_key_obj=_response,
|
||||||
# raise the original exception, the wrapper on `get_key_object` handles logging db failure to prometheus
|
user_api_key_cache=user_api_key_cache,
|
||||||
raise e
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _response
|
||||||
|
|
||||||
|
|
||||||
@log_db_metrics
|
@log_db_metrics
|
||||||
|
|
153
litellm/proxy/auth/auth_exception_handler.py
Normal file
153
litellm/proxy/auth/auth_exception_handler.py
Normal file
|
@ -0,0 +1,153 @@
|
||||||
|
"""
|
||||||
|
Handles Authentication Errors
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request, status
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from litellm.proxy._types import (
|
||||||
|
DB_CONNECTION_ERROR_TYPES,
|
||||||
|
ProxyErrorTypes,
|
||||||
|
ProxyException,
|
||||||
|
UserAPIKeyAuth,
|
||||||
|
)
|
||||||
|
from litellm.proxy.auth.auth_utils import _get_request_ip_address
|
||||||
|
from litellm.types.services import ServiceTypes
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from opentelemetry.trace import Span as _Span
|
||||||
|
|
||||||
|
Span = _Span
|
||||||
|
else:
|
||||||
|
Span = Any
|
||||||
|
|
||||||
|
|
||||||
|
class UserAPIKeyAuthExceptionHandler:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _handle_authentication_error(
|
||||||
|
e: Exception,
|
||||||
|
request: Request,
|
||||||
|
request_data: dict,
|
||||||
|
route: str,
|
||||||
|
parent_otel_span: Optional[Span],
|
||||||
|
api_key: str,
|
||||||
|
) -> UserAPIKeyAuth:
|
||||||
|
"""
|
||||||
|
Handles Connection Errors when reading a Virtual Key from LiteLLM DB
|
||||||
|
Use this if you don't want failed DB queries to block LLM API reqiests
|
||||||
|
|
||||||
|
Reliability scenarios this covers:
|
||||||
|
- DB is down and having an outage
|
||||||
|
- Unable to read / recover a key from the DB
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- UserAPIKeyAuth: If general_settings.allow_requests_on_db_unavailable is True
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
- Orignal Exception in all other cases
|
||||||
|
"""
|
||||||
|
from litellm.proxy.proxy_server import (
|
||||||
|
general_settings,
|
||||||
|
litellm_proxy_admin_name,
|
||||||
|
proxy_logging_obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
UserAPIKeyAuthExceptionHandler.should_allow_request_on_db_unavailable()
|
||||||
|
and UserAPIKeyAuthExceptionHandler.is_database_connection_error(e)
|
||||||
|
):
|
||||||
|
# log this as a DB failure on prometheus
|
||||||
|
proxy_logging_obj.service_logging_obj.service_failure_hook(
|
||||||
|
service=ServiceTypes.DB,
|
||||||
|
call_type="get_key_object",
|
||||||
|
error=e,
|
||||||
|
duration=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
return UserAPIKeyAuth(
|
||||||
|
key_name="failed-to-connect-to-db",
|
||||||
|
token="failed-to-connect-to-db",
|
||||||
|
user_id=litellm_proxy_admin_name,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# raise the exception to the caller
|
||||||
|
requester_ip = _get_request_ip_address(
|
||||||
|
request=request,
|
||||||
|
use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.exception(
|
||||||
|
"litellm.proxy.proxy_server.user_api_key_auth(): Exception occured - {}\nRequester IP Address:{}".format(
|
||||||
|
str(e),
|
||||||
|
requester_ip,
|
||||||
|
),
|
||||||
|
extra={"requester_ip": requester_ip},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log this exception to OTEL, Datadog etc
|
||||||
|
user_api_key_dict = UserAPIKeyAuth(
|
||||||
|
parent_otel_span=parent_otel_span,
|
||||||
|
api_key=api_key,
|
||||||
|
)
|
||||||
|
asyncio.create_task(
|
||||||
|
proxy_logging_obj.post_call_failure_hook(
|
||||||
|
request_data=request_data,
|
||||||
|
original_exception=e,
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
error_type=ProxyErrorTypes.auth_error,
|
||||||
|
route=route,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(e, litellm.BudgetExceededError):
|
||||||
|
raise ProxyException(
|
||||||
|
message=e.message,
|
||||||
|
type=ProxyErrorTypes.budget_exceeded,
|
||||||
|
param=None,
|
||||||
|
code=400,
|
||||||
|
)
|
||||||
|
if isinstance(e, HTTPException):
|
||||||
|
raise ProxyException(
|
||||||
|
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param=getattr(e, "param", "None"),
|
||||||
|
code=getattr(e, "status_code", status.HTTP_401_UNAUTHORIZED),
|
||||||
|
)
|
||||||
|
elif isinstance(e, ProxyException):
|
||||||
|
raise e
|
||||||
|
raise ProxyException(
|
||||||
|
message="Authentication Error, " + str(e),
|
||||||
|
type=ProxyErrorTypes.auth_error,
|
||||||
|
param=getattr(e, "param", "None"),
|
||||||
|
code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def should_allow_request_on_db_unavailable() -> bool:
|
||||||
|
"""
|
||||||
|
Returns True if the request should be allowed to proceed despite the DB connection error
|
||||||
|
"""
|
||||||
|
from litellm.proxy.proxy_server import general_settings
|
||||||
|
|
||||||
|
if general_settings.get("allow_requests_on_db_unavailable", False) is True:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_database_connection_error(e: Exception) -> bool:
|
||||||
|
"""
|
||||||
|
Returns True if the exception is from a database outage / connection error
|
||||||
|
"""
|
||||||
|
import prisma
|
||||||
|
|
||||||
|
if isinstance(e, DB_CONNECTION_ERROR_TYPES):
|
||||||
|
return True
|
||||||
|
if isinstance(e, prisma.errors.PrismaError):
|
||||||
|
return True
|
||||||
|
if isinstance(e, ProxyException) and e.type == ProxyErrorTypes.no_db_connection:
|
||||||
|
return True
|
||||||
|
return False
|
|
@ -26,7 +26,6 @@ from litellm.proxy._types import *
|
||||||
from litellm.proxy.auth.auth_checks import (
|
from litellm.proxy.auth.auth_checks import (
|
||||||
_cache_key_object,
|
_cache_key_object,
|
||||||
_get_user_role,
|
_get_user_role,
|
||||||
_handle_failed_db_connection_for_get_key_object,
|
|
||||||
_is_user_proxy_admin,
|
_is_user_proxy_admin,
|
||||||
_virtual_key_max_budget_check,
|
_virtual_key_max_budget_check,
|
||||||
_virtual_key_soft_budget_check,
|
_virtual_key_soft_budget_check,
|
||||||
|
@ -38,8 +37,8 @@ from litellm.proxy.auth.auth_checks import (
|
||||||
get_user_object,
|
get_user_object,
|
||||||
is_valid_fallback_model,
|
is_valid_fallback_model,
|
||||||
)
|
)
|
||||||
|
from litellm.proxy.auth.auth_exception_handler import UserAPIKeyAuthExceptionHandler
|
||||||
from litellm.proxy.auth.auth_utils import (
|
from litellm.proxy.auth.auth_utils import (
|
||||||
_get_request_ip_address,
|
|
||||||
get_end_user_id_from_request_body,
|
get_end_user_id_from_request_body,
|
||||||
get_request_route,
|
get_request_route,
|
||||||
is_pass_through_provider_route,
|
is_pass_through_provider_route,
|
||||||
|
@ -675,8 +674,11 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
if (
|
if (
|
||||||
prisma_client is None
|
prisma_client is None
|
||||||
): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
|
): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
|
||||||
return await _handle_failed_db_connection_for_get_key_object(
|
raise ProxyException(
|
||||||
e=Exception("No connected db.")
|
message="No connected db.",
|
||||||
|
type=ProxyErrorTypes.no_db_connection,
|
||||||
|
code=400,
|
||||||
|
param=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
## check for cache hit (In-Memory Cache)
|
## check for cache hit (In-Memory Cache)
|
||||||
|
@ -685,37 +687,25 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
api_key = hash_token(token=api_key)
|
api_key = hash_token(token=api_key)
|
||||||
|
|
||||||
if valid_token is None:
|
if valid_token is None:
|
||||||
try:
|
valid_token = await get_key_object(
|
||||||
valid_token = await get_key_object(
|
hashed_token=api_key,
|
||||||
hashed_token=api_key,
|
prisma_client=prisma_client,
|
||||||
prisma_client=prisma_client,
|
user_api_key_cache=user_api_key_cache,
|
||||||
user_api_key_cache=user_api_key_cache,
|
parent_otel_span=parent_otel_span,
|
||||||
parent_otel_span=parent_otel_span,
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
)
|
||||||
)
|
# update end-user params on valid token
|
||||||
# update end-user params on valid token
|
# These can change per request - it's important to update them here
|
||||||
# These can change per request - it's important to update them here
|
valid_token.end_user_id = end_user_params.get("end_user_id")
|
||||||
valid_token.end_user_id = end_user_params.get("end_user_id")
|
valid_token.end_user_tpm_limit = end_user_params.get("end_user_tpm_limit")
|
||||||
valid_token.end_user_tpm_limit = end_user_params.get(
|
valid_token.end_user_rpm_limit = end_user_params.get("end_user_rpm_limit")
|
||||||
"end_user_tpm_limit"
|
valid_token.allowed_model_region = end_user_params.get(
|
||||||
)
|
"allowed_model_region"
|
||||||
valid_token.end_user_rpm_limit = end_user_params.get(
|
)
|
||||||
"end_user_rpm_limit"
|
# update key budget with temp budget increase
|
||||||
)
|
valid_token = _update_key_budget_with_temp_budget_increase(
|
||||||
valid_token.allowed_model_region = end_user_params.get(
|
valid_token
|
||||||
"allowed_model_region"
|
) # updating it here, allows all downstream reporting / checks to use the updated budget
|
||||||
)
|
|
||||||
# update key budget with temp budget increase
|
|
||||||
valid_token = _update_key_budget_with_temp_budget_increase(
|
|
||||||
valid_token
|
|
||||||
) # updating it here, allows all downstream reporting / checks to use the updated budget
|
|
||||||
except Exception:
|
|
||||||
verbose_logger.info(
|
|
||||||
"litellm.proxy.auth.user_api_key_auth.py::user_api_key_auth() - Unable to find token={} in cache or `LiteLLM_VerificationTokenTable`. Defaulting 'valid_token' to None'".format(
|
|
||||||
api_key
|
|
||||||
)
|
|
||||||
)
|
|
||||||
valid_token = None
|
|
||||||
|
|
||||||
if valid_token is None:
|
if valid_token is None:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
@ -1015,58 +1005,15 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
route=route,
|
route=route,
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
raise Exception()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
requester_ip = _get_request_ip_address(
|
return await UserAPIKeyAuthExceptionHandler._handle_authentication_error(
|
||||||
|
e=e,
|
||||||
request=request,
|
request=request,
|
||||||
use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
|
request_data=request_data,
|
||||||
)
|
route=route,
|
||||||
verbose_proxy_logger.exception(
|
|
||||||
"litellm.proxy.proxy_server.user_api_key_auth(): Exception occured - {}\nRequester IP Address:{}".format(
|
|
||||||
str(e),
|
|
||||||
requester_ip,
|
|
||||||
),
|
|
||||||
extra={"requester_ip": requester_ip},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Log this exception to OTEL, Datadog etc
|
|
||||||
user_api_key_dict = UserAPIKeyAuth(
|
|
||||||
parent_otel_span=parent_otel_span,
|
parent_otel_span=parent_otel_span,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
)
|
)
|
||||||
asyncio.create_task(
|
|
||||||
proxy_logging_obj.post_call_failure_hook(
|
|
||||||
request_data=request_data,
|
|
||||||
original_exception=e,
|
|
||||||
user_api_key_dict=user_api_key_dict,
|
|
||||||
error_type=ProxyErrorTypes.auth_error,
|
|
||||||
route=route,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(e, litellm.BudgetExceededError):
|
|
||||||
raise ProxyException(
|
|
||||||
message=e.message,
|
|
||||||
type=ProxyErrorTypes.budget_exceeded,
|
|
||||||
param=None,
|
|
||||||
code=400,
|
|
||||||
)
|
|
||||||
if isinstance(e, HTTPException):
|
|
||||||
raise ProxyException(
|
|
||||||
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param=getattr(e, "param", "None"),
|
|
||||||
code=getattr(e, "status_code", status.HTTP_401_UNAUTHORIZED),
|
|
||||||
)
|
|
||||||
elif isinstance(e, ProxyException):
|
|
||||||
raise e
|
|
||||||
raise ProxyException(
|
|
||||||
message="Authentication Error, " + str(e),
|
|
||||||
type=ProxyErrorTypes.auth_error,
|
|
||||||
param=getattr(e, "param", "None"),
|
|
||||||
code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tracer.wrap()
|
@tracer.wrap()
|
||||||
|
|
|
@ -1,37 +1,9 @@
|
||||||
model_list:
|
model_list:
|
||||||
- model_name: gpt-3.5-turbo-end-user-test
|
- model_name: fake-openai-endpoint
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: azure/chatgpt-v-2
|
model: openai/fake
|
||||||
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
|
api_key: fake-key
|
||||||
api_version: "2023-05-15"
|
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||||
api_key: os.environ/AZURE_API_KEY
|
|
||||||
|
|
||||||
|
general_settings:
|
||||||
|
allow_requests_on_db_unavailable: True
|
||||||
mcp_tools:
|
|
||||||
- name: "get_current_time"
|
|
||||||
description: "Get the current time"
|
|
||||||
input_schema: {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"format": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The format of the time to return",
|
|
||||||
"enum": ["short"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
handler: "mcp_tools.get_current_time"
|
|
||||||
- name: "get_current_date"
|
|
||||||
description: "Get the current date"
|
|
||||||
input_schema: {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"format": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The format of the date to return",
|
|
||||||
"enum": ["short"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
handler: "mcp_tools.get_current_date"
|
|
184
tests/litellm/proxy/auth/test_auth_exception_handler.py
Normal file
184
tests/litellm/proxy/auth/test_auth_exception_handler.py
Normal file
|
@ -0,0 +1,184 @@
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import HTTPException, Request, status
|
||||||
|
from prisma import errors as prisma_errors
|
||||||
|
from prisma.errors import (
|
||||||
|
ClientNotConnectedError,
|
||||||
|
DataError,
|
||||||
|
ForeignKeyViolationError,
|
||||||
|
HTTPClientClosedError,
|
||||||
|
MissingRequiredValueError,
|
||||||
|
PrismaError,
|
||||||
|
RawQueryError,
|
||||||
|
RecordNotFoundError,
|
||||||
|
TableNotFoundError,
|
||||||
|
UniqueViolationError,
|
||||||
|
)
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from litellm.proxy._types import ProxyErrorTypes, ProxyException
|
||||||
|
from litellm.proxy.auth.auth_exception_handler import UserAPIKeyAuthExceptionHandler
|
||||||
|
|
||||||
|
|
||||||
|
# Test is_database_connection_error method
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"prisma_error",
|
||||||
|
[
|
||||||
|
PrismaError(),
|
||||||
|
DataError(data={"user_facing_error": {"meta": {"table": "test_table"}}}),
|
||||||
|
UniqueViolationError(
|
||||||
|
data={"user_facing_error": {"meta": {"table": "test_table"}}}
|
||||||
|
),
|
||||||
|
ForeignKeyViolationError(
|
||||||
|
data={"user_facing_error": {"meta": {"table": "test_table"}}}
|
||||||
|
),
|
||||||
|
MissingRequiredValueError(
|
||||||
|
data={"user_facing_error": {"meta": {"table": "test_table"}}}
|
||||||
|
),
|
||||||
|
RawQueryError(data={"user_facing_error": {"meta": {"table": "test_table"}}}),
|
||||||
|
TableNotFoundError(
|
||||||
|
data={"user_facing_error": {"meta": {"table": "test_table"}}}
|
||||||
|
),
|
||||||
|
RecordNotFoundError(
|
||||||
|
data={"user_facing_error": {"meta": {"table": "test_table"}}}
|
||||||
|
),
|
||||||
|
HTTPClientClosedError(),
|
||||||
|
ClientNotConnectedError(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_is_database_connection_error_prisma_errors(prisma_error):
|
||||||
|
"""
|
||||||
|
Test that all Prisma errors are considered database connection errors
|
||||||
|
"""
|
||||||
|
handler = UserAPIKeyAuthExceptionHandler()
|
||||||
|
assert handler.is_database_connection_error(prisma_error) == True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_database_connection_generic_errors():
|
||||||
|
"""
|
||||||
|
Test non-Prisma error cases for database connection checking
|
||||||
|
"""
|
||||||
|
handler = UserAPIKeyAuthExceptionHandler()
|
||||||
|
|
||||||
|
# Test with ProxyException (DB connection)
|
||||||
|
db_proxy_exception = ProxyException(
|
||||||
|
message="DB Connection Error",
|
||||||
|
type=ProxyErrorTypes.no_db_connection,
|
||||||
|
param="test-param",
|
||||||
|
)
|
||||||
|
assert handler.is_database_connection_error(db_proxy_exception) == True
|
||||||
|
|
||||||
|
# Test with non-DB error
|
||||||
|
regular_exception = Exception("Regular error")
|
||||||
|
assert handler.is_database_connection_error(regular_exception) == False
|
||||||
|
|
||||||
|
|
||||||
|
# Test should_allow_request_on_db_unavailable method
|
||||||
|
@patch(
|
||||||
|
"litellm.proxy.proxy_server.general_settings",
|
||||||
|
{"allow_requests_on_db_unavailable": True},
|
||||||
|
)
|
||||||
|
def test_should_allow_request_on_db_unavailable_true():
|
||||||
|
handler = UserAPIKeyAuthExceptionHandler()
|
||||||
|
assert handler.should_allow_request_on_db_unavailable() == True
|
||||||
|
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"litellm.proxy.proxy_server.general_settings",
|
||||||
|
{"allow_requests_on_db_unavailable": False},
|
||||||
|
)
|
||||||
|
def test_should_allow_request_on_db_unavailable_false():
|
||||||
|
handler = UserAPIKeyAuthExceptionHandler()
|
||||||
|
assert handler.should_allow_request_on_db_unavailable() == False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"prisma_error",
|
||||||
|
[
|
||||||
|
PrismaError(),
|
||||||
|
DataError(data={"user_facing_error": {"meta": {"table": "test_table"}}}),
|
||||||
|
UniqueViolationError(
|
||||||
|
data={"user_facing_error": {"meta": {"table": "test_table"}}}
|
||||||
|
),
|
||||||
|
ForeignKeyViolationError(
|
||||||
|
data={"user_facing_error": {"meta": {"table": "test_table"}}}
|
||||||
|
),
|
||||||
|
MissingRequiredValueError(
|
||||||
|
data={"user_facing_error": {"meta": {"table": "test_table"}}}
|
||||||
|
),
|
||||||
|
RawQueryError(data={"user_facing_error": {"meta": {"table": "test_table"}}}),
|
||||||
|
TableNotFoundError(
|
||||||
|
data={"user_facing_error": {"meta": {"table": "test_table"}}}
|
||||||
|
),
|
||||||
|
RecordNotFoundError(
|
||||||
|
data={"user_facing_error": {"meta": {"table": "test_table"}}}
|
||||||
|
),
|
||||||
|
HTTPClientClosedError(),
|
||||||
|
ClientNotConnectedError(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_handle_authentication_error_db_unavailable(prisma_error):
|
||||||
|
handler = UserAPIKeyAuthExceptionHandler()
|
||||||
|
|
||||||
|
# Mock request and other dependencies
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request_data = {}
|
||||||
|
mock_route = "/test"
|
||||||
|
mock_span = None
|
||||||
|
mock_api_key = "test-key"
|
||||||
|
|
||||||
|
# Test with DB connection error when requests are allowed
|
||||||
|
with patch(
|
||||||
|
"litellm.proxy.proxy_server.general_settings",
|
||||||
|
{"allow_requests_on_db_unavailable": True},
|
||||||
|
):
|
||||||
|
result = await handler._handle_authentication_error(
|
||||||
|
prisma_error,
|
||||||
|
mock_request,
|
||||||
|
mock_request_data,
|
||||||
|
mock_route,
|
||||||
|
mock_span,
|
||||||
|
mock_api_key,
|
||||||
|
)
|
||||||
|
assert result.key_name == "failed-to-connect-to-db"
|
||||||
|
assert result.token == "failed-to-connect-to-db"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_authentication_error_budget_exceeded():
|
||||||
|
handler = UserAPIKeyAuthExceptionHandler()
|
||||||
|
|
||||||
|
# Mock request and other dependencies
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request_data = {}
|
||||||
|
mock_route = "/test"
|
||||||
|
mock_span = None
|
||||||
|
mock_api_key = "test-key"
|
||||||
|
|
||||||
|
# Test with budget exceeded error
|
||||||
|
with pytest.raises(ProxyException) as exc_info:
|
||||||
|
from litellm.exceptions import BudgetExceededError
|
||||||
|
|
||||||
|
budget_error = BudgetExceededError(
|
||||||
|
message="Budget exceeded", current_cost=100, max_budget=100
|
||||||
|
)
|
||||||
|
await handler._handle_authentication_error(
|
||||||
|
budget_error,
|
||||||
|
mock_request,
|
||||||
|
mock_request_data,
|
||||||
|
mock_route,
|
||||||
|
mock_span,
|
||||||
|
mock_api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc_info.value.type == ProxyErrorTypes.budget_exceeded
|
|
@ -13,9 +13,6 @@ sys.path.insert(
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
import pytest, litellm
|
import pytest, litellm
|
||||||
import httpx
|
import httpx
|
||||||
from litellm.proxy.auth.auth_checks import (
|
|
||||||
_handle_failed_db_connection_for_get_key_object,
|
|
||||||
)
|
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
from litellm.proxy.auth.auth_checks import get_end_user_object
|
from litellm.proxy.auth.auth_checks import get_end_user_object
|
||||||
from litellm.caching.caching import DualCache
|
from litellm.caching.caching import DualCache
|
||||||
|
@ -78,36 +75,6 @@ async def test_get_end_user_object(customer_spend, customer_budget):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_failed_db_connection():
|
|
||||||
"""
|
|
||||||
Test cases:
|
|
||||||
1. When allow_requests_on_db_unavailable=True -> return UserAPIKeyAuth
|
|
||||||
2. When allow_requests_on_db_unavailable=False -> raise original error
|
|
||||||
"""
|
|
||||||
from litellm.proxy.proxy_server import general_settings, litellm_proxy_admin_name
|
|
||||||
|
|
||||||
# Test case 1: allow_requests_on_db_unavailable=True
|
|
||||||
general_settings["allow_requests_on_db_unavailable"] = True
|
|
||||||
mock_error = httpx.ConnectError("Failed to connect to DB")
|
|
||||||
|
|
||||||
result = await _handle_failed_db_connection_for_get_key_object(e=mock_error)
|
|
||||||
|
|
||||||
assert isinstance(result, UserAPIKeyAuth)
|
|
||||||
assert result.key_name == "failed-to-connect-to-db"
|
|
||||||
assert result.token == "failed-to-connect-to-db"
|
|
||||||
assert result.user_id == litellm_proxy_admin_name
|
|
||||||
|
|
||||||
# Test case 2: allow_requests_on_db_unavailable=False
|
|
||||||
general_settings["allow_requests_on_db_unavailable"] = False
|
|
||||||
|
|
||||||
with pytest.raises(httpx.ConnectError) as exc_info:
|
|
||||||
await _handle_failed_db_connection_for_get_key_object(e=mock_error)
|
|
||||||
print("_handle_failed_db_connection_for_get_key_object got exception", exc_info)
|
|
||||||
|
|
||||||
assert str(exc_info.value) == "Failed to connect to DB"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model, expect_to_work",
|
"model, expect_to_work",
|
||||||
[
|
[
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue