forked from phoenix/litellm-mirror
feat(proxy_server.py): enable background health checks
This commit is contained in:
parent
b8b15435b7
commit
9cf3051ea2
3 changed files with 65 additions and 13 deletions
|
@ -98,7 +98,7 @@ from litellm.proxy.utils import (
|
|||
import pydantic
|
||||
from litellm.proxy._types import *
|
||||
from litellm.caching import DualCache
|
||||
from litellm.health_check import perform_health_check
|
||||
from litellm.proxy.health_check import perform_health_check
|
||||
litellm.suppress_debug_info = True
|
||||
from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks
|
||||
from fastapi.routing import APIRouter
|
||||
|
@ -193,6 +193,9 @@ otel_logging = False
|
|||
prisma_client: Optional[PrismaClient] = None
|
||||
user_api_key_cache = DualCache()
|
||||
user_custom_auth = None
|
||||
use_background_health_checks = None
|
||||
health_check_interval = None
|
||||
health_check_results = {}
|
||||
### REDIS QUEUE ###
|
||||
async_result = None
|
||||
celery_app_conn = None
|
||||
|
@ -418,8 +421,26 @@ def run_ollama_serve():
|
|||
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
|
||||
""")
|
||||
|
||||
async def _run_background_health_check():
|
||||
"""
|
||||
Periodically run health checks in the background on the endpoints.
|
||||
|
||||
Update health_check_results, based on this.
|
||||
"""
|
||||
global health_check_results, llm_model_list, health_check_interval
|
||||
while True:
|
||||
healthy_endpoints, unhealthy_endpoints = await perform_health_check(model_list=llm_model_list)
|
||||
|
||||
# Update the global variable with the health check results
|
||||
health_check_results["healthy_endpoints"] = healthy_endpoints
|
||||
health_check_results["unhealthy_endpoints"] = unhealthy_endpoints
|
||||
health_check_results["healthy_count"] = len(healthy_endpoints)
|
||||
health_check_results["unhealthy_count"] = len(unhealthy_endpoints)
|
||||
|
||||
await asyncio.sleep(health_check_interval)
|
||||
|
||||
def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path
|
||||
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval
|
||||
config = {}
|
||||
try:
|
||||
if os.path.exists(config_file_path):
|
||||
|
@ -473,8 +494,15 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
custom_auth = general_settings.get("custom_auth", None)
|
||||
if custom_auth:
|
||||
user_custom_auth = get_instance_fn(value=custom_auth, config_file_path=config_file_path)
|
||||
### BACKGROUND HEALTH CHECKS ###
|
||||
# Enable background health checks
|
||||
use_background_health_checks = general_settings.get("background_health_checks", False)
|
||||
health_check_interval = general_settings.get("health_check_interval", 300)
|
||||
|
||||
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
||||
litellm_settings = config.get('litellm_settings', None)
|
||||
if litellm_settings is None:
|
||||
litellm_settings = {}
|
||||
if litellm_settings:
|
||||
# ANSI escape code for blue text
|
||||
blue_color_code = "\033[94m"
|
||||
|
@ -711,7 +739,6 @@ def data_generator(response):
|
|||
except:
|
||||
yield f"data: {json.dumps(chunk)}\n\n"
|
||||
|
||||
|
||||
async def async_data_generator(response):
|
||||
print_verbose("inside generator")
|
||||
async for chunk in response:
|
||||
|
@ -797,7 +824,7 @@ async def rate_limit_per_token(request: Request, call_next):
|
|||
|
||||
@router.on_event("startup")
|
||||
async def startup_event():
|
||||
global prisma_client, master_key
|
||||
global prisma_client, master_key, use_background_health_checks
|
||||
import json
|
||||
|
||||
worker_config = litellm.get_secret("WORKER_CONFIG")
|
||||
|
@ -810,6 +837,10 @@ async def startup_event():
|
|||
# if not, assume it's a json string
|
||||
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
|
||||
initialize(**worker_config)
|
||||
|
||||
if use_background_health_checks:
|
||||
asyncio.create_task(_run_background_health_check()) # start the background health check coroutine.
|
||||
|
||||
print_verbose(f"prisma client - {prisma_client}")
|
||||
if prisma_client:
|
||||
await prisma_client.connect()
|
||||
|
@ -824,8 +855,11 @@ async def shutdown_event():
|
|||
if prisma_client:
|
||||
print("Disconnecting from Prisma")
|
||||
await prisma_client.disconnect()
|
||||
|
||||
## RESET CUSTOM VARIABLES ##
|
||||
master_key = None
|
||||
user_custom_auth = None
|
||||
|
||||
#### API ENDPOINTS ####
|
||||
@router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
|
||||
@router.get("/models", dependencies=[Depends(user_api_key_auth)]) # if project requires model list
|
||||
|
@ -1326,23 +1360,38 @@ async def config_yaml_endpoint(config_info: ConfigYAML):
|
|||
async def test_endpoint(request: Request):
|
||||
return {"route": request.url.path}
|
||||
|
||||
@router.get("/health", description="Check the health of all the endpoints in config.yaml", tags=["health"], dependencies=[Depends(user_api_key_auth)])
|
||||
@router.get("/health", tags=["health"], dependencies=[Depends(user_api_key_auth)])
|
||||
async def health_endpoint(request: Request, model: Optional[str] = fastapi.Query(None, description="Specify the model name (optional)")):
|
||||
global llm_model_list
|
||||
"""
|
||||
Check the health of all the endpoints in config.yaml
|
||||
|
||||
To run health checks in the background, add this to config.yaml:
|
||||
```
|
||||
general_settings:
|
||||
# ... other settings
|
||||
background_health_checks: True
|
||||
```
|
||||
else, the health checks will be run on models when /health is called.
|
||||
"""
|
||||
global health_check_results, use_background_health_checks
|
||||
|
||||
if llm_model_list is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"error": "Model list not initialized"},
|
||||
)
|
||||
healthy_endpoints, unhealthy_endpoints = await perform_health_check(llm_model_list, model)
|
||||
|
||||
if use_background_health_checks:
|
||||
return health_check_results
|
||||
else:
|
||||
healthy_endpoints, unhealthy_endpoints = await perform_health_check(llm_model_list, model)
|
||||
|
||||
return {
|
||||
"healthy_endpoints": healthy_endpoints,
|
||||
"unhealthy_endpoints": unhealthy_endpoints,
|
||||
"healthy_count": len(healthy_endpoints),
|
||||
"unhealthy_count": len(unhealthy_endpoints),
|
||||
}
|
||||
return {
|
||||
"healthy_endpoints": healthy_endpoints,
|
||||
"unhealthy_endpoints": unhealthy_endpoints,
|
||||
"healthy_count": len(healthy_endpoints),
|
||||
"unhealthy_count": len(unhealthy_endpoints),
|
||||
}
|
||||
|
||||
@router.get("/")
|
||||
async def home(request: Request):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue