From f4aef679c9588b6a23dfd7678ce57c5c5bf5820a Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Thu, 18 Jan 2024 17:44:39 -0800 Subject: [PATCH 1/4] (feat) proxy - track cost for user_ids that do not exist --- litellm/proxy/utils.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 23b66f22d7..9f183644da 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -530,7 +530,11 @@ class PrismaClient: where={"token": token}, # type: ignore data={**db_data}, # type: ignore ) - print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m") + print_verbose( + "\033[91m" + + f"DB Token Table update succeeded {response}" + + "\033[0m" + ) return {"token": token, "data": db_data} elif user_id is not None: """ @@ -540,6 +544,23 @@ class PrismaClient: where={"user_id": user_id}, # type: ignore data={**db_data}, # type: ignore ) + if update_user_row is None: + # if the provided user does not exist, STILL Track this! + # make a new user with {"user_id": user_id, "spend": data['spend']} + + db_data["user_id"] = user_id + update_user_row = await self.db.litellm_usertable.upsert( + where={"user_id": user_id}, # type: ignore + data={ + "create": {**db_data}, # type: ignore + "update": {}, # don't do anything if it already exists + }, + ) + print_verbose( + "\033[91m" + + f"DB User Table - update succeeded {update_user_row}" + + "\033[0m" + ) return {"user_id": user_id, "data": db_data} except Exception as e: asyncio.create_task( From d8d1cea69f7bbe3a6364c6d67225dbc501656c02 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Thu, 18 Jan 2024 17:45:59 -0800 Subject: [PATCH 2/4] (feat) support user param for all providers --- litellm/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/litellm/utils.py b/litellm/utils.py index f7cc5d2a54..2e0650be6e 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -3305,6 +3305,8 @@ def get_optional_params( unsupported_params = {} for k in non_default_params.keys(): if k not in supported_params: + if k == "user": + continue if k == "n" and n == 1: # langchain sends n=1 as a default value continue # skip this param if ( From 9e952ef7e92250498b7d2237e3636222b8f2a34c Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Thu, 18 Jan 2024 17:51:48 -0800 Subject: [PATCH 3/4] (feat) use user_id passed to request - cost track --- litellm/proxy/proxy_server.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 5d55a7161b..9e6589c482 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -523,19 +523,27 @@ async def track_cost_callback( verbose_proxy_logger.debug( f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}" ) + litellm_params = kwargs.get("litellm_params", {}) + proxy_server_request = litellm_params.get("proxy_server_request") + user_id = proxy_server_request.get("body", {}).get("user", None) if "complete_streaming_response" in kwargs: # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost completion_response = kwargs["complete_streaming_response"] response_cost = litellm.completion_cost( completion_response=completion_response ) - verbose_proxy_logger.debug(f"streaming response_cost {response_cost}") + user_api_key = kwargs["litellm_params"]["metadata"].get( "user_api_key", None ) - user_id = kwargs["litellm_params"]["metadata"].get( + + user_id = user_id or kwargs["litellm_params"]["metadata"].get( "user_api_key_user_id", None ) + + verbose_proxy_logger.debug( + f"streaming response_cost {response_cost}, for user_id {user_id}" + ) if user_api_key and ( prisma_client is not None or custom_db_client is not None ): @@ -555,9 +563,12 @@ async def track_cost_callback( user_api_key = kwargs["litellm_params"]["metadata"].get( "user_api_key", None ) - user_id = kwargs["litellm_params"]["metadata"].get( + user_id = user_id or kwargs["litellm_params"]["metadata"].get( "user_api_key_user_id", None ) + verbose_proxy_logger.debug( + f"response_cost {response_cost}, for user_id {user_id}" + ) if user_api_key and ( prisma_client is not None or custom_db_client is not None ): From c4133314a5d547709801858fbfb6e905ba5846f6 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Thu, 18 Jan 2024 18:05:51 -0800 Subject: [PATCH 4/4] (fix) safe access litellm_params, proxy_server_request --- litellm/proxy/proxy_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 9e6589c482..297d0eed18 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -523,8 +523,8 @@ async def track_cost_callback( verbose_proxy_logger.debug( f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}" ) - litellm_params = kwargs.get("litellm_params", {}) - proxy_server_request = litellm_params.get("proxy_server_request") + litellm_params = kwargs.get("litellm_params", {}) or {} + proxy_server_request = litellm_params.get("proxy_server_request") or {} user_id = proxy_server_request.get("body", {}).get("user", None) if "complete_streaming_response" in kwargs: # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost