feat(proxy_server.py): enable background health checks

Krrish Dholakia 2023-12-07 19:40:06 -08:00
parent b8b15435b7
commit 9cf3051ea2
3 changed files with 65 additions and 13 deletions


@@ -98,7 +98,7 @@ from litellm.proxy.utils import (
 import pydantic
 from litellm.proxy._types import *
 from litellm.caching import DualCache
-from litellm.health_check import perform_health_check
+from litellm.proxy.health_check import perform_health_check
 litellm.suppress_debug_info = True
 from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks
 from fastapi.routing import APIRouter
@@ -193,6 +193,9 @@ otel_logging = False
 prisma_client: Optional[PrismaClient] = None
 user_api_key_cache = DualCache()
 user_custom_auth = None
+use_background_health_checks = None
+health_check_interval = None
+health_check_results = {}
 ### REDIS QUEUE ###
 async_result = None
 celery_app_conn = None
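The three new globals drive the feature: `use_background_health_checks` toggles the mode, `health_check_interval` is the number of seconds between runs, and `health_check_results` caches the latest results for `/health` to serve. Judging by how the background loop below populates it, the cached dict would end up shaped roughly like this (a sketch; only the four keys come from the commit, the endpoint entry is illustrative):

```
# Hypothetical snapshot of health_check_results after one background run
health_check_results = {
    "healthy_endpoints": [
        {"model": "gpt-3.5-turbo"},  # illustrative; real entries come from perform_health_check
    ],
    "unhealthy_endpoints": [],
    "healthy_count": 1,
    "unhealthy_count": 0,
}
```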
@@ -418,8 +421,26 @@ def run_ollama_serve():
             LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
         """)
 
+async def _run_background_health_check():
+    """
+    Periodically run health checks in the background on the endpoints.
+
+    Update health_check_results based on the outcome.
+    """
+    global health_check_results, llm_model_list, health_check_interval
+    while True:
+        healthy_endpoints, unhealthy_endpoints = await perform_health_check(model_list=llm_model_list)
+
+        # Update the global variable with the health check results
+        health_check_results["healthy_endpoints"] = healthy_endpoints
+        health_check_results["unhealthy_endpoints"] = unhealthy_endpoints
+        health_check_results["healthy_count"] = len(healthy_endpoints)
+        health_check_results["unhealthy_count"] = len(unhealthy_endpoints)
+
+        await asyncio.sleep(health_check_interval)
+
 def load_router_config(router: Optional[litellm.Router], config_file_path: str):
-    global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path
+    global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval
     config = {}
    try:
         if os.path.exists(config_file_path):
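The coroutine is the standard asyncio periodic-task pattern: do the work, publish the result into shared state, sleep, repeat. One property worth noting is that an unhandled exception in `perform_health_check` would end the loop silently. A hardened variant of the same pattern (a sketch, not part of this commit; all names are illustrative) would wrap the body in `try`/`except`:

```
import asyncio

results_cache = {}  # shared state, read by request handlers

async def run_periodically(check, interval: float):
    """Run `check` forever, caching its result every `interval` seconds."""
    while True:
        try:
            results_cache["latest"] = await check()
        except Exception as e:
            # keep the loop alive even if one run fails
            print(f"health check failed: {e}")
        await asyncio.sleep(interval)  # yield control between runs
```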
@@ -473,8 +494,15 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         custom_auth = general_settings.get("custom_auth", None)
         if custom_auth:
             user_custom_auth = get_instance_fn(value=custom_auth, config_file_path=config_file_path)
+        ### BACKGROUND HEALTH CHECKS ###
+        # Enable background health checks
+        use_background_health_checks = general_settings.get("background_health_checks", False)
+        health_check_interval = general_settings.get("health_check_interval", 300)
+
     ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
     litellm_settings = config.get('litellm_settings', None)
+    if litellm_settings is None:
+        litellm_settings = {}
     if litellm_settings:
         # ANSI escape code for blue text
         blue_color_code = "\033[94m"
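With those two `general_settings` keys wired up, turning the feature on from `config.yaml` would look roughly like this (the key names and the 300-second default come from the code above; the model entry is illustrative):

```
model_list:
  - model_name: gpt-3.5-turbo          # illustrative deployment
    litellm_params:
      model: gpt-3.5-turbo

general_settings:
  background_health_checks: True       # defaults to False
  health_check_interval: 300           # seconds between runs; defaults to 300
```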
@@ -711,7 +739,6 @@ def data_generator(response):
     except:
         yield f"data: {json.dumps(chunk)}\n\n"
 
-
 async def async_data_generator(response):
     print_verbose("inside generator")
     async for chunk in response:
@@ -797,7 +824,7 @@ async def rate_limit_per_token(request: Request, call_next):
 
 @router.on_event("startup")
 async def startup_event():
-    global prisma_client, master_key
+    global prisma_client, master_key, use_background_health_checks
     import json
 
     worker_config = litellm.get_secret("WORKER_CONFIG")
@@ -810,6 +837,10 @@ async def startup_event():
         # if not, assume it's a json string
         worker_config = json.loads(os.getenv("WORKER_CONFIG"))
         initialize(**worker_config)
+
+    if use_background_health_checks:
+        asyncio.create_task(_run_background_health_check())  # start the background health check coroutine.
+
     print_verbose(f"prisma client - {prisma_client}")
     if prisma_client:
         await prisma_client.connect()
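`asyncio.create_task` here is fire-and-forget: nothing holds a reference to the task, and the event loop keeps only a weak one, so exceptions surface late (if at all) and the loop cannot be cancelled cleanly. A common hardening sketch (not part of this commit; `_health_check_task` is a hypothetical name) keeps a module-level handle and cancels it on shutdown:

```
_health_check_task = None  # hypothetical module-level handle

async def startup_event():
    global _health_check_task
    if use_background_health_checks:
        # keep a strong reference so the task isn't garbage-collected mid-run
        _health_check_task = asyncio.create_task(_run_background_health_check())

async def shutdown_event():
    if _health_check_task is not None:
        _health_check_task.cancel()  # stop the background loop on server exit
```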
@@ -824,8 +855,11 @@ async def shutdown_event():
     if prisma_client:
         print("Disconnecting from Prisma")
         await prisma_client.disconnect()
+    ## RESET CUSTOM VARIABLES ##
+    master_key = None
+    user_custom_auth = None
 
 #### API ENDPOINTS ####
 @router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
 @router.get("/models", dependencies=[Depends(user_api_key_auth)]) # if project requires model list
@@ -1326,23 +1360,38 @@ async def config_yaml_endpoint(config_info: ConfigYAML):
 async def test_endpoint(request: Request):
     return {"route": request.url.path}
 
-@router.get("/health", description="Check the health of all the endpoints in config.yaml", tags=["health"], dependencies=[Depends(user_api_key_auth)])
+@router.get("/health", tags=["health"], dependencies=[Depends(user_api_key_auth)])
 async def health_endpoint(request: Request, model: Optional[str] = fastapi.Query(None, description="Specify the model name (optional)")):
-    global llm_model_list
+    """
+    Check the health of all the endpoints in config.yaml
+
+    To run health checks in the background, add this to config.yaml:
+    ```
+    general_settings:
+        # ... other settings
+        background_health_checks: True
+    ```
+    Otherwise, health checks run on the models only when /health is called.
+    """
+    global health_check_results, use_background_health_checks
     if llm_model_list is None:
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail={"error": "Model list not initialized"},
         )
-    healthy_endpoints, unhealthy_endpoints = await perform_health_check(llm_model_list, model)
-    return {
-        "healthy_endpoints": healthy_endpoints,
-        "unhealthy_endpoints": unhealthy_endpoints,
-        "healthy_count": len(healthy_endpoints),
-        "unhealthy_count": len(unhealthy_endpoints),
-    }
+
+    if use_background_health_checks:
+        return health_check_results
+    else:
+        healthy_endpoints, unhealthy_endpoints = await perform_health_check(llm_model_list, model)
+        return {
+            "healthy_endpoints": healthy_endpoints,
+            "unhealthy_endpoints": unhealthy_endpoints,
+            "healthy_count": len(healthy_endpoints),
+            "unhealthy_count": len(unhealthy_endpoints),
+        }
 @router.get("/")
 async def home(request: Request):
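With `background_health_checks: True`, `/health` now returns the cached `health_check_results` immediately instead of probing every deployment on each request; the data is at most `health_check_interval` seconds stale. A quick smoke test against a locally running proxy might look like this (host, port, and key are assumptions, not from the commit):

```
import requests

# assumes the proxy is running locally and a master key of sk-1234 (hypothetical)
resp = requests.get(
    "http://0.0.0.0:8000/health",
    headers={"Authorization": "Bearer sk-1234"},
)
print(resp.json())
# -> {"healthy_endpoints": [...], "unhealthy_endpoints": [...], "healthy_count": ..., "unhealthy_count": ...}
```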