refactor file structure

This commit is contained in:
Ishaan Jaff 2025-04-01 18:30:48 -07:00
parent 290e837515
commit 8dc792139e
9 changed files with 135 additions and 108 deletions

View file

@@ -25,9 +25,12 @@ from litellm.proxy._types import (
SpendLogsPayload,
SpendUpdateQueueItem,
)
from litellm.proxy.db.pod_lock_manager import PodLockManager
from litellm.proxy.db.redis_update_buffer import RedisUpdateBuffer
from litellm.proxy.db.spend_update_queue import DailySpendUpdateQueue, SpendUpdateQueue
from litellm.proxy.db.db_transaction_queue.daily_spend_update_queue import (
DailySpendUpdateQueue,
)
from litellm.proxy.db.db_transaction_queue.pod_lock_manager import PodLockManager
from litellm.proxy.db.db_transaction_queue.redis_update_buffer import RedisUpdateBuffer
from litellm.proxy.db.db_transaction_queue.spend_update_queue import SpendUpdateQueue
if TYPE_CHECKING:
from litellm.proxy.utils import PrismaClient, ProxyLogging

View file

@@ -0,0 +1,22 @@
import asyncio
from litellm._logging import verbose_proxy_logger
class BaseUpdateQueue:
    """Base class for in-memory spend-update queues.

    Provides the shared asyncio.Queue plus the common enqueue/drain
    operations that the concrete spend / daily-spend queues build on.
    """

    def __init__(self):
        # Unbounded queue: producers never block on put().
        self.update_queue: asyncio.Queue = asyncio.Queue()

    async def add_update(self, update) -> None:
        """Enqueue a single update payload."""
        verbose_proxy_logger.debug("Adding update to queue: %s", update)
        await self.update_queue.put(update)

    async def flush_all_updates_from_in_memory_queue(self) -> list:
        """Drain and return every update currently in the queue.

        Uses get_nowait() so the drain loop can never suspend: an awaited
        get() between the empty-check and the read could yield control to
        the event loop and let another task interleave mid-flush.
        """
        updates: list = []
        while True:
            try:
                updates.append(self.update_queue.get_nowait())
            except asyncio.QueueEmpty:
                break
        return updates

View file

@@ -0,0 +1,95 @@
import asyncio
from typing import Dict, List
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import DailyUserSpendTransaction
from litellm.proxy.db.db_transaction_queue.base_update_queue import BaseUpdateQueue
class DailySpendUpdateQueue(BaseUpdateQueue):
    """
    In-memory buffer for daily spend updates that should be committed to the database.

    To add a new daily spend update transaction, use the following format:
        daily_spend_update_queue.add_update({
            "user1_date_api_key_model_custom_llm_provider": {
                "spend": 10,
                "prompt_tokens": 100,
                "completion_tokens": 100,
            }
        })

    Queue contains a list of daily spend update transactions, e.g.:
        queue = [
            {
                "user1_date_api_key_model_custom_llm_provider": {
                    "spend": 10,
                    "prompt_tokens": 100,
                    "completion_tokens": 100,
                    "api_requests": 100,
                    "successful_requests": 100,
                    "failed_requests": 100,
                }
            },
            {
                "user2_date_api_key_model_custom_llm_provider": {
                    "spend": 10,
                    "prompt_tokens": 100,
                    "completion_tokens": 100,
                    "api_requests": 100,
                    "successful_requests": 100,
                    "failed_requests": 100,
                }
            }
        ]
    """

    def __init__(self):
        super().__init__()
        # Narrow the base-class queue type: each item maps a
        # daily_transaction_key -> DailyUserSpendTransaction payload.
        self.update_queue: asyncio.Queue[
            Dict[str, DailyUserSpendTransaction]
        ] = asyncio.Queue()

    async def flush_and_get_aggregated_daily_spend_update_transactions(
        self,
    ) -> Dict[str, DailyUserSpendTransaction]:
        """Drain the queue and return all updates aggregated by daily_transaction_key."""
        updates = await self.flush_all_updates_from_in_memory_queue()
        aggregated_daily_spend_update_transactions = (
            DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
                updates
            )
        )
        verbose_proxy_logger.debug(
            "Aggregated daily spend update transactions: %s",
            aggregated_daily_spend_update_transactions,
        )
        return aggregated_daily_spend_update_transactions

    @staticmethod
    def get_aggregated_daily_spend_update_transactions(
        updates: List[Dict[str, DailyUserSpendTransaction]]
    ) -> Dict[str, DailyUserSpendTransaction]:
        """Aggregate updates by daily_transaction_key.

        Sums the spend / token / request counters across every update that
        shares a key. The first payload seen for a key is copied before it
        becomes the accumulation target, so the in-place ``+=`` below never
        mutates the caller-owned dicts inside ``updates``.
        """
        aggregated_daily_spend_update_transactions: Dict[
            str, DailyUserSpendTransaction
        ] = {}
        for _update in updates:
            for _key, payload in _update.items():
                if _key not in aggregated_daily_spend_update_transactions:
                    # Copy: storing the payload itself would alias the
                    # caller's dict, and later aggregation for this key
                    # would silently corrupt the caller's input data.
                    aggregated_daily_spend_update_transactions[_key] = payload.copy()
                    continue
                daily_transaction = aggregated_daily_spend_update_transactions[_key]
                daily_transaction["spend"] += payload["spend"]
                daily_transaction["prompt_tokens"] += payload["prompt_tokens"]
                daily_transaction["completion_tokens"] += payload[
                    "completion_tokens"
                ]
                daily_transaction["api_requests"] += payload["api_requests"]
                daily_transaction["successful_requests"] += payload[
                    "successful_requests"
                ]
                daily_transaction["failed_requests"] += payload["failed_requests"]
        return aggregated_daily_spend_update_transactions

View file

@@ -16,7 +16,10 @@ from litellm.constants import (
)
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy._types import DailyUserSpendTransaction, DBSpendUpdateTransactions
from litellm.proxy.db.spend_update_queue import DailySpendUpdateQueue, SpendUpdateQueue
from litellm.proxy.db.db_transaction_queue.daily_spend_update_queue import (
DailySpendUpdateQueue,
)
from litellm.proxy.db.db_transaction_queue.spend_update_queue import SpendUpdateQueue
from litellm.secret_managers.main import str_to_bool
if TYPE_CHECKING:

View file

@@ -1,44 +1,24 @@
import asyncio
from typing import TYPE_CHECKING, Any, Dict, List
from typing import List
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
DailyUserSpendTransaction,
DBSpendUpdateTransactions,
Litellm_EntityType,
SpendUpdateQueueItem,
)
if TYPE_CHECKING:
from litellm.proxy.utils import PrismaClient
else:
PrismaClient = Any
from litellm.proxy.db.db_transaction_queue.base_update_queue import BaseUpdateQueue
class SpendUpdateQueue:
class SpendUpdateQueue(BaseUpdateQueue):
"""
In memory buffer for spend updates that should be committed to the database
"""
def __init__(
self,
):
def __init__(self):
super().__init__()
self.update_queue: asyncio.Queue[SpendUpdateQueueItem] = asyncio.Queue()
async def add_update(self, update: SpendUpdateQueueItem) -> None:
"""Enqueue an update. Each update might be a dict like {'entity_type': 'user', 'entity_id': '123', 'amount': 1.2}."""
verbose_proxy_logger.debug("Adding update to queue: %s", update)
await self.update_queue.put(update)
async def flush_all_updates_from_in_memory_queue(
self,
) -> List[SpendUpdateQueueItem]:
"""Get all updates from the queue."""
updates: List[SpendUpdateQueueItem] = []
while not self.update_queue.empty():
updates.append(await self.update_queue.get())
return updates
async def flush_and_get_aggregated_db_spend_update_transactions(
self,
) -> DBSpendUpdateTransactions:
@@ -131,79 +111,3 @@ class SpendUpdateQueue:
transactions_dict[entity_id] += response_cost or 0
return db_spend_update_transactions
class DailySpendUpdateQueue:
def __init__(
self,
):
self.update_queue: asyncio.Queue[
Dict[str, DailyUserSpendTransaction]
] = asyncio.Queue()
async def add_update(self, update: Dict[str, DailyUserSpendTransaction]) -> None:
"""Enqueue an update. Each update might be a dict like
{
"user_date_api_key_model_custom_llm_provider": {
"spend": 1.2,
"prompt_tokens": 1000,
"completion_tokens": 1000,
"api_requests": 1000,
"successful_requests": 1000,
"failed_requests": 1000,
}
}
."""
verbose_proxy_logger.debug("Adding update to queue: %s", update)
await self.update_queue.put(update)
async def flush_all_updates_from_in_memory_queue(
self,
) -> List[Dict[str, DailyUserSpendTransaction]]:
"""Get all updates from the queue."""
updates: List[Dict[str, DailyUserSpendTransaction]] = []
while not self.update_queue.empty():
updates.append(await self.update_queue.get())
return updates
async def flush_and_get_aggregated_daily_spend_update_transactions(
self,
) -> Dict[str, DailyUserSpendTransaction]:
"""Get all updates from the queue and return all updates aggregated by daily_transaction_key."""
updates = await self.flush_all_updates_from_in_memory_queue()
aggregated_daily_spend_update_transactions = (
DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
updates
)
)
verbose_proxy_logger.debug(
"Aggregated daily spend update transactions: %s",
aggregated_daily_spend_update_transactions,
)
return aggregated_daily_spend_update_transactions
@staticmethod
def get_aggregated_daily_spend_update_transactions(
updates: List[Dict[str, DailyUserSpendTransaction]]
) -> Dict[str, DailyUserSpendTransaction]:
"""Aggregate updates by daily_transaction_key."""
aggregated_daily_spend_update_transactions: Dict[
str, DailyUserSpendTransaction
] = {}
for _update in updates:
for _key, payload in _update.items():
if _key in aggregated_daily_spend_update_transactions:
daily_transaction = aggregated_daily_spend_update_transactions[_key]
daily_transaction["spend"] += payload["spend"]
daily_transaction["prompt_tokens"] += payload["prompt_tokens"]
daily_transaction["completion_tokens"] += payload[
"completion_tokens"
]
daily_transaction["api_requests"] += payload["api_requests"]
daily_transaction["successful_requests"] += payload[
"successful_requests"
]
daily_transaction["failed_requests"] += payload["failed_requests"]
else:
aggregated_daily_spend_update_transactions[_key] = payload
return aggregated_daily_spend_update_transactions

View file

@@ -12,7 +12,7 @@ sys.path.insert(
) # Adds the parent directory to the system path
from litellm.constants import DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
from litellm.proxy.db.pod_lock_manager import PodLockManager
from litellm.proxy.db.db_transaction_queue.pod_lock_manager import PodLockManager
# Mock Prisma client class

View file

@@ -7,7 +7,7 @@ import pytest
from fastapi.testclient import TestClient
from litellm.proxy._types import Litellm_EntityType, SpendUpdateQueueItem
from litellm.proxy.db.spend_update_queue import SpendUpdateQueue
from litellm.proxy.db.db_transaction_queue.spend_update_queue import SpendUpdateQueue
sys.path.insert(
0, os.path.abspath("../../..")

View file

@@ -23,7 +23,7 @@ import asyncio
import logging
import pytest
from litellm.proxy.db.pod_lock_manager import PodLockManager
from litellm.proxy.db.db_transaction_queue.pod_lock_manager import PodLockManager
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy.management_endpoints.internal_user_endpoints import (