feat(proxy_server.py): emit webhook event whenever customer spend is tracked

Closes https://github.com/BerriAI/litellm/issues/3903
Krrish Dholakia 2024-05-29 15:59:32 -07:00
parent 8f0019c241
commit f729370890
4 changed files with 92 additions and 52 deletions
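
The new event is delivered through the existing webhook alerting path. As a rough illustration (not part of this commit), the body a webhook receiver gets for a customer spend event follows the WebhookEvent model changed below; every value here is made up:

# Illustrative spend_tracked payload, shown as a Python dict; all values are hypothetical.
spend_tracked_example = {
    "event": "spend_tracked",
    "event_group": "customer",
    "event_message": "Customer spend tracked. Customer=end-user-123, spend=0.25",
    "customer_id": "end-user-123",   # the end_user_id sent on the request
    "spend": 0.25,                   # response_cost of the tracked call
    "max_budget": 10.0,              # end user's max budget, None if unset
    "token": "hashed-api-key",       # hashed value of the key that made the call
    "key_alias": "my-test-key",
    "user_id": None,
    "team_id": None,
    "user_email": None,
    "projected_exceeded_date": None,
    "projected_spend": None,
}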

View file

@@ -684,14 +684,16 @@ class SlackAlerting(CustomLogger):
         event: Optional[
             Literal["budget_crossed", "threshold_crossed", "projected_limit_exceeded"]
         ] = None
-        event_group: Optional[Literal["user", "team", "key", "proxy"]] = None
+        event_group: Optional[
+            Literal["internal_user", "team", "key", "proxy", "customer"]
+        ] = None
         event_message: str = ""
         webhook_event: Optional[WebhookEvent] = None
         if type == "proxy_budget":
             event_group = "proxy"
             event_message += "Proxy Budget: "
         elif type == "user_budget":
-            event_group = "user"
+            event_group = "internal_user"
             event_message += "User Budget: "
             _id = user_info.user_id or _id
         elif type == "team_budget":
@@ -755,6 +757,36 @@ class SlackAlerting(CustomLogger):
                 return
             return
 
+    async def customer_spend_alert(
+        self,
+        token: Optional[str],
+        key_alias: Optional[str],
+        end_user_id: Optional[str],
+        response_cost: Optional[float],
+        max_budget: Optional[float],
+    ):
+        if end_user_id is not None and token is not None and response_cost is not None:
+            # log customer spend
+            event = WebhookEvent(
+                spend=response_cost,
+                max_budget=max_budget,
+                token=token,
+                customer_id=end_user_id,
+                user_id=None,
+                team_id=None,
+                user_email=None,
+                key_alias=key_alias,
+                projected_exceeded_date=None,
+                projected_spend=None,
+                event="spend_tracked",
+                event_group="customer",
+                event_message="Customer spend tracked. Customer={}, spend={}".format(
+                    end_user_id, response_cost
+                ),
+            )
+            await self.send_webhook_alert(webhook_event=event)
+
     def _count_outage_alerts(self, alerts: List[int]) -> str:
         """
         Parameters:

View file

@@ -1051,6 +1051,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
     end_user_id: Optional[str] = None
     end_user_tpm_limit: Optional[int] = None
     end_user_rpm_limit: Optional[int] = None
+    end_user_max_budget: Optional[float] = None
 
 
 class UserAPIKeyAuth(
@@ -1178,6 +1179,7 @@ class CallInfo(LiteLLMBase):
     spend: float
     max_budget: Optional[float] = None
     token: str = Field(description="Hashed value of that key")
+    customer_id: Optional[str] = None
     user_id: Optional[str] = None
     team_id: Optional[str] = None
     user_email: Optional[str] = None
@@ -1188,9 +1190,13 @@ class CallInfo(LiteLLMBase):
 class WebhookEvent(CallInfo):
     event: Literal[
-        "budget_crossed", "threshold_crossed", "projected_limit_exceeded", "key_created"
+        "budget_crossed",
+        "threshold_crossed",
+        "projected_limit_exceeded",
+        "key_created",
+        "spend_tracked",
     ]
-    event_group: Literal["user", "key", "team", "proxy"]
+    event_group: Literal["internal_user", "key", "team", "proxy", "customer"]
     event_message: str  # human-readable description of event
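
For consumers of these events, the new literals make it possible to route customer spend separately from budget alerts. A minimal sketch of a receiver is below; FastAPI, the route path, and the print statement are assumptions for illustration, not part of this change:

# Hypothetical webhook receiver; point the proxy's configured webhook URL at it.
from fastapi import FastAPI, Request

app = FastAPI()

@app.post("/litellm-webhook")  # assumed path
async def litellm_webhook(request: Request):
    payload = await request.json()
    # React only to the customer spend events introduced in this commit
    if payload.get("event") == "spend_tracked" and payload.get("event_group") == "customer":
        customer_id = payload.get("customer_id")
        spend = payload.get("spend")
        max_budget = payload.get("max_budget")  # may be None when no budget is set
        print(f"customer={customer_id} spend={spend} max_budget={max_budget}")
    return {"status": "ok"}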

View file

@@ -717,6 +717,10 @@ async def user_api_key_auth(
                         end_user_params["end_user_rpm_limit"] = (
                             budget_info.rpm_limit
                         )
+                        if budget_info.max_budget is not None:
+                            end_user_params["end_user_max_budget"] = (
+                                budget_info.max_budget
+                            )
             except Exception as e:
                 verbose_proxy_logger.debug(
                     "Unable to find user in db. Error - {}".format(str(e))
@@ -1568,6 +1572,10 @@ async def _PROXY_track_cost_callback(
         user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None)
         team_id = kwargs["litellm_params"]["metadata"].get("user_api_key_team_id", None)
         org_id = kwargs["litellm_params"]["metadata"].get("user_api_key_org_id", None)
+        key_alias = kwargs["litellm_params"]["metadata"].get("user_api_key_alias", None)
+        end_user_max_budget = kwargs["litellm_params"]["metadata"].get(
+            "user_api_end_user_max_budget", None
+        )
         if kwargs.get("response_cost", None) is not None:
             response_cost = kwargs["response_cost"]
             user_api_key = kwargs["litellm_params"]["metadata"].get(
@@ -1604,6 +1612,14 @@ async def _PROXY_track_cost_callback(
                     end_user_id=end_user_id,
                     response_cost=response_cost,
                 )
+
+                await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
+                    token=user_api_key,
+                    key_alias=key_alias,
+                    end_user_id=end_user_id,
+                    response_cost=response_cost,
+                    max_budget=end_user_max_budget,
+                )
             else:
                 raise Exception(
                     "User API key and team id and user id missing from custom callback."
@@ -3959,6 +3975,10 @@ async def chat_completion(
         data["metadata"]["user_api_key_alias"] = getattr(
             user_api_key_dict, "key_alias", None
         )
+        data["metadata"]["user_api_end_user_max_budget"] = getattr(
+            user_api_key_dict, "end_user_max_budget", None
+        )
+
         data["metadata"]["global_max_parallel_requests"] = general_settings.get(
             "global_max_parallel_requests", None
         )
@@ -13102,54 +13122,6 @@ async def token_generate():
     return {"token": token}
 
 
-# @router.post("/update_database", dependencies=[Depends(user_api_key_auth)])
-# async def update_database_endpoint(
-#     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
-# ):
-#     """
-#     Test endpoint. DO NOT MERGE IN PROD.
-#
-#     Used for isolating and testing our prisma db update logic in high-traffic.
-#     """
-#     try:
-#         request_id = f"chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac{time.time()}"
-#         resp = litellm.ModelResponse(
-#             id=request_id,
-#             choices=[
-#                 litellm.Choices(
-#                     finish_reason=None,
-#                     index=0,
-#                     message=litellm.Message(
-#                         content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a",
-#                         role="assistant",
-#                     ),
-#                 )
-#             ],
-#             model="gpt-35-turbo",  # azure always has model written like this
-#             usage=litellm.Usage(
-#                 prompt_tokens=210, completion_tokens=200, total_tokens=410
-#             ),
-#         )
-#
-#         await _PROXY_track_cost_callback(
-#             kwargs={
-#                 "model": "chatgpt-v-2",
-#                 "stream": False,
-#                 "litellm_params": {
-#                     "metadata": {
-#                         "user_api_key": user_api_key_dict.token,
-#                         "user_api_key_user_id": user_api_key_dict.user_id,
-#                     }
-#                 },
-#                 "response_cost": 0.00002,
-#             },
-#             completion_response=resp,
-#             start_time=datetime.now(),
-#             end_time=datetime.now(),
-#         )
-#     except Exception as e:
-#         raise e
-
 
 def _has_user_setup_sso():
     """
     Check if the user has set up single sign-on (SSO) by verifying the presence of Microsoft client ID, Google client ID, and UI username environment variables.
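
Taken together, the proxy_server.py changes thread the customer budget end to end: user_api_key_auth copies budget_info.max_budget onto the auth object as end_user_max_budget, chat_completion forwards it in request metadata as user_api_end_user_max_budget, and _PROXY_track_cost_callback reads it back alongside the key alias and calls customer_spend_alert once the response cost has been recorded.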

View file

@@ -499,6 +499,36 @@ async def test_webhook_alerting(alerting_type):
             mock_send_alert.assert_awaited_once()
 
 
+# @pytest.mark.asyncio
+# async def test_webhook_customer_spend_event():
+#     """
+#     Test if customer spend is working as expected
+#     """
+#     slack_alerting = SlackAlerting(alerting=["webhook"])
+#     with patch.object(
+#         slack_alerting, "send_webhook_alert", new=AsyncMock()
+#     ) as mock_send_alert:
+#         user_info = {
+#             "token": "50e55ca5bfbd0759697538e8d23c0cd5031f52d9e19e176d7233b20c7c4d3403",
+#             "spend": 1,
+#             "max_budget": 0,
+#             "user_id": "ishaan@berri.ai",
+#             "user_email": "ishaan@berri.ai",
+#             "key_alias": "my-test-key",
+#             "projected_exceeded_date": "10/20/2024",
+#             "projected_spend": 200,
+#         }
+#
+#         user_info = CallInfo(**user_info)
+#         for _ in range(50):
+#             await slack_alerting.budget_alerts(
+#                 type=alerting_type,
+#                 user_info=user_info,
+#             )
+#         mock_send_alert.assert_awaited_once()
+
+
 @pytest.mark.parametrize(
     "model, api_base, llm_provider, vertex_project, vertex_location",
     [