forked from phoenix/litellm-mirror

feat(proxy/utils.py): return api base for request hanging alerts

parent b49e47b634 · commit 6110d32b1c
7 changed files with 180 additions and 15 deletions
@@ -595,6 +595,7 @@ from .utils import (
    _should_retry,
    get_secret,
    get_supported_openai_params,
    get_api_base,
)
from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig

@@ -3441,13 +3441,15 @@ async def chat_completion(
)  # run the moderation check in parallel to the actual llm api call
response = responses[1]

# Post Call Processing
data["litellm_status"] = "success"  # used for alerting

hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""

# Post Call Processing
if llm_router is not None:
    data["deployment"] = llm_router.get_deployment(model_id=model_id)
data["litellm_status"] = "success"  # used for alerting

if (
    "stream" in data and data["stream"] == True
):  # use generate_responses to stream responses

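For orientation: the hidden params read here are assumed to be attached to the response object by litellm, and llm_router.get_deployment (added in the router hunk further down) maps the model_id back to a full deployment entry. A minimal sketch of that flow, with a stand-in response object; StubResponse and the id values are hypothetical, not taken from the diff:

# Hedged sketch: resolve the deployment that served a response so the hanging-alert
# code can later report its API base.
class StubResponse:
    _hidden_params = {"model_id": "abc-123", "cache_key": "some-key"}


def attach_deployment(data: dict, response, llm_router) -> dict:
    hidden_params = getattr(response, "_hidden_params", {}) or {}
    model_id = hidden_params.get("model_id", None) or ""
    if llm_router is not None:
        # stored on the request data so ProxyLogging can read it if the request hangs
        data["deployment"] = llm_router.get_deployment(model_id=model_id)
    data["litellm_status"] = "success"  # marks the request as completed for alerting
    return data
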
@@ -182,6 +182,25 @@ class ProxyLogging:
            raise e
        return data

    def _response_taking_too_long_callback(
        self,
        kwargs,  # kwargs to completion
        start_time,
        end_time,  # start/end time
    ):
        try:
            time_difference = end_time - start_time
            # Convert the timedelta to float (in seconds)
            time_difference_float = time_difference.total_seconds()
            litellm_params = kwargs.get("litellm_params", {})
            api_base = litellm_params.get("api_base", "")
            model = kwargs.get("model", "")
            messages = kwargs.get("messages", "")

            return time_difference_float, model, api_base, messages
        except Exception as e:
            raise e

    async def response_taking_too_long_callback(
        self,
        kwargs,  # kwargs to completion

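The helper is synchronous and side-effect free, which is what the new unit test below exercises. Roughly, it can be called like this (model, messages, and api_base are placeholders):

from datetime import datetime, timedelta

from litellm.caching import DualCache
from litellm.proxy.utils import ProxyLogging

_pl = ProxyLogging(user_api_key_cache=DualCache())

start = datetime.now()
end = start + timedelta(seconds=42)
elapsed, model, api_base, messages = _pl._response_taking_too_long_callback(
    kwargs={
        "model": "azure/chatgpt-v-2",  # placeholder values
        "messages": [{"role": "user", "content": "hi"}],
        "litellm_params": {"api_base": "https://example-resource.openai.azure.com/"},
    },
    start_time=start,
    end_time=end,
)
assert elapsed == 42.0
assert api_base == "https://example-resource.openai.azure.com/"
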
@@ -191,13 +210,13 @@ class ProxyLogging:
    ):
        if self.alerting is None:
            return
        time_difference = end_time - start_time
        # Convert the timedelta to float (in seconds)
        time_difference_float = time_difference.total_seconds()
        litellm_params = kwargs.get("litellm_params", {})
        api_base = litellm_params.get("api_base", "")
        model = kwargs.get("model", "")
        messages = kwargs.get("messages", "")
        time_difference_float, model, api_base, messages = (
            self._response_taking_too_long_callback(
                kwargs=kwargs,
                start_time=start_time,
                end_time=end_time,
            )
        )
        request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
        slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
        if time_difference_float > self.alerting_threshold:

@@ -244,6 +263,20 @@ class ProxyLogging:
                request_data is not None
                and request_data.get("litellm_status", "") != "success"
            ):
                if request_data.get("deployment", None) is not None and isinstance(
                    request_data["deployment"], dict
                ):
                    _api_base = litellm.get_api_base(
                        model=model,
                        optional_params=request_data["deployment"].get(
                            "litellm_params", {}
                        ),
                    )

                    if _api_base is None:
                        _api_base = ""

                    request_info += f"\nAPI Base: {_api_base}"
                # only alert hanging responses if they have not been marked as success
                alerting_message = (
                    f"`Requests are hanging - {self.alerting_threshold}s+ request time`"

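For context, request_data["deployment"] is the router entry attached in chat_completion above, and its litellm_params are what litellm.get_api_base (added near the end of this diff) inspects. A rough sketch with an illustrative deployment entry; the model name, base URL, and id are placeholders:

import litellm

# Illustrative deployment dict, shaped like an entry in the router's model_list.
deployment = {
    "model_name": "gpt-4-team",
    "litellm_params": {
        "model": "azure/chatgpt-v-2",
        "api_base": "https://example-resource.openai.azure.com/",
    },
    "model_info": {"id": "abc-123"},
}

_api_base = litellm.get_api_base(
    model="azure/chatgpt-v-2",
    optional_params=deployment.get("litellm_params", {}),
)
if _api_base is None:  # same fallback as the alerting code above
    _api_base = ""
print(_api_base)  # https://example-resource.openai.azure.com/
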
@@ -77,6 +77,13 @@ class LiteLLM_Params(BaseModel):
    )
    max_retries: int = 2  # follows openai default of 2
    organization: Optional[str] = None  # for openai orgs
    ## VERTEX AI ##
    vertex_project: Optional[str] = None
    vertex_location: Optional[str] = None
    ## AWS BEDROCK / SAGEMAKER ##
    aws_access_key_id: Optional[str] = None
    aws_secret_access_key: Optional[str] = None
    aws_region_name: Optional[str] = None

    def __init__(self, max_retries: Optional[Union[int, str]] = None, **params):
        if max_retries is None:

@@ -2263,6 +2270,13 @@ class Router:
            self.model_names.append(deployment.model_name)
        return

    def get_deployment(self, model_id: str):
        for model in self.model_list:
            if "model_info" in model and "id" in model["model_info"]:
                if model_id == model["model_info"]["id"]:
                    return model
        return None

    def get_model_ids(self):
        ids = []
        for model in self.model_list:

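A quick, hedged usage sketch of the new lookup, assuming a router configured with a single OpenAI deployment; the API key and model id below are placeholders:

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-placeholder"},
            "model_info": {"id": "abc-123"},
        }
    ]
)

deployment = router.get_deployment(model_id="abc-123")      # full deployment dict
missing = router.get_deployment(model_id="not-a-real-id")   # None
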
litellm/tests/test_alerting.py (new file, 86 lines)

@@ -0,0 +1,86 @@
# What is this?
## Tests slack alerting on proxy logging object

import sys
import os
import io, asyncio
from datetime import datetime

# import logging
# logging.basicConfig(level=logging.DEBUG)
sys.path.insert(0, os.path.abspath("../.."))
from litellm.proxy.utils import ProxyLogging
from litellm.caching import DualCache
import litellm
import pytest


@pytest.mark.asyncio
async def test_get_api_base():
    _pl = ProxyLogging(user_api_key_cache=DualCache())
    _pl.update_values(alerting=["slack"], alerting_threshold=100, redis_cache=None)
    model = "chatgpt-v-2"
    messages = [{"role": "user", "content": "Hey how's it going?"}]
    litellm_params = {
        "acompletion": True,
        "api_key": None,
        "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
        "force_timeout": 600,
        "logger_fn": None,
        "verbose": False,
        "custom_llm_provider": "azure",
        "litellm_call_id": "68f46d2d-714d-4ad8-8137-69600ec8755c",
        "model_alias_map": {},
        "completion_call_id": None,
        "metadata": None,
        "model_info": None,
        "proxy_server_request": None,
        "preset_cache_key": None,
        "no-log": False,
        "stream_response": {},
    }
    start_time = datetime.now()
    end_time = datetime.now()

    time_difference_float, model, api_base, messages = (
        _pl._response_taking_too_long_callback(
            kwargs={
                "model": model,
                "messages": messages,
                "litellm_params": litellm_params,
            },
            start_time=start_time,
            end_time=end_time,
        )
    )

    assert api_base is not None
    assert isinstance(api_base, str)
    assert len(api_base) > 0
    request_info = (
        f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
    )
    slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {100}s`"
    await _pl.alerting_handler(
        message=slow_message + request_info,
        level="Low",
    )


@pytest.mark.asyncio
async def test_request_taking_too_long():
    """
    - attach request_taking_too_long as a success callback to litellm
    - unit test kwargs for azure call vs. vertex ai call -> ensure api base constructed correctly for both
    """
    import time

    _pl = ProxyLogging(user_api_key_cache=DualCache())
    litellm.success_callback = [_pl.response_taking_too_long_callback]

    response = await litellm.acompletion(
        model="azure/chatgpt-v-2",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )

    raise Exception("it worked!")

@@ -75,7 +75,6 @@ class CompletionCustomHandler(

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        try:
            print(f"kwargs: {kwargs}")
            self.states.append("post_api_call")
            ## START TIME
            assert isinstance(start_time, datetime)

@@ -167,6 +166,8 @@ class CompletionCustomHandler(
            )
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["litellm_params"]["api_base"], str)
            assert isinstance(kwargs["cache_hit"], Optional[bool])
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))

@@ -265,8 +266,10 @@ class CompletionCustomHandler(
            assert isinstance(kwargs["messages"], list)
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["litellm_params"]["api_base"], str)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["cache_hit"], Optional[bool])
            assert isinstance(kwargs["user"], (str, type(None)))
            assert isinstance(kwargs["input"], (list, dict, str))
            assert isinstance(kwargs["api_key"], (str, type(None)))

@@ -651,8 +654,8 @@ def load_vertex_ai_credentials():
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)


@pytest.mark.skip(reason="Vertex AI Hanging")
@pytest.mark.asyncio
@pytest.mark.skip(reason="Skipping on this PR to test other stuff")
async def test_async_chat_vertex_ai_stream():
    try:
        load_vertex_ai_credentials()

@@ -75,6 +75,7 @@ from .proxy._types import KeyManagementSystem
from openai import OpenAIError as OriginalError
from openai._models import BaseModel as OpenAIObject
from .caching import S3Cache, RedisSemanticCache, RedisCache
from .router import LiteLLM_Params
from .exceptions import (
    AuthenticationError,
    BadRequestError,

@@ -1075,6 +1076,9 @@ class Logging:
                headers = {}
            data = additional_args.get("complete_input_dict", {})
            api_base = additional_args.get("api_base", "")
            self.model_call_details["litellm_params"]["api_base"] = str(
                api_base
            )  # used for alerting
            masked_headers = {
                k: (v[:-20] + "*" * 20) if (isinstance(v, str) and len(v) > 20) else v
                for k, v in headers.items()

@@ -1203,7 +1207,6 @@ class Logging:
            self.model_call_details["original_response"] = original_response
            self.model_call_details["additional_args"] = additional_args
            self.model_call_details["log_event_type"] = "post_api_call"

            # User Logging -> if you pass in a custom logging function
            print_verbose(
                f"RAW RESPONSE:\n{self.model_call_details.get('original_response', self.model_call_details)}\n\n",

@@ -2546,7 +2549,7 @@ def client(original_function):
            langfuse_secret=kwargs.pop("langfuse_secret", None),
        )
        ## check if metadata is passed in
        litellm_params = {}
        litellm_params = {"api_base": ""}
        if "metadata" in kwargs:
            litellm_params["metadata"] = kwargs["metadata"]
        logging_obj.update_environment_variables(

@@ -3033,7 +3036,7 @@ def client(original_function):
                    cached_result = await litellm.cache.async_get_cache(
                        *args, **kwargs
                    )
                else:
                else:  # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
                    preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
                    kwargs["preset_cache_key"] = (
                        preset_cache_key  # for streaming calls, we need to pass the preset_cache_key

@@ -3076,6 +3079,7 @@ def client(original_function):
                            "preset_cache_key", None
                        ),
                        "stream_response": kwargs.get("stream_response", {}),
                        "api_base": kwargs.get("api_base", ""),
                    },
                    input=kwargs.get("messages", ""),
                    api_key=kwargs.get("api_key", None),

@@ -3209,6 +3213,7 @@ def client(original_function):
                        "stream_response": kwargs.get(
                            "stream_response", {}
                        ),
                        "api_base": "",
                    },
                    input=kwargs.get("messages", ""),
                    api_key=kwargs.get("api_key", None),

@@ -5305,6 +5310,27 @@ def get_optional_params(
    return optional_params


def get_api_base(model: str, optional_params: dict) -> Optional[str]:
    _optional_params = LiteLLM_Params(**optional_params)  # convert to pydantic object

    if _optional_params.api_base is not None:
        return _optional_params.api_base

    if (
        _optional_params.vertex_location is not None
        and _optional_params.vertex_project is not None
    ):
        _api_base = "{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent".format(
            _optional_params.vertex_location,
            _optional_params.vertex_project,
            _optional_params.vertex_location,
            model,
        )
        return _api_base

    return None


def get_supported_openai_params(model: str, custom_llm_provider: str):
    """
    Returns the supported openai params for a given model + provider

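To illustrate the two branches of the new helper (an explicitly configured api_base vs. a Vertex AI project/location pair), a hedged example; the project, location, and base URL values are placeholders:

from litellm import get_api_base

# Branch 1: api_base configured directly (e.g. an Azure deployment).
azure_base = get_api_base(
    model="azure/chatgpt-v-2",
    optional_params={
        "model": "azure/chatgpt-v-2",
        "api_base": "https://example-resource.openai.azure.com/",
    },
)
# -> "https://example-resource.openai.azure.com/"

# Branch 2: Vertex AI, where the base is constructed from project + location.
vertex_base = get_api_base(
    model="gemini-pro",
    optional_params={
        "model": "gemini-pro",
        "vertex_project": "my-gcp-project",
        "vertex_location": "us-central1",
    },
)
# -> "us-central1-aiplatform.googleapis.com/v1/projects/my-gcp-project/
#     locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent"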