Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-27 11:43:54 +00:00
Merge pull request #1023 from PSU3D0/speedup_health_endpoint
(feat) Speedup health endpoint
Commit f3d8825290
4 changed files with 129 additions and 27 deletions
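In short, the previous /health handler called each configured model one after another with blocking litellm.completion() calls, so total latency grew with the number of endpoints; the new litellm/health_check.py module issues all checks concurrently with litellm.acompletion() and asyncio.gather(). The following is a minimal sketch of that difference, using a stand-in coroutine instead of real litellm calls (not part of the PR):

import asyncio
import time


async def fake_check(delay: float) -> bool:
    # Stand-in for litellm.acompletion(); sleeps to simulate a provider round-trip.
    await asyncio.sleep(delay)
    return True


async def sequential(delays):
    # Old behaviour: one call at a time, latencies add up.
    return [await fake_check(d) for d in delays]


async def concurrent(delays):
    # New behaviour: all checks in flight at once via asyncio.gather.
    return await asyncio.gather(*(fake_check(d) for d in delays))


if __name__ == "__main__":
    delays = [0.5, 0.5, 0.5]

    start = time.perf_counter()
    asyncio.run(sequential(delays))
    print(f"sequential: {time.perf_counter() - start:.2f}s")  # ~1.5s

    start = time.perf_counter()
    asyncio.run(concurrent(delays))
    print(f"concurrent: {time.perf_counter() - start:.2f}s")  # ~0.5s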
@@ -2,6 +2,7 @@
 import threading, requests
 from typing import Callable, List, Optional, Dict, Union, Any
 from litellm.caching import Cache
+from litellm._logging import set_verbose
 import httpx
 
 input_callback: List[Union[str, Callable]] = []
@@ -11,7 +12,6 @@ callbacks: List[Callable] = []
 _async_success_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
 pre_call_rules: List[Callable] = []
 post_call_rules: List[Callable] = []
-set_verbose = False
 email: Optional[
     str
 ] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
litellm/_logging.py (new file, 5 lines)
@@ -0,0 +1,5 @@
+set_verbose = False
+
+def print_verbose(print_statement):
+    if set_verbose:
+        print(print_statement)
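For context, a minimal sketch of how this helper behaves (not part of the diff): print_verbose() reads the module-level flag at call time, so flipping it enables the health check's debug output.

from litellm import _logging

_logging.print_verbose("dropped")      # set_verbose is False by default, so nothing prints
_logging.set_verbose = True
_logging.print_verbose("now printed")  # printed because the flag is enabled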
litellm/health_check.py (new file, 115 lines)
@@ -0,0 +1,115 @@
+import asyncio
+import random
+from typing import Optional
+
+import litellm
+import logging
+from litellm._logging import print_verbose
+
+
+logger = logging.getLogger(__name__)
+
+
+ILLEGAL_DISPLAY_PARAMS = [
+    "messages",
+    "api_key"
+]
+
+
+def _get_random_llm_message():
+    """
+    Get a random message from the LLM.
+    """
+    messages = [
+        "Hey how's it going?",
+        "What's 1 + 1?"
+    ]
+
+    return [
+        {"role": "user", "content": random.choice(messages)}
+    ]
+
+
+def _clean_litellm_params(litellm_params: dict):
+    """
+    Clean the litellm params for display to users.
+    """
+    return {k: v for k, v in litellm_params.items() if k not in ILLEGAL_DISPLAY_PARAMS}
+
+
+async def _perform_health_check(model_list: list):
+    """
+    Perform a health check for each model in the list.
+    """
+
+    async def _check_model(model_params: dict):
+        try:
+            await litellm.acompletion(**model_params)
+        except Exception as e:
+            print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}")
+            return False
+
+        return True
+
+    prepped_params = []
+
+    for model in model_list:
+        litellm_params = model["litellm_params"]
+        litellm_params["model"] = litellm.utils.remove_model_id(litellm_params["model"])
+        litellm_params["messages"] = _get_random_llm_message()
+
+        prepped_params.append(litellm_params)
+
+    tasks = [_check_model(x) for x in prepped_params]
+
+    results = await asyncio.gather(*tasks)
+
+    healthy_endpoints = []
+    unhealthy_endpoints = []
+
+    for is_healthy, model in zip(results, model_list):
+        cleaned_litellm_params = _clean_litellm_params(model["litellm_params"])
+
+        if is_healthy:
+            healthy_endpoints.append(cleaned_litellm_params)
+        else:
+            unhealthy_endpoints.append(cleaned_litellm_params)
+
+    return healthy_endpoints, unhealthy_endpoints
+
+
+async def perform_health_check(model_list: list, model: Optional[str] = None):
+    """
+    Perform a health check on the system.
+
+    Returns:
+        (bool): True if the health check passes, False otherwise.
+    """
+    if not model_list:
+        return [], []
+
+    if model is not None:
+        model_list = [x for x in model_list if x["litellm_params"]["model"] == model]
+
+    models_to_check = []
+
+    for model in model_list:
+        litellm_params = model["litellm_params"]
+        model_name = litellm.utils.remove_model_id(litellm_params["model"])
+
+        if model_name in litellm.all_embedding_models:
+            continue  # Skip embedding models
+
+        models_to_check.append(model)
+
+    healthy_endpoints, unhealthy_endpoints = await _perform_health_check(model_list)
+
+    return healthy_endpoints, unhealthy_endpoints
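For reference, a minimal sketch of exercising the new module directly; the model_list shape mirrors what the proxy builds from config.yaml, but the model names and API keys below are placeholders, not taken from the PR:

import asyncio

from litellm.health_check import perform_health_check

model_list = [
    # Each entry's "litellm_params" holds the kwargs forwarded to litellm.acompletion().
    {"litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-..."}},
    {"litellm_params": {"model": "gpt-4", "api_key": "sk-..."}},
]

healthy, unhealthy = asyncio.run(perform_health_check(model_list))
print("healthy:", healthy)      # litellm_params with "messages" and "api_key" stripped
print("unhealthy:", unhealthy)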
@@ -7,6 +7,8 @@ import secrets, subprocess
 import hashlib, uuid
 import warnings
 import importlib
+
+from litellm.health_check import perform_health_check
 messages: list = []
 sys.path.insert(
     0, os.path.abspath("../..")
@@ -1173,34 +1175,14 @@ async def test_endpoint(request: Request):
 @router.get("/health", description="Check the health of all the endpoints in config.yaml", tags=["health"])
 async def health_endpoint(request: Request, model: Optional[str] = fastapi.Query(None, description="Specify the model name (optional)")):
     global llm_model_list
-    healthy_endpoints = []
-    unhealthy_endpoints = []
-    if llm_model_list:
-        for model_name in llm_model_list:
-            try:
-                if model is None or model == model_name["litellm_params"]["model"]: # if model specified, just call that one.
-                    litellm_params = model_name["litellm_params"]
-                    model_name = litellm.utils.remove_model_id(litellm_params["model"]) # removes, ids set by litellm.router
-                    if model_name not in litellm.all_embedding_models: # filter out embedding models
-                        litellm_params["messages"] = [{"role": "user", "content": "Hey, how's it going?"}]
-                        litellm_params["model"] = model_name
-                        litellm.completion(**litellm_params)
-                        cleaned_params = {}
-                        for key in litellm_params:
-                            if key != "api_key" and key != "messages":
-                                cleaned_params[key] = litellm_params[key]
-                        healthy_endpoints.append(cleaned_params)
-            except Exception as e:
-                print("Got Exception", e)
-                cleaned_params = {}
-                for key in litellm_params:
-                    if key != "api_key" and key != "messages":
-                        cleaned_params[key] = litellm_params[key]
-                unhealthy_endpoints.append(cleaned_params)
-                pass
+    healthy_endpoints, unhealthy_endpoints = await perform_health_check(llm_model_list, model)
     return {
         "healthy_endpoints": healthy_endpoints,
-        "unhealthy_endpoints": unhealthy_endpoints
+        "unhealthy_endpoints": unhealthy_endpoints,
+        "healthy_count": len(healthy_endpoints),
+        "unhealthy_count": len(unhealthy_endpoints),
     }
 
 
 @router.get("/")
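And a hedged sketch of what a client of the reworked endpoint might see; the proxy URL and the response values are illustrative, not taken from the PR:

import requests

# Assumes a litellm proxy running locally; adjust host/port for your deployment.
resp = requests.get("http://0.0.0.0:8000/health", params={"model": "gpt-3.5-turbo"})
print(resp.json())
# Expected shape after this change (values illustrative):
# {
#     "healthy_endpoints": [{"model": "gpt-3.5-turbo"}],
#     "unhealthy_endpoints": [],
#     "healthy_count": 1,
#     "unhealthy_count": 0,
# }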