forked from phoenix/litellm-mirror

feat(proxy/utils.py): return api base for request hanging alerts

parent b49e47b634, commit 6110d32b1c
7 changed files with 180 additions and 15 deletions
@@ -595,6 +595,7 @@ from .utils import (
     _should_retry,
     get_secret,
     get_supported_openai_params,
+    get_api_base,
 )
 from .llms.huggingface_restapi import HuggingfaceConfig
 from .llms.anthropic import AnthropicConfig
@@ -3441,13 +3441,15 @@ async def chat_completion(
         )  # run the moderation check in parallel to the actual llm api call
         response = responses[1]

-        # Post Call Processing
-        data["litellm_status"] = "success"  # used for alerting
-
         hidden_params = getattr(response, "_hidden_params", {}) or {}
         model_id = hidden_params.get("model_id", None) or ""
         cache_key = hidden_params.get("cache_key", None) or ""

+        # Post Call Processing
+        if llm_router is not None:
+            data["deployment"] = llm_router.get_deployment(model_id=model_id)
+        data["litellm_status"] = "success"  # used for alerting
+
         if (
             "stream" in data and data["stream"] == True
         ):  # use generate_responses to stream responses
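With this change, the proxy stores the resolved router deployment on the request payload before any streaming begins, so the hanging-request alert can recover the API base even when a response never arrives. A rough sketch of the shape `data["deployment"]` takes, assuming the usual router `model_list` entry format (all values illustrative):

import litellm  # noqa: F401  (only for context; the dict below is plain data)

# Hypothetical entry returned by llm_router.get_deployment(model_id=...)
deployment = {
    "model_name": "gpt-4-team",  # public alias served by the proxy
    "litellm_params": {          # params the router forwards to litellm
        "model": "azure/chatgpt-v-2",
        "api_base": "https://example-resource.openai.azure.com/",
    },
    "model_info": {"id": "1234"},  # matched against hidden_params["model_id"]
}
# The alerting path further down reads deployment["litellm_params"] to rebuild the API base.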
@@ -182,6 +182,25 @@ class ProxyLogging:
             raise e
         return data

+    def _response_taking_too_long_callback(
+        self,
+        kwargs,  # kwargs to completion
+        start_time,
+        end_time,  # start/end time
+    ):
+        try:
+            time_difference = end_time - start_time
+            # Convert the timedelta to float (in seconds)
+            time_difference_float = time_difference.total_seconds()
+            litellm_params = kwargs.get("litellm_params", {})
+            api_base = litellm_params.get("api_base", "")
+            model = kwargs.get("model", "")
+            messages = kwargs.get("messages", "")
+
+            return time_difference_float, model, api_base, messages
+        except Exception as e:
+            raise e
+
     async def response_taking_too_long_callback(
         self,
         kwargs,  # kwargs to completion
@@ -191,13 +210,13 @@ class ProxyLogging:
     ):
         if self.alerting is None:
             return
-        time_difference = end_time - start_time
-        # Convert the timedelta to float (in seconds)
-        time_difference_float = time_difference.total_seconds()
-        litellm_params = kwargs.get("litellm_params", {})
-        api_base = litellm_params.get("api_base", "")
-        model = kwargs.get("model", "")
-        messages = kwargs.get("messages", "")
+        time_difference_float, model, api_base, messages = (
+            self._response_taking_too_long_callback(
+                kwargs=kwargs,
+                start_time=start_time,
+                end_time=end_time,
+            )
+        )
         request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
         slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
         if time_difference_float > self.alerting_threshold:
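Splitting the timing and field extraction into a synchronous `_response_taking_too_long_callback` helper makes it unit-testable without actually firing a Slack alert. A minimal sketch of calling the helper directly, assuming `proxy_logging_obj` is an already-initialized `ProxyLogging` instance (values illustrative):

from datetime import datetime, timedelta

kwargs = {
    "model": "azure/chatgpt-v-2",
    "messages": [{"role": "user", "content": "hi"}],
    "litellm_params": {"api_base": "https://example-resource.openai.azure.com/"},
}
start = datetime.now()
end = start + timedelta(seconds=7)

elapsed, model, api_base, messages = proxy_logging_obj._response_taking_too_long_callback(
    kwargs=kwargs, start_time=start, end_time=end
)
assert elapsed == 7.0
assert api_base.startswith("https://")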
@@ -244,6 +263,20 @@ class ProxyLogging:
                 request_data is not None
                 and request_data.get("litellm_status", "") != "success"
             ):
+                if request_data.get("deployment", None) is not None and isinstance(
+                    request_data["deployment"], dict
+                ):
+                    _api_base = litellm.get_api_base(
+                        model=model,
+                        optional_params=request_data["deployment"].get(
+                            "litellm_params", {}
+                        ),
+                    )
+
+                    if _api_base is None:
+                        _api_base = ""
+
+                    request_info += f"\nAPI Base: {_api_base}"
                 # only alert hanging responses if they have not been marked as success
                 alerting_message = (
                     f"`Requests are hanging - {self.alerting_threshold}s+ request time`"
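The hanging-request branch can now resolve an API base from the deployment stored by `chat_completion`. A minimal sketch of that lookup in isolation, assuming a deployment dict in the shape shown earlier (values illustrative):

import litellm

deployment = {
    "litellm_params": {
        "model": "azure/chatgpt-v-2",
        "api_base": "https://example-resource.openai.azure.com/",
    },
    "model_info": {"id": "1234"},
}

# Same call the alerting code makes: fall back to "" when nothing can be resolved.
_api_base = litellm.get_api_base(
    model="azure/chatgpt-v-2",
    optional_params=deployment.get("litellm_params", {}),
) or ""
print(_api_base)  # -> https://example-resource.openai.azure.com/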
@@ -77,6 +77,13 @@ class LiteLLM_Params(BaseModel):
     )
     max_retries: int = 2  # follows openai default of 2
     organization: Optional[str] = None  # for openai orgs
+    ## VERTEX AI ##
+    vertex_project: Optional[str] = None
+    vertex_location: Optional[str] = None
+    ## AWS BEDROCK / SAGEMAKER ##
+    aws_access_key_id: Optional[str] = None
+    aws_secret_access_key: Optional[str] = None
+    aws_region_name: Optional[str] = None

     def __init__(self, max_retries: Optional[Union[int, str]] = None, **params):
         if max_retries is None:
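The new optional provider fields let `get_api_base` rebuild an endpoint from a deployment's `litellm_params` even when no explicit `api_base` is configured. A minimal sketch, with the import path taken from the `from .router import LiteLLM_Params` change later in this diff and illustrative project/location values:

from litellm.router import LiteLLM_Params

# Vertex AI deployment: no api_base, but project + location are enough
# for get_api_base to derive the endpoint.
params = LiteLLM_Params(
    model="gemini-pro",
    vertex_project="my-gcp-project",  # illustrative
    vertex_location="us-central1",    # illustrative
)
print(params.api_base, params.vertex_project, params.vertex_location)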
@@ -2263,6 +2270,13 @@ class Router:
         self.model_names.append(deployment.model_name)
         return

+    def get_deployment(self, model_id: str):
+        for model in self.model_list:
+            if "model_info" in model and "id" in model["model_info"]:
+                if model_id == model["model_info"]["id"]:
+                    return model
+        return None
+
     def get_model_ids(self):
         ids = []
         for model in self.model_list:
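A minimal sketch of the new lookup, assuming the router can be built from a single-deployment `model_list` alone (credentials and ids are placeholders):

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-4-team",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_base": "https://example-resource.openai.azure.com/",
                "api_key": "placeholder",
            },
            "model_info": {"id": "1234"},
        }
    ]
)

deployment = router.get_deployment(model_id="1234")
print(deployment["litellm_params"]["api_base"] if deployment else "not found")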
litellm/tests/test_alerting.py (new file, 86 lines)
@@ -0,0 +1,86 @@
+# What is this?
+## Tests slack alerting on proxy logging object
+
+import sys
+import os
+import io, asyncio
+from datetime import datetime
+
+# import logging
+# logging.basicConfig(level=logging.DEBUG)
+sys.path.insert(0, os.path.abspath("../.."))
+from litellm.proxy.utils import ProxyLogging
+from litellm.caching import DualCache
+import litellm
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_get_api_base():
+    _pl = ProxyLogging(user_api_key_cache=DualCache())
+    _pl.update_values(alerting=["slack"], alerting_threshold=100, redis_cache=None)
+    model = "chatgpt-v-2"
+    messages = [{"role": "user", "content": "Hey how's it going?"}]
+    litellm_params = {
+        "acompletion": True,
+        "api_key": None,
+        "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
+        "force_timeout": 600,
+        "logger_fn": None,
+        "verbose": False,
+        "custom_llm_provider": "azure",
+        "litellm_call_id": "68f46d2d-714d-4ad8-8137-69600ec8755c",
+        "model_alias_map": {},
+        "completion_call_id": None,
+        "metadata": None,
+        "model_info": None,
+        "proxy_server_request": None,
+        "preset_cache_key": None,
+        "no-log": False,
+        "stream_response": {},
+    }
+    start_time = datetime.now()
+    end_time = datetime.now()
+
+    time_difference_float, model, api_base, messages = (
+        _pl._response_taking_too_long_callback(
+            kwargs={
+                "model": model,
+                "messages": messages,
+                "litellm_params": litellm_params,
+            },
+            start_time=start_time,
+            end_time=end_time,
+        )
+    )
+
+    assert api_base is not None
+    assert isinstance(api_base, str)
+    assert len(api_base) > 0
+    request_info = (
+        f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
+    )
+    slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {100}s`"
+    await _pl.alerting_handler(
+        message=slow_message + request_info,
+        level="Low",
+    )
+
+
+@pytest.mark.asyncio
+async def test_request_taking_too_long():
+    """
+    - attach request_taking_too_long as a success callback to litellm
+    - unit test kwargs for azure call vs. vertex ai call -> ensure api base constructed correctly for both
+    """
+    import time
+
+    _pl = ProxyLogging(user_api_key_cache=DualCache())
+    litellm.success_callback = [_pl.response_taking_too_long_callback]
+
+    response = await litellm.acompletion(
+        model="azure/chatgpt-v-2",
+        messages=[{"role": "user", "content": "Hey, how's it going?"}],
+    )
+
+    raise Exception("it worked!")
@@ -75,7 +75,6 @@ class CompletionCustomHandler(

     def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
         try:
-            print(f"kwargs: {kwargs}")
             self.states.append("post_api_call")
             ## START TIME
             assert isinstance(start_time, datetime)
@@ -167,6 +166,8 @@ class CompletionCustomHandler(
             )
             assert isinstance(kwargs["optional_params"], dict)
             assert isinstance(kwargs["litellm_params"], dict)
+            assert isinstance(kwargs["litellm_params"]["api_base"], str)
+            assert isinstance(kwargs["cache_hit"], Optional[bool])
             assert isinstance(kwargs["start_time"], (datetime, type(None)))
             assert isinstance(kwargs["stream"], bool)
             assert isinstance(kwargs["user"], (str, type(None)))
@@ -265,8 +266,10 @@ class CompletionCustomHandler(
             assert isinstance(kwargs["messages"], list)
             assert isinstance(kwargs["optional_params"], dict)
             assert isinstance(kwargs["litellm_params"], dict)
+            assert isinstance(kwargs["litellm_params"]["api_base"], str)
             assert isinstance(kwargs["start_time"], (datetime, type(None)))
             assert isinstance(kwargs["stream"], bool)
+            assert isinstance(kwargs["cache_hit"], Optional[bool])
             assert isinstance(kwargs["user"], (str, type(None)))
             assert isinstance(kwargs["input"], (list, dict, str))
             assert isinstance(kwargs["api_key"], (str, type(None)))
@@ -651,8 +654,8 @@ def load_vertex_ai_credentials():
     os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)


+@pytest.mark.skip(reason="Vertex AI Hanging")
 @pytest.mark.asyncio
-@pytest.mark.skip(reason="Skipping on this PR to test other stuff")
 async def test_async_chat_vertex_ai_stream():
     try:
         load_vertex_ai_credentials()
@@ -75,6 +75,7 @@ from .proxy._types import KeyManagementSystem
 from openai import OpenAIError as OriginalError
 from openai._models import BaseModel as OpenAIObject
 from .caching import S3Cache, RedisSemanticCache, RedisCache
+from .router import LiteLLM_Params
 from .exceptions import (
     AuthenticationError,
     BadRequestError,
@@ -1075,6 +1076,9 @@ class Logging:
             headers = {}
         data = additional_args.get("complete_input_dict", {})
         api_base = additional_args.get("api_base", "")
+        self.model_call_details["litellm_params"]["api_base"] = str(
+            api_base
+        )  # used for alerting
         masked_headers = {
             k: (v[:-20] + "*" * 20) if (isinstance(v, str) and len(v) > 20) else v
             for k, v in headers.items()
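Because `pre_call` now copies the resolved API base into `litellm_params`, callbacks receive it in their `kwargs`; the updated custom-callback tests above assert exactly this. A minimal sketch of a custom success callback reading it, assuming litellm's standard `(kwargs, completion_response, start_time, end_time)` custom-callback signature:

import litellm

def log_api_base(kwargs, completion_response, start_time, end_time):
    # populated by Logging.pre_call for every provider call
    api_base = kwargs.get("litellm_params", {}).get("api_base", "")
    print(f"{kwargs.get('model')} -> {api_base}")

litellm.success_callback = [log_api_base]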
@@ -1203,7 +1207,6 @@ class Logging:
         self.model_call_details["original_response"] = original_response
         self.model_call_details["additional_args"] = additional_args
         self.model_call_details["log_event_type"] = "post_api_call"

         # User Logging -> if you pass in a custom logging function
         print_verbose(
             f"RAW RESPONSE:\n{self.model_call_details.get('original_response', self.model_call_details)}\n\n",
@@ -2546,7 +2549,7 @@ def client(original_function):
             langfuse_secret=kwargs.pop("langfuse_secret", None),
         )
         ## check if metadata is passed in
-        litellm_params = {}
+        litellm_params = {"api_base": ""}
         if "metadata" in kwargs:
             litellm_params["metadata"] = kwargs["metadata"]
         logging_obj.update_environment_variables(
@@ -3033,7 +3036,7 @@ def client(original_function):
                     cached_result = await litellm.cache.async_get_cache(
                         *args, **kwargs
                     )
-                    else:
+                    else:  # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
                         preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
                         kwargs["preset_cache_key"] = (
                             preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
@@ -3076,6 +3079,7 @@ def client(original_function):
                         "preset_cache_key", None
                     ),
                     "stream_response": kwargs.get("stream_response", {}),
+                    "api_base": kwargs.get("api_base", ""),
                 },
                 input=kwargs.get("messages", ""),
                 api_key=kwargs.get("api_key", None),
@@ -3209,6 +3213,7 @@ def client(original_function):
                         "stream_response": kwargs.get(
                             "stream_response", {}
                         ),
+                        "api_base": "",
                     },
                     input=kwargs.get("messages", ""),
                     api_key=kwargs.get("api_key", None),
@@ -5305,6 +5310,27 @@ def get_optional_params(
     return optional_params


+def get_api_base(model: str, optional_params: dict) -> Optional[str]:
+    _optional_params = LiteLLM_Params(**optional_params)  # convert to pydantic object
+
+    if _optional_params.api_base is not None:
+        return _optional_params.api_base
+
+    if (
+        _optional_params.vertex_location is not None
+        and _optional_params.vertex_project is not None
+    ):
+        _api_base = "{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent".format(
+            _optional_params.vertex_location,
+            _optional_params.vertex_project,
+            _optional_params.vertex_location,
+            model,
+        )
+        return _api_base
+
+    return None
+
+
 def get_supported_openai_params(model: str, custom_llm_provider: str):
     """
     Returns the supported openai params for a given model + provider
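A minimal sketch of the two resolution paths the new helper covers: an explicit `api_base` is returned as-is, and a Vertex AI endpoint is derived from `vertex_location` and `vertex_project` (project and location values are illustrative):

import litellm

# 1. Explicit api_base wins
print(litellm.get_api_base(
    model="azure/chatgpt-v-2",
    optional_params={
        "model": "azure/chatgpt-v-2",
        "api_base": "https://example-resource.openai.azure.com/",
    },
))

# 2. Vertex AI: endpoint built from location + project + model
print(litellm.get_api_base(
    model="gemini-pro",
    optional_params={
        "model": "gemini-pro",
        "vertex_project": "my-gcp-project",  # illustrative
        "vertex_location": "us-central1",    # illustrative
    },
))
# -> us-central1-aiplatform.googleapis.com/v1/projects/my-gcp-project/
#    locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent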