(perf improvement proxy) use one redis set cache to update spend in db (30-40% perf improvement) (#5960)

* use one set op to update spend in db

* fix test_team_cache_update_called
Ishaan Jaff 2024-09-28 13:00:31 -07:00 committed by GitHub
parent 8bf7573fd8
commit eb325cce7d
2 changed files with 46 additions and 35 deletions
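Before the diff itself, the core change in one picture: instead of awaiting one cache write per entity (key, user, end user, team), `update_cache` now appends `(key, value)` tuples to a list and flushes them all with a single batched write. A minimal sketch of the pattern, using a stand-in `cache` object whose method names mirror the diff below (litellm's real object is `user_api_key_cache`):

```python
import asyncio
from typing import Any, List, Tuple

async def update_spend_per_key(cache: Any, updates: List[Tuple[str, Any]]) -> None:
    # Old behavior: one awaited cache SET per key -- N round trips,
    # each on the request's critical path.
    for key, value in updates:
        await cache.async_set_cache(key=key, value=value)

def update_spend_batched(cache: Any, updates: List[Tuple[str, Any]]) -> None:
    # New behavior: one batched write for all keys, scheduled as a
    # background task so the request path does not wait on it.
    asyncio.create_task(
        cache.async_batch_set_cache(cache_list=updates, ttl=60)
    )
```

Collapsing N round trips into a single off-path write is presumably where the claimed 30-40% improvement comes from.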


@@ -18,6 +18,7 @@ from typing import (
     Any,
     List,
     Optional,
+    Tuple,
     get_args,
     get_origin,
     get_type_hints,
@@ -119,7 +120,10 @@ from litellm.integrations.SlackAlerting.slack_alerting import (
     SlackAlerting,
     SlackAlertingArgs,
 )
-from litellm.litellm_core_utils.core_helpers import get_litellm_metadata_from_kwargs
+from litellm.litellm_core_utils.core_helpers import (
+    _get_parent_otel_span_from_kwargs,
+    get_litellm_metadata_from_kwargs,
+)
 from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
 from litellm.proxy._types import *
 from litellm.proxy.analytics_endpoints.analytics_endpoints import (
@@ -771,6 +775,7 @@ async def _PROXY_track_cost_callback(
         verbose_proxy_logger.debug(
             f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
         )
+        parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
         litellm_params = kwargs.get("litellm_params", {}) or {}
         proxy_server_request = litellm_params.get("proxy_server_request") or {}
         end_user_id = proxy_server_request.get("body", {}).get("user", None)
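The hunk above also threads tracing context through: the parent OpenTelemetry span is pulled out of the callback `kwargs` so the eventual batched cache write (passed `litellm_parent_otel_span` later in the diff) can be attributed to the originating request. A hedged sketch of what such a helper plausibly does; only the helper's name and import are confirmed by the diff, the metadata keys are assumptions:

```python
from typing import Any, Optional

def _get_parent_otel_span_from_kwargs(kwargs: Optional[dict] = None) -> Optional[Any]:
    # Assumed shape: litellm callbacks carry request metadata under
    # kwargs["litellm_params"]["metadata"], and the proxy is assumed to
    # stash the active request span there. Key names are illustrative.
    if not kwargs:
        return None
    litellm_params = kwargs.get("litellm_params") or {}
    metadata = litellm_params.get("metadata") or {}
    return metadata.get("litellm_parent_otel_span")
```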
@@ -808,12 +813,16 @@ async def _PROXY_track_cost_callback(
                 org_id=org_id,
             )
-            await update_cache(
-                token=user_api_key,
-                user_id=user_id,
-                end_user_id=end_user_id,
-                response_cost=response_cost,
-                team_id=team_id,
+            # update cache
+            asyncio.create_task(
+                update_cache(
+                    token=user_api_key,
+                    user_id=user_id,
+                    end_user_id=end_user_id,
+                    response_cost=response_cost,
+                    team_id=team_id,
+                    parent_otel_span=parent_otel_span,
+                )
             )
             await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
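Note the second half of the win in the hunk above: `update_cache` itself used to be awaited inside `_PROXY_track_cost_callback`; it is now wrapped in `asyncio.create_task`, so cost bookkeeping no longer blocks the response. One general asyncio caveat (standard-library guidance, not something this diff addresses) is that fire-and-forget tasks should be kept referenced so they are not garbage-collected mid-flight:

```python
import asyncio
from typing import Coroutine, Set

_background_tasks: Set[asyncio.Task] = set()

def fire_and_forget(coro: Coroutine) -> None:
    # Schedule the coroutine and return immediately. Holding a reference
    # keeps the task alive until it finishes; the done-callback drops it.
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
```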
@@ -1065,6 +1074,7 @@ async def update_cache(
     end_user_id: Optional[str],
     team_id: Optional[str],
     response_cost: Optional[float],
+    parent_otel_span: Optional[Span],
 ):
     """
     Use this to update the cache with new user spend.
@@ -1072,6 +1082,8 @@ async def update_cache(
 
     Put any alerting logic in here.
     """
+    values_to_update_in_cache: List[Tuple[Any, Any]] = []
+
     ### UPDATE KEY SPEND ###
     async def _update_key_cache(token: str, response_cost: float):
         # Fetch the existing cost for the given token
@@ -1157,9 +1169,7 @@ async def update_cache(
             # Update the cost column for the given token
             existing_spend_obj.spend = new_spend
-            await user_api_key_cache.async_set_cache(
-                key=hashed_token, value=existing_spend_obj
-            )
+            values_to_update_in_cache.append((hashed_token, existing_spend_obj))
 
     ### UPDATE USER SPEND ###
     async def _update_user_cache():
@@ -1188,14 +1198,10 @@ async def update_cache(
                 # Update the cost column for the given user
                 if isinstance(existing_spend_obj, dict):
                     existing_spend_obj["spend"] = new_spend
-                    await user_api_key_cache.async_set_cache(
-                        key=_id, value=existing_spend_obj
-                    )
+                    values_to_update_in_cache.append((_id, existing_spend_obj))
                 else:
                     existing_spend_obj.spend = new_spend
-                    await user_api_key_cache.async_set_cache(
-                        key=_id, value=existing_spend_obj.json()
-                    )
+                    values_to_update_in_cache.append((_id, existing_spend_obj.json()))
             ## UPDATE GLOBAL PROXY ##
             global_proxy_spend = await user_api_key_cache.async_get_cache(
                 key="{}:spend".format(litellm_proxy_admin_name)
@@ -1205,8 +1211,8 @@ async def update_cache(
                 return
             elif response_cost is not None and global_proxy_spend is not None:
                 increment = global_proxy_spend + response_cost
-                await user_api_key_cache.async_set_cache(
-                    key="{}:spend".format(litellm_proxy_admin_name), value=increment
+                values_to_update_in_cache.append(
+                    ("{}:spend".format(litellm_proxy_admin_name), increment)
                 )
         except Exception as e:
             verbose_proxy_logger.debug(
@@ -1242,14 +1248,10 @@ async def update_cache(
                 # Update the cost column for the given user
                 if isinstance(existing_spend_obj, dict):
                     existing_spend_obj["spend"] = new_spend
-                    await user_api_key_cache.async_set_cache(
-                        key=_id, value=existing_spend_obj
-                    )
+                    values_to_update_in_cache.append((_id, existing_spend_obj))
                 else:
                     existing_spend_obj.spend = new_spend
-                    await user_api_key_cache.async_set_cache(
-                        key=_id, value=existing_spend_obj.json()
-                    )
+                    values_to_update_in_cache.append((_id, existing_spend_obj.json()))
         except Exception as e:
             verbose_proxy_logger.exception(
                 f"An error occurred updating end user cache: {str(e)}"
@@ -1288,30 +1290,34 @@ async def update_cache(
                 # Update the cost column for the given user
                 if isinstance(existing_spend_obj, dict):
                     existing_spend_obj["spend"] = new_spend
-                    await user_api_key_cache.async_set_cache(
-                        key=_id, value=existing_spend_obj
-                    )
+                    values_to_update_in_cache.append((_id, existing_spend_obj))
                 else:
                     existing_spend_obj.spend = new_spend
-                    await user_api_key_cache.async_set_cache(
-                        key=_id, value=existing_spend_obj
-                    )
+                    values_to_update_in_cache.append((_id, existing_spend_obj))
         except Exception as e:
             verbose_proxy_logger.exception(
                 f"An error occurred updating end user cache: {str(e)}"
             )
 
     if token is not None and response_cost is not None:
-        asyncio.create_task(_update_key_cache(token=token, response_cost=response_cost))
+        await _update_key_cache(token=token, response_cost=response_cost)
 
     if user_id is not None:
-        asyncio.create_task(_update_user_cache())
+        await _update_user_cache()
 
     if end_user_id is not None:
-        asyncio.create_task(_update_end_user_cache())
+        await _update_end_user_cache()
 
     if team_id is not None:
-        asyncio.create_task(_update_team_cache())
+        await _update_team_cache()
 
+    asyncio.create_task(
+        user_api_key_cache.async_batch_set_cache(
+            cache_list=values_to_update_in_cache,
+            ttl=60,
+            litellm_parent_otel_span=parent_otel_span,
+        )
+    )
 
 
 def run_ollama_serve():
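On the `fix test_team_cache_update_called` part of the commit message: since the per-entity writes are gone, a test that previously patched `async_set_cache` has to assert on the single `async_batch_set_cache` call instead, and because the flush now runs in a background task, the test must yield to the event loop before asserting. A self-contained illustration of that shape (not litellm's actual test; the keys and costs are made up):

```python
import asyncio
from unittest.mock import AsyncMock

async def _demo() -> None:
    cache = AsyncMock()

    # Stand-in for update_cache's new tail: sub-updates accumulate into a
    # list, then one background task flushes everything in a single call.
    values = [("hashed-token", {"spend": 0.25}), ("team-1", {"spend": 0.25})]
    asyncio.create_task(cache.async_batch_set_cache(cache_list=values, ttl=60))

    await asyncio.sleep(0)  # let the background task run before asserting
    cache.async_batch_set_cache.assert_called_once_with(cache_list=values, ttl=60)

asyncio.run(_demo())
```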