fix(proxy_server.py): add alerting for responses taking too long

https://github.com/BerriAI/litellm/issues/1298
This commit is contained in:
Krrish Dholakia 2024-01-03 11:18:21 +05:30
parent 0a6e4db999
commit cd98d256b5
2 changed files with 95 additions and 10 deletions

View file

@ -648,7 +648,10 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
### ALERTING ###
proxy_logging_obj.update_values(alerting=general_settings.get("alerting", None))
proxy_logging_obj.update_values(
alerting=general_settings.get("alerting", None),
alerting_threshold=general_settings.get("alerting_threshold", 600),
)
### CONNECT TO DATABASE ###
database_url = general_settings.get("database_url", None)
if database_url and database_url.startswith("os.environ/"):
@ -927,6 +930,7 @@ def data_generator(response):
async def async_data_generator(response, user_api_key_dict):
print_verbose("inside generator")
try:
start_time = time.time()
async for chunk in response:
print_verbose(f"returned chunk: {chunk}")
try:
@ -934,6 +938,14 @@ async def async_data_generator(response, user_api_key_dict):
except Exception as e:
yield f"data: {str(e)}\n\n"
### ALERTING ###
end_time = time.time()
asyncio.create_task(
proxy_logging_obj.response_taking_too_long(
start_time=start_time, end_time=end_time, type="slow_response"
)
)
# Streaming is done, yield the [DONE] chunk
done_message = "[DONE]"
yield f"data: {done_message}\n\n"
@ -1103,6 +1115,8 @@ async def completion(
user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
)
start_time = time.time()
### ROUTE THE REQUESTs ###
router_model_names = (
[m["model_name"] for m in llm_model_list]
@ -1150,6 +1164,14 @@ async def completion(
headers=custom_headers,
)
### ALERTING ###
end_time = time.time()
asyncio.create_task(
proxy_logging_obj.response_taking_too_long(
start_time=start_time, end_time=end_time, type="slow_response"
)
)
fastapi_response.headers["x-litellm-model-id"] = model_id
return response
except Exception as e:
@ -1254,6 +1276,8 @@ async def chat_completion(
user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
)
start_time = time.time()
### ROUTE THE REQUEST ###
router_model_names = (
[m["model_name"] for m in llm_model_list]
@ -1289,6 +1313,7 @@ async def chat_completion(
model_id = response._hidden_params.get("model_id", None) or ""
else:
model_id = ""
if (
"stream" in data and data["stream"] == True
): # use generate_responses to stream responses
@ -1302,6 +1327,14 @@ async def chat_completion(
headers=custom_headers,
)
### ALERTING ###
end_time = time.time()
asyncio.create_task(
proxy_logging_obj.response_taking_too_long(
start_time=start_time, end_time=end_time, type="slow_response"
)
)
fastapi_response.headers["x-litellm-model-id"] = model_id
return response
except Exception as e:
@ -1428,6 +1461,8 @@ async def embeddings(
user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings"
)
start_time = time.time()
## ROUTE TO CORRECT ENDPOINT ##
# skip router if user passed their key
if "api_key" in data:
@ -1456,6 +1491,14 @@ async def embeddings(
else:
response = await litellm.aembedding(**data)
### ALERTING ###
end_time = time.time()
asyncio.create_task(
proxy_logging_obj.response_taking_too_long(
start_time=start_time, end_time=end_time, type="slow_response"
)
)
return response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(
@ -1535,6 +1578,8 @@ async def image_generation(
user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings"
)
start_time = time.time()
## ROUTE TO CORRECT ENDPOINT ##
# skip router if user passed their key
if "api_key" in data:
@ -1560,6 +1605,14 @@ async def image_generation(
else:
response = await litellm.aimage_generation(**data)
### ALERTING ###
end_time = time.time()
asyncio.create_task(
proxy_logging_obj.response_taking_too_long(
start_time=start_time, end_time=end_time, type="slow_response"
)
)
return response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(