mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
fix update_database helper on db_spend_update_writer
This commit is contained in:
parent
bcd49204f6
commit
3e16a51ca6
4 changed files with 73 additions and 126 deletions
|
@ -25,6 +25,7 @@ from litellm.proxy._types import (
|
||||||
)
|
)
|
||||||
from litellm.proxy.db.pod_lock_manager import PodLockManager
|
from litellm.proxy.db.pod_lock_manager import PodLockManager
|
||||||
from litellm.proxy.db.redis_update_buffer import RedisUpdateBuffer
|
from litellm.proxy.db.redis_update_buffer import RedisUpdateBuffer
|
||||||
|
from litellm.proxy.db.spend_update_queue import SpendUpdateQueue
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||||
|
@ -48,10 +49,11 @@ class DBSpendUpdateWriter:
|
||||||
self.redis_cache = redis_cache
|
self.redis_cache = redis_cache
|
||||||
self.redis_update_buffer = RedisUpdateBuffer(redis_cache=self.redis_cache)
|
self.redis_update_buffer = RedisUpdateBuffer(redis_cache=self.redis_cache)
|
||||||
self.pod_lock_manager = PodLockManager(cronjob_id=DB_SPEND_UPDATE_JOB_NAME)
|
self.pod_lock_manager = PodLockManager(cronjob_id=DB_SPEND_UPDATE_JOB_NAME)
|
||||||
|
self.spend_update_queue = SpendUpdateQueue()
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def update_database(
|
async def update_database(
|
||||||
# LiteLLM management object fields
|
# LiteLLM management object fields
|
||||||
|
self,
|
||||||
token: Optional[str],
|
token: Optional[str],
|
||||||
user_id: Optional[str],
|
user_id: Optional[str],
|
||||||
end_user_id: Optional[str],
|
end_user_id: Optional[str],
|
||||||
|
@ -84,7 +86,7 @@ class DBSpendUpdateWriter:
|
||||||
hashed_token = token
|
hashed_token = token
|
||||||
|
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
DBSpendUpdateWriter._update_user_db(
|
self._update_user_db(
|
||||||
response_cost=response_cost,
|
response_cost=response_cost,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
prisma_client=prisma_client,
|
prisma_client=prisma_client,
|
||||||
|
@ -94,14 +96,14 @@ class DBSpendUpdateWriter:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
DBSpendUpdateWriter._update_key_db(
|
self._update_key_db(
|
||||||
response_cost=response_cost,
|
response_cost=response_cost,
|
||||||
hashed_token=hashed_token,
|
hashed_token=hashed_token,
|
||||||
prisma_client=prisma_client,
|
prisma_client=prisma_client,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
DBSpendUpdateWriter._update_team_db(
|
self._update_team_db(
|
||||||
response_cost=response_cost,
|
response_cost=response_cost,
|
||||||
team_id=team_id,
|
team_id=team_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
@ -109,7 +111,7 @@ class DBSpendUpdateWriter:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
DBSpendUpdateWriter._update_org_db(
|
self._update_org_db(
|
||||||
response_cost=response_cost,
|
response_cost=response_cost,
|
||||||
org_id=org_id,
|
org_id=org_id,
|
||||||
prisma_client=prisma_client,
|
prisma_client=prisma_client,
|
||||||
|
@ -135,56 +137,8 @@ class DBSpendUpdateWriter:
|
||||||
f"Error updating Prisma database: {traceback.format_exc()}"
|
f"Error updating Prisma database: {traceback.format_exc()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def _update_transaction_list(
|
|
||||||
response_cost: Optional[float],
|
|
||||||
entity_id: Optional[str],
|
|
||||||
transaction_list: dict,
|
|
||||||
entity_type: Litellm_EntityType,
|
|
||||||
debug_msg: Optional[str] = None,
|
|
||||||
prisma_client: Optional[PrismaClient] = None,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Common helper method to update a transaction list for an entity
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response_cost: The cost to add
|
|
||||||
entity_id: The ID of the entity to update
|
|
||||||
transaction_list: The transaction list dictionary to update
|
|
||||||
entity_type: The type of entity (from EntityType enum)
|
|
||||||
debug_msg: Optional custom debug message
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if update happened, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if debug_msg:
|
|
||||||
verbose_proxy_logger.debug(debug_msg)
|
|
||||||
else:
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f"adding spend to {entity_type.value} db. Response cost: {response_cost}. {entity_type.value}_id: {entity_id}."
|
|
||||||
)
|
|
||||||
if prisma_client is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if entity_id is None:
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f"track_cost_callback: {entity_type.value}_id is None. Not tracking spend for {entity_type.value}"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
transaction_list[entity_id] = response_cost + transaction_list.get(
|
|
||||||
entity_id, 0
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
verbose_proxy_logger.info(
|
|
||||||
f"Update {entity_type.value.capitalize()} DB failed to execute - {str(e)}\n{traceback.format_exc()}"
|
|
||||||
)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def _update_key_db(
|
async def _update_key_db(
|
||||||
|
self,
|
||||||
response_cost: Optional[float],
|
response_cost: Optional[float],
|
||||||
hashed_token: Optional[str],
|
hashed_token: Optional[str],
|
||||||
prisma_client: Optional[PrismaClient],
|
prisma_client: Optional[PrismaClient],
|
||||||
|
@ -193,13 +147,12 @@ class DBSpendUpdateWriter:
|
||||||
if hashed_token is None or prisma_client is None:
|
if hashed_token is None or prisma_client is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
await DBSpendUpdateWriter._update_transaction_list(
|
await self.spend_update_queue.add_update(
|
||||||
response_cost=response_cost,
|
update={
|
||||||
entity_id=hashed_token,
|
"entity_type": Litellm_EntityType.KEY.value,
|
||||||
transaction_list=prisma_client.key_list_transactions,
|
"entity_id": hashed_token,
|
||||||
entity_type=Litellm_EntityType.KEY,
|
"amount": response_cost,
|
||||||
debug_msg=f"adding spend to key db. Response cost: {response_cost}. Token: {hashed_token}.",
|
}
|
||||||
prisma_client=prisma_client,
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.exception(
|
verbose_proxy_logger.exception(
|
||||||
|
@ -207,8 +160,8 @@ class DBSpendUpdateWriter:
|
||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def _update_user_db(
|
async def _update_user_db(
|
||||||
|
self,
|
||||||
response_cost: Optional[float],
|
response_cost: Optional[float],
|
||||||
user_id: Optional[str],
|
user_id: Optional[str],
|
||||||
prisma_client: Optional[PrismaClient],
|
prisma_client: Optional[PrismaClient],
|
||||||
|
@ -234,21 +187,21 @@ class DBSpendUpdateWriter:
|
||||||
|
|
||||||
for _id in user_ids:
|
for _id in user_ids:
|
||||||
if _id is not None:
|
if _id is not None:
|
||||||
await DBSpendUpdateWriter._update_transaction_list(
|
await self.spend_update_queue.add_update(
|
||||||
response_cost=response_cost,
|
update={
|
||||||
entity_id=_id,
|
"entity_type": Litellm_EntityType.USER.value,
|
||||||
transaction_list=prisma_client.user_list_transactions,
|
"entity_id": _id,
|
||||||
entity_type=Litellm_EntityType.USER,
|
"amount": response_cost,
|
||||||
prisma_client=prisma_client,
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
if end_user_id is not None:
|
if end_user_id is not None:
|
||||||
await DBSpendUpdateWriter._update_transaction_list(
|
await self.spend_update_queue.add_update(
|
||||||
response_cost=response_cost,
|
update={
|
||||||
entity_id=end_user_id,
|
"entity_type": Litellm_EntityType.END_USER.value,
|
||||||
transaction_list=prisma_client.end_user_list_transactions,
|
"entity_id": end_user_id,
|
||||||
entity_type=Litellm_EntityType.END_USER,
|
"amount": response_cost,
|
||||||
prisma_client=prisma_client,
|
}
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.info(
|
verbose_proxy_logger.info(
|
||||||
|
@ -256,8 +209,8 @@ class DBSpendUpdateWriter:
|
||||||
+ f"Update User DB call failed to execute {str(e)}\n{traceback.format_exc()}"
|
+ f"Update User DB call failed to execute {str(e)}\n{traceback.format_exc()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def _update_team_db(
|
async def _update_team_db(
|
||||||
|
self,
|
||||||
response_cost: Optional[float],
|
response_cost: Optional[float],
|
||||||
team_id: Optional[str],
|
team_id: Optional[str],
|
||||||
user_id: Optional[str],
|
user_id: Optional[str],
|
||||||
|
@ -270,12 +223,12 @@ class DBSpendUpdateWriter:
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
await DBSpendUpdateWriter._update_transaction_list(
|
await self.spend_update_queue.add_update(
|
||||||
response_cost=response_cost,
|
update={
|
||||||
entity_id=team_id,
|
"entity_type": Litellm_EntityType.TEAM.value,
|
||||||
transaction_list=prisma_client.team_list_transactions,
|
"entity_id": team_id,
|
||||||
entity_type=Litellm_EntityType.TEAM,
|
"amount": response_cost,
|
||||||
prisma_client=prisma_client,
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -283,12 +236,12 @@ class DBSpendUpdateWriter:
|
||||||
if user_id is not None:
|
if user_id is not None:
|
||||||
# key is "team_id::<value>::user_id::<value>"
|
# key is "team_id::<value>::user_id::<value>"
|
||||||
team_member_key = f"team_id::{team_id}::user_id::{user_id}"
|
team_member_key = f"team_id::{team_id}::user_id::{user_id}"
|
||||||
await DBSpendUpdateWriter._update_transaction_list(
|
await self.spend_update_queue.add_update(
|
||||||
response_cost=response_cost,
|
update={
|
||||||
entity_id=team_member_key,
|
"entity_type": Litellm_EntityType.TEAM_MEMBER.value,
|
||||||
transaction_list=prisma_client.team_member_list_transactions,
|
"entity_id": team_member_key,
|
||||||
entity_type=Litellm_EntityType.TEAM_MEMBER,
|
"amount": response_cost,
|
||||||
prisma_client=prisma_client,
|
}
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
@ -298,8 +251,8 @@ class DBSpendUpdateWriter:
|
||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def _update_org_db(
|
async def _update_org_db(
|
||||||
|
self,
|
||||||
response_cost: Optional[float],
|
response_cost: Optional[float],
|
||||||
org_id: Optional[str],
|
org_id: Optional[str],
|
||||||
prisma_client: Optional[PrismaClient],
|
prisma_client: Optional[PrismaClient],
|
||||||
|
@ -311,12 +264,12 @@ class DBSpendUpdateWriter:
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
await DBSpendUpdateWriter._update_transaction_list(
|
await self.spend_update_queue.add_update(
|
||||||
response_cost=response_cost,
|
update={
|
||||||
entity_id=org_id,
|
"entity_type": Litellm_EntityType.ORGANIZATION.value,
|
||||||
transaction_list=prisma_client.org_list_transactions,
|
"entity_id": org_id,
|
||||||
entity_type=Litellm_EntityType.ORGANIZATION,
|
"amount": response_cost,
|
||||||
prisma_client=prisma_client,
|
}
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.info(
|
verbose_proxy_logger.info(
|
||||||
|
@ -435,7 +388,7 @@ class DBSpendUpdateWriter:
|
||||||
- Only 1 pod will commit to db at a time (based on if it can acquire the lock over writing to DB)
|
- Only 1 pod will commit to db at a time (based on if it can acquire the lock over writing to DB)
|
||||||
"""
|
"""
|
||||||
await self.redis_update_buffer.store_in_memory_spend_updates_in_redis(
|
await self.redis_update_buffer.store_in_memory_spend_updates_in_redis(
|
||||||
prisma_client=prisma_client,
|
spend_update_queue=self.spend_update_queue,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only commit from redis to db if this pod is the leader
|
# Only commit from redis to db if this pod is the leader
|
||||||
|
@ -447,7 +400,7 @@ class DBSpendUpdateWriter:
|
||||||
await self.redis_update_buffer.get_all_update_transactions_from_redis_buffer()
|
await self.redis_update_buffer.get_all_update_transactions_from_redis_buffer()
|
||||||
)
|
)
|
||||||
if db_spend_update_transactions is not None:
|
if db_spend_update_transactions is not None:
|
||||||
await DBSpendUpdateWriter._commit_spend_updates_to_db(
|
await self._commit_spend_updates_to_db(
|
||||||
prisma_client=prisma_client,
|
prisma_client=prisma_client,
|
||||||
n_retry_times=n_retry_times,
|
n_retry_times=n_retry_times,
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
@ -471,23 +424,26 @@ class DBSpendUpdateWriter:
|
||||||
|
|
||||||
Note: This flow causes Deadlocks in production (1K RPS+). Use self._commit_spend_updates_to_db_with_redis() instead if you expect 1K+ RPS.
|
Note: This flow causes Deadlocks in production (1K RPS+). Use self._commit_spend_updates_to_db_with_redis() instead if you expect 1K+ RPS.
|
||||||
"""
|
"""
|
||||||
db_spend_update_transactions = DBSpendUpdateTransactions(
|
aggregated_updates = (
|
||||||
user_list_transactions=prisma_client.user_list_transactions,
|
await self.spend_update_queue.flush_and_get_all_aggregated_updates_by_entity_type()
|
||||||
end_user_list_transactions=prisma_client.end_user_list_transactions,
|
|
||||||
key_list_transactions=prisma_client.key_list_transactions,
|
|
||||||
team_list_transactions=prisma_client.team_list_transactions,
|
|
||||||
team_member_list_transactions=prisma_client.team_member_list_transactions,
|
|
||||||
org_list_transactions=prisma_client.org_list_transactions,
|
|
||||||
)
|
)
|
||||||
await DBSpendUpdateWriter._commit_spend_updates_to_db(
|
db_spend_update_transactions = DBSpendUpdateTransactions(
|
||||||
|
user_list_transactions=aggregated_updates.get("user", {}),
|
||||||
|
end_user_list_transactions=aggregated_updates.get("end_user", {}),
|
||||||
|
key_list_transactions=aggregated_updates.get("key", {}),
|
||||||
|
team_list_transactions=aggregated_updates.get("team", {}),
|
||||||
|
team_member_list_transactions=aggregated_updates.get("team_member", {}),
|
||||||
|
org_list_transactions=aggregated_updates.get("organization", {}),
|
||||||
|
)
|
||||||
|
await self._commit_spend_updates_to_db(
|
||||||
prisma_client=prisma_client,
|
prisma_client=prisma_client,
|
||||||
n_retry_times=n_retry_times,
|
n_retry_times=n_retry_times,
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
db_spend_update_transactions=db_spend_update_transactions,
|
db_spend_update_transactions=db_spend_update_transactions,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def _commit_spend_updates_to_db( # noqa: PLR0915
|
async def _commit_spend_updates_to_db( # noqa: PLR0915
|
||||||
|
self,
|
||||||
prisma_client: PrismaClient,
|
prisma_client: PrismaClient,
|
||||||
n_retry_times: int,
|
n_retry_times: int,
|
||||||
proxy_logging_obj: ProxyLogging,
|
proxy_logging_obj: ProxyLogging,
|
||||||
|
@ -526,9 +482,6 @@ class DBSpendUpdateWriter:
|
||||||
where={"user_id": user_id},
|
where={"user_id": user_id},
|
||||||
data={"spend": {"increment": response_cost}},
|
data={"spend": {"increment": response_cost}},
|
||||||
)
|
)
|
||||||
prisma_client.user_list_transactions = (
|
|
||||||
{}
|
|
||||||
) # Clear the remaining transactions after processing all batches in the loop.
|
|
||||||
break
|
break
|
||||||
except DB_CONNECTION_ERROR_TYPES as e:
|
except DB_CONNECTION_ERROR_TYPES as e:
|
||||||
if (
|
if (
|
||||||
|
@ -583,9 +536,6 @@ class DBSpendUpdateWriter:
|
||||||
where={"token": token},
|
where={"token": token},
|
||||||
data={"spend": {"increment": response_cost}},
|
data={"spend": {"increment": response_cost}},
|
||||||
)
|
)
|
||||||
prisma_client.key_list_transactions = (
|
|
||||||
{}
|
|
||||||
) # Clear the remaining transactions after processing all batches in the loop.
|
|
||||||
break
|
break
|
||||||
except DB_CONNECTION_ERROR_TYPES as e:
|
except DB_CONNECTION_ERROR_TYPES as e:
|
||||||
if (
|
if (
|
||||||
|
@ -632,9 +582,6 @@ class DBSpendUpdateWriter:
|
||||||
where={"team_id": team_id},
|
where={"team_id": team_id},
|
||||||
data={"spend": {"increment": response_cost}},
|
data={"spend": {"increment": response_cost}},
|
||||||
)
|
)
|
||||||
prisma_client.team_list_transactions = (
|
|
||||||
{}
|
|
||||||
) # Clear the remaining transactions after processing all batches in the loop.
|
|
||||||
break
|
break
|
||||||
except DB_CONNECTION_ERROR_TYPES as e:
|
except DB_CONNECTION_ERROR_TYPES as e:
|
||||||
if (
|
if (
|
||||||
|
@ -684,9 +631,6 @@ class DBSpendUpdateWriter:
|
||||||
where={"team_id": team_id, "user_id": user_id},
|
where={"team_id": team_id, "user_id": user_id},
|
||||||
data={"spend": {"increment": response_cost}},
|
data={"spend": {"increment": response_cost}},
|
||||||
)
|
)
|
||||||
prisma_client.team_member_list_transactions = (
|
|
||||||
{}
|
|
||||||
) # Clear the remaining transactions after processing all batches in the loop.
|
|
||||||
break
|
break
|
||||||
except DB_CONNECTION_ERROR_TYPES as e:
|
except DB_CONNECTION_ERROR_TYPES as e:
|
||||||
if (
|
if (
|
||||||
|
@ -725,9 +669,6 @@ class DBSpendUpdateWriter:
|
||||||
where={"organization_id": org_id},
|
where={"organization_id": org_id},
|
||||||
data={"spend": {"increment": response_cost}},
|
data={"spend": {"increment": response_cost}},
|
||||||
)
|
)
|
||||||
prisma_client.org_list_transactions = (
|
|
||||||
{}
|
|
||||||
) # Clear the remaining transactions after processing all batches in the loop.
|
|
||||||
break
|
break
|
||||||
except DB_CONNECTION_ERROR_TYPES as e:
|
except DB_CONNECTION_ERROR_TYPES as e:
|
||||||
if (
|
if (
|
||||||
|
|
|
@ -33,7 +33,6 @@ class RedisUpdateBuffer:
|
||||||
redis_cache: Optional[RedisCache] = None,
|
redis_cache: Optional[RedisCache] = None,
|
||||||
):
|
):
|
||||||
self.redis_cache = redis_cache
|
self.redis_cache = redis_cache
|
||||||
self.spend_update_queue = SpendUpdateQueue()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _should_commit_spend_updates_to_redis() -> bool:
|
def _should_commit_spend_updates_to_redis() -> bool:
|
||||||
|
@ -56,6 +55,7 @@ class RedisUpdateBuffer:
|
||||||
|
|
||||||
async def store_in_memory_spend_updates_in_redis(
|
async def store_in_memory_spend_updates_in_redis(
|
||||||
self,
|
self,
|
||||||
|
spend_update_queue: SpendUpdateQueue,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Stores the in-memory spend updates to Redis
|
Stores the in-memory spend updates to Redis
|
||||||
|
@ -81,9 +81,9 @@ class RedisUpdateBuffer:
|
||||||
return
|
return
|
||||||
|
|
||||||
aggregated_updates = (
|
aggregated_updates = (
|
||||||
await self.spend_update_queue.flush_and_get_all_aggregated_updates_by_entity_type()
|
await spend_update_queue.flush_and_get_all_aggregated_updates_by_entity_type()
|
||||||
)
|
)
|
||||||
verbose_proxy_logger.debug("ALL AGGREGATED UPDATES: ", aggregated_updates)
|
verbose_proxy_logger.debug("ALL AGGREGATED UPDATES: %s", aggregated_updates)
|
||||||
|
|
||||||
db_spend_update_transactions: DBSpendUpdateTransactions = (
|
db_spend_update_transactions: DBSpendUpdateTransactions = (
|
||||||
DBSpendUpdateTransactions(
|
DBSpendUpdateTransactions(
|
||||||
|
@ -92,7 +92,7 @@ class RedisUpdateBuffer:
|
||||||
key_list_transactions=aggregated_updates.get("key", {}),
|
key_list_transactions=aggregated_updates.get("key", {}),
|
||||||
team_list_transactions=aggregated_updates.get("team", {}),
|
team_list_transactions=aggregated_updates.get("team", {}),
|
||||||
team_member_list_transactions=aggregated_updates.get("team_member", {}),
|
team_member_list_transactions=aggregated_updates.get("team_member", {}),
|
||||||
org_list_transactions=aggregated_updates.get("org", {}),
|
org_list_transactions=aggregated_updates.get("organization", {}),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List
|
from typing import TYPE_CHECKING, Any, Dict, List
|
||||||
|
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from litellm.proxy.utils import PrismaClient
|
from litellm.proxy.utils import PrismaClient
|
||||||
else:
|
else:
|
||||||
|
@ -21,6 +23,7 @@ class SpendUpdateQueue:
|
||||||
|
|
||||||
async def add_update(self, update: Dict[str, Any]) -> None:
|
async def add_update(self, update: Dict[str, Any]) -> None:
|
||||||
"""Enqueue an update. Each update might be a dict like {'entity_type': 'user', 'entity_id': '123', 'amount': 1.2}."""
|
"""Enqueue an update. Each update might be a dict like {'entity_type': 'user', 'entity_id': '123', 'amount': 1.2}."""
|
||||||
|
verbose_proxy_logger.debug("Adding update to queue: %s", update)
|
||||||
await self.update_queue.put(update)
|
await self.update_queue.put(update)
|
||||||
|
|
||||||
async def flush_all_updates_from_in_memory_queue(self) -> List[Dict[str, Any]]:
|
async def flush_all_updates_from_in_memory_queue(self) -> List[Dict[str, Any]]:
|
||||||
|
@ -35,6 +38,7 @@ class SpendUpdateQueue:
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Flush all updates from the queue and return all updates aggregated by entity type."""
|
"""Flush all updates from the queue and return all updates aggregated by entity type."""
|
||||||
updates = await self.flush_all_updates_from_in_memory_queue()
|
updates = await self.flush_all_updates_from_in_memory_queue()
|
||||||
|
verbose_proxy_logger.debug("Aggregating updates by entity type: %s", updates)
|
||||||
return self.aggregate_updates_by_entity_type(updates)
|
return self.aggregate_updates_by_entity_type(updates)
|
||||||
|
|
||||||
def aggregate_updates_by_entity_type(
|
def aggregate_updates_by_entity_type(
|
||||||
|
|
|
@ -37,6 +37,8 @@ class _ProxyDBLogger(CustomLogger):
|
||||||
if _ProxyDBLogger._should_track_errors_in_db() is False:
|
if _ProxyDBLogger._should_track_errors_in_db() is False:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
from litellm.proxy.proxy_server import proxy_logging_obj
|
||||||
|
|
||||||
_metadata = dict(
|
_metadata = dict(
|
||||||
StandardLoggingUserAPIKeyMetadata(
|
StandardLoggingUserAPIKeyMetadata(
|
||||||
user_api_key_hash=user_api_key_dict.api_key,
|
user_api_key_hash=user_api_key_dict.api_key,
|
||||||
|
@ -66,7 +68,7 @@ class _ProxyDBLogger(CustomLogger):
|
||||||
request_data.get("proxy_server_request") or {}
|
request_data.get("proxy_server_request") or {}
|
||||||
)
|
)
|
||||||
request_data["litellm_params"]["metadata"] = existing_metadata
|
request_data["litellm_params"]["metadata"] = existing_metadata
|
||||||
await DBSpendUpdateWriter.update_database(
|
await proxy_logging_obj.db_spend_update_writer.update_database(
|
||||||
token=user_api_key_dict.api_key,
|
token=user_api_key_dict.api_key,
|
||||||
response_cost=0.0,
|
response_cost=0.0,
|
||||||
user_id=user_api_key_dict.user_id,
|
user_id=user_api_key_dict.user_id,
|
||||||
|
@ -136,7 +138,7 @@ class _ProxyDBLogger(CustomLogger):
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
):
|
):
|
||||||
## UPDATE DATABASE
|
## UPDATE DATABASE
|
||||||
await DBSpendUpdateWriter.update_database(
|
await proxy_logging_obj.db_spend_update_writer.update_database(
|
||||||
token=user_api_key,
|
token=user_api_key,
|
||||||
response_cost=response_cost,
|
response_cost=response_cost,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue