Merge branch 'main' into litellm_moderations_improvements

This commit is contained in:
Krish Dholakia 2024-02-15 23:08:25 -08:00 committed by GitHub
commit 999fab82f7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 683 additions and 72 deletions

View file

@ -819,6 +819,7 @@ async def _PROXY_track_cost_callback(
user_id = user_id or kwargs["litellm_params"]["metadata"].get(
"user_api_key_user_id", None
)
team_id = kwargs["litellm_params"]["metadata"].get("user_api_key_team_id", None)
if kwargs.get("response_cost", None) is not None:
response_cost = kwargs["response_cost"]
user_api_key = kwargs["litellm_params"]["metadata"].get(
@ -842,6 +843,7 @@ async def _PROXY_track_cost_callback(
token=user_api_key,
response_cost=response_cost,
user_id=user_id,
team_id=team_id,
kwargs=kwargs,
completion_response=completion_response,
start_time=start_time,
@ -879,6 +881,7 @@ async def update_database(
token,
response_cost,
user_id=None,
team_id=None,
kwargs=None,
completion_response=None,
start_time=None,
@ -886,7 +889,7 @@ async def update_database(
):
try:
verbose_proxy_logger.info(
f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}"
f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}; team_id: {team_id}"
)
### [TODO] STEP 1: GET KEY + USER SPEND ### (key, user)
@ -1039,8 +1042,69 @@ async def update_database(
except Exception as e:
verbose_proxy_logger.info(f"Update Spend Logs DB failed to execute")
### UPDATE KEY SPEND ###
async def _update_team_db():
try:
verbose_proxy_logger.debug(
f"adding spend to team db. Response cost: {response_cost}. team_id: {team_id}."
)
if team_id is None:
verbose_proxy_logger.debug(
"track_cost_callback: team_id is None. Not tracking spend for team"
)
return
if prisma_client is not None:
# Fetch the existing cost for the given token
existing_spend_obj = await prisma_client.get_data(
team_id=team_id, table_name="team"
)
verbose_proxy_logger.debug(
f"_update_team_db: existing spend: {existing_spend_obj}"
)
if existing_spend_obj is None:
existing_spend = 0
else:
existing_spend = existing_spend_obj.spend
# Calculate the new cost by adding the existing cost and response_cost
new_spend = existing_spend + response_cost
verbose_proxy_logger.debug(f"new cost: {new_spend}")
# Update the cost column for the given token
await prisma_client.update_data(
team_id=team_id, data={"spend": new_spend}, table_name="team"
)
elif custom_db_client is not None:
# Fetch the existing cost for the given token
existing_spend_obj = await custom_db_client.get_data(
key=token, table_name="key"
)
verbose_proxy_logger.debug(
f"_update_key_db existing spend: {existing_spend_obj}"
)
if existing_spend_obj is None:
existing_spend = 0
else:
existing_spend = existing_spend_obj.spend
# Calculate the new cost by adding the existing cost and response_cost
new_spend = existing_spend + response_cost
verbose_proxy_logger.debug(f"new cost: {new_spend}")
# Update the cost column for the given token
await custom_db_client.update_data(
key=token, value={"spend": new_spend}, table_name="key"
)
valid_token = user_api_key_cache.get_cache(key=token)
if valid_token is not None:
valid_token.spend = new_spend
user_api_key_cache.set_cache(key=token, value=valid_token)
except Exception as e:
verbose_proxy_logger.info(f"Update Team DB failed to execute")
asyncio.create_task(_update_user_db())
asyncio.create_task(_update_key_db())
asyncio.create_task(_update_team_db())
asyncio.create_task(_insert_spend_log_to_db())
verbose_proxy_logger.info("Successfully updated spend in all 3 tables")
except Exception as e:
@ -2143,6 +2207,9 @@ async def completion(
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
_headers = dict(request.headers)
_headers.pop(
"authorization", None
@ -2306,6 +2373,9 @@ async def chat_completion(
data["metadata"] = {}
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
_headers = dict(request.headers)
_headers.pop(
@ -2527,6 +2597,9 @@ async def embeddings(
) # do not store the original `sk-..` api key in the db
data["metadata"]["headers"] = _headers
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
data["metadata"]["endpoint"] = str(request.url)
### TEAM-SPECIFIC PARAMS ###
@ -2698,6 +2771,9 @@ async def image_generation(
) # do not store the original `sk-..` api key in the db
data["metadata"]["headers"] = _headers
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
data["metadata"]["endpoint"] = str(request.url)
### TEAM-SPECIFIC PARAMS ###
@ -2853,6 +2929,9 @@ async def moderations(
) # do not store the original `sk-..` api key in the db
data["metadata"]["headers"] = _headers
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
data["metadata"]["endpoint"] = str(request.url)
### TEAM-SPECIFIC PARAMS ###
@ -4208,6 +4287,9 @@ async def async_queue_request(
) # do not store the original `sk-..` api key in the db
data["metadata"]["headers"] = _headers
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
data["metadata"]["endpoint"] = str(request.url)
global user_temperature, user_request_timeout, user_max_tokens, user_api_base