mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
fix(proxy_server.py): initializing sentry in proxy logging before db init
This commit is contained in:
parent
a07659a7c2
commit
8460924f1d
2 changed files with 48 additions and 45 deletions
|
@ -303,7 +303,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
||||||
|
|
||||||
def prisma_setup(database_url: Optional[str]):
|
def prisma_setup(database_url: Optional[str]):
|
||||||
global prisma_client, proxy_logging_obj
|
global prisma_client, proxy_logging_obj
|
||||||
if database_url is not None and proxy_logging_obj is not None:
|
### INITIALIZE GLOBAL LOGGING OBJECT ###
|
||||||
|
proxy_logging_obj = ProxyLogging()
|
||||||
|
|
||||||
|
if database_url is not None:
|
||||||
try:
|
try:
|
||||||
prisma_client = PrismaClient(database_url=database_url, proxy_logging_obj=proxy_logging_obj)
|
prisma_client = PrismaClient(database_url=database_url, proxy_logging_obj=proxy_logging_obj)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -474,41 +477,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||||
for key, value in environment_variables.items():
|
for key, value in environment_variables.items():
|
||||||
os.environ[key] = value
|
os.environ[key] = value
|
||||||
|
|
||||||
## GENERAL SERVER SETTINGS (e.g. master key,..)
|
|
||||||
general_settings = config.get("general_settings", {})
|
|
||||||
if general_settings is None:
|
|
||||||
general_settings = {}
|
|
||||||
if general_settings:
|
|
||||||
### LOAD FROM AZURE KEY VAULT ###
|
|
||||||
use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
|
|
||||||
load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
|
|
||||||
### CONNECT TO DATABASE ###
|
|
||||||
database_url = general_settings.get("database_url", None)
|
|
||||||
if database_url and database_url.startswith("os.environ/"):
|
|
||||||
database_url = litellm.get_secret(database_url)
|
|
||||||
prisma_setup(database_url=database_url)
|
|
||||||
## COST TRACKING ##
|
|
||||||
cost_tracking()
|
|
||||||
### START REDIS QUEUE ###
|
|
||||||
use_queue = general_settings.get("use_queue", False)
|
|
||||||
celery_setup(use_queue=use_queue)
|
|
||||||
### MASTER KEY ###
|
|
||||||
master_key = general_settings.get("master_key", None)
|
|
||||||
if master_key and master_key.startswith("os.environ/"):
|
|
||||||
master_key = litellm.get_secret(master_key)
|
|
||||||
#### OpenTelemetry Logging (OTEL) ########
|
|
||||||
otel_logging = general_settings.get("otel", False)
|
|
||||||
if otel_logging == True:
|
|
||||||
print("\nOpenTelemetry Logging Activated")
|
|
||||||
### CUSTOM API KEY AUTH ###
|
|
||||||
custom_auth = general_settings.get("custom_auth", None)
|
|
||||||
if custom_auth:
|
|
||||||
user_custom_auth = get_instance_fn(value=custom_auth, config_file_path=config_file_path)
|
|
||||||
### BACKGROUND HEALTH CHECKS ###
|
|
||||||
# Enable background health checks
|
|
||||||
use_background_health_checks = general_settings.get("background_health_checks", False)
|
|
||||||
health_check_interval = general_settings.get("health_check_interval", 300)
|
|
||||||
|
|
||||||
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
||||||
litellm_settings = config.get('litellm_settings', None)
|
litellm_settings = config.get('litellm_settings', None)
|
||||||
if litellm_settings is None:
|
if litellm_settings is None:
|
||||||
|
@ -571,6 +539,41 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||||
else:
|
else:
|
||||||
setattr(litellm, key, value)
|
setattr(litellm, key, value)
|
||||||
|
|
||||||
|
## GENERAL SERVER SETTINGS (e.g. master key,..) # do this after initializing litellm, to ensure sentry logging works for proxylogging
|
||||||
|
general_settings = config.get("general_settings", {})
|
||||||
|
if general_settings is None:
|
||||||
|
general_settings = {}
|
||||||
|
if general_settings:
|
||||||
|
### LOAD FROM AZURE KEY VAULT ###
|
||||||
|
use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
|
||||||
|
load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
|
||||||
|
### CONNECT TO DATABASE ###
|
||||||
|
database_url = general_settings.get("database_url", None)
|
||||||
|
if database_url and database_url.startswith("os.environ/"):
|
||||||
|
database_url = litellm.get_secret(database_url)
|
||||||
|
prisma_setup(database_url=database_url)
|
||||||
|
## COST TRACKING ##
|
||||||
|
cost_tracking()
|
||||||
|
### START REDIS QUEUE ###
|
||||||
|
use_queue = general_settings.get("use_queue", False)
|
||||||
|
celery_setup(use_queue=use_queue)
|
||||||
|
### MASTER KEY ###
|
||||||
|
master_key = general_settings.get("master_key", None)
|
||||||
|
if master_key and master_key.startswith("os.environ/"):
|
||||||
|
master_key = litellm.get_secret(master_key)
|
||||||
|
#### OpenTelemetry Logging (OTEL) ########
|
||||||
|
otel_logging = general_settings.get("otel", False)
|
||||||
|
if otel_logging == True:
|
||||||
|
print("\nOpenTelemetry Logging Activated")
|
||||||
|
### CUSTOM API KEY AUTH ###
|
||||||
|
custom_auth = general_settings.get("custom_auth", None)
|
||||||
|
if custom_auth:
|
||||||
|
user_custom_auth = get_instance_fn(value=custom_auth, config_file_path=config_file_path)
|
||||||
|
### BACKGROUND HEALTH CHECKS ###
|
||||||
|
# Enable background health checks
|
||||||
|
use_background_health_checks = general_settings.get("background_health_checks", False)
|
||||||
|
health_check_interval = general_settings.get("health_check_interval", 300)
|
||||||
|
|
||||||
## MODEL LIST
|
## MODEL LIST
|
||||||
model_list = config.get('model_list', None)
|
model_list = config.get('model_list', None)
|
||||||
if model_list:
|
if model_list:
|
||||||
|
@ -841,10 +844,8 @@ async def rate_limit_per_token(request: Request, call_next):
|
||||||
|
|
||||||
@router.on_event("startup")
|
@router.on_event("startup")
|
||||||
async def startup_event():
|
async def startup_event():
|
||||||
global prisma_client, master_key, use_background_health_checks, proxy_logging_obj
|
global prisma_client, master_key, use_background_health_checks
|
||||||
import json
|
import json
|
||||||
### INITIALIZE GLOBAL LOGGING OBJECT ###
|
|
||||||
proxy_logging_obj = ProxyLogging()
|
|
||||||
|
|
||||||
### LOAD CONFIG ###
|
### LOAD CONFIG ###
|
||||||
worker_config = litellm.get_secret("WORKER_CONFIG")
|
worker_config = litellm.get_secret("WORKER_CONFIG")
|
||||||
|
@ -857,6 +858,7 @@ async def startup_event():
|
||||||
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
|
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
|
||||||
initialize(**worker_config)
|
initialize(**worker_config)
|
||||||
|
|
||||||
|
|
||||||
if use_background_health_checks:
|
if use_background_health_checks:
|
||||||
asyncio.create_task(_run_background_health_check()) # start the background health check coroutine.
|
asyncio.create_task(_run_background_health_check()) # start the background health check coroutine.
|
||||||
|
|
||||||
|
|
|
@ -31,11 +31,12 @@ class ProxyLogging:
|
||||||
litellm._async_success_callback.append(callback)
|
litellm._async_success_callback.append(callback)
|
||||||
if callback not in litellm._async_failure_callback:
|
if callback not in litellm._async_failure_callback:
|
||||||
litellm._async_failure_callback.append(callback)
|
litellm._async_failure_callback.append(callback)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
len(litellm.input_callback) > 0
|
len(litellm.input_callback) > 0
|
||||||
or len(litellm.success_callback) > 0
|
or len(litellm.success_callback) > 0
|
||||||
or len(litellm.failure_callback) > 0
|
or len(litellm.failure_callback) > 0
|
||||||
) and len(callback_list) == 0:
|
):
|
||||||
callback_list = list(
|
callback_list = list(
|
||||||
set(
|
set(
|
||||||
litellm.input_callback
|
litellm.input_callback
|
||||||
|
@ -59,7 +60,6 @@ class ProxyLogging:
|
||||||
|
|
||||||
Currently only logs exceptions to sentry
|
Currently only logs exceptions to sentry
|
||||||
"""
|
"""
|
||||||
print(f"reaches failure handler logging - {original_exception}; sentry: {litellm.utils.capture_exception}")
|
|
||||||
if litellm.utils.capture_exception:
|
if litellm.utils.capture_exception:
|
||||||
litellm.utils.capture_exception(error=original_exception)
|
litellm.utils.capture_exception(error=original_exception)
|
||||||
|
|
||||||
|
@ -68,6 +68,9 @@ class ProxyLogging:
|
||||||
class PrismaClient:
|
class PrismaClient:
|
||||||
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
|
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
|
||||||
print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
|
print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
|
||||||
|
## init logging object
|
||||||
|
self.proxy_logging_obj = proxy_logging_obj
|
||||||
|
|
||||||
os.environ["DATABASE_URL"] = database_url
|
os.environ["DATABASE_URL"] = database_url
|
||||||
# Save the current working directory
|
# Save the current working directory
|
||||||
original_dir = os.getcwd()
|
original_dir = os.getcwd()
|
||||||
|
@ -85,8 +88,7 @@ class PrismaClient:
|
||||||
from prisma import Client # type: ignore
|
from prisma import Client # type: ignore
|
||||||
self.db = Client() #Client to connect to Prisma db
|
self.db = Client() #Client to connect to Prisma db
|
||||||
|
|
||||||
## init logging object
|
|
||||||
self.proxy_logging_obj = proxy_logging_obj
|
|
||||||
|
|
||||||
def hash_token(self, token: str):
|
def hash_token(self, token: str):
|
||||||
# Hash the string using SHA-256
|
# Hash the string using SHA-256
|
||||||
|
@ -122,7 +124,6 @@ class PrismaClient:
|
||||||
token = data["token"]
|
token = data["token"]
|
||||||
hashed_token = self.hash_token(token=token)
|
hashed_token = self.hash_token(token=token)
|
||||||
data["token"] = hashed_token
|
data["token"] = hashed_token
|
||||||
print(f"passed in data: {data}; hashed_token: {hashed_token}")
|
|
||||||
|
|
||||||
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
|
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
|
||||||
where={
|
where={
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue