forked from phoenix/litellm-mirror
feat(proxy_server.py): enable background health checks
This commit is contained in:
parent
b8b15435b7
commit
9cf3051ea2
3 changed files with 65 additions and 13 deletions
|
@ -111,6 +111,9 @@ class ConfigGeneralSettings(BaseModel):
|
||||||
custom_auth: Optional[str] = Field(None, description="override user_api_key_auth with your own auth script - https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth")
|
custom_auth: Optional[str] = Field(None, description="override user_api_key_auth with your own auth script - https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth")
|
||||||
max_parallel_requests: Optional[int] = Field(None, description="maximum parallel requests for each api key")
|
max_parallel_requests: Optional[int] = Field(None, description="maximum parallel requests for each api key")
|
||||||
infer_model_from_keys: Optional[bool] = Field(None, description="for `/models` endpoint, infers available model based on environment keys (e.g. OPENAI_API_KEY)")
|
infer_model_from_keys: Optional[bool] = Field(None, description="for `/models` endpoint, infers available model based on environment keys (e.g. OPENAI_API_KEY)")
|
||||||
|
background_health_checks: Optional[bool] = Field(None, description="run health checks in background")
|
||||||
|
health_check_interval: int = Field(300, description="background health check interval in seconds")
|
||||||
|
|
||||||
|
|
||||||
class ConfigYAML(BaseModel):
|
class ConfigYAML(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -98,7 +98,7 @@ from litellm.proxy.utils import (
|
||||||
import pydantic
|
import pydantic
|
||||||
from litellm.proxy._types import *
|
from litellm.proxy._types import *
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
from litellm.health_check import perform_health_check
|
from litellm.proxy.health_check import perform_health_check
|
||||||
litellm.suppress_debug_info = True
|
litellm.suppress_debug_info = True
|
||||||
from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks
|
from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
|
@ -193,6 +193,9 @@ otel_logging = False
|
||||||
prisma_client: Optional[PrismaClient] = None
|
prisma_client: Optional[PrismaClient] = None
|
||||||
user_api_key_cache = DualCache()
|
user_api_key_cache = DualCache()
|
||||||
user_custom_auth = None
|
user_custom_auth = None
|
||||||
|
use_background_health_checks = None
|
||||||
|
health_check_interval = None
|
||||||
|
health_check_results = {}
|
||||||
### REDIS QUEUE ###
|
### REDIS QUEUE ###
|
||||||
async_result = None
|
async_result = None
|
||||||
celery_app_conn = None
|
celery_app_conn = None
|
||||||
|
@ -418,8 +421,26 @@ def run_ollama_serve():
|
||||||
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
|
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
async def _run_background_health_check():
    """
    Periodically run health checks in the background on the endpoints.

    Each pass calls `perform_health_check` on the current `llm_model_list`
    and stores the outcome in the module-level `health_check_results` dict
    under the keys "healthy_endpoints", "unhealthy_endpoints",
    "healthy_count" and "unhealthy_count". The coroutine then sleeps for
    `health_check_interval` seconds and repeats; it never returns on its own.

    If a pass raises, the error is reported through `print_verbose` and the
    results of the previous successful pass are left untouched, so a single
    failing pass does not end the loop.
    """
    global health_check_results, llm_model_list, health_check_interval
    while True:
        try:
            healthy_endpoints, unhealthy_endpoints = await perform_health_check(model_list=llm_model_list)
        except asyncio.CancelledError:
            # Cancellation has to propagate, otherwise the task could never be stopped.
            raise
        except Exception as e:
            # An uncaught exception here would end the task for good and leave
            # health_check_results stale; report it and retry on the next pass.
            print_verbose(f"background health check failed - {e}")
        else:
            # Update the global variable with the health check results.
            # The dict is mutated in place (no await in between), so readers
            # never observe a half-written result set.
            health_check_results["healthy_endpoints"] = healthy_endpoints
            health_check_results["unhealthy_endpoints"] = unhealthy_endpoints
            health_check_results["healthy_count"] = len(healthy_endpoints)
            health_check_results["unhealthy_count"] = len(unhealthy_endpoints)

        # health_check_interval starts out as None at module level; asyncio.sleep(None)
        # raises TypeError, so fall back to the same 300 second default that is used
        # when the setting is read from the config.
        interval = 300 if health_check_interval is None else health_check_interval
        await asyncio.sleep(interval)
|
||||||
|
|
||||||
def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||||
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path
|
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval
|
||||||
config = {}
|
config = {}
|
||||||
try:
|
try:
|
||||||
if os.path.exists(config_file_path):
|
if os.path.exists(config_file_path):
|
||||||
|
@ -473,8 +494,15 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||||
custom_auth = general_settings.get("custom_auth", None)
|
custom_auth = general_settings.get("custom_auth", None)
|
||||||
if custom_auth:
|
if custom_auth:
|
||||||
user_custom_auth = get_instance_fn(value=custom_auth, config_file_path=config_file_path)
|
user_custom_auth = get_instance_fn(value=custom_auth, config_file_path=config_file_path)
|
||||||
|
### BACKGROUND HEALTH CHECKS ###
|
||||||
|
# Enable background health checks
|
||||||
|
use_background_health_checks = general_settings.get("background_health_checks", False)
|
||||||
|
health_check_interval = general_settings.get("health_check_interval", 300)
|
||||||
|
|
||||||
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
||||||
litellm_settings = config.get('litellm_settings', None)
|
litellm_settings = config.get('litellm_settings', None)
|
||||||
|
if litellm_settings is None:
|
||||||
|
litellm_settings = {}
|
||||||
if litellm_settings:
|
if litellm_settings:
|
||||||
# ANSI escape code for blue text
|
# ANSI escape code for blue text
|
||||||
blue_color_code = "\033[94m"
|
blue_color_code = "\033[94m"
|
||||||
|
@ -711,7 +739,6 @@ def data_generator(response):
|
||||||
except:
|
except:
|
||||||
yield f"data: {json.dumps(chunk)}\n\n"
|
yield f"data: {json.dumps(chunk)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
async def async_data_generator(response):
|
async def async_data_generator(response):
|
||||||
print_verbose("inside generator")
|
print_verbose("inside generator")
|
||||||
async for chunk in response:
|
async for chunk in response:
|
||||||
|
@ -797,7 +824,7 @@ async def rate_limit_per_token(request: Request, call_next):
|
||||||
|
|
||||||
@router.on_event("startup")
|
@router.on_event("startup")
|
||||||
async def startup_event():
|
async def startup_event():
|
||||||
global prisma_client, master_key
|
global prisma_client, master_key, use_background_health_checks
|
||||||
import json
|
import json
|
||||||
|
|
||||||
worker_config = litellm.get_secret("WORKER_CONFIG")
|
worker_config = litellm.get_secret("WORKER_CONFIG")
|
||||||
|
@ -810,6 +837,10 @@ async def startup_event():
|
||||||
# if not, assume it's a json string
|
# if not, assume it's a json string
|
||||||
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
|
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
|
||||||
initialize(**worker_config)
|
initialize(**worker_config)
|
||||||
|
|
||||||
|
if use_background_health_checks:
|
||||||
|
asyncio.create_task(_run_background_health_check()) # start the background health check coroutine.
|
||||||
|
|
||||||
print_verbose(f"prisma client - {prisma_client}")
|
print_verbose(f"prisma client - {prisma_client}")
|
||||||
if prisma_client:
|
if prisma_client:
|
||||||
await prisma_client.connect()
|
await prisma_client.connect()
|
||||||
|
@ -824,8 +855,11 @@ async def shutdown_event():
|
||||||
if prisma_client:
|
if prisma_client:
|
||||||
print("Disconnecting from Prisma")
|
print("Disconnecting from Prisma")
|
||||||
await prisma_client.disconnect()
|
await prisma_client.disconnect()
|
||||||
|
|
||||||
|
## RESET CUSTOM VARIABLES ##
|
||||||
master_key = None
|
master_key = None
|
||||||
user_custom_auth = None
|
user_custom_auth = None
|
||||||
|
|
||||||
#### API ENDPOINTS ####
|
#### API ENDPOINTS ####
|
||||||
@router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
|
@router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
|
||||||
@router.get("/models", dependencies=[Depends(user_api_key_auth)]) # if project requires model list
|
@router.get("/models", dependencies=[Depends(user_api_key_auth)]) # if project requires model list
|
||||||
|
@ -1326,15 +1360,30 @@ async def config_yaml_endpoint(config_info: ConfigYAML):
|
||||||
async def test_endpoint(request: Request):
|
async def test_endpoint(request: Request):
|
||||||
return {"route": request.url.path}
|
return {"route": request.url.path}
|
||||||
|
|
||||||
@router.get("/health", description="Check the health of all the endpoints in config.yaml", tags=["health"], dependencies=[Depends(user_api_key_auth)])
|
@router.get("/health", tags=["health"], dependencies=[Depends(user_api_key_auth)])
|
||||||
async def health_endpoint(request: Request, model: Optional[str] = fastapi.Query(None, description="Specify the model name (optional)")):
|
async def health_endpoint(request: Request, model: Optional[str] = fastapi.Query(None, description="Specify the model name (optional)")):
|
||||||
global llm_model_list
|
"""
|
||||||
|
Check the health of all the endpoints in config.yaml
|
||||||
|
|
||||||
|
To run health checks in the background, add this to config.yaml:
|
||||||
|
```
|
||||||
|
general_settings:
|
||||||
|
# ... other settings
|
||||||
|
background_health_checks: True
|
||||||
|
```
|
||||||
|
else, the health checks will be run on models when /health is called.
|
||||||
|
"""
|
||||||
|
global health_check_results, use_background_health_checks
|
||||||
|
|
||||||
if llm_model_list is None:
|
if llm_model_list is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail={"error": "Model list not initialized"},
|
detail={"error": "Model list not initialized"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if use_background_health_checks:
|
||||||
|
return health_check_results
|
||||||
|
else:
|
||||||
healthy_endpoints, unhealthy_endpoints = await perform_health_check(llm_model_list, model)
|
healthy_endpoints, unhealthy_endpoints = await perform_health_check(llm_model_list, model)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue