diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index fdd5feda93..fdc3263052 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -523,19 +523,27 @@ async def track_cost_callback( verbose_proxy_logger.debug( f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}" ) + litellm_params = kwargs.get("litellm_params", {}) or {} + proxy_server_request = litellm_params.get("proxy_server_request") or {} + user_id = proxy_server_request.get("body", {}).get("user", None) if "complete_streaming_response" in kwargs: # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost completion_response = kwargs["complete_streaming_response"] response_cost = litellm.completion_cost( completion_response=completion_response ) - verbose_proxy_logger.debug(f"streaming response_cost {response_cost}") + user_api_key = kwargs["litellm_params"]["metadata"].get( "user_api_key", None ) - user_id = kwargs["litellm_params"]["metadata"].get( + + user_id = user_id or kwargs["litellm_params"]["metadata"].get( "user_api_key_user_id", None ) + + verbose_proxy_logger.debug( + f"streaming response_cost {response_cost}, for user_id {user_id}" + ) if user_api_key and ( prisma_client is not None or custom_db_client is not None ): @@ -555,9 +563,12 @@ async def track_cost_callback( user_api_key = kwargs["litellm_params"]["metadata"].get( "user_api_key", None ) - user_id = kwargs["litellm_params"]["metadata"].get( + user_id = user_id or kwargs["litellm_params"]["metadata"].get( "user_api_key_user_id", None ) + verbose_proxy_logger.debug( + f"response_cost {response_cost}, for user_id {user_id}" + ) if user_api_key and ( prisma_client is not None or custom_db_client is not None ): diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 23b66f22d7..9f183644da 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -530,7 +530,11 @@ class PrismaClient: where={"token": token}, # type: ignore data={**db_data}, # type: ignore ) - print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m") + print_verbose( + "\033[91m" + + f"DB Token Table update succeeded {response}" + + "\033[0m" + ) return {"token": token, "data": db_data} elif user_id is not None: """ @@ -540,6 +544,23 @@ class PrismaClient: where={"user_id": user_id}, # type: ignore data={**db_data}, # type: ignore ) + if update_user_row is None: + # if the provided user does not exist, STILL Track this! + # make a new user with {"user_id": user_id, "spend": data['spend']} + + db_data["user_id"] = user_id + update_user_row = await self.db.litellm_usertable.upsert( + where={"user_id": user_id}, # type: ignore + data={ + "create": {**db_data}, # type: ignore + "update": {}, # don't do anything if it already exists + }, + ) + print_verbose( + "\033[91m" + + f"DB User Table - update succeeded {update_user_row}" + + "\033[0m" + ) return {"user_id": user_id, "data": db_data} except Exception as e: asyncio.create_task( diff --git a/litellm/utils.py b/litellm/utils.py index f7cc5d2a54..2e0650be6e 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -3305,6 +3305,8 @@ def get_optional_params( unsupported_params = {} for k in non_default_params.keys(): if k not in supported_params: + if k == "user": + continue if k == "n" and n == 1: # langchain sends n=1 as a default value continue # skip this param if (