mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
feat(proxy_server.py): emit webhook event whenever customer spend is tracked
Closes https://github.com/BerriAI/litellm/issues/3903
This commit is contained in:
parent
8f0019c241
commit
f729370890
4 changed files with 92 additions and 52 deletions
|
@ -684,14 +684,16 @@ class SlackAlerting(CustomLogger):
|
||||||
event: Optional[
|
event: Optional[
|
||||||
Literal["budget_crossed", "threshold_crossed", "projected_limit_exceeded"]
|
Literal["budget_crossed", "threshold_crossed", "projected_limit_exceeded"]
|
||||||
] = None
|
] = None
|
||||||
event_group: Optional[Literal["user", "team", "key", "proxy"]] = None
|
event_group: Optional[
|
||||||
|
Literal["internal_user", "team", "key", "proxy", "customer"]
|
||||||
|
] = None
|
||||||
event_message: str = ""
|
event_message: str = ""
|
||||||
webhook_event: Optional[WebhookEvent] = None
|
webhook_event: Optional[WebhookEvent] = None
|
||||||
if type == "proxy_budget":
|
if type == "proxy_budget":
|
||||||
event_group = "proxy"
|
event_group = "proxy"
|
||||||
event_message += "Proxy Budget: "
|
event_message += "Proxy Budget: "
|
||||||
elif type == "user_budget":
|
elif type == "user_budget":
|
||||||
event_group = "user"
|
event_group = "internal_user"
|
||||||
event_message += "User Budget: "
|
event_message += "User Budget: "
|
||||||
_id = user_info.user_id or _id
|
_id = user_info.user_id or _id
|
||||||
elif type == "team_budget":
|
elif type == "team_budget":
|
||||||
|
@ -755,6 +757,36 @@ class SlackAlerting(CustomLogger):
|
||||||
return
|
return
|
||||||
return
|
return
|
||||||
|
|
||||||
|
async def customer_spend_alert(
|
||||||
|
self,
|
||||||
|
token: Optional[str],
|
||||||
|
key_alias: Optional[str],
|
||||||
|
end_user_id: Optional[str],
|
||||||
|
response_cost: Optional[float],
|
||||||
|
max_budget: Optional[float],
|
||||||
|
):
|
||||||
|
if end_user_id is not None and token is not None and response_cost is not None:
|
||||||
|
# log customer spend
|
||||||
|
event = WebhookEvent(
|
||||||
|
spend=response_cost,
|
||||||
|
max_budget=max_budget,
|
||||||
|
token=token,
|
||||||
|
customer_id=end_user_id,
|
||||||
|
user_id=None,
|
||||||
|
team_id=None,
|
||||||
|
user_email=None,
|
||||||
|
key_alias=key_alias,
|
||||||
|
projected_exceeded_date=None,
|
||||||
|
projected_spend=None,
|
||||||
|
event="spend_tracked",
|
||||||
|
event_group="customer",
|
||||||
|
event_message="Customer spend tracked. Customer={}, spend={}".format(
|
||||||
|
end_user_id, response_cost
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.send_webhook_alert(webhook_event=event)
|
||||||
|
|
||||||
def _count_outage_alerts(self, alerts: List[int]) -> str:
|
def _count_outage_alerts(self, alerts: List[int]) -> str:
|
||||||
"""
|
"""
|
||||||
Parameters:
|
Parameters:
|
||||||
|
|
|
@ -1051,6 +1051,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
|
||||||
end_user_id: Optional[str] = None
|
end_user_id: Optional[str] = None
|
||||||
end_user_tpm_limit: Optional[int] = None
|
end_user_tpm_limit: Optional[int] = None
|
||||||
end_user_rpm_limit: Optional[int] = None
|
end_user_rpm_limit: Optional[int] = None
|
||||||
|
end_user_max_budget: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
class UserAPIKeyAuth(
|
class UserAPIKeyAuth(
|
||||||
|
@ -1178,6 +1179,7 @@ class CallInfo(LiteLLMBase):
|
||||||
spend: float
|
spend: float
|
||||||
max_budget: Optional[float] = None
|
max_budget: Optional[float] = None
|
||||||
token: str = Field(description="Hashed value of that key")
|
token: str = Field(description="Hashed value of that key")
|
||||||
|
customer_id: Optional[str] = None
|
||||||
user_id: Optional[str] = None
|
user_id: Optional[str] = None
|
||||||
team_id: Optional[str] = None
|
team_id: Optional[str] = None
|
||||||
user_email: Optional[str] = None
|
user_email: Optional[str] = None
|
||||||
|
@ -1188,9 +1190,13 @@ class CallInfo(LiteLLMBase):
|
||||||
|
|
||||||
class WebhookEvent(CallInfo):
|
class WebhookEvent(CallInfo):
|
||||||
event: Literal[
|
event: Literal[
|
||||||
"budget_crossed", "threshold_crossed", "projected_limit_exceeded", "key_created"
|
"budget_crossed",
|
||||||
|
"threshold_crossed",
|
||||||
|
"projected_limit_exceeded",
|
||||||
|
"key_created",
|
||||||
|
"spend_tracked",
|
||||||
]
|
]
|
||||||
event_group: Literal["user", "key", "team", "proxy"]
|
event_group: Literal["internal_user", "key", "team", "proxy", "customer"]
|
||||||
event_message: str # human-readable description of event
|
event_message: str # human-readable description of event
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -717,6 +717,10 @@ async def user_api_key_auth(
|
||||||
end_user_params["end_user_rpm_limit"] = (
|
end_user_params["end_user_rpm_limit"] = (
|
||||||
budget_info.rpm_limit
|
budget_info.rpm_limit
|
||||||
)
|
)
|
||||||
|
if budget_info.max_budget is not None:
|
||||||
|
end_user_params["end_user_max_budget"] = (
|
||||||
|
budget_info.max_budget
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
"Unable to find user in db. Error - {}".format(str(e))
|
"Unable to find user in db. Error - {}".format(str(e))
|
||||||
|
@ -1568,6 +1572,10 @@ async def _PROXY_track_cost_callback(
|
||||||
user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None)
|
user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None)
|
||||||
team_id = kwargs["litellm_params"]["metadata"].get("user_api_key_team_id", None)
|
team_id = kwargs["litellm_params"]["metadata"].get("user_api_key_team_id", None)
|
||||||
org_id = kwargs["litellm_params"]["metadata"].get("user_api_key_org_id", None)
|
org_id = kwargs["litellm_params"]["metadata"].get("user_api_key_org_id", None)
|
||||||
|
key_alias = kwargs["litellm_params"]["metadata"].get("user_api_key_alias", None)
|
||||||
|
end_user_max_budget = kwargs["litellm_params"]["metadata"].get(
|
||||||
|
"user_api_end_user_max_budget", None
|
||||||
|
)
|
||||||
if kwargs.get("response_cost", None) is not None:
|
if kwargs.get("response_cost", None) is not None:
|
||||||
response_cost = kwargs["response_cost"]
|
response_cost = kwargs["response_cost"]
|
||||||
user_api_key = kwargs["litellm_params"]["metadata"].get(
|
user_api_key = kwargs["litellm_params"]["metadata"].get(
|
||||||
|
@ -1604,6 +1612,14 @@ async def _PROXY_track_cost_callback(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
response_cost=response_cost,
|
response_cost=response_cost,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
|
||||||
|
token=user_api_key,
|
||||||
|
key_alias=key_alias,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
response_cost=response_cost,
|
||||||
|
max_budget=end_user_max_budget,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"User API key and team id and user id missing from custom callback."
|
"User API key and team id and user id missing from custom callback."
|
||||||
|
@ -3959,6 +3975,10 @@ async def chat_completion(
|
||||||
data["metadata"]["user_api_key_alias"] = getattr(
|
data["metadata"]["user_api_key_alias"] = getattr(
|
||||||
user_api_key_dict, "key_alias", None
|
user_api_key_dict, "key_alias", None
|
||||||
)
|
)
|
||||||
|
data["metadata"]["user_api_end_user_max_budget"] = getattr(
|
||||||
|
user_api_key_dict, "end_user_max_budget", None
|
||||||
|
)
|
||||||
|
|
||||||
data["metadata"]["global_max_parallel_requests"] = general_settings.get(
|
data["metadata"]["global_max_parallel_requests"] = general_settings.get(
|
||||||
"global_max_parallel_requests", None
|
"global_max_parallel_requests", None
|
||||||
)
|
)
|
||||||
|
@ -13102,54 +13122,6 @@ async def token_generate():
|
||||||
return {"token": token}
|
return {"token": token}
|
||||||
|
|
||||||
|
|
||||||
# @router.post("/update_database", dependencies=[Depends(user_api_key_auth)])
|
|
||||||
# async def update_database_endpoint(
|
|
||||||
# user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
||||||
# ):
|
|
||||||
# """
|
|
||||||
# Test endpoint. DO NOT MERGE IN PROD.
|
|
||||||
|
|
||||||
# Used for isolating and testing our prisma db update logic in high-traffic.
|
|
||||||
# """
|
|
||||||
# try:
|
|
||||||
# request_id = f"chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac{time.time()}"
|
|
||||||
# resp = litellm.ModelResponse(
|
|
||||||
# id=request_id,
|
|
||||||
# choices=[
|
|
||||||
# litellm.Choices(
|
|
||||||
# finish_reason=None,
|
|
||||||
# index=0,
|
|
||||||
# message=litellm.Message(
|
|
||||||
# content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a",
|
|
||||||
# role="assistant",
|
|
||||||
# ),
|
|
||||||
# )
|
|
||||||
# ],
|
|
||||||
# model="gpt-35-turbo", # azure always has model written like this
|
|
||||||
# usage=litellm.Usage(
|
|
||||||
# prompt_tokens=210, completion_tokens=200, total_tokens=410
|
|
||||||
# ),
|
|
||||||
# )
|
|
||||||
# await _PROXY_track_cost_callback(
|
|
||||||
# kwargs={
|
|
||||||
# "model": "chatgpt-v-2",
|
|
||||||
# "stream": False,
|
|
||||||
# "litellm_params": {
|
|
||||||
# "metadata": {
|
|
||||||
# "user_api_key": user_api_key_dict.token,
|
|
||||||
# "user_api_key_user_id": user_api_key_dict.user_id,
|
|
||||||
# }
|
|
||||||
# },
|
|
||||||
# "response_cost": 0.00002,
|
|
||||||
# },
|
|
||||||
# completion_response=resp,
|
|
||||||
# start_time=datetime.now(),
|
|
||||||
# end_time=datetime.now(),
|
|
||||||
# )
|
|
||||||
# except Exception as e:
|
|
||||||
# raise e
|
|
||||||
|
|
||||||
|
|
||||||
def _has_user_setup_sso():
|
def _has_user_setup_sso():
|
||||||
"""
|
"""
|
||||||
Check if the user has set up single sign-on (SSO) by verifying the presence of Microsoft client ID, Google client ID, and UI username environment variables.
|
Check if the user has set up single sign-on (SSO) by verifying the presence of Microsoft client ID, Google client ID, and UI username environment variables.
|
||||||
|
|
|
@ -499,6 +499,36 @@ async def test_webhook_alerting(alerting_type):
|
||||||
mock_send_alert.assert_awaited_once()
|
mock_send_alert.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
# @pytest.mark.asyncio
|
||||||
|
# async def test_webhook_customer_spend_event():
|
||||||
|
# """
|
||||||
|
# Test if customer spend is working as expected
|
||||||
|
# """
|
||||||
|
# slack_alerting = SlackAlerting(alerting=["webhook"])
|
||||||
|
|
||||||
|
# with patch.object(
|
||||||
|
# slack_alerting, "send_webhook_alert", new=AsyncMock()
|
||||||
|
# ) as mock_send_alert:
|
||||||
|
# user_info = {
|
||||||
|
# "token": "50e55ca5bfbd0759697538e8d23c0cd5031f52d9e19e176d7233b20c7c4d3403",
|
||||||
|
# "spend": 1,
|
||||||
|
# "max_budget": 0,
|
||||||
|
# "user_id": "ishaan@berri.ai",
|
||||||
|
# "user_email": "ishaan@berri.ai",
|
||||||
|
# "key_alias": "my-test-key",
|
||||||
|
# "projected_exceeded_date": "10/20/2024",
|
||||||
|
# "projected_spend": 200,
|
||||||
|
# }
|
||||||
|
|
||||||
|
# user_info = CallInfo(**user_info)
|
||||||
|
# for _ in range(50):
|
||||||
|
# await slack_alerting.budget_alerts(
|
||||||
|
# type=alerting_type,
|
||||||
|
# user_info=user_info,
|
||||||
|
# )
|
||||||
|
# mock_send_alert.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model, api_base, llm_provider, vertex_project, vertex_location",
|
"model, api_base, llm_provider, vertex_project, vertex_location",
|
||||||
[
|
[
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue