feat(proxy/utils.py): return api base for request hanging alerts

Krrish Dholakia 2024-04-06 15:58:53 -07:00
parent b49e47b634
commit 6110d32b1c
7 changed files with 180 additions and 15 deletions

litellm/__init__.py View file

@@ -595,6 +595,7 @@ from .utils import (
_should_retry,
get_secret,
get_supported_openai_params,
get_api_base,
)
from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig

litellm/proxy/proxy_server.py View file

@@ -3441,13 +3441,15 @@ async def chat_completion(
) # run the moderation check in parallel to the actual llm api call
response = responses[1]
# Post Call Processing
data["litellm_status"] = "success" # used for alerting
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
# Post Call Processing
if llm_router is not None:
data["deployment"] = llm_router.get_deployment(model_id=model_id)
data["litellm_status"] = "success" # used for alerting
if (
"stream" in data and data["stream"] == True
): # use generate_responses to stream responses

litellm/proxy/utils.py View file

@@ -182,6 +182,25 @@ class ProxyLogging:
raise e
return data
def _response_taking_too_long_callback(
self,
kwargs, # kwargs to completion
start_time,
end_time, # start/end time
):
try:
time_difference = end_time - start_time
# Convert the timedelta to float (in seconds)
time_difference_float = time_difference.total_seconds()
litellm_params = kwargs.get("litellm_params", {})
api_base = litellm_params.get("api_base", "")
model = kwargs.get("model", "")
messages = kwargs.get("messages", "")
return time_difference_float, model, api_base, messages
except Exception as e:
raise e
async def response_taking_too_long_callback(
self,
kwargs, # kwargs to completion
@@ -191,13 +210,13 @@ class ProxyLogging:
):
if self.alerting is None:
return
time_difference = end_time - start_time
# Convert the timedelta to float (in seconds)
time_difference_float = time_difference.total_seconds()
litellm_params = kwargs.get("litellm_params", {})
api_base = litellm_params.get("api_base", "")
model = kwargs.get("model", "")
messages = kwargs.get("messages", "")
time_difference_float, model, api_base, messages = (
self._response_taking_too_long_callback(
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
)
request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
if time_difference_float > self.alerting_threshold:
@@ -244,6 +263,20 @@
request_data is not None
and request_data.get("litellm_status", "") != "success"
):
if request_data.get("deployment", None) is not None and isinstance(
request_data["deployment"], dict
):
_api_base = litellm.get_api_base(
model=model,
optional_params=request_data["deployment"].get(
"litellm_params", {}
),
)
if _api_base is None:
_api_base = ""
request_info += f"\nAPI Base: {_api_base}"
# only alert hanging responses if they have not been marked as success
alerting_message = (
f"`Requests are hanging - {self.alerting_threshold}s+ request time`"

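A minimal sketch of the hanging-request branch above, assuming a deployment dict shaped like a router model_list entry (all values here are illustrative, not from the commit):

import litellm

# Hypothetical deployment stashed on the request by proxy_server.py's post-call hook.
request_data = {
    "litellm_status": "running",  # never marked "success" -> treated as hanging
    "deployment": {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "azure/chatgpt-v-2",
            "api_base": "https://my-endpoint.openai.azure.com/",
        },
    },
}

# Mirrors the alerting code path: derive an API base from the deployment's litellm_params.
_api_base = litellm.get_api_base(
    model="azure/chatgpt-v-2",
    optional_params=request_data["deployment"].get("litellm_params", {}),
)
request_info = f"\nAPI Base: {_api_base or ''}"
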
litellm/router.py View file

@@ -77,6 +77,13 @@ class LiteLLM_Params(BaseModel):
)
max_retries: int = 2 # follows openai default of 2
organization: Optional[str] = None # for openai orgs
## VERTEX AI ##
vertex_project: Optional[str] = None
vertex_location: Optional[str] = None
## AWS BEDROCK / SAGEMAKER ##
aws_access_key_id: Optional[str] = None
aws_secret_access_key: Optional[str] = None
aws_region_name: Optional[str] = None
def __init__(self, max_retries: Optional[Union[int, str]] = None, **params):
if max_retries is None:
@@ -2263,6 +2270,13 @@ class Router:
self.model_names.append(deployment.model_name)
return
def get_deployment(self, model_id: str):
for model in self.model_list:
if "model_info" in model and "id" in model["model_info"]:
if model_id == model["model_info"]["id"]:
return model
return None
def get_model_ids(self):
ids = []
for model in self.model_list:

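A hedged usage sketch for the new Router.get_deployment helper; the model_list entry, api_key, and id below are illustrative:

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-illustrative-key"},
            "model_info": {"id": "deployment-abc123"},
        }
    ]
)

# Look up the full deployment entry by the id a response carries in its hidden params.
deployment = router.get_deployment(model_id="deployment-abc123")
assert deployment is not None
assert deployment["model_info"]["id"] == "deployment-abc123"
assert router.get_deployment(model_id="unknown-id") is None  # no match -> None
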
litellm/tests/test_alerting.py View file

@@ -0,0 +1,86 @@
# What is this?
## Tests slack alerting on proxy logging object
import sys
import os
import io, asyncio
from datetime import datetime
# import logging
# logging.basicConfig(level=logging.DEBUG)
sys.path.insert(0, os.path.abspath("../.."))
from litellm.proxy.utils import ProxyLogging
from litellm.caching import DualCache
import litellm
import pytest
@pytest.mark.asyncio
async def test_get_api_base():
_pl = ProxyLogging(user_api_key_cache=DualCache())
_pl.update_values(alerting=["slack"], alerting_threshold=100, redis_cache=None)
model = "chatgpt-v-2"
messages = [{"role": "user", "content": "Hey how's it going?"}]
litellm_params = {
"acompletion": True,
"api_key": None,
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
"force_timeout": 600,
"logger_fn": None,
"verbose": False,
"custom_llm_provider": "azure",
"litellm_call_id": "68f46d2d-714d-4ad8-8137-69600ec8755c",
"model_alias_map": {},
"completion_call_id": None,
"metadata": None,
"model_info": None,
"proxy_server_request": None,
"preset_cache_key": None,
"no-log": False,
"stream_response": {},
}
start_time = datetime.now()
end_time = datetime.now()
time_difference_float, model, api_base, messages = (
_pl._response_taking_too_long_callback(
kwargs={
"model": model,
"messages": messages,
"litellm_params": litellm_params,
},
start_time=start_time,
end_time=end_time,
)
)
assert api_base is not None
assert isinstance(api_base, str)
assert len(api_base) > 0
request_info = (
f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
)
slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {100}s`"
await _pl.alerting_handler(
message=slow_message + request_info,
level="Low",
)
@pytest.mark.asyncio
async def test_request_taking_too_long():
"""
- attach request_taking_too_long as a success callback to litellm
- unit test kwargs for azure call vs. vertex ai call -> ensure api base constructed correctly for both
"""
import time
_pl = ProxyLogging(user_api_key_cache=DualCache())
litellm.success_callback = [_pl.response_taking_too_long_callback]
response = await litellm.acompletion(
model="azure/chatgpt-v-2",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
assert response is not None

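Assuming this new file lands at litellm/tests/test_alerting.py (the path is inferred, not shown in this view), the first test can be run locally with something like the command below; note alerting_handler posts to Slack, so a SLACK_WEBHOOK_URL should be configured:

pytest litellm/tests/test_alerting.py -k test_get_api_base -s
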
litellm/tests/test_custom_callback_input.py View file

@@ -75,7 +75,6 @@ class CompletionCustomHandler(
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
try:
print(f"kwargs: {kwargs}")
self.states.append("post_api_call")
## START TIME
assert isinstance(start_time, datetime)
@@ -167,6 +166,8 @@ class CompletionCustomHandler(
)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["litellm_params"]["api_base"], str)
assert isinstance(kwargs["cache_hit"], Optional[bool])
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
@@ -265,8 +266,10 @@ class CompletionCustomHandler(
assert isinstance(kwargs["messages"], list)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["litellm_params"]["api_base"], str)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["cache_hit"], Optional[bool])
assert isinstance(kwargs["user"], (str, type(None)))
assert isinstance(kwargs["input"], (list, dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
@@ -651,8 +654,8 @@ def load_vertex_ai_credentials():
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)
@pytest.mark.skip(reason="Vertex AI Hanging")
@pytest.mark.asyncio
@pytest.mark.skip(reason="Skipping on this PR to test other stuff")
async def test_async_chat_vertex_ai_stream():
try:
load_vertex_ai_credentials()

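The two new api_base assertions above lean on litellm_params["api_base"] now always being present as a string. A minimal sketch of a custom handler consuming it (class name and print format are illustrative):

import litellm
from litellm.integrations.custom_logger import CustomLogger

class ApiBaseLogger(CustomLogger):
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        # seeded to "" by the client wrapper, filled in during the pre-call hook
        api_base = kwargs["litellm_params"]["api_base"]
        print(f"request served by: {api_base or '<unset>'}")

litellm.callbacks = [ApiBaseLogger()]
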
litellm/utils.py View file

@@ -75,6 +75,7 @@ from .proxy._types import KeyManagementSystem
from openai import OpenAIError as OriginalError
from openai._models import BaseModel as OpenAIObject
from .caching import S3Cache, RedisSemanticCache, RedisCache
from .router import LiteLLM_Params
from .exceptions import (
AuthenticationError,
BadRequestError,
@@ -1075,6 +1076,9 @@ class Logging:
headers = {}
data = additional_args.get("complete_input_dict", {})
api_base = additional_args.get("api_base", "")
self.model_call_details["litellm_params"]["api_base"] = str(
api_base
) # used for alerting
masked_headers = {
k: (v[:-20] + "*" * 20) if (isinstance(v, str) and len(v) > 20) else v
for k, v in headers.items()
@@ -1203,7 +1207,6 @@
self.model_call_details["original_response"] = original_response
self.model_call_details["additional_args"] = additional_args
self.model_call_details["log_event_type"] = "post_api_call"
# User Logging -> if you pass in a custom logging function
print_verbose(
f"RAW RESPONSE:\n{self.model_call_details.get('original_response', self.model_call_details)}\n\n",
@@ -2546,7 +2549,7 @@ def client(original_function):
langfuse_secret=kwargs.pop("langfuse_secret", None),
)
## check if metadata is passed in
litellm_params = {}
litellm_params = {"api_base": ""}
if "metadata" in kwargs:
litellm_params["metadata"] = kwargs["metadata"]
logging_obj.update_environment_variables(
@@ -3033,7 +3036,7 @@ def client(original_function):
cached_result = await litellm.cache.async_get_cache(
*args, **kwargs
)
else:
else: # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
kwargs["preset_cache_key"] = (
preset_cache_key # for streaming calls, we need to pass the preset_cache_key
@@ -3076,6 +3079,7 @@ def client(original_function):
"preset_cache_key", None
),
"stream_response": kwargs.get("stream_response", {}),
"api_base": kwargs.get("api_base", ""),
},
input=kwargs.get("messages", ""),
api_key=kwargs.get("api_key", None),
@@ -3209,6 +3213,7 @@ def client(original_function):
"stream_response": kwargs.get(
"stream_response", {}
),
"api_base": "",
},
input=kwargs.get("messages", ""),
api_key=kwargs.get("api_key", None),
@@ -5305,6 +5310,27 @@ def get_optional_params(
return optional_params
def get_api_base(model: str, optional_params: dict) -> Optional[str]:
_optional_params = LiteLLM_Params(**optional_params) # convert to pydantic object
if _optional_params.api_base is not None:
return _optional_params.api_base
if (
_optional_params.vertex_location is not None
and _optional_params.vertex_project is not None
):
_api_base = "{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent".format(
_optional_params.vertex_location,
_optional_params.vertex_project,
_optional_params.vertex_location,
model,
)
return _api_base
return None
def get_supported_openai_params(model: str, custom_llm_provider: str):
"""
Returns the supported openai params for a given model + provider
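A hedged sketch of the two branches of the new get_api_base helper (endpoint values are illustrative; the optional_params are validated through LiteLLM_Params, so a model key is included):

import litellm

# Branch 1: an explicit api_base is returned as-is.
litellm.get_api_base(
    model="azure/chatgpt-v-2",
    optional_params={
        "model": "azure/chatgpt-v-2",
        "api_base": "https://my-endpoint.openai.azure.com/",
    },
)  # -> "https://my-endpoint.openai.azure.com/"

# Branch 2: vertex project + location are formatted into the streamGenerateContent URL.
litellm.get_api_base(
    model="gemini-pro",
    optional_params={
        "model": "gemini-pro",
        "vertex_project": "my-project",
        "vertex_location": "us-central1",
    },
)  # -> "us-central1-aiplatform.googleapis.com/v1/projects/my-project/
#       locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent"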