Merge pull request #1023 from PSU3D0/speedup_health_endpoint

(feat) Speedup health endpoint
This commit is contained in:
Ishaan Jaff 2023-12-06 09:52:13 -08:00 committed by GitHub
commit f3d8825290
4 changed files with 129 additions and 27 deletions

View file

@@ -2,6 +2,7 @@
import threading, requests import threading, requests
from typing import Callable, List, Optional, Dict, Union, Any from typing import Callable, List, Optional, Dict, Union, Any
from litellm.caching import Cache from litellm.caching import Cache
from litellm._logging import set_verbose
import httpx import httpx
input_callback: List[Union[str, Callable]] = [] input_callback: List[Union[str, Callable]] = []
@@ -11,7 +12,6 @@ callbacks: List[Callable] = []
_async_success_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here. _async_success_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
pre_call_rules: List[Callable] = [] pre_call_rules: List[Callable] = []
post_call_rules: List[Callable] = [] post_call_rules: List[Callable] = []
set_verbose = False
email: Optional[ email: Optional[
str str
] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 ] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648

5
litellm/_logging.py Normal file
View file

@@ -0,0 +1,5 @@
# Module-level verbosity flag; disabled by default.
set_verbose = False


def print_verbose(print_statement):
    """Emit *print_statement* to stdout only when verbose mode is enabled."""
    if not set_verbose:
        return
    print(print_statement)

115
litellm/health_check.py Normal file
View file

@@ -0,0 +1,115 @@
import asyncio
import random
from typing import Optional
import litellm
import logging
from litellm._logging import print_verbose
logger = logging.getLogger(__name__)
# Keys stripped from litellm params before they are displayed to users:
# "api_key" is a secret, "messages" is the throwaway health-check payload.
ILLEGAL_DISPLAY_PARAMS = [
    "messages",
    "api_key"
]
def _get_random_llm_message():
"""
Get a random message from the LLM.
"""
messages = [
"Hey how's it going?",
"What's 1 + 1?"
]
return [
{"role": "user", "content": random.choice(messages)}
]
def _clean_litellm_params(litellm_params: dict):
    """Return a copy of *litellm_params* without user-hidden keys.

    Keys listed in ILLEGAL_DISPLAY_PARAMS (secrets / bulky payloads) are
    dropped; everything else is passed through unchanged.
    """
    cleaned = {}
    for key, value in litellm_params.items():
        if key in ILLEGAL_DISPLAY_PARAMS:
            continue
        cleaned[key] = value
    return cleaned
async def _perform_health_check(model_list: list):
    """
    Concurrently run one test completion against every model in *model_list*.

    Each entry is expected to be a dict carrying a "litellm_params" sub-dict
    (router-style model config). Returns a pair of lists
    (healthy_endpoints, unhealthy_endpoints) of cleaned litellm params.

    NOTE(review): each model's "litellm_params" dict is mutated in place
    ("model" is rewritten, "messages" is injected) — callers holding a
    reference to model_list will observe the change.
    """
    async def _check_model(model_params: dict):
        # Any exception raised by the completion call marks the endpoint
        # unhealthy; the error is only surfaced via verbose logging.
        try:
            await litellm.acompletion(**model_params)
        except Exception as e:
            print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}")
            return False
        return True
    prepped_params = []
    for model in model_list:
        litellm_params = model["litellm_params"]
        # Strip the router-assigned id so the raw model name is used.
        litellm_params["model"] = litellm.utils.remove_model_id(litellm_params["model"])
        # Inject a throwaway prompt for the test completion.
        litellm_params["messages"] = _get_random_llm_message()
        prepped_params.append(litellm_params)
    # Fire all checks concurrently; gather preserves the order of model_list,
    # which the zip below relies on to pair results with models.
    tasks = [_check_model(x) for x in prepped_params]
    results = await asyncio.gather(*tasks)
    healthy_endpoints = []
    unhealthy_endpoints = []
    for is_healthy, model in zip(results, model_list):
        # Hide sensitive/bulky keys (api_key, messages) before reporting.
        cleaned_litellm_params = _clean_litellm_params(model["litellm_params"])
        if is_healthy:
            healthy_endpoints.append(cleaned_litellm_params)
        else:
            unhealthy_endpoints.append(cleaned_litellm_params)
    return healthy_endpoints, unhealthy_endpoints
async def perform_health_check(model_list: list, model: Optional[str] = None):
    """
    Run a health check across the proxy's configured models.

    Args:
        model_list: router-style model configs; each entry holds a
            "litellm_params" dict.
        model: optional model name — when given, only matching entries are
            checked.

    Returns:
        (healthy_endpoints, unhealthy_endpoints): two lists of cleaned
        litellm params describing which endpoints passed / failed the check.
    """
    if not model_list:
        # Nothing configured — trivially no healthy or unhealthy endpoints.
        return [], []
    if model is not None:
        model_list = [x for x in model_list if x["litellm_params"]["model"] == model]
    models_to_check = []
    # Loop variable renamed from `model` so it no longer shadows the parameter.
    for entry in model_list:
        litellm_params = entry["litellm_params"]
        model_name = litellm.utils.remove_model_id(litellm_params["model"])
        if model_name in litellm.all_embedding_models:
            continue  # Skip embedding models — they can't answer a chat prompt
        models_to_check.append(entry)
    # BUG FIX: previously the unfiltered model_list was passed here, so the
    # embedding-model filter above had no effect and embedding endpoints were
    # health-checked (and reported unhealthy) anyway.
    healthy_endpoints, unhealthy_endpoints = await _perform_health_check(models_to_check)
    return healthy_endpoints, unhealthy_endpoints

View file

@@ -7,6 +7,8 @@ import secrets, subprocess
import hashlib, uuid import hashlib, uuid
import warnings import warnings
import importlib import importlib
from litellm.health_check import perform_health_check
messages: list = [] messages: list = []
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
@@ -1173,34 +1175,14 @@ async def test_endpoint(request: Request):
@router.get("/health", description="Check the health of all the endpoints in config.yaml", tags=["health"]) @router.get("/health", description="Check the health of all the endpoints in config.yaml", tags=["health"])
async def health_endpoint(request: Request, model: Optional[str] = fastapi.Query(None, description="Specify the model name (optional)")): async def health_endpoint(request: Request, model: Optional[str] = fastapi.Query(None, description="Specify the model name (optional)")):
global llm_model_list global llm_model_list
healthy_endpoints = []
unhealthy_endpoints = [] healthy_endpoints, unhealthy_endpoints = await perform_health_check(llm_model_list, model)
if llm_model_list:
for model_name in llm_model_list:
try:
if model is None or model == model_name["litellm_params"]["model"]: # if model specified, just call that one.
litellm_params = model_name["litellm_params"]
model_name = litellm.utils.remove_model_id(litellm_params["model"]) # removes, ids set by litellm.router
if model_name not in litellm.all_embedding_models: # filter out embedding models
litellm_params["messages"] = [{"role": "user", "content": "Hey, how's it going?"}]
litellm_params["model"] = model_name
litellm.completion(**litellm_params)
cleaned_params = {}
for key in litellm_params:
if key != "api_key" and key != "messages":
cleaned_params[key] = litellm_params[key]
healthy_endpoints.append(cleaned_params)
except Exception as e:
print("Got Exception", e)
cleaned_params = {}
for key in litellm_params:
if key != "api_key" and key != "messages":
cleaned_params[key] = litellm_params[key]
unhealthy_endpoints.append(cleaned_params)
pass
return { return {
"healthy_endpoints": healthy_endpoints, "healthy_endpoints": healthy_endpoints,
"unhealthy_endpoints": unhealthy_endpoints "unhealthy_endpoints": unhealthy_endpoints,
"healthy_count": len(healthy_endpoints),
"unhealthy_count": len(unhealthy_endpoints),
} }
@router.get("/") @router.get("/")