Merge branch 'main' into litellm_slack_daily_reports

commit aa62d891a0
Author: Krish Dholakia (committed by GitHub)
Date: 2024-05-06 19:31:20 -07:00
32 changed files with 346 additions and 73 deletions

litellm/proxy/proxy_server.py

@@ -221,6 +221,12 @@ class ProxyException(Exception):
         }
 
 
+class UserAPIKeyCacheTTLEnum(enum.Enum):
+    key_information_cache = 600
+    user_information_cache = 600
+    global_proxy_spend = 60
+
+
 @app.exception_handler(ProxyException)
 async def openai_exception_handler(request: Request, exc: ProxyException):
     # NOTE: DO NOT MODIFY THIS, its crucial to map to Openai exceptions
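
For context, a minimal sketch of the pattern this hunk introduces: cache TTLs move out of scattered integer literals into one enum, so every cache write names its TTL and the values can be tuned in a single place. The `InMemoryCache` below is a hypothetical stand-in, not litellm's actual cache class, and the key is illustrative.

```python
import enum
import time


class UserAPIKeyCacheTTLEnum(enum.Enum):
    key_information_cache = 600
    user_information_cache = 600
    global_proxy_spend = 60


class InMemoryCache:
    """Hypothetical stand-in for the proxy's user_api_key_cache."""

    def __init__(self):
        self._store = {}  # key -> (value, expiry unix timestamp)

    def set_cache(self, key, value, ttl: int):
        self._store[key] = (value, time.time() + ttl)

    def get_cache(self, key):
        entry = self._store.get(key)
        if entry is None:
            return None
        value, expires_at = entry
        if time.time() > expires_at:  # lazily evict expired entries
            del self._store[key]
            return None
        return value


cache = InMemoryCache()
cache.set_cache(
    key="proxy_admin:spend",  # illustrative key
    value=42.0,
    ttl=UserAPIKeyCacheTTLEnum.global_proxy_spend.value,  # 60, defined once above
)
```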
@@ -479,7 +485,7 @@ async def user_api_key_auth(
             await user_api_key_cache.async_set_cache(
                 key="{}:spend".format(litellm_proxy_admin_name),
                 value=global_proxy_spend,
-                ttl=60,
+                ttl=UserAPIKeyCacheTTLEnum.global_proxy_spend.value,
             )
         if global_proxy_spend is not None:
             user_info = {
@@ -740,7 +746,9 @@ async def user_api_key_auth(
             )
             for _id in user_id_information:
                 await user_api_key_cache.async_set_cache(
-                    key=_id["user_id"], value=_id, ttl=600
+                    key=_id["user_id"],
+                    value=_id,
+                    ttl=UserAPIKeyCacheTTLEnum.user_information_cache.value,
                 )
         if custom_db_client is not None:
             user_id_information = await custom_db_client.get_data(
@@ -961,7 +969,7 @@ async def user_api_key_auth(
             await user_api_key_cache.async_set_cache(
                 key="{}:spend".format(litellm_proxy_admin_name),
                 value=global_proxy_spend,
-                ttl=60,
+                ttl=UserAPIKeyCacheTTLEnum.global_proxy_spend.value,
             )
         if global_proxy_spend is not None:
@@ -993,7 +1001,9 @@ async def user_api_key_auth(
         # Add hashed token to cache
         await user_api_key_cache.async_set_cache(
-            key=api_key, value=valid_token, ttl=600
+            key=api_key,
+            value=valid_token,
+            ttl=UserAPIKeyCacheTTLEnum.key_information_cache.value,
        )
        valid_token_dict = _get_pydantic_json_dict(valid_token)
        valid_token_dict.pop("token", None)
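
The "Add hashed token to cache" comment above reflects the proxy's convention of keying cached key-information by a digest rather than the raw secret, and of popping the `token` field before returning. A hedged sketch of that idea; `hash_token` here is a hypothetical helper, not necessarily litellm's implementation:

```python
import hashlib


def hash_token(api_key: str) -> str:
    # cache under a digest so the raw secret never sits in cache memory
    return hashlib.sha256(api_key.encode()).hexdigest()


hashed = hash_token("sk-my-proxy-key")  # illustrative key
# cache.set_cache(key=hashed, value=valid_token_dict,
#                 ttl=UserAPIKeyCacheTTLEnum.key_information_cache.value)
```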
@@ -7308,6 +7318,7 @@ async def add_new_model(
     """
     # encrypt litellm params #
     _litellm_params_dict = model_params.litellm_params.dict(exclude_none=True)
+    _orignal_litellm_model_name = model_params.litellm_params.model
     for k, v in _litellm_params_dict.items():
         if isinstance(v, str):
             encrypted_value = encrypt_value(value=v, master_key=master_key)  # type: ignore
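
The loop above encrypts every string-valued litellm param with the proxy's master key before persisting it. litellm's `encrypt_value` helper is not shown in this diff; as a rough, hypothetical stand-in, a Fernet-based version might look like this:

```python
from cryptography.fernet import Fernet  # pip install cryptography


def encrypt_value(value: str, master_key: bytes) -> bytes:
    """Hypothetical stand-in for litellm's encrypt_value helper."""
    return Fernet(master_key).encrypt(value.encode())


master_key = Fernet.generate_key()  # the proxy derives its key differently
_litellm_params_dict = {"model": "azure/gpt-4", "api_key": "sk-1234"}
for k, v in _litellm_params_dict.items():
    if isinstance(v, str):
        _litellm_params_dict[k] = encrypt_value(value=v, master_key=master_key)
```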
@@ -7334,6 +7345,11 @@ async def add_new_model(
                 prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj
             )
 
+            await proxy_logging_obj.slack_alerting_instance.model_added_alert(
+                model_name=model_params.model_name,
+                litellm_model_name=_orignal_litellm_model_name,
+            )
+
         else:
             raise HTTPException(
                 status_code=500,
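
`model_added_alert` receives both the public `model_name` and the underlying litellm model captured before encryption, so the Slack message can show the mapping. litellm's `SlackAlerting` class is not part of this hunk; a minimal hypothetical version of such a method, assuming a plain incoming-webhook URL:

```python
import asyncio
import json
import urllib.request


class SlackAlerting:
    """Hypothetical minimal alerter, not litellm's actual class."""

    def __init__(self, webhook_url: str):
        self.webhook_url = webhook_url

    async def model_added_alert(self, model_name: str, litellm_model_name: str):
        payload = {
            "text": f"New model added: `{model_name}` -> `{litellm_model_name}`"
        }
        req = urllib.request.Request(
            self.webhook_url,
            data=json.dumps(payload).encode(),
            headers={"Content-Type": "application/json"},
        )
        # keep the blocking HTTP call off the event loop
        await asyncio.get_running_loop().run_in_executor(
            None, urllib.request.urlopen, req
        )
```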
@@ -9220,24 +9236,24 @@ async def active_callbacks():
     """
     global proxy_logging_obj
     _alerting = str(general_settings.get("alerting"))
-    # get success callback
-    success_callback_names = []
-    try:
-        # this was returning a JSON of the values in some of the callbacks
-        # all we need is the callback name, hence we do str(callback)
-        success_callback_names = [str(x) for x in litellm.success_callback]
-    except:
-        # don't let this block the /health/readiness response, if we can't convert to str -> return litellm.success_callback
-        success_callback_names = litellm.success_callback
 
-    # get success callbacks
-    _num_callbacks = (
-        len(litellm.callbacks)
-        + len(litellm.input_callback)
-        + len(litellm.failure_callback)
-        + len(litellm.success_callback)
-        + len(litellm._async_failure_callback)
-        + len(litellm._async_success_callback)
-        + len(litellm._async_input_callback)
+    litellm_callbacks = [str(x) for x in litellm.callbacks]
+    litellm_input_callbacks = [str(x) for x in litellm.input_callback]
+    litellm_failure_callbacks = [str(x) for x in litellm.failure_callback]
+    litellm_success_callbacks = [str(x) for x in litellm.success_callback]
+    litellm_async_success_callbacks = [str(x) for x in litellm._async_success_callback]
+    litellm_async_failure_callbacks = [str(x) for x in litellm._async_failure_callback]
+    litellm_async_input_callbacks = [str(x) for x in litellm._async_input_callback]
+    all_litellm_callbacks = (
+        litellm_callbacks
+        + litellm_input_callbacks
+        + litellm_failure_callbacks
+        + litellm_success_callbacks
+        + litellm_async_success_callbacks
+        + litellm_async_failure_callbacks
+        + litellm_async_input_callbacks
     )
     alerting = proxy_logging_obj.alerting
@@ -9247,20 +9263,15 @@ async def active_callbacks():
     return {
         "alerting": _alerting,
-        "litellm.callbacks": [str(x) for x in litellm.callbacks],
-        "litellm.input_callback": [str(x) for x in litellm.input_callback],
-        "litellm.failure_callback": [str(x) for x in litellm.failure_callback],
-        "litellm.success_callback": [str(x) for x in litellm.success_callback],
-        "litellm._async_success_callback": [
-            str(x) for x in litellm._async_success_callback
-        ],
-        "litellm._async_failure_callback": [
-            str(x) for x in litellm._async_failure_callback
-        ],
-        "litellm._async_input_callback": [
-            str(x) for x in litellm._async_input_callback
-        ],
-        "num_callbacks": _num_callbacks,
+        "litellm.callbacks": litellm_callbacks,
+        "litellm.input_callback": litellm_input_callbacks,
+        "litellm.failure_callback": litellm_failure_callbacks,
+        "litellm.success_callback": litellm_success_callbacks,
+        "litellm._async_success_callback": litellm_async_success_callbacks,
+        "litellm._async_failure_callback": litellm_async_failure_callbacks,
+        "litellm._async_input_callback": litellm_async_input_callbacks,
+        "all_litellm_callbacks": all_litellm_callbacks,
+        "num_callbacks": len(all_litellm_callbacks),
         "num_alerting": _num_alerting,
     }
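
The point of the refactor in the last two hunks: `num_callbacks` is now derived from the same concatenated list the response reports, so the count can never drift from the listing, and the per-list `str()` conversions happen exactly once. A toy illustration with stand-in lists in place of the litellm module-level ones:

```python
# stand-ins for the litellm module-level callback lists
success_callback = ["langfuse", "slack"]
failure_callback = ["slack"]
input_callback: list = []

success_names = [str(x) for x in success_callback]
failure_names = [str(x) for x in failure_callback]
input_names = [str(x) for x in input_callback]

all_callbacks = success_names + failure_names + input_names

response = {
    "success_callback": success_names,
    "failure_callback": failure_names,
    "input_callback": input_names,
    "all_callbacks": all_callbacks,
    "num_callbacks": len(all_callbacks),  # always consistent with the lists above
}
print(response)  # {'success_callback': ['langfuse', 'slack'], ..., 'num_callbacks': 3}
```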