forked from phoenix/litellm-mirror
(perf improvement proxy) use one redis set cache to update spend in db (30-40% perf improvement) (#5960)
* use one set op to update spend in db
* fix test_team_cache_update_called
parent 8bf7573fd8
commit eb325cce7d

2 changed files with 46 additions and 35 deletions
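The shape of the change, as a minimal self-contained sketch (toy class and key names, not litellm's actual cache API): each per-entity spend helper appends its (key, value) pair to a list, and the list is flushed with one batched set instead of one awaited write per key.

import asyncio
from typing import Any, Dict, List, Tuple


class BatchingCache:
    # Toy stand-in for the proxy's cache; async_batch_set_cache plays the
    # role of the single batched Redis write this commit switches to.
    def __init__(self) -> None:
        self._store: Dict[str, Any] = {}

    async def async_set_cache(self, key: str, value: Any) -> None:
        # old pattern: one awaited write (one round trip) per key
        self._store[key] = value

    async def async_batch_set_cache(self, cache_list: List[Tuple[Any, Any]]) -> None:
        # new pattern: one write for the whole batch
        self._store.update(dict(cache_list))


async def track_spend(cache: BatchingCache, response_cost: float) -> None:
    values_to_update_in_cache: List[Tuple[Any, Any]] = []
    # each helper appends instead of awaiting its own cache write
    values_to_update_in_cache.append(("token:hashed", response_cost))
    values_to_update_in_cache.append(("user:123", response_cost))
    values_to_update_in_cache.append(("team:1234", response_cost))
    await cache.async_batch_set_cache(cache_list=values_to_update_in_cache)


asyncio.run(track_spend(BatchingCache(), 0.02))

With several spend keys touched per request, this trades N cache round trips for one, which is consistent with the 30-40% improvement the commit title reports.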
litellm/proxy/proxy_server.py

@@ -18,6 +18,7 @@ from typing import (
     Any,
     List,
     Optional,
+    Tuple,
     get_args,
     get_origin,
     get_type_hints,
@@ -119,7 +120,10 @@ from litellm.integrations.SlackAlerting.slack_alerting import (
     SlackAlerting,
     SlackAlertingArgs,
 )
-from litellm.litellm_core_utils.core_helpers import get_litellm_metadata_from_kwargs
+from litellm.litellm_core_utils.core_helpers import (
+    _get_parent_otel_span_from_kwargs,
+    get_litellm_metadata_from_kwargs,
+)
 from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
 from litellm.proxy._types import *
 from litellm.proxy.analytics_endpoints.analytics_endpoints import (
@@ -771,6 +775,7 @@ async def _PROXY_track_cost_callback(
         verbose_proxy_logger.debug(
             f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
         )
+        parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
         litellm_params = kwargs.get("litellm_params", {}) or {}
         proxy_server_request = litellm_params.get("proxy_server_request") or {}
         end_user_id = proxy_server_request.get("body", {}).get("user", None)
@@ -808,12 +813,16 @@ async def _PROXY_track_cost_callback(
                     org_id=org_id,
                 )

-                await update_cache(
-                    token=user_api_key,
-                    user_id=user_id,
-                    end_user_id=end_user_id,
-                    response_cost=response_cost,
-                    team_id=team_id,
-                )
+                # update cache
+                asyncio.create_task(
+                    update_cache(
+                        token=user_api_key,
+                        user_id=user_id,
+                        end_user_id=end_user_id,
+                        response_cost=response_cost,
+                        team_id=team_id,
+                        parent_otel_span=parent_otel_span,
+                    )
+                )

                 await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
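Note the hunk above also stops awaiting update_cache inline and schedules it with asyncio.create_task. A hedged sketch of that fire-and-forget pattern (the background_tasks set is general asyncio hygiene, not something this diff adds): the event loop holds only a weak reference to a task, so long-lived code usually keeps its own strong reference.

import asyncio

background_tasks: set = set()


async def update_cache_stub(**kwargs) -> None:
    await asyncio.sleep(0)  # stands in for the real cache writes


def schedule_cache_update(**kwargs) -> None:
    # Fire-and-forget: the caller returns immediately, the write completes later.
    task = asyncio.create_task(update_cache_stub(**kwargs))
    background_tasks.add(task)  # strong reference, so the task is not GC'd mid-flight
    task.add_done_callback(background_tasks.discard)


async def main() -> None:
    schedule_cache_update(token="hashed-token", response_cost=0.01)
    await asyncio.gather(*background_tasks)


asyncio.run(main())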
@@ -1065,6 +1074,7 @@ async def update_cache(
     end_user_id: Optional[str],
     team_id: Optional[str],
     response_cost: Optional[float],
+    parent_otel_span: Optional[Span],
 ):
     """
     Use this to update the cache with new user spend.
@@ -1072,6 +1082,8 @@ async def update_cache(
     Put any alerting logic in here.
     """

+    values_to_update_in_cache: List[Tuple[Any, Any]] = []
+
     ### UPDATE KEY SPEND ###
     async def _update_key_cache(token: str, response_cost: float):
         # Fetch the existing cost for the given token
@@ -1157,9 +1169,7 @@ async def update_cache(

             # Update the cost column for the given token
             existing_spend_obj.spend = new_spend
-            await user_api_key_cache.async_set_cache(
-                key=hashed_token, value=existing_spend_obj
-            )
+            values_to_update_in_cache.append((hashed_token, existing_spend_obj))

     ### UPDATE USER SPEND ###
     async def _update_user_cache():
@@ -1188,14 +1198,10 @@ async def update_cache(
                 # Update the cost column for the given user
                 if isinstance(existing_spend_obj, dict):
                     existing_spend_obj["spend"] = new_spend
-                    await user_api_key_cache.async_set_cache(
-                        key=_id, value=existing_spend_obj
-                    )
+                    values_to_update_in_cache.append((_id, existing_spend_obj))
                 else:
                     existing_spend_obj.spend = new_spend
-                    await user_api_key_cache.async_set_cache(
-                        key=_id, value=existing_spend_obj.json()
-                    )
+                    values_to_update_in_cache.append((_id, existing_spend_obj.json()))
             ## UPDATE GLOBAL PROXY ##
             global_proxy_spend = await user_api_key_cache.async_get_cache(
                 key="{}:spend".format(litellm_proxy_admin_name)
@@ -1205,8 +1211,8 @@ async def update_cache(
                 return
             elif response_cost is not None and global_proxy_spend is not None:
                 increment = global_proxy_spend + response_cost
-                await user_api_key_cache.async_set_cache(
-                    key="{}:spend".format(litellm_proxy_admin_name), value=increment
+                values_to_update_in_cache.append(
+                    ("{}:spend".format(litellm_proxy_admin_name), increment)
                 )
         except Exception as e:
             verbose_proxy_logger.debug(
@@ -1242,14 +1248,10 @@ async def update_cache(
                 # Update the cost column for the given user
                 if isinstance(existing_spend_obj, dict):
                     existing_spend_obj["spend"] = new_spend
-                    await user_api_key_cache.async_set_cache(
-                        key=_id, value=existing_spend_obj
-                    )
+                    values_to_update_in_cache.append((_id, existing_spend_obj))
                 else:
                     existing_spend_obj.spend = new_spend
-                    await user_api_key_cache.async_set_cache(
-                        key=_id, value=existing_spend_obj.json()
-                    )
+                    values_to_update_in_cache.append((_id, existing_spend_obj.json()))
         except Exception as e:
             verbose_proxy_logger.exception(
                 f"An error occurred updating end user cache: {str(e)}"
@@ -1288,30 +1290,34 @@ async def update_cache(
                 # Update the cost column for the given user
                 if isinstance(existing_spend_obj, dict):
                     existing_spend_obj["spend"] = new_spend
-                    await user_api_key_cache.async_set_cache(
-                        key=_id, value=existing_spend_obj
-                    )
+                    values_to_update_in_cache.append((_id, existing_spend_obj))
                 else:
                     existing_spend_obj.spend = new_spend
-                    await user_api_key_cache.async_set_cache(
-                        key=_id, value=existing_spend_obj
-                    )
+                    values_to_update_in_cache.append((_id, existing_spend_obj))
         except Exception as e:
             verbose_proxy_logger.exception(
                 f"An error occurred updating end user cache: {str(e)}"
             )

     if token is not None and response_cost is not None:
-        asyncio.create_task(_update_key_cache(token=token, response_cost=response_cost))
+        await _update_key_cache(token=token, response_cost=response_cost)

     if user_id is not None:
-        asyncio.create_task(_update_user_cache())
+        await _update_user_cache()

     if end_user_id is not None:
-        asyncio.create_task(_update_end_user_cache())
+        await _update_end_user_cache()

     if team_id is not None:
-        asyncio.create_task(_update_team_cache())
+        await _update_team_cache()
+
+    asyncio.create_task(
+        user_api_key_cache.async_batch_set_cache(
+            cache_list=values_to_update_in_cache,
+            ttl=60,
+            litellm_parent_otel_span=parent_otel_span,
+        )
+    )


 def run_ollama_serve():
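For intuition on what a batched set can do at the Redis layer, here is a sketch (assuming redis-py's asyncio client; litellm's actual async_batch_set_cache internals may differ): every SET is buffered client-side and flushed in a single pipeline round trip, with ex=ttl mirroring the ttl=60 above, since MSET cannot attach a per-key expiry.

import json
from typing import Any, List, Tuple

import redis.asyncio as redis


async def batch_set(
    client: redis.Redis, cache_list: List[Tuple[str, Any]], ttl: int = 60
) -> None:
    # Queue all writes locally, then send them in one round trip.
    pipe = client.pipeline(transaction=False)
    for key, value in cache_list:
        pipe.set(key, json.dumps(value, default=str), ex=ttl)
    await pipe.execute()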
@@ -838,7 +838,12 @@ async def test_team_cache_update_called():
     cache.async_get_cache = mock_call_cache
     # Call the function under test
     await litellm.proxy.proxy_server.update_cache(
-        token=None, user_id=None, end_user_id=None, team_id="1234", response_cost=20
+        token=None,
+        user_id=None,
+        end_user_id=None,
+        team_id="1234",
+        response_cost=20,
+        parent_otel_span=None,
     ) # type: ignore

     await asyncio.sleep(3)
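The await asyncio.sleep(3) above is needed because the batched write now runs in a fire-and-forget task, so the mocked cache call only happens after update_cache returns. A tighter alternative, sketched with a hypothetical helper that is not part of the litellm test suite, is to drain pending tasks instead of sleeping a fixed interval:

import asyncio


async def drain_background_tasks() -> None:
    # Await every task except the current one, so assertions run only
    # after the scheduled batch write has completed.
    pending = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
    if pending:
        await asyncio.gather(*pending, return_exceptions=True)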