Merge branch 'main' into litellm_azure_content_filter_fallbacks

commit 39c2fe511c
Krish Dholakia, 2024-06-22 21:28:29 -07:00 (committed by GitHub)
51 changed files with 1650 additions and 1074 deletions

litellm/router.py

@@ -11,6 +11,7 @@ import asyncio
import concurrent
import copy
import datetime as datetime_og
import enum
import hashlib
import inspect
import json
@@ -90,6 +91,10 @@ from litellm.utils import (
)
class RoutingArgs(enum.Enum):
ttl = 60  # 1 minute: TTL for the per-minute RPM/TPM usage keys
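# The 60s TTL matches the per-minute usage window used further below: each
# TPM key embeds the current minute, so stale counters simply expire instead
# of needing cleanup. A minimal sketch of the key lifecycle (illustrative,
# not part of this diff; `deployment_id` is a stand-in name):
#
#     tpm_key = f"global_router:{deployment_id}:tpm:{dt.strftime('%H-%M')}"
#     await cache.async_increment_cache(
#         key=tpm_key, value=total_tokens, ttl=RoutingArgs.ttl.value
#     )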
class Router:
model_names: List = []
cache_responses: Optional[bool] = False
@@ -387,6 +392,11 @@ class Router:
routing_strategy=routing_strategy,
routing_strategy_args=routing_strategy_args,
)
## USAGE TRACKING ##
if isinstance(litellm._async_success_callback, list):
litellm._async_success_callback.append(self.deployment_callback_on_success)
else:
litellm._async_success_callback = [self.deployment_callback_on_success]
## COOLDOWNS ##
if isinstance(litellm.failure_callback, list):
litellm.failure_callback.append(self.deployment_callback_on_failure)
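# Both hooks are registered on litellm's module-level callback lists: the
# async success callback feeds the per-minute usage counters, while the
# failure callback drives deployment cooldowns.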
@@ -2664,13 +2674,69 @@ class Router:
time.sleep(_timeout)
if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
setattr(original_exception, "max_retries", num_retries)
setattr(original_exception, "num_retries", current_attempt)
raise original_exception
### HELPER FUNCTIONS
async def deployment_callback_on_success(
self,
kwargs, # kwargs to completion
completion_response, # response from completion
start_time,
end_time, # start/end time
):
"""
Update the per-minute TPM usage counter for the deployment that served this request
"""
try:
"""
Update TPM usage on success
"""
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
total_tokens = completion_response["usage"]["total_tokens"]
# ------------
# Setup values
# ------------
dt = get_utc_datetime()
current_minute = dt.strftime(
"%H-%M"
) # use the same timezone regardless of system clock
tpm_key = f"global_router:{id}:tpm:{current_minute}"
# ------------
# Update usage
# ------------
# update cache
## TPM
await self.cache.async_increment_cache(
key=tpm_key, value=total_tokens, ttl=RoutingArgs.ttl.value
)
except Exception as e:
verbose_router_logger.error(
    "litellm.router.Router::deployment_callback_on_success(): Exception occurred - {}\n{}".format(
        str(e), traceback.format_exc()
    )
)
pass
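# A minimal usage sketch (assumes an already-configured Router instance named
# `router` and a model group "gpt-3.5-turbo"; both names are illustrative):
#
#     resp = await router.acompletion(
#         model="gpt-3.5-turbo",
#         messages=[{"role": "user", "content": "hi"}],
#     )
#     # the success callback above has now incremented the per-minute TPM
#     # counter for the deployment that served the request, which
#     # get_model_group_usage (further below) reads back:
#     used_tpm = await router.get_model_group_usage(model_group="gpt-3.5-turbo")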
def deployment_callback_on_failure(
self,
kwargs, # kwargs to completion
@@ -3870,10 +3936,39 @@ class Router:
model_group_info: Optional[ModelGroupInfo] = None
total_tpm: Optional[int] = None
total_rpm: Optional[int] = None
for model in self.model_list:
if "model_name" in model and model["model_name"] == model_group:
# model in model group found #
litellm_params = LiteLLM_Params(**model["litellm_params"])
# get model tpm
_deployment_tpm: Optional[int] = None
if _deployment_tpm is None:
_deployment_tpm = model.get("tpm", None)
if _deployment_tpm is None:
_deployment_tpm = model.get("litellm_params", {}).get("tpm", None)
if _deployment_tpm is None:
_deployment_tpm = model.get("model_info", {}).get("tpm", None)
if _deployment_tpm is not None:
if total_tpm is None:
total_tpm = 0
total_tpm += _deployment_tpm # type: ignore
# get model rpm
_deployment_rpm: Optional[int] = None
if _deployment_rpm is None:
_deployment_rpm = model.get("rpm", None)
if _deployment_rpm is None:
_deployment_rpm = model.get("litellm_params", {}).get("rpm", None)
if _deployment_rpm is None:
_deployment_rpm = model.get("model_info", {}).get("rpm", None)
if _deployment_rpm is not None:
if total_rpm is None:
total_rpm = 0
total_rpm += _deployment_rpm # type: ignore
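# tpm/rpm are resolved per deployment with the same precedence: a top-level
# "tpm"/"rpm" key wins, then litellm_params, then model_info; the first
# non-None value found is added to the group total.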
# get model info
try:
model_info = litellm.get_model_info(model=litellm_params.model)
@@ -3987,8 +4082,44 @@ class Router:
"supported_openai_params"
]
## UPDATE WITH TOTAL TPM/RPM FOR MODEL GROUP
if total_tpm is not None and model_group_info is not None:
model_group_info.tpm = total_tpm
if total_rpm is not None and model_group_info is not None:
model_group_info.rpm = total_rpm
return model_group_info
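# Illustrative use of the aggregated limits (assuming this loop lives in
# get_model_group_info, as the surrounding context suggests):
#
#     info = router.get_model_group_info(model_group="gpt-3.5-turbo")
#     if info is not None:
#         print(info.tpm, info.rpm)  # summed across deployments, None if unset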
async def get_model_group_usage(self, model_group: str) -> Optional[int]:
"""
Returns remaining tpm quota for model group
"""
dt = get_utc_datetime()
current_minute = dt.strftime(
"%H-%M"
) # use the same timezone regardless of system clock
tpm_keys: List[str] = []
for model in self.model_list:
if "model_name" in model and model["model_name"] == model_group:
tpm_keys.append(
f"global_router:{model['model_info']['id']}:tpm:{current_minute}"
)
## TPM
tpm_usage_list: Optional[List] = await self.cache.async_batch_get_cache(
keys=tpm_keys
)
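# cache misses come back as None entries, so only int values are summed;
# a group with no recorded traffic this minute returns None rather than 0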
tpm_usage: Optional[int] = None
if tpm_usage_list is not None:
for t in tpm_usage_list:
if isinstance(t, int):
if tpm_usage is None:
tpm_usage = 0
tpm_usage += t
return tpm_usage
def get_model_ids(self) -> List[str]:
"""
Returns list of model id's.
@@ -4916,7 +5047,7 @@ class Router:
def reset(self):
## clean up on close
litellm.success_callback = []
litellm._async_success_callback = []
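# Bugfix: this replaces `litellm.__async_success_callback = []`, whose double
# leading underscore was name-mangled inside the class body to
# `litellm._Router__async_success_callback`, so the real async success
# callback list was never actually cleared by reset().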
litellm.failure_callback = []
litellm._async_failure_callback = []
self.retry_policy = None