diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index 029430f09a..626ec513c9 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -111,6 +111,9 @@ class ConfigGeneralSettings(BaseModel):
     custom_auth: Optional[str] = Field(None, description="override user_api_key_auth with your own auth script - https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth")
     max_parallel_requests: Optional[int] = Field(None, description="maximum parallel requests for each api key")
     infer_model_from_keys: Optional[bool] = Field(None, description="for `/models` endpoint, infers available model based on environment keys (e.g. OPENAI_API_KEY)")
+    background_health_checks: Optional[bool] = Field(None, description="run health checks in background")
+    health_check_interval: int = Field(300, description="background health check interval in seconds")
+
 
 class ConfigYAML(BaseModel):
     """
diff --git a/litellm/health_check.py b/litellm/proxy/health_check.py
similarity index 100%
rename from litellm/health_check.py
rename to litellm/proxy/health_check.py
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index b4763c11cf..bdd415f28d 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -98,7 +98,7 @@ from litellm.proxy.utils import (
 import pydantic
 from litellm.proxy._types import *
 from litellm.caching import DualCache
-from litellm.health_check import perform_health_check
+from litellm.proxy.health_check import perform_health_check
 litellm.suppress_debug_info = True
 from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks
 from fastapi.routing import APIRouter
@@ -193,6 +193,9 @@ otel_logging = False
 prisma_client: Optional[PrismaClient] = None
 user_api_key_cache = DualCache()
 user_custom_auth = None
+use_background_health_checks = None
+health_check_interval = None
+health_check_results = {}
 ### REDIS QUEUE ###
 async_result = None
 celery_app_conn = None
@@ -418,8 +421,26 @@ def run_ollama_serve():
             LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
         """)
 
+async def _run_background_health_check():
+    """
+    Periodically run health checks in the background on the endpoints.
+
+    Update health_check_results, based on this.
+    """
+    global health_check_results, llm_model_list, health_check_interval
+    while True:
+        healthy_endpoints, unhealthy_endpoints = await perform_health_check(model_list=llm_model_list)
+
+        # Update the global variable with the health check results
+        health_check_results["healthy_endpoints"] = healthy_endpoints
+        health_check_results["unhealthy_endpoints"] = unhealthy_endpoints
+        health_check_results["healthy_count"] = len(healthy_endpoints)
+        health_check_results["unhealthy_count"] = len(unhealthy_endpoints)
+
+        await asyncio.sleep(health_check_interval)
+
 def load_router_config(router: Optional[litellm.Router], config_file_path: str):
-    global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path
+    global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval
     config = {}
     try:
         if os.path.exists(config_file_path):
@@ -473,8 +494,15 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         custom_auth = general_settings.get("custom_auth", None)
         if custom_auth:
             user_custom_auth = get_instance_fn(value=custom_auth, config_file_path=config_file_path)
+        ### BACKGROUND HEALTH CHECKS ###
+        # Enable background health checks
+        use_background_health_checks = general_settings.get("background_health_checks", False)
+        health_check_interval = general_settings.get("health_check_interval", 300)
+
     ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
     litellm_settings = config.get('litellm_settings', None)
+    if litellm_settings is None:
+        litellm_settings = {}
     if litellm_settings:
         # ANSI escape code for blue text
         blue_color_code = "\033[94m"
@@ -711,7 +739,6 @@ def data_generator(response):
         except:
             yield f"data: {json.dumps(chunk)}\n\n"
 
-
 async def async_data_generator(response):
     print_verbose("inside generator")
     async for chunk in response:
@@ -797,7 +824,7 @@ async def rate_limit_per_token(request: Request, call_next):
 
 @router.on_event("startup")
 async def startup_event():
-    global prisma_client, master_key
+    global prisma_client, master_key, use_background_health_checks
     import json
 
     worker_config = litellm.get_secret("WORKER_CONFIG")
@@ -810,6 +837,10 @@ async def startup_event():
             # if not, assume it's a json string
             worker_config = json.loads(os.getenv("WORKER_CONFIG"))
         initialize(**worker_config)
+
+    if use_background_health_checks:
+        asyncio.create_task(_run_background_health_check())  # start the background health check coroutine
+
     print_verbose(f"prisma client - {prisma_client}")
     if prisma_client:
         await prisma_client.connect()
@@ -824,8 +855,11 @@ async def shutdown_event():
     if prisma_client:
         print("Disconnecting from Prisma")
         await prisma_client.disconnect()
+
+    ## RESET CUSTOM VARIABLES ##
     master_key = None
     user_custom_auth = None
+
 #### API ENDPOINTS ####
 @router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
 @router.get("/models", dependencies=[Depends(user_api_key_auth)]) # if project requires model list
@@ -1326,23 +1360,38 @@ async def config_yaml_endpoint(config_info: ConfigYAML):
 async def test_endpoint(request: Request):
     return {"route": request.url.path}
 
-@router.get("/health", description="Check the health of all the endpoints in config.yaml", tags=["health"], dependencies=[Depends(user_api_key_auth)])
+@router.get("/health", tags=["health"], dependencies=[Depends(user_api_key_auth)])
 async def health_endpoint(request: Request, model: Optional[str] = fastapi.Query(None, description="Specify the model name (optional)")):
-    global llm_model_list
+    """
+    Check the health of all the endpoints in config.yaml
+
+    To run health checks in the background, add this to config.yaml:
+    ```
+    general_settings:
+        # ... other settings
+        background_health_checks: True
+    ```
+    else, the health checks will be run on models when /health is called.
+    """
+    global health_check_results, use_background_health_checks
     if llm_model_list is None:
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"error": "Model list not initialized"},
         )
-    healthy_endpoints, unhealthy_endpoints = await perform_health_check(llm_model_list, model)
+
+    if use_background_health_checks:
+        return health_check_results
+    else:
+        healthy_endpoints, unhealthy_endpoints = await perform_health_check(llm_model_list, model)
 
-    return {
-        "healthy_endpoints": healthy_endpoints,
-        "unhealthy_endpoints": unhealthy_endpoints,
-        "healthy_count": len(healthy_endpoints),
-        "unhealthy_count": len(unhealthy_endpoints),
-    }
+        return {
+            "healthy_endpoints": healthy_endpoints,
+            "unhealthy_endpoints": unhealthy_endpoints,
+            "healthy_count": len(healthy_endpoints),
+            "unhealthy_count": len(unhealthy_endpoints),
+        }
 
 
 @router.get("/")
 async def home(request: Request):
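For reference, a minimal sketch of a proxy `config.yaml` that exercises the settings this diff introduces. Only `background_health_checks` and `health_check_interval` come from this change; the `model_list` entry is an illustrative assumption, not part of the patch.

```yaml
# Hypothetical config.yaml sketch -- only the two general_settings keys below are added by this diff.
model_list:
  - model_name: gpt-3.5-turbo        # illustrative model entry, not part of this change
    litellm_params:
      model: gpt-3.5-turbo

general_settings:
  # ... other settings
  background_health_checks: True     # start _run_background_health_check() on proxy startup
  health_check_interval: 300         # seconds between background checks (default: 300)
```

With `background_health_checks` enabled, `GET /health` returns the cached `health_check_results` dict (`healthy_endpoints`, `unhealthy_endpoints`, `healthy_count`, `unhealthy_count`) instead of awaiting `perform_health_check` on every request.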