diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 5848fe451..893ca4507 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -18,6 +18,7 @@ from typing import ( Any, List, Optional, + Tuple, get_args, get_origin, get_type_hints, @@ -119,7 +120,10 @@ from litellm.integrations.SlackAlerting.slack_alerting import ( SlackAlerting, SlackAlertingArgs, ) -from litellm.litellm_core_utils.core_helpers import get_litellm_metadata_from_kwargs +from litellm.litellm_core_utils.core_helpers import ( + _get_parent_otel_span_from_kwargs, + get_litellm_metadata_from_kwargs, +) from litellm.llms.custom_httpx.httpx_handler import HTTPHandler from litellm.proxy._types import * from litellm.proxy.analytics_endpoints.analytics_endpoints import ( @@ -771,6 +775,7 @@ async def _PROXY_track_cost_callback( verbose_proxy_logger.debug( f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}" ) + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs) litellm_params = kwargs.get("litellm_params", {}) or {} proxy_server_request = litellm_params.get("proxy_server_request") or {} end_user_id = proxy_server_request.get("body", {}).get("user", None) @@ -808,12 +813,16 @@ async def _PROXY_track_cost_callback( org_id=org_id, ) - await update_cache( - token=user_api_key, - user_id=user_id, - end_user_id=end_user_id, - response_cost=response_cost, - team_id=team_id, + # update cache + asyncio.create_task( + update_cache( + token=user_api_key, + user_id=user_id, + end_user_id=end_user_id, + response_cost=response_cost, + team_id=team_id, + parent_otel_span=parent_otel_span, + ) ) await proxy_logging_obj.slack_alerting_instance.customer_spend_alert( @@ -1065,6 +1074,7 @@ async def update_cache( end_user_id: Optional[str], team_id: Optional[str], response_cost: Optional[float], + parent_otel_span: Optional[Span], ): """ Use this to update the cache with new user spend. @@ -1072,6 +1082,8 @@ async def update_cache( Put any alerting logic in here. """ + values_to_update_in_cache: List[Tuple[Any, Any]] = [] + ### UPDATE KEY SPEND ### async def _update_key_cache(token: str, response_cost: float): # Fetch the existing cost for the given token @@ -1157,9 +1169,7 @@ async def update_cache( # Update the cost column for the given token existing_spend_obj.spend = new_spend - await user_api_key_cache.async_set_cache( - key=hashed_token, value=existing_spend_obj - ) + values_to_update_in_cache.append((hashed_token, existing_spend_obj)) ### UPDATE USER SPEND ### async def _update_user_cache(): @@ -1188,14 +1198,10 @@ async def update_cache( # Update the cost column for the given user if isinstance(existing_spend_obj, dict): existing_spend_obj["spend"] = new_spend - await user_api_key_cache.async_set_cache( - key=_id, value=existing_spend_obj - ) + values_to_update_in_cache.append((_id, existing_spend_obj)) else: existing_spend_obj.spend = new_spend - await user_api_key_cache.async_set_cache( - key=_id, value=existing_spend_obj.json() - ) + values_to_update_in_cache.append((_id, existing_spend_obj.json())) ## UPDATE GLOBAL PROXY ## global_proxy_spend = await user_api_key_cache.async_get_cache( key="{}:spend".format(litellm_proxy_admin_name) @@ -1205,8 +1211,8 @@ async def update_cache( return elif response_cost is not None and global_proxy_spend is not None: increment = global_proxy_spend + response_cost - await user_api_key_cache.async_set_cache( - key="{}:spend".format(litellm_proxy_admin_name), value=increment + values_to_update_in_cache.append( + ("{}:spend".format(litellm_proxy_admin_name), increment) ) except Exception as e: verbose_proxy_logger.debug( @@ -1242,14 +1248,10 @@ async def update_cache( # Update the cost column for the given user if isinstance(existing_spend_obj, dict): existing_spend_obj["spend"] = new_spend - await user_api_key_cache.async_set_cache( - key=_id, value=existing_spend_obj - ) + values_to_update_in_cache.append((_id, existing_spend_obj)) else: existing_spend_obj.spend = new_spend - await user_api_key_cache.async_set_cache( - key=_id, value=existing_spend_obj.json() - ) + values_to_update_in_cache.append((_id, existing_spend_obj.json())) except Exception as e: verbose_proxy_logger.exception( f"An error occurred updating end user cache: {str(e)}" @@ -1288,30 +1290,34 @@ async def update_cache( # Update the cost column for the given user if isinstance(existing_spend_obj, dict): existing_spend_obj["spend"] = new_spend - await user_api_key_cache.async_set_cache( - key=_id, value=existing_spend_obj - ) + values_to_update_in_cache.append((_id, existing_spend_obj)) else: existing_spend_obj.spend = new_spend - await user_api_key_cache.async_set_cache( - key=_id, value=existing_spend_obj - ) + values_to_update_in_cache.append((_id, existing_spend_obj)) except Exception as e: verbose_proxy_logger.exception( f"An error occurred updating end user cache: {str(e)}" ) if token is not None and response_cost is not None: - asyncio.create_task(_update_key_cache(token=token, response_cost=response_cost)) + await _update_key_cache(token=token, response_cost=response_cost) if user_id is not None: - asyncio.create_task(_update_user_cache()) + await _update_user_cache() if end_user_id is not None: - asyncio.create_task(_update_end_user_cache()) + await _update_end_user_cache() if team_id is not None: - asyncio.create_task(_update_team_cache()) + await _update_team_cache() + + asyncio.create_task( + user_api_key_cache.async_batch_set_cache( + cache_list=values_to_update_in_cache, + ttl=60, + litellm_parent_otel_span=parent_otel_span, + ) + ) def run_ollama_serve(): diff --git a/litellm/tests/test_jwt.py b/litellm/tests/test_jwt.py index 51bf55c9c..8b1f9ac0d 100644 --- a/litellm/tests/test_jwt.py +++ b/litellm/tests/test_jwt.py @@ -838,7 +838,12 @@ async def test_team_cache_update_called(): cache.async_get_cache = mock_call_cache # Call the function under test await litellm.proxy.proxy_server.update_cache( - token=None, user_id=None, end_user_id=None, team_id="1234", response_cost=20 + token=None, + user_id=None, + end_user_id=None, + team_id="1234", + response_cost=20, + parent_otel_span=None, ) # type: ignore await asyncio.sleep(3)