Merge pull request #2880 from BerriAI/litellm_api_base_alerting

feat(proxy/utils.py): return api base for request hanging alerts
Krish Dholakia, 2024-04-06 19:17:18 -07:00, committed by GitHub
commit 83f608dc5d
8 changed files with 320 additions and 126 deletions
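For orientation: this PR exports a new `litellm.get_api_base` helper and threads the API base into the proxy's slow-response and hanging-request alerts. A minimal sketch of the exported helper, assuming litellm at this commit; the endpoint URL below is an illustrative placeholder, not taken from the diff:

# Sketch only: the api_base value is a placeholder, not from this PR.
import litellm

api_base = litellm.get_api_base(
    model="azure/chatgpt-v-2",
    optional_params={
        "model": "azure/chatgpt-v-2",
        "api_base": "https://my-endpoint.openai.azure.com/",
    },
)
print(api_base)  # an explicit api_base is returned as-is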


@@ -595,6 +595,7 @@ from .utils import (
     _should_retry,
     get_secret,
     get_supported_openai_params,
+    get_api_base,
 )
 from .llms.huggingface_restapi import HuggingfaceConfig
 from .llms.anthropic import AnthropicConfig


@@ -3441,13 +3441,15 @@ async def chat_completion(
         )  # run the moderation check in parallel to the actual llm api call
         response = responses[1]

-        # Post Call Processing
-        data["litellm_status"] = "success"  # used for alerting
         hidden_params = getattr(response, "_hidden_params", {}) or {}
         model_id = hidden_params.get("model_id", None) or ""
         cache_key = hidden_params.get("cache_key", None) or ""

+        # Post Call Processing
+        if llm_router is not None:
+            data["deployment"] = llm_router.get_deployment(model_id=model_id)
+        data["litellm_status"] = "success"  # used for alerting
         if (
             "stream" in data and data["stream"] == True
         ):  # use generate_responses to stream responses


@@ -182,6 +182,25 @@ class ProxyLogging:
                 raise e
         return data

+    def _response_taking_too_long_callback(
+        self,
+        kwargs,  # kwargs to completion
+        start_time,
+        end_time,  # start/end time
+    ):
+        try:
+            time_difference = end_time - start_time
+            # Convert the timedelta to float (in seconds)
+            time_difference_float = time_difference.total_seconds()
+            litellm_params = kwargs.get("litellm_params", {})
+            api_base = litellm_params.get("api_base", "")
+            model = kwargs.get("model", "")
+            messages = kwargs.get("messages", "")
+
+            return time_difference_float, model, api_base, messages
+        except Exception as e:
+            raise e
+
     async def response_taking_too_long_callback(
         self,
         kwargs,  # kwargs to completion
@@ -191,13 +210,13 @@ class ProxyLogging:
     ):
         if self.alerting is None:
             return
-        time_difference = end_time - start_time
-        # Convert the timedelta to float (in seconds)
-        time_difference_float = time_difference.total_seconds()
-        litellm_params = kwargs.get("litellm_params", {})
-        api_base = litellm_params.get("api_base", "")
-        model = kwargs.get("model", "")
-        messages = kwargs.get("messages", "")
+        time_difference_float, model, api_base, messages = (
+            self._response_taking_too_long_callback(
+                kwargs=kwargs,
+                start_time=start_time,
+                end_time=end_time,
+            )
+        )
         request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
         slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
         if time_difference_float > self.alerting_threshold:
@@ -244,6 +263,20 @@ class ProxyLogging:
                 request_data is not None
                 and request_data.get("litellm_status", "") != "success"
             ):
+                if request_data.get("deployment", None) is not None and isinstance(
+                    request_data["deployment"], dict
+                ):
+                    _api_base = litellm.get_api_base(
+                        model=model,
+                        optional_params=request_data["deployment"].get(
+                            "litellm_params", {}
+                        ),
+                    )
+
+                    if _api_base is None:
+                        _api_base = ""
+
+                    request_info += f"\nAPI Base: {_api_base}"
                 # only alert hanging responses if they have not been marked as success
                 alerting_message = (
                     f"`Requests are hanging - {self.alerting_threshold}s+ request time`"


@@ -29,117 +29,7 @@ from litellm.utils import ModelResponse, CustomStreamWrapper
 import copy
 from litellm._logging import verbose_router_logger
 import logging
-from pydantic import BaseModel, validator
-
-
-class ModelInfo(BaseModel):
-    id: Optional[
-        str
-    ]  # Allow id to be optional on input, but it will always be present as a str in the model instance
-
-    def __init__(self, id: Optional[Union[str, int]] = None, **params):
-        if id is None:
-            id = str(uuid.uuid4())  # Generate a UUID if id is None or not provided
-        elif isinstance(id, int):
-            id = str(id)
-        super().__init__(id=id, **params)
-
-    class Config:
-        extra = "allow"
-
-    def __contains__(self, key):
-        # Define custom behavior for the 'in' operator
-        return hasattr(self, key)
-
-    def get(self, key, default=None):
-        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
-        return getattr(self, key, default)
-
-    def __getitem__(self, key):
-        # Allow dictionary-style access to attributes
-        return getattr(self, key)
-
-    def __setitem__(self, key, value):
-        # Allow dictionary-style assignment of attributes
-        setattr(self, key, value)
-
-
-class LiteLLM_Params(BaseModel):
-    model: str
-    tpm: Optional[int] = None
-    rpm: Optional[int] = None
-    api_key: Optional[str] = None
-    api_base: Optional[str] = None
-    api_version: Optional[str] = None
-    timeout: Optional[Union[float, str]] = None  # if str, pass in as os.environ/
-    stream_timeout: Optional[Union[float, str]] = (
-        None  # timeout when making stream=True calls, if str, pass in as os.environ/
-    )
-    max_retries: int = 2  # follows openai default of 2
-    organization: Optional[str] = None  # for openai orgs
-
-    def __init__(self, max_retries: Optional[Union[int, str]] = None, **params):
-        if max_retries is None:
-            max_retries = 2
-        elif isinstance(max_retries, str):
-            max_retries = int(max_retries)  # cast to int
-        super().__init__(max_retries=max_retries, **params)
-
-    class Config:
-        extra = "allow"
-
-    def __contains__(self, key):
-        # Define custom behavior for the 'in' operator
-        return hasattr(self, key)
-
-    def get(self, key, default=None):
-        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
-        return getattr(self, key, default)
-
-    def __getitem__(self, key):
-        # Allow dictionary-style access to attributes
-        return getattr(self, key)
-
-    def __setitem__(self, key, value):
-        # Allow dictionary-style assignment of attributes
-        setattr(self, key, value)
-
-
-class Deployment(BaseModel):
-    model_name: str
-    litellm_params: LiteLLM_Params
-    model_info: ModelInfo
-
-    def __init__(self, model_info: Optional[ModelInfo] = None, **params):
-        if model_info is None:
-            model_info = ModelInfo()
-        super().__init__(model_info=model_info, **params)
-
-    def to_json(self, **kwargs):
-        try:
-            return self.model_dump(**kwargs)  # noqa
-        except Exception as e:
-            # if using pydantic v1
-            return self.dict(**kwargs)
-
-    class Config:
-        extra = "allow"
-
-    def __contains__(self, key):
-        # Define custom behavior for the 'in' operator
-        return hasattr(self, key)
-
-    def get(self, key, default=None):
-        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
-        return getattr(self, key, default)
-
-    def __getitem__(self, key):
-        # Allow dictionary-style access to attributes
-        return getattr(self, key)
-
-    def __setitem__(self, key, value):
-        # Allow dictionary-style assignment of attributes
-        setattr(self, key, value)
+from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params


 class Router:
@@ -2283,6 +2173,13 @@ class Router:
             self.model_names.append(deployment.model_name)
         return

+    def get_deployment(self, model_id: str):
+        for model in self.model_list:
+            if "model_info" in model and "id" in model["model_info"]:
+                if model_id == model["model_info"]["id"]:
+                    return model
+        return None
+
     def get_model_ids(self):
         ids = []
         for model in self.model_list:
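A short sketch of the new `Router.get_deployment` lookup added above; the model_list entry, its id, and the dummy api_key are illustrative placeholders:

# Sketch: look up a deployment dict by its model_info id. Values are placeholders.
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-placeholder"},
            "model_info": {"id": "my-deployment-id"},
        }
    ]
)

deployment = router.get_deployment(model_id="my-deployment-id")
assert deployment is not None and deployment["model_info"]["id"] == "my-deployment-id"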


@@ -0,0 +1,86 @@
+# What is this?
+## Tests slack alerting on proxy logging object
+
+import sys
+import os
+import io, asyncio
+from datetime import datetime
+
+# import logging
+# logging.basicConfig(level=logging.DEBUG)
+sys.path.insert(0, os.path.abspath("../.."))
+from litellm.proxy.utils import ProxyLogging
+from litellm.caching import DualCache
+import litellm
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_get_api_base():
+    _pl = ProxyLogging(user_api_key_cache=DualCache())
+    _pl.update_values(alerting=["slack"], alerting_threshold=100, redis_cache=None)
+    model = "chatgpt-v-2"
+    messages = [{"role": "user", "content": "Hey how's it going?"}]
+    litellm_params = {
+        "acompletion": True,
+        "api_key": None,
+        "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
+        "force_timeout": 600,
+        "logger_fn": None,
+        "verbose": False,
+        "custom_llm_provider": "azure",
+        "litellm_call_id": "68f46d2d-714d-4ad8-8137-69600ec8755c",
+        "model_alias_map": {},
+        "completion_call_id": None,
+        "metadata": None,
+        "model_info": None,
+        "proxy_server_request": None,
+        "preset_cache_key": None,
+        "no-log": False,
+        "stream_response": {},
+    }
+    start_time = datetime.now()
+    end_time = datetime.now()
+
+    time_difference_float, model, api_base, messages = (
+        _pl._response_taking_too_long_callback(
+            kwargs={
+                "model": model,
+                "messages": messages,
+                "litellm_params": litellm_params,
+            },
+            start_time=start_time,
+            end_time=end_time,
+        )
+    )
+
+    assert api_base is not None
+    assert isinstance(api_base, str)
+    assert len(api_base) > 0
+    request_info = (
+        f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
+    )
+    slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {100}s`"
+    await _pl.alerting_handler(
+        message=slow_message + request_info,
+        level="Low",
+    )
+
+
+@pytest.mark.asyncio
+async def test_request_taking_too_long():
+    """
+    - attach request_taking_too_long as a success callback to litellm
+    - unit test kwargs for azure call vs. vertex ai call -> ensure api base constructed correctly for both
+    """
+    import time
+
+    _pl = ProxyLogging(user_api_key_cache=DualCache())
+    litellm.success_callback = [_pl.response_taking_too_long_callback]
+
+    response = await litellm.acompletion(
+        model="azure/chatgpt-v-2",
+        messages=[{"role": "user", "content": "Hey, how's it going?"}],
+    )
+
+    raise Exception("it worked!")


@@ -75,7 +75,6 @@ class CompletionCustomHandler(
     def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
         try:
-            print(f"kwargs: {kwargs}")
             self.states.append("post_api_call")
             ## START TIME
             assert isinstance(start_time, datetime)
@@ -167,6 +166,8 @@ class CompletionCustomHandler(
             )
             assert isinstance(kwargs["optional_params"], dict)
             assert isinstance(kwargs["litellm_params"], dict)
+            assert isinstance(kwargs["litellm_params"]["api_base"], str)
+            assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
             assert isinstance(kwargs["start_time"], (datetime, type(None)))
             assert isinstance(kwargs["stream"], bool)
             assert isinstance(kwargs["user"], (str, type(None)))
@@ -265,8 +266,10 @@ class CompletionCustomHandler(
             assert isinstance(kwargs["messages"], list)
             assert isinstance(kwargs["optional_params"], dict)
             assert isinstance(kwargs["litellm_params"], dict)
+            assert isinstance(kwargs["litellm_params"]["api_base"], str)
             assert isinstance(kwargs["start_time"], (datetime, type(None)))
             assert isinstance(kwargs["stream"], bool)
+            assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
             assert isinstance(kwargs["user"], (str, type(None)))
             assert isinstance(kwargs["input"], (list, dict, str))
             assert isinstance(kwargs["api_key"], (str, type(None)))
@@ -651,8 +654,8 @@ def load_vertex_ai_credentials():
     os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)


-@pytest.mark.skip(reason="Vertex AI Hanging")
 @pytest.mark.asyncio
+@pytest.mark.skip(reason="Skipping on this PR to test other stuff")
 async def test_async_chat_vertex_ai_stream():
     try:
         load_vertex_ai_credentials()


@@ -3,6 +3,7 @@ from typing import List, Optional, Union, Dict, Tuple, Literal
 from pydantic import BaseModel, validator
 from .completion import CompletionRequest
 from .embedding import EmbeddingRequest
+import uuid


 class ModelConfig(BaseModel):
@@ -39,3 +40,120 @@ class RouterConfig(BaseModel):
         "usage-based-routing",
         "latency-based-routing",
     ] = "simple-shuffle"
+
+
+class ModelInfo(BaseModel):
+    id: Optional[
+        str
+    ]  # Allow id to be optional on input, but it will always be present as a str in the model instance
+
+    def __init__(self, id: Optional[Union[str, int]] = None, **params):
+        if id is None:
+            id = str(uuid.uuid4())  # Generate a UUID if id is None or not provided
+        elif isinstance(id, int):
+            id = str(id)
+        super().__init__(id=id, **params)
+
+    class Config:
+        extra = "allow"
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
+
+class LiteLLM_Params(BaseModel):
+    model: str
+    tpm: Optional[int] = None
+    rpm: Optional[int] = None
+    api_key: Optional[str] = None
+    api_base: Optional[str] = None
+    api_version: Optional[str] = None
+    timeout: Optional[Union[float, str]] = None  # if str, pass in as os.environ/
+    stream_timeout: Optional[Union[float, str]] = (
+        None  # timeout when making stream=True calls, if str, pass in as os.environ/
+    )
+    max_retries: int = 2  # follows openai default of 2
+    organization: Optional[str] = None  # for openai orgs
+    ## VERTEX AI ##
+    vertex_project: Optional[str] = None
+    vertex_location: Optional[str] = None
+    ## AWS BEDROCK / SAGEMAKER ##
+    aws_access_key_id: Optional[str] = None
+    aws_secret_access_key: Optional[str] = None
+    aws_region_name: Optional[str] = None
+
+    def __init__(self, max_retries: Optional[Union[int, str]] = None, **params):
+        if max_retries is None:
+            max_retries = 2
+        elif isinstance(max_retries, str):
+            max_retries = int(max_retries)  # cast to int
+        super().__init__(max_retries=max_retries, **params)
+
+    class Config:
+        extra = "allow"
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
+
+class Deployment(BaseModel):
+    model_name: str
+    litellm_params: LiteLLM_Params
+    model_info: ModelInfo
+
+    def __init__(self, model_info: Optional[ModelInfo] = None, **params):
+        if model_info is None:
+            model_info = ModelInfo()
+        super().__init__(model_info=model_info, **params)
+
+    def to_json(self, **kwargs):
+        try:
+            return self.model_dump(**kwargs)  # noqa
+        except Exception as e:
+            # if using pydantic v1
+            return self.dict(**kwargs)
+
+    class Config:
+        extra = "allow"
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
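A brief sketch of the relocated pydantic models (now importable from `litellm.types.router`), showing the auto-generated model_info id and the dict-style access helpers; the parameter values are illustrative placeholders:

# Sketch: values are placeholders; behavior follows the class definitions above.
from litellm.types.router import Deployment, LiteLLM_Params

params = LiteLLM_Params(
    model="azure/chatgpt-v-2",
    api_base="https://my-endpoint.openai.azure.com/",
)
deployment = Deployment(model_name="gpt-3.5-turbo", litellm_params=params)

assert isinstance(deployment.model_info.id, str)  # uuid generated when omitted
assert "litellm_params" in deployment  # custom __contains__
assert deployment["litellm_params"].get("api_base") == params.api_base  # dict-style access
print(deployment.to_json()["model_name"])  # -> "gpt-3.5-turbo"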


@@ -53,6 +53,7 @@ os.environ["TIKTOKEN_CACHE_DIR"] = (
 encoding = tiktoken.get_encoding("cl100k_base")
 import importlib.metadata
 from ._logging import verbose_logger
+from .types.router import LiteLLM_Params
 from .integrations.traceloop import TraceloopLogger
 from .integrations.athina import AthinaLogger
 from .integrations.helicone import HeliconeLogger
@@ -1075,6 +1076,9 @@ class Logging:
                 headers = {}
             data = additional_args.get("complete_input_dict", {})
             api_base = additional_args.get("api_base", "")
+            self.model_call_details["litellm_params"]["api_base"] = str(
+                api_base
+            )  # used for alerting
             masked_headers = {
                 k: (v[:-20] + "*" * 20) if (isinstance(v, str) and len(v) > 20) else v
                 for k, v in headers.items()
@@ -1203,7 +1207,6 @@ class Logging:
         self.model_call_details["original_response"] = original_response
         self.model_call_details["additional_args"] = additional_args
         self.model_call_details["log_event_type"] = "post_api_call"
-        # User Logging -> if you pass in a custom logging function
         print_verbose(
             f"RAW RESPONSE:\n{self.model_call_details.get('original_response', self.model_call_details)}\n\n",
@@ -2546,7 +2549,7 @@ def client(original_function):
                 langfuse_secret=kwargs.pop("langfuse_secret", None),
             )
             ## check if metadata is passed in
-            litellm_params = {}
+            litellm_params = {"api_base": ""}
             if "metadata" in kwargs:
                 litellm_params["metadata"] = kwargs["metadata"]
             logging_obj.update_environment_variables(
@@ -3033,7 +3036,7 @@ def client(original_function):
                         cached_result = await litellm.cache.async_get_cache(
                             *args, **kwargs
                         )
-                    else:
+                    else:  # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
                         preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
                         kwargs["preset_cache_key"] = (
                             preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
@@ -3076,6 +3079,7 @@ def client(original_function):
                                 "preset_cache_key", None
                             ),
                             "stream_response": kwargs.get("stream_response", {}),
+                            "api_base": kwargs.get("api_base", ""),
                         },
                         input=kwargs.get("messages", ""),
                         api_key=kwargs.get("api_key", None),
@@ -3209,6 +3213,7 @@ def client(original_function):
                                     "stream_response": kwargs.get(
                                         "stream_response", {}
                                     ),
+                                    "api_base": "",
                                 },
                                 input=kwargs.get("messages", ""),
                                 api_key=kwargs.get("api_key", None),
@@ -5305,6 +5310,55 @@ def get_optional_params(
     return optional_params


+def get_api_base(model: str, optional_params: dict) -> Optional[str]:
+    """
+    Returns the api base used for calling the model.
+
+    Parameters:
+    - model: str - the model passed to litellm.completion()
+    - optional_params - the additional params passed to litellm.completion - eg. api_base, api_key, etc. See `LiteLLM_Params` - https://github.com/BerriAI/litellm/blob/f09e6ba98d65e035a79f73bc069145002ceafd36/litellm/router.py#L67
+
+    Returns:
+    - string (api_base) or None
+
+    Example:
+    ```
+    from litellm import get_api_base
+
+    get_api_base(model="gemini/gemini-pro")
+    ```
+    """
+    _optional_params = LiteLLM_Params(**optional_params)  # convert to pydantic object
+    # get llm provider
+    try:
+        model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
+            model=model
+        )
+    except:
+        custom_llm_provider = None
+
+    if _optional_params.api_base is not None:
+        return _optional_params.api_base
+
+    if (
+        _optional_params.vertex_location is not None
+        and _optional_params.vertex_project is not None
+    ):
+        _api_base = "{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent".format(
+            _optional_params.vertex_location,
+            _optional_params.vertex_project,
+            _optional_params.vertex_location,
+            model,
+        )
+        return _api_base
+
+    if custom_llm_provider is not None and custom_llm_provider == "gemini":
+        _api_base = "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent".format(
+            model
+        )
+        return _api_base
+    return None
+
+
 def get_supported_openai_params(model: str, custom_llm_provider: str):
     """
     Returns the supported openai params for a given model + provider
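To round out the new `get_api_base` helper above, a sketch of its fallback branch for Vertex AI, where the API base is constructed from project and location; the project and location values are illustrative placeholders:

# Sketch: vertex_project / vertex_location are placeholders.
from litellm import get_api_base

vertex_base = get_api_base(
    model="chat-bison",
    optional_params={
        "model": "chat-bison",
        "vertex_project": "my-gcp-project",
        "vertex_location": "us-central1",
    },
)
print(vertex_base)
# -> "us-central1-aiplatform.googleapis.com/v1/projects/my-gcp-project/
#     locations/us-central1/publishers/google/models/chat-bison:streamGenerateContent"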