forked from phoenix/litellm-mirror

fix(proxy_server.py): add alerting for responses taking too long
https://github.com/BerriAI/litellm/issues/1298

parent 0a6e4db999
commit cd98d256b5

2 changed files with 95 additions and 10 deletions
litellm/proxy/proxy_server.py

@@ -648,7 +648,10 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
         load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
         ### ALERTING ###
-        proxy_logging_obj.update_values(alerting=general_settings.get("alerting", None))
+        proxy_logging_obj.update_values(
+            alerting=general_settings.get("alerting", None),
+            alerting_threshold=general_settings.get("alerting_threshold", 600),
+        )
         ### CONNECT TO DATABASE ###
         database_url = general_settings.get("database_url", None)
         if database_url and database_url.startswith("os.environ/"):
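The new alerting_threshold is read from the proxy's general_settings next to the existing alerting list. A minimal config.yaml sketch that would exercise this path (assuming a Slack destination; this diff does not show which alerting destinations are supported):

general_settings:
  alerting: ["slack"]       # handed to ProxyLogging.alerting
  alerting_threshold: 300   # seconds before a request counts as hanging/slow

Note that when config.yaml omits alerting_threshold entirely, the proxy passes 600 here, overriding ProxyLogging's constructor default of 300 (see the utils.py hunk below).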
@@ -927,6 +930,7 @@ def data_generator(response):
 async def async_data_generator(response, user_api_key_dict):
     print_verbose("inside generator")
     try:
+        start_time = time.time()
         async for chunk in response:
             print_verbose(f"returned chunk: {chunk}")
             try:
@@ -934,6 +938,14 @@ async def async_data_generator(response, user_api_key_dict):
             except Exception as e:
                 yield f"data: {str(e)}\n\n"
 
+        ### ALERTING ###
+        end_time = time.time()
+        asyncio.create_task(
+            proxy_logging_obj.response_taking_too_long(
+                start_time=start_time, end_time=end_time, type="slow_response"
+            )
+        )
+
         # Streaming is done, yield the [DONE] chunk
         done_message = "[DONE]"
         yield f"data: {done_message}\n\n"
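asyncio.create_task schedules the slow-response check in the background, so the alert can never delay the final [DONE] chunk. A self-contained sketch of the same fire-and-forget pattern (check_slow and the dummy stream are illustrative, not from the commit):

import asyncio
import time

async def check_slow(start_time: float, end_time: float, threshold: float = 0.05):
    # Background check; the generator never awaits this directly.
    if end_time - start_time > threshold:
        print(f"ALERT: slow response ({end_time - start_time:.2f}s)")

async def stream(chunks):
    start_time = time.time()
    for chunk in chunks:
        await asyncio.sleep(0.1)  # stand-in for model latency
        yield f"data: {chunk}\n\n"
    # Schedule the check without blocking the [DONE] chunk below.
    asyncio.create_task(check_slow(start_time, time.time()))
    yield "data: [DONE]\n\n"

async def main():
    async for line in stream(["a", "b"]):
        print(line, end="")
    await asyncio.sleep(0)  # yield control so the background task can run

asyncio.run(main())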
@@ -1103,6 +1115,8 @@ async def completion(
             user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
         )
 
+        start_time = time.time()
+
         ### ROUTE THE REQUESTs ###
         router_model_names = (
             [m["model_name"] for m in llm_model_list]
@@ -1150,6 +1164,14 @@ async def completion(
                 headers=custom_headers,
             )
 
+        ### ALERTING ###
+        end_time = time.time()
+        asyncio.create_task(
+            proxy_logging_obj.response_taking_too_long(
+                start_time=start_time, end_time=end_time, type="slow_response"
+            )
+        )
+
         fastapi_response.headers["x-litellm-model-id"] = model_id
         return response
     except Exception as e:
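This timing block is repeated verbatim in chat_completion, embeddings, and image_generation below. Purely as a hypothetical refactor (not part of this commit), the pattern could be captured once in a context manager; note the commit only times successful responses, which is why the task is scheduled after the yield rather than in a finally block:

import asyncio
import time
from contextlib import asynccontextmanager

@asynccontextmanager
async def alert_if_slow(proxy_logging_obj):
    # Hypothetical helper wrapping the inlined ### ALERTING ### block.
    start_time = time.time()
    yield  # the endpoint does its routing + response handling here
    end_time = time.time()  # skipped if the body raised, matching the commit
    asyncio.create_task(
        proxy_logging_obj.response_taking_too_long(
            start_time=start_time, end_time=end_time, type="slow_response"
        )
    )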
@@ -1254,6 +1276,8 @@ async def chat_completion(
             user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
         )
 
+        start_time = time.time()
+
         ### ROUTE THE REQUEST ###
         router_model_names = (
             [m["model_name"] for m in llm_model_list]
@@ -1289,6 +1313,7 @@ async def chat_completion(
             model_id = response._hidden_params.get("model_id", None) or ""
         else:
             model_id = ""
+
         if (
             "stream" in data and data["stream"] == True
         ): # use generate_responses to stream responses
@@ -1302,6 +1327,14 @@ async def chat_completion(
                 headers=custom_headers,
             )
 
+        ### ALERTING ###
+        end_time = time.time()
+        asyncio.create_task(
+            proxy_logging_obj.response_taking_too_long(
+                start_time=start_time, end_time=end_time, type="slow_response"
+            )
+        )
+
         fastapi_response.headers["x-litellm-model-id"] = model_id
         return response
     except Exception as e:
@@ -1428,6 +1461,8 @@ async def embeddings(
             user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings"
         )
 
+        start_time = time.time()
+
         ## ROUTE TO CORRECT ENDPOINT ##
         # skip router if user passed their key
         if "api_key" in data:
@@ -1456,6 +1491,14 @@ async def embeddings(
         else:
             response = await litellm.aembedding(**data)
 
+        ### ALERTING ###
+        end_time = time.time()
+        asyncio.create_task(
+            proxy_logging_obj.response_taking_too_long(
+                start_time=start_time, end_time=end_time, type="slow_response"
+            )
+        )
+
         return response
     except Exception as e:
         await proxy_logging_obj.post_call_failure_hook(
@@ -1535,6 +1578,8 @@ async def image_generation(
             user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings"
         )
 
+        start_time = time.time()
+
         ## ROUTE TO CORRECT ENDPOINT ##
         # skip router if user passed their key
         if "api_key" in data:
@@ -1560,6 +1605,14 @@ async def image_generation(
         else:
             response = await litellm.aimage_generation(**data)
 
+        ### ALERTING ###
+        end_time = time.time()
+        asyncio.create_task(
+            proxy_logging_obj.response_taking_too_long(
+                start_time=start_time, end_time=end_time, type="slow_response"
+            )
+        )
+
         return response
     except Exception as e:
         await proxy_logging_obj.post_call_failure_hook(
litellm/proxy/utils.py

@@ -34,10 +34,15 @@ class ProxyLogging:
         self.max_parallel_request_limiter = MaxParallelRequestsHandler()
         self.max_budget_limiter = MaxBudgetLimiter()
         self.alerting: Optional[List] = None
+        self.alerting_threshold: float = 300  # default to 5 min. threshold
         pass
 
-    def update_values(self, alerting: Optional[List]):
+    def update_values(
+        self, alerting: Optional[List], alerting_threshold: Optional[float]
+    ):
         self.alerting = alerting
+        if alerting_threshold is not None:
+            self.alerting_threshold = alerting_threshold
 
     def _init_litellm_callbacks(self):
         print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
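A small runnable sketch of the resulting precedence, using a stub with only the fields this hunk touches: the constructor default of 300s survives a None threshold, while any value passed through update_values wins:

from typing import List, Optional

class ProxyLoggingStub:
    # Stub mirroring just the attributes in the hunk above.
    def __init__(self):
        self.alerting: Optional[List] = None
        self.alerting_threshold: float = 300  # default to 5 min. threshold

    def update_values(
        self, alerting: Optional[List], alerting_threshold: Optional[float]
    ):
        self.alerting = alerting
        if alerting_threshold is not None:
            self.alerting_threshold = alerting_threshold

obj = ProxyLoggingStub()
obj.update_values(alerting=["slack"], alerting_threshold=None)
print(obj.alerting_threshold)  # 300 - None leaves the default untouched

obj.update_values(alerting=["slack"], alerting_threshold=600)
print(obj.alerting_threshold)  # 600 - an explicit threshold overrides it

Since load_router_config passes general_settings.get("alerting_threshold", 600), the 300 default only applies when update_values is never called.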
@@ -105,18 +110,45 @@ class ProxyLogging:
         except Exception as e:
             raise e
 
-    async def success_handler(self, *args, **kwargs):
+    async def success_handler(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        response: Any,
+        call_type: Literal["completion", "embeddings"],
+        start_time,
+        end_time,
+    ):
         """
-        Log successful db read/writes
+        Log successful API calls / db read/writes
         """
         pass
 
-    async def response_taking_too_long(self):
-        # Simulate a long-running operation that could take more than 5 minutes
-        await asyncio.sleep(
-            300
-        )  # Set it to 5 minutes - i'd imagine this might be different for streaming, non-streaming, non-completion (embedding + img) requests
-        await self.alerting_handler(message="Requests are hanging", level="Medium")
+    async def response_taking_too_long(
+        self,
+        start_time: Optional[float] = None,
+        end_time: Optional[float] = None,
+        type: Literal["hanging_request", "slow_response"] = "hanging_request",
+    ):
+        if type == "hanging_request":
+            # Simulate a long-running operation that could take more than 5 minutes
+            await asyncio.sleep(
+                self.alerting_threshold
+            )  # Set it to 5 minutes - i'd imagine this might be different for streaming, non-streaming, non-completion (embedding + img) requests
+
+            await self.alerting_handler(
+                message=f"Requests are hanging - {self.alerting_threshold}s+ request time",
+                level="Medium",
+            )
+
+        elif (
+            type == "slow_response" and start_time is not None and end_time is not None
+        ):
+            if end_time - start_time > self.alerting_threshold:
+                await self.alerting_handler(
+                    message=f"Responses are slow - {round(end_time-start_time,2)}s response time",
+                    level="Low",
+                )
 
     async def alerting_handler(
         self, message: str, level: Literal["Low", "Medium", "High"]
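A minimal end-to-end sketch of both alert types with a stubbed alerting_handler (the real handler, not shown in this diff, fans out to the configured destinations). The hanging_request branch sleeps for the full threshold and then always alerts, so it is presumably kicked off as a background task when a request starts (its call site is not part of this diff); the slow_response branch only alerts when the measured time exceeds the threshold:

import asyncio
from typing import Literal, Optional

class MiniProxyLogging:
    # Trimmed stand-in for ProxyLogging, enough to exercise both branches.
    def __init__(self, alerting_threshold: float):
        self.alerting_threshold = alerting_threshold

    async def alerting_handler(
        self, message: str, level: Literal["Low", "Medium", "High"]
    ):
        print(f"[{level}] {message}")

    async def response_taking_too_long(
        self,
        start_time: Optional[float] = None,
        end_time: Optional[float] = None,
        type: Literal["hanging_request", "slow_response"] = "hanging_request",
    ):
        if type == "hanging_request":
            await asyncio.sleep(self.alerting_threshold)
            await self.alerting_handler(
                message=f"Requests are hanging - {self.alerting_threshold}s+ request time",
                level="Medium",
            )
        elif type == "slow_response" and start_time is not None and end_time is not None:
            if end_time - start_time > self.alerting_threshold:
                await self.alerting_handler(
                    message=f"Responses are slow - {round(end_time - start_time, 2)}s response time",
                    level="Low",
                )

async def main():
    obj = MiniProxyLogging(alerting_threshold=0.1)
    # Fires: 0.5s elapsed exceeds the 0.1s threshold.
    await obj.response_taking_too_long(start_time=0.0, end_time=0.5, type="slow_response")
    # Sleeps 0.1s, then alerts unconditionally (the shown code has no completion check).
    await obj.response_taking_too_long(type="hanging_request")

asyncio.run(main())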