From f729370890e36c1f23ccf73839a5924e69e18766 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 29 May 2024 15:59:32 -0700
Subject: [PATCH] feat(proxy_server.py): emit webhook event whenever customer spend is tracked

Closes https://github.com/BerriAI/litellm/issues/3903
---
 litellm/integrations/slack_alerting.py | 36 +++++++++++++-
 litellm/proxy/_types.py                | 10 +++-
 litellm/proxy/proxy_server.py          | 68 ++++++++------------------
 litellm/tests/test_alerting.py         | 30 ++++++++++++
 4 files changed, 92 insertions(+), 52 deletions(-)

diff --git a/litellm/integrations/slack_alerting.py b/litellm/integrations/slack_alerting.py
index 8fa1be8915..c1fdc553b9 100644
--- a/litellm/integrations/slack_alerting.py
+++ b/litellm/integrations/slack_alerting.py
@@ -684,14 +684,16 @@ class SlackAlerting(CustomLogger):
         event: Optional[
             Literal["budget_crossed", "threshold_crossed", "projected_limit_exceeded"]
         ] = None
-        event_group: Optional[Literal["user", "team", "key", "proxy"]] = None
+        event_group: Optional[
+            Literal["internal_user", "team", "key", "proxy", "customer"]
+        ] = None
         event_message: str = ""
         webhook_event: Optional[WebhookEvent] = None
         if type == "proxy_budget":
             event_group = "proxy"
             event_message += "Proxy Budget: "
         elif type == "user_budget":
-            event_group = "user"
+            event_group = "internal_user"
             event_message += "User Budget: "
             _id = user_info.user_id or _id
         elif type == "team_budget":
@@ -755,6 +757,36 @@ class SlackAlerting(CustomLogger):
                 return
         return
 
+    async def customer_spend_alert(
+        self,
+        token: Optional[str],
+        key_alias: Optional[str],
+        end_user_id: Optional[str],
+        response_cost: Optional[float],
+        max_budget: Optional[float],
+    ):
+        if end_user_id is not None and token is not None and response_cost is not None:
+            # log customer spend
+            event = WebhookEvent(
+                spend=response_cost,
+                max_budget=max_budget,
+                token=token,
+                customer_id=end_user_id,
+                user_id=None,
+                team_id=None,
+                user_email=None,
+                key_alias=key_alias,
+                projected_exceeded_date=None,
+                projected_spend=None,
+                event="spend_tracked",
+                event_group="customer",
+                event_message="Customer spend tracked. Customer={}, spend={}".format(
+                    end_user_id, response_cost
+                ),
+            )
+
+            await self.send_webhook_alert(webhook_event=event)
+
     def _count_outage_alerts(self, alerts: List[int]) -> str:
         """
         Parameters:
diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index 07812a756d..555254a633 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -1051,6 +1051,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
     end_user_id: Optional[str] = None
     end_user_tpm_limit: Optional[int] = None
     end_user_rpm_limit: Optional[int] = None
+    end_user_max_budget: Optional[float] = None
 
 
 class UserAPIKeyAuth(
@@ -1178,6 +1179,7 @@ class CallInfo(LiteLLMBase):
     spend: float
     max_budget: Optional[float] = None
     token: str = Field(description="Hashed value of that key")
+    customer_id: Optional[str] = None
     user_id: Optional[str] = None
     team_id: Optional[str] = None
     user_email: Optional[str] = None
@@ -1188,9 +1190,13 @@ class CallInfo(LiteLLMBase):
 
 class WebhookEvent(CallInfo):
     event: Literal[
-        "budget_crossed", "threshold_crossed", "projected_limit_exceeded", "key_created"
+        "budget_crossed",
+        "threshold_crossed",
+        "projected_limit_exceeded",
+        "key_created",
+        "spend_tracked",
     ]
-    event_group: Literal["user", "key", "team", "proxy"]
+    event_group: Literal["internal_user", "key", "team", "proxy", "customer"]
     event_message: str  # human-readable description of event
 
 
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 6efa150a44..2e9c256658 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -717,6 +717,10 @@ async def user_api_key_auth(
                             end_user_params["end_user_rpm_limit"] = (
                                 budget_info.rpm_limit
                             )
+                        if budget_info.max_budget is not None:
+                            end_user_params["end_user_max_budget"] = (
+                                budget_info.max_budget
+                            )
             except Exception as e:
                 verbose_proxy_logger.debug(
                     "Unable to find user in db. Error - {}".format(str(e))
@@ -1568,6 +1572,10 @@ async def _PROXY_track_cost_callback(
         user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None)
         team_id = kwargs["litellm_params"]["metadata"].get("user_api_key_team_id", None)
         org_id = kwargs["litellm_params"]["metadata"].get("user_api_key_org_id", None)
+        key_alias = kwargs["litellm_params"]["metadata"].get("user_api_key_alias", None)
+        end_user_max_budget = kwargs["litellm_params"]["metadata"].get(
+            "user_api_end_user_max_budget", None
+        )
         if kwargs.get("response_cost", None) is not None:
             response_cost = kwargs["response_cost"]
             user_api_key = kwargs["litellm_params"]["metadata"].get(
@@ -1604,6 +1612,14 @@
                     end_user_id=end_user_id,
                     response_cost=response_cost,
                 )
+
+                await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
+                    token=user_api_key,
+                    key_alias=key_alias,
+                    end_user_id=end_user_id,
+                    response_cost=response_cost,
+                    max_budget=end_user_max_budget,
+                )
             else:
                 raise Exception(
                     "User API key and team id and user id missing from custom callback."
@@ -3959,6 +3975,10 @@ async def chat_completion(
         data["metadata"]["user_api_key_alias"] = getattr(
             user_api_key_dict, "key_alias", None
         )
+        data["metadata"]["user_api_end_user_max_budget"] = getattr(
+            user_api_key_dict, "end_user_max_budget", None
+        )
+
         data["metadata"]["global_max_parallel_requests"] = general_settings.get(
             "global_max_parallel_requests", None
         )
@@ -13102,54 +13122,6 @@ async def token_generate():
     return {"token": token}
 
 
-# @router.post("/update_database", dependencies=[Depends(user_api_key_auth)])
-# async def update_database_endpoint(
-#     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
-# ):
-#     """
-#     Test endpoint. DO NOT MERGE IN PROD.
-
-#     Used for isolating and testing our prisma db update logic in high-traffic.
-#     """
-#     try:
-#         request_id = f"chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac{time.time()}"
-#         resp = litellm.ModelResponse(
-#             id=request_id,
-#             choices=[
-#                 litellm.Choices(
-#                     finish_reason=None,
-#                     index=0,
-#                     message=litellm.Message(
-#                         content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a",
-#                         role="assistant",
-#                     ),
-#                 )
-#             ],
-#             model="gpt-35-turbo",  # azure always has model written like this
-#             usage=litellm.Usage(
-#                 prompt_tokens=210, completion_tokens=200, total_tokens=410
-#             ),
-#         )
-#         await _PROXY_track_cost_callback(
-#             kwargs={
-#                 "model": "chatgpt-v-2",
-#                 "stream": False,
-#                 "litellm_params": {
-#                     "metadata": {
-#                         "user_api_key": user_api_key_dict.token,
-#                         "user_api_key_user_id": user_api_key_dict.user_id,
-#                     }
-#                 },
-#                 "response_cost": 0.00002,
-#             },
-#             completion_response=resp,
-#             start_time=datetime.now(),
-#             end_time=datetime.now(),
-#         )
-#     except Exception as e:
-#         raise e
-
-
 def _has_user_setup_sso():
     """
     Check if the user has set up single sign-on (SSO) by verifying the presence of Microsoft client ID, Google client ID, and UI username environment variables.
diff --git a/litellm/tests/test_alerting.py b/litellm/tests/test_alerting.py
index 4bb2bd9369..19e0e719ff 100644
--- a/litellm/tests/test_alerting.py
+++ b/litellm/tests/test_alerting.py
@@ -499,6 +499,36 @@ async def test_webhook_alerting(alerting_type):
         mock_send_alert.assert_awaited_once()
 
 
+# @pytest.mark.asyncio
+# async def test_webhook_customer_spend_event():
+#     """
+#     Test if customer spend is working as expected
+#     """
+#     slack_alerting = SlackAlerting(alerting=["webhook"])
+
+#     with patch.object(
+#         slack_alerting, "send_webhook_alert", new=AsyncMock()
+#     ) as mock_send_alert:
+#         user_info = {
+#             "token": "50e55ca5bfbd0759697538e8d23c0cd5031f52d9e19e176d7233b20c7c4d3403",
+#             "spend": 1,
+#             "max_budget": 0,
+#             "user_id": "ishaan@berri.ai",
+#             "user_email": "ishaan@berri.ai",
+#             "key_alias": "my-test-key",
+#             "projected_exceeded_date": "10/20/2024",
+#             "projected_spend": 200,
+#         }
+
+#         user_info = CallInfo(**user_info)
+#         for _ in range(50):
+#             await slack_alerting.budget_alerts(
+#                 type=alerting_type,
+#                 user_info=user_info,
+#             )
+#         mock_send_alert.assert_awaited_once()
+
+
 @pytest.mark.parametrize(
     "model, api_base, llm_provider, vertex_project, vertex_location",
     [