mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Merge pull request #1509 from BerriAI/litellm_track_cost_user_id_chat_completions
[Feat] Proxy - Track Cost Per User (Using `user` passed to requests)
This commit is contained in:
commit
79e261f533
3 changed files with 38 additions and 4 deletions
|
@ -523,19 +523,27 @@ async def track_cost_callback(
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
|
f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
|
||||||
)
|
)
|
||||||
|
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||||
|
proxy_server_request = litellm_params.get("proxy_server_request") or {}
|
||||||
|
user_id = proxy_server_request.get("body", {}).get("user", None)
|
||||||
if "complete_streaming_response" in kwargs:
|
if "complete_streaming_response" in kwargs:
|
||||||
# for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
|
# for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
|
||||||
completion_response = kwargs["complete_streaming_response"]
|
completion_response = kwargs["complete_streaming_response"]
|
||||||
response_cost = litellm.completion_cost(
|
response_cost = litellm.completion_cost(
|
||||||
completion_response=completion_response
|
completion_response=completion_response
|
||||||
)
|
)
|
||||||
verbose_proxy_logger.debug(f"streaming response_cost {response_cost}")
|
|
||||||
user_api_key = kwargs["litellm_params"]["metadata"].get(
|
user_api_key = kwargs["litellm_params"]["metadata"].get(
|
||||||
"user_api_key", None
|
"user_api_key", None
|
||||||
)
|
)
|
||||||
user_id = kwargs["litellm_params"]["metadata"].get(
|
|
||||||
|
user_id = user_id or kwargs["litellm_params"]["metadata"].get(
|
||||||
"user_api_key_user_id", None
|
"user_api_key_user_id", None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"streaming response_cost {response_cost}, for user_id {user_id}"
|
||||||
|
)
|
||||||
if user_api_key and (
|
if user_api_key and (
|
||||||
prisma_client is not None or custom_db_client is not None
|
prisma_client is not None or custom_db_client is not None
|
||||||
):
|
):
|
||||||
|
@ -555,9 +563,12 @@ async def track_cost_callback(
|
||||||
user_api_key = kwargs["litellm_params"]["metadata"].get(
|
user_api_key = kwargs["litellm_params"]["metadata"].get(
|
||||||
"user_api_key", None
|
"user_api_key", None
|
||||||
)
|
)
|
||||||
user_id = kwargs["litellm_params"]["metadata"].get(
|
user_id = user_id or kwargs["litellm_params"]["metadata"].get(
|
||||||
"user_api_key_user_id", None
|
"user_api_key_user_id", None
|
||||||
)
|
)
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"response_cost {response_cost}, for user_id {user_id}"
|
||||||
|
)
|
||||||
if user_api_key and (
|
if user_api_key and (
|
||||||
prisma_client is not None or custom_db_client is not None
|
prisma_client is not None or custom_db_client is not None
|
||||||
):
|
):
|
||||||
|
|
|
@ -530,7 +530,11 @@ class PrismaClient:
|
||||||
where={"token": token}, # type: ignore
|
where={"token": token}, # type: ignore
|
||||||
data={**db_data}, # type: ignore
|
data={**db_data}, # type: ignore
|
||||||
)
|
)
|
||||||
print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
|
print_verbose(
|
||||||
|
"\033[91m"
|
||||||
|
+ f"DB Token Table update succeeded {response}"
|
||||||
|
+ "\033[0m"
|
||||||
|
)
|
||||||
return {"token": token, "data": db_data}
|
return {"token": token, "data": db_data}
|
||||||
elif user_id is not None:
|
elif user_id is not None:
|
||||||
"""
|
"""
|
||||||
|
@ -540,6 +544,23 @@ class PrismaClient:
|
||||||
where={"user_id": user_id}, # type: ignore
|
where={"user_id": user_id}, # type: ignore
|
||||||
data={**db_data}, # type: ignore
|
data={**db_data}, # type: ignore
|
||||||
)
|
)
|
||||||
|
if update_user_row is None:
|
||||||
|
# if the provided user does not exist, STILL Track this!
|
||||||
|
# make a new user with {"user_id": user_id, "spend": data['spend']}
|
||||||
|
|
||||||
|
db_data["user_id"] = user_id
|
||||||
|
update_user_row = await self.db.litellm_usertable.upsert(
|
||||||
|
where={"user_id": user_id}, # type: ignore
|
||||||
|
data={
|
||||||
|
"create": {**db_data}, # type: ignore
|
||||||
|
"update": {}, # don't do anything if it already exists
|
||||||
|
},
|
||||||
|
)
|
||||||
|
print_verbose(
|
||||||
|
"\033[91m"
|
||||||
|
+ f"DB User Table - update succeeded {update_user_row}"
|
||||||
|
+ "\033[0m"
|
||||||
|
)
|
||||||
return {"user_id": user_id, "data": db_data}
|
return {"user_id": user_id, "data": db_data}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
|
|
|
@ -3305,6 +3305,8 @@ def get_optional_params(
|
||||||
unsupported_params = {}
|
unsupported_params = {}
|
||||||
for k in non_default_params.keys():
|
for k in non_default_params.keys():
|
||||||
if k not in supported_params:
|
if k not in supported_params:
|
||||||
|
if k == "user":
|
||||||
|
continue
|
||||||
if k == "n" and n == 1: # langchain sends n=1 as a default value
|
if k == "n" and n == 1: # langchain sends n=1 as a default value
|
||||||
continue # skip this param
|
continue # skip this param
|
||||||
if (
|
if (
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue