diff --git a/.circleci/config.yml b/.circleci/config.yml
index e1488a9083..d2a6aafef7 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -1450,7 +1450,7 @@ jobs:
command: |
pwd
ls
- python -m pytest -s -vv tests/*.py -x --junitxml=test-results/junit.xml --durations=5 --ignore=tests/otel_tests --ignore=tests/pass_through_tests --ignore=tests/proxy_admin_ui_tests --ignore=tests/load_tests --ignore=tests/llm_translation --ignore=tests/llm_responses_api_testing --ignore=tests/mcp_tests --ignore=tests/image_gen_tests --ignore=tests/pass_through_unit_tests
+ python -m pytest -s -vv tests/*.py -x --junitxml=test-results/junit.xml --durations=5 --ignore=tests/otel_tests --ignore=tests/spend_tracking_tests --ignore=tests/pass_through_tests --ignore=tests/proxy_admin_ui_tests --ignore=tests/load_tests --ignore=tests/llm_translation --ignore=tests/llm_responses_api_testing --ignore=tests/mcp_tests --ignore=tests/image_gen_tests --ignore=tests/pass_through_unit_tests
no_output_timeout: 120m
# Store test results
@@ -1743,6 +1743,96 @@ jobs:
# Store test results
- store_test_results:
path: test-results
+ proxy_spend_accuracy_tests:
+ machine:
+ image: ubuntu-2204:2023.10.1
+ resource_class: xlarge
+ working_directory: ~/project
+ steps:
+ - checkout
+ - setup_google_dns
+ - run:
+ name: Install Docker CLI (In case it's not already installed)
+ command: |
+ sudo apt-get update
+ sudo apt-get install -y docker-ce docker-ce-cli containerd.io
+ - run:
+ name: Install Python 3.9
+ command: |
+ curl https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh --output miniconda.sh
+ bash miniconda.sh -b -p $HOME/miniconda
+ export PATH="$HOME/miniconda/bin:$PATH"
+ conda init bash
+ source ~/.bashrc
+ conda create -n myenv python=3.9 -y
+ conda activate myenv
+ python --version
+ - run:
+ name: Install Dependencies
+ command: |
+ pip install "pytest==7.3.1"
+ pip install "pytest-asyncio==0.21.1"
+ pip install aiohttp
+ python -m pip install --upgrade pip
+ python -m pip install -r requirements.txt
+ - run:
+ name: Build Docker image
+ command: docker build -t my-app:latest -f ./docker/Dockerfile.database .
+ - run:
+ name: Run Docker container
+      # pass real redis credentials here - spend updates are buffered in redis
+      # (use_redis_transaction_buffer) before being committed to the database
+ command: |
+ docker run -d \
+ -p 4000:4000 \
+ -e DATABASE_URL=$PROXY_DATABASE_URL \
+ -e REDIS_HOST=$REDIS_HOST \
+ -e REDIS_PASSWORD=$REDIS_PASSWORD \
+ -e REDIS_PORT=$REDIS_PORT \
+ -e LITELLM_MASTER_KEY="sk-1234" \
+ -e OPENAI_API_KEY=$OPENAI_API_KEY \
+ -e LITELLM_LICENSE=$LITELLM_LICENSE \
+ -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
+ -e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
+ -e USE_DDTRACE=True \
+ -e DD_API_KEY=$DD_API_KEY \
+ -e DD_SITE=$DD_SITE \
+ -e AWS_REGION_NAME=$AWS_REGION_NAME \
+ --name my-app \
+ -v $(pwd)/litellm/proxy/example_config_yaml/spend_tracking_config.yaml:/app/config.yaml \
+ my-app:latest \
+ --config /app/config.yaml \
+ --port 4000 \
+ --detailed_debug \
+ - run:
+ name: Install curl and dockerize
+ command: |
+ sudo apt-get update
+ sudo apt-get install -y curl
+ sudo wget https://github.com/jwilder/dockerize/releases/download/v0.6.1/dockerize-linux-amd64-v0.6.1.tar.gz
+ sudo tar -C /usr/local/bin -xzvf dockerize-linux-amd64-v0.6.1.tar.gz
+ sudo rm dockerize-linux-amd64-v0.6.1.tar.gz
+ - run:
+ name: Start outputting logs
+ command: docker logs -f my-app
+ background: true
+ - run:
+ name: Wait for app to be ready
+ command: dockerize -wait http://localhost:4000 -timeout 5m
+ - run:
+ name: Run tests
+ command: |
+ pwd
+ ls
+ python -m pytest -vv tests/spend_tracking_tests -x --junitxml=test-results/junit.xml --durations=5
+ no_output_timeout:
+ 120m
+ # Clean up first container
+ - run:
+ name: Stop and remove first container
+ command: |
+ docker stop my-app
+ docker rm my-app
proxy_multi_instance_tests:
machine:
@@ -2553,6 +2643,12 @@ workflows:
only:
- main
- /litellm_.*/
+ - proxy_spend_accuracy_tests:
+ filters:
+ branches:
+ only:
+ - main
+ - /litellm_.*/
- proxy_multi_instance_tests:
filters:
branches:
@@ -2714,6 +2810,7 @@ workflows:
- installing_litellm_on_python
- installing_litellm_on_python_3_13
- proxy_logging_guardrails_model_info_tests
+ - proxy_spend_accuracy_tests
- proxy_multi_instance_tests
- proxy_store_model_in_db_tests
- proxy_build_from_pip_tests
diff --git a/cookbook/misc/dev_release.txt b/cookbook/misc/dev_release.txt
index 717a6da546..bd40f89e6f 100644
--- a/cookbook/misc/dev_release.txt
+++ b/cookbook/misc/dev_release.txt
@@ -1,2 +1,11 @@
python3 -m build
-twine upload --verbose dist/litellm-1.18.13.dev4.tar.gz -u __token__ -
\ No newline at end of file
+twine upload --verbose dist/litellm-1.18.13.dev4.tar.gz -u __token__ -
+
+
+Note: You might need to create a MANIFEST.in file at the repo root in case the build process fails
+
+Place this in MANIFEST.in
+recursive-exclude venv *
+recursive-exclude myenv *
+recursive-exclude py313_env *
+recursive-exclude **/.venv *
diff --git a/docs/my-website/docs/enterprise.md b/docs/my-website/docs/enterprise.md
index 5aeeb710ff..706ca33714 100644
--- a/docs/my-website/docs/enterprise.md
+++ b/docs/my-website/docs/enterprise.md
@@ -1,3 +1,5 @@
+import Image from '@theme/IdealImage';
+
# Enterprise
For companies that need SSO, user management and professional support for LiteLLM Proxy
@@ -7,6 +9,8 @@ Get free 7-day trial key [here](https://www.litellm.ai/#trial)
Includes all enterprise features.
+
+
[**Procurement available via AWS / Azure Marketplace**](./data_security.md#legalcompliance-faqs)
diff --git a/docs/my-website/img/enterprise_vs_oss.png b/docs/my-website/img/enterprise_vs_oss.png
new file mode 100644
index 0000000000..f2b58fbc14
Binary files /dev/null and b/docs/my-website/img/enterprise_vs_oss.png differ
diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index 9536442475..aacc9f525b 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -2751,3 +2751,9 @@ class DBSpendUpdateTransactions(TypedDict):
team_list_transactions: Optional[Dict[str, float]]
team_member_list_transactions: Optional[Dict[str, float]]
org_list_transactions: Optional[Dict[str, float]]
+
+
+class SpendUpdateQueueItem(TypedDict, total=False):
+ entity_type: Litellm_EntityType
+ entity_id: str
+ response_cost: Optional[float]
diff --git a/litellm/proxy/db/db_spend_update_writer.py b/litellm/proxy/db/db_spend_update_writer.py
index f46b03b57a..5bf255feae 100644
--- a/litellm/proxy/db/db_spend_update_writer.py
+++ b/litellm/proxy/db/db_spend_update_writer.py
@@ -22,9 +22,11 @@ from litellm.proxy._types import (
Litellm_EntityType,
LiteLLM_UserTable,
SpendLogsPayload,
+ SpendUpdateQueueItem,
)
from litellm.proxy.db.pod_lock_manager import PodLockManager
from litellm.proxy.db.redis_update_buffer import RedisUpdateBuffer
+from litellm.proxy.db.spend_update_queue import SpendUpdateQueue
if TYPE_CHECKING:
from litellm.proxy.utils import PrismaClient, ProxyLogging
@@ -48,10 +50,11 @@ class DBSpendUpdateWriter:
self.redis_cache = redis_cache
self.redis_update_buffer = RedisUpdateBuffer(redis_cache=self.redis_cache)
self.pod_lock_manager = PodLockManager(cronjob_id=DB_SPEND_UPDATE_JOB_NAME)
+ self.spend_update_queue = SpendUpdateQueue()
- @staticmethod
async def update_database(
# LiteLLM management object fields
+ self,
token: Optional[str],
user_id: Optional[str],
end_user_id: Optional[str],
@@ -84,7 +87,7 @@ class DBSpendUpdateWriter:
hashed_token = token
asyncio.create_task(
- DBSpendUpdateWriter._update_user_db(
+ self._update_user_db(
response_cost=response_cost,
user_id=user_id,
prisma_client=prisma_client,
@@ -94,14 +97,14 @@ class DBSpendUpdateWriter:
)
)
asyncio.create_task(
- DBSpendUpdateWriter._update_key_db(
+ self._update_key_db(
response_cost=response_cost,
hashed_token=hashed_token,
prisma_client=prisma_client,
)
)
asyncio.create_task(
- DBSpendUpdateWriter._update_team_db(
+ self._update_team_db(
response_cost=response_cost,
team_id=team_id,
user_id=user_id,
@@ -109,7 +112,7 @@ class DBSpendUpdateWriter:
)
)
asyncio.create_task(
- DBSpendUpdateWriter._update_org_db(
+ self._update_org_db(
response_cost=response_cost,
org_id=org_id,
prisma_client=prisma_client,
@@ -135,56 +138,8 @@ class DBSpendUpdateWriter:
f"Error updating Prisma database: {traceback.format_exc()}"
)
- @staticmethod
- async def _update_transaction_list(
- response_cost: Optional[float],
- entity_id: Optional[str],
- transaction_list: dict,
- entity_type: Litellm_EntityType,
- debug_msg: Optional[str] = None,
- prisma_client: Optional[PrismaClient] = None,
- ) -> bool:
- """
- Common helper method to update a transaction list for an entity
-
- Args:
- response_cost: The cost to add
- entity_id: The ID of the entity to update
- transaction_list: The transaction list dictionary to update
- entity_type: The type of entity (from EntityType enum)
- debug_msg: Optional custom debug message
-
- Returns:
- bool: True if update happened, False otherwise
- """
- try:
- if debug_msg:
- verbose_proxy_logger.debug(debug_msg)
- else:
- verbose_proxy_logger.debug(
- f"adding spend to {entity_type.value} db. Response cost: {response_cost}. {entity_type.value}_id: {entity_id}."
- )
- if prisma_client is None:
- return False
-
- if entity_id is None:
- verbose_proxy_logger.debug(
- f"track_cost_callback: {entity_type.value}_id is None. Not tracking spend for {entity_type.value}"
- )
- return False
- transaction_list[entity_id] = response_cost + transaction_list.get(
- entity_id, 0
- )
- return True
-
- except Exception as e:
- verbose_proxy_logger.info(
- f"Update {entity_type.value.capitalize()} DB failed to execute - {str(e)}\n{traceback.format_exc()}"
- )
- raise e
-
- @staticmethod
async def _update_key_db(
+ self,
response_cost: Optional[float],
hashed_token: Optional[str],
prisma_client: Optional[PrismaClient],
@@ -193,13 +148,12 @@ class DBSpendUpdateWriter:
if hashed_token is None or prisma_client is None:
return
- await DBSpendUpdateWriter._update_transaction_list(
- response_cost=response_cost,
- entity_id=hashed_token,
- transaction_list=prisma_client.key_list_transactions,
- entity_type=Litellm_EntityType.KEY,
- debug_msg=f"adding spend to key db. Response cost: {response_cost}. Token: {hashed_token}.",
- prisma_client=prisma_client,
+ await self.spend_update_queue.add_update(
+ update=SpendUpdateQueueItem(
+ entity_type=Litellm_EntityType.KEY,
+ entity_id=hashed_token,
+ response_cost=response_cost,
+ )
)
except Exception as e:
verbose_proxy_logger.exception(
@@ -207,8 +161,8 @@ class DBSpendUpdateWriter:
)
raise e
- @staticmethod
async def _update_user_db(
+ self,
response_cost: Optional[float],
user_id: Optional[str],
prisma_client: Optional[PrismaClient],
@@ -234,21 +188,21 @@ class DBSpendUpdateWriter:
for _id in user_ids:
if _id is not None:
- await DBSpendUpdateWriter._update_transaction_list(
- response_cost=response_cost,
- entity_id=_id,
- transaction_list=prisma_client.user_list_transactions,
- entity_type=Litellm_EntityType.USER,
- prisma_client=prisma_client,
+ await self.spend_update_queue.add_update(
+ update=SpendUpdateQueueItem(
+ entity_type=Litellm_EntityType.USER,
+ entity_id=_id,
+ response_cost=response_cost,
+ )
)
if end_user_id is not None:
- await DBSpendUpdateWriter._update_transaction_list(
- response_cost=response_cost,
- entity_id=end_user_id,
- transaction_list=prisma_client.end_user_list_transactions,
- entity_type=Litellm_EntityType.END_USER,
- prisma_client=prisma_client,
+ await self.spend_update_queue.add_update(
+ update=SpendUpdateQueueItem(
+ entity_type=Litellm_EntityType.END_USER,
+ entity_id=end_user_id,
+ response_cost=response_cost,
+ )
)
except Exception as e:
verbose_proxy_logger.info(
@@ -256,8 +210,8 @@ class DBSpendUpdateWriter:
+ f"Update User DB call failed to execute {str(e)}\n{traceback.format_exc()}"
)
- @staticmethod
async def _update_team_db(
+ self,
response_cost: Optional[float],
team_id: Optional[str],
user_id: Optional[str],
@@ -270,12 +224,12 @@ class DBSpendUpdateWriter:
)
return
- await DBSpendUpdateWriter._update_transaction_list(
- response_cost=response_cost,
- entity_id=team_id,
- transaction_list=prisma_client.team_list_transactions,
- entity_type=Litellm_EntityType.TEAM,
- prisma_client=prisma_client,
+ await self.spend_update_queue.add_update(
+ update=SpendUpdateQueueItem(
+ entity_type=Litellm_EntityType.TEAM,
+ entity_id=team_id,
+ response_cost=response_cost,
+ )
)
try:
@@ -283,12 +237,12 @@ class DBSpendUpdateWriter:
if user_id is not None:
# key is "team_id::::user_id::"
team_member_key = f"team_id::{team_id}::user_id::{user_id}"
- await DBSpendUpdateWriter._update_transaction_list(
- response_cost=response_cost,
- entity_id=team_member_key,
- transaction_list=prisma_client.team_member_list_transactions,
- entity_type=Litellm_EntityType.TEAM_MEMBER,
- prisma_client=prisma_client,
+ await self.spend_update_queue.add_update(
+ update=SpendUpdateQueueItem(
+ entity_type=Litellm_EntityType.TEAM_MEMBER,
+ entity_id=team_member_key,
+ response_cost=response_cost,
+ )
)
except Exception:
pass
@@ -298,8 +252,8 @@ class DBSpendUpdateWriter:
)
raise e
- @staticmethod
async def _update_org_db(
+ self,
response_cost: Optional[float],
org_id: Optional[str],
prisma_client: Optional[PrismaClient],
@@ -311,12 +265,12 @@ class DBSpendUpdateWriter:
)
return
- await DBSpendUpdateWriter._update_transaction_list(
- response_cost=response_cost,
- entity_id=org_id,
- transaction_list=prisma_client.org_list_transactions,
- entity_type=Litellm_EntityType.ORGANIZATION,
- prisma_client=prisma_client,
+ await self.spend_update_queue.add_update(
+ update=SpendUpdateQueueItem(
+ entity_type=Litellm_EntityType.ORGANIZATION,
+ entity_id=org_id,
+ response_cost=response_cost,
+ )
)
except Exception as e:
verbose_proxy_logger.info(
@@ -435,7 +389,7 @@ class DBSpendUpdateWriter:
- Only 1 pod will commit to db at a time (based on if it can acquire the lock over writing to DB)
"""
await self.redis_update_buffer.store_in_memory_spend_updates_in_redis(
- prisma_client=prisma_client,
+ spend_update_queue=self.spend_update_queue,
)
# Only commit from redis to db if this pod is the leader
@@ -447,7 +401,7 @@ class DBSpendUpdateWriter:
await self.redis_update_buffer.get_all_update_transactions_from_redis_buffer()
)
if db_spend_update_transactions is not None:
- await DBSpendUpdateWriter._commit_spend_updates_to_db(
+ await self._commit_spend_updates_to_db(
prisma_client=prisma_client,
n_retry_times=n_retry_times,
proxy_logging_obj=proxy_logging_obj,
@@ -471,23 +425,18 @@ class DBSpendUpdateWriter:
Note: This flow causes Deadlocks in production (1K RPS+). Use self._commit_spend_updates_to_db_with_redis() instead if you expect 1K+ RPS.
"""
- db_spend_update_transactions = DBSpendUpdateTransactions(
- user_list_transactions=prisma_client.user_list_transactions,
- end_user_list_transactions=prisma_client.end_user_list_transactions,
- key_list_transactions=prisma_client.key_list_transactions,
- team_list_transactions=prisma_client.team_list_transactions,
- team_member_list_transactions=prisma_client.team_member_list_transactions,
- org_list_transactions=prisma_client.org_list_transactions,
+ db_spend_update_transactions = (
+ await self.spend_update_queue.flush_and_get_aggregated_db_spend_update_transactions()
)
- await DBSpendUpdateWriter._commit_spend_updates_to_db(
+ await self._commit_spend_updates_to_db(
prisma_client=prisma_client,
n_retry_times=n_retry_times,
proxy_logging_obj=proxy_logging_obj,
db_spend_update_transactions=db_spend_update_transactions,
)
- @staticmethod
async def _commit_spend_updates_to_db( # noqa: PLR0915
+ self,
prisma_client: PrismaClient,
n_retry_times: int,
proxy_logging_obj: ProxyLogging,
@@ -526,9 +475,6 @@ class DBSpendUpdateWriter:
where={"user_id": user_id},
data={"spend": {"increment": response_cost}},
)
- prisma_client.user_list_transactions = (
- {}
- ) # Clear the remaining transactions after processing all batches in the loop.
break
except DB_CONNECTION_ERROR_TYPES as e:
if (
@@ -561,6 +507,7 @@ class DBSpendUpdateWriter:
n_retry_times=n_retry_times,
prisma_client=prisma_client,
proxy_logging_obj=proxy_logging_obj,
+ end_user_list_transactions=end_user_list_transactions,
)
### UPDATE KEY TABLE ###
key_list_transactions = db_spend_update_transactions["key_list_transactions"]
@@ -583,9 +530,6 @@ class DBSpendUpdateWriter:
where={"token": token},
data={"spend": {"increment": response_cost}},
)
- prisma_client.key_list_transactions = (
- {}
- ) # Clear the remaining transactions after processing all batches in the loop.
break
except DB_CONNECTION_ERROR_TYPES as e:
if (
@@ -632,9 +576,6 @@ class DBSpendUpdateWriter:
where={"team_id": team_id},
data={"spend": {"increment": response_cost}},
)
- prisma_client.team_list_transactions = (
- {}
- ) # Clear the remaining transactions after processing all batches in the loop.
break
except DB_CONNECTION_ERROR_TYPES as e:
if (
@@ -684,9 +625,6 @@ class DBSpendUpdateWriter:
where={"team_id": team_id, "user_id": user_id},
data={"spend": {"increment": response_cost}},
)
- prisma_client.team_member_list_transactions = (
- {}
- ) # Clear the remaining transactions after processing all batches in the loop.
break
except DB_CONNECTION_ERROR_TYPES as e:
if (
@@ -725,9 +663,6 @@ class DBSpendUpdateWriter:
where={"organization_id": org_id},
data={"spend": {"increment": response_cost}},
)
- prisma_client.org_list_transactions = (
- {}
- ) # Clear the remaining transactions after processing all batches in the loop.
break
except DB_CONNECTION_ERROR_TYPES as e:
if (
diff --git a/litellm/proxy/db/redis_update_buffer.py b/litellm/proxy/db/redis_update_buffer.py
index f98fc9300f..1a3fd3d42d 100644
--- a/litellm/proxy/db/redis_update_buffer.py
+++ b/litellm/proxy/db/redis_update_buffer.py
@@ -12,6 +12,7 @@ from litellm.caching import RedisCache
from litellm.constants import MAX_REDIS_BUFFER_DEQUEUE_COUNT, REDIS_UPDATE_BUFFER_KEY
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy._types import DBSpendUpdateTransactions
+from litellm.proxy.db.spend_update_queue import SpendUpdateQueue
from litellm.secret_managers.main import str_to_bool
if TYPE_CHECKING:
@@ -54,7 +55,7 @@ class RedisUpdateBuffer:
async def store_in_memory_spend_updates_in_redis(
self,
- prisma_client: PrismaClient,
+ spend_update_queue: SpendUpdateQueue,
):
"""
Stores the in-memory spend updates to Redis
@@ -78,13 +79,12 @@ class RedisUpdateBuffer:
"redis_cache is None, skipping store_in_memory_spend_updates_in_redis"
)
return
- db_spend_update_transactions: DBSpendUpdateTransactions = DBSpendUpdateTransactions(
- user_list_transactions=prisma_client.user_list_transactions,
- end_user_list_transactions=prisma_client.end_user_list_transactions,
- key_list_transactions=prisma_client.key_list_transactions,
- team_list_transactions=prisma_client.team_list_transactions,
- team_member_list_transactions=prisma_client.team_member_list_transactions,
- org_list_transactions=prisma_client.org_list_transactions,
+
+ db_spend_update_transactions = (
+ await spend_update_queue.flush_and_get_aggregated_db_spend_update_transactions()
+ )
+ verbose_proxy_logger.debug(
+ "ALL DB SPEND UPDATE TRANSACTIONS: %s", db_spend_update_transactions
)
# only store in redis if there are any updates to commit
@@ -100,9 +100,6 @@ class RedisUpdateBuffer:
values=list_of_transactions,
)
- # clear the in-memory spend updates
- RedisUpdateBuffer._clear_all_in_memory_spend_updates(prisma_client)
-
@staticmethod
def _number_of_transactions_to_store_in_redis(
db_spend_update_transactions: DBSpendUpdateTransactions,
@@ -116,20 +113,6 @@ class RedisUpdateBuffer:
num_transactions += len(v)
return num_transactions
- @staticmethod
- def _clear_all_in_memory_spend_updates(
- prisma_client: PrismaClient,
- ):
- """
- Clears all in-memory spend updates
- """
- prisma_client.user_list_transactions = {}
- prisma_client.end_user_list_transactions = {}
- prisma_client.key_list_transactions = {}
- prisma_client.team_list_transactions = {}
- prisma_client.team_member_list_transactions = {}
- prisma_client.org_list_transactions = {}
-
@staticmethod
def _remove_prefix_from_keys(data: Dict[str, Any], prefix: str) -> Dict[str, Any]:
"""
diff --git a/litellm/proxy/db/spend_update_queue.py b/litellm/proxy/db/spend_update_queue.py
new file mode 100644
index 0000000000..28e05246fa
--- /dev/null
+++ b/litellm/proxy/db/spend_update_queue.py
@@ -0,0 +1,132 @@
+import asyncio
+from typing import TYPE_CHECKING, Any, List
+
+from litellm._logging import verbose_proxy_logger
+from litellm.proxy._types import (
+ DBSpendUpdateTransactions,
+ Litellm_EntityType,
+ SpendUpdateQueueItem,
+)
+
+if TYPE_CHECKING:
+ from litellm.proxy.utils import PrismaClient
+else:
+ PrismaClient = Any
+
+
+class SpendUpdateQueue:
+ """
+ In memory buffer for spend updates that should be committed to the database
+ """
+
+ def __init__(
+ self,
+ ):
+ self.update_queue: asyncio.Queue[SpendUpdateQueueItem] = asyncio.Queue()
+
+ async def add_update(self, update: SpendUpdateQueueItem) -> None:
+        """Enqueue an update. Each update is a dict like {'entity_type': 'user', 'entity_id': '123', 'response_cost': 1.2}."""
+ verbose_proxy_logger.debug("Adding update to queue: %s", update)
+ await self.update_queue.put(update)
+
+ async def flush_all_updates_from_in_memory_queue(
+ self,
+ ) -> List[SpendUpdateQueueItem]:
+ """Get all updates from the queue."""
+ updates: List[SpendUpdateQueueItem] = []
+ while not self.update_queue.empty():
+ updates.append(await self.update_queue.get())
+ return updates
+
+ async def flush_and_get_aggregated_db_spend_update_transactions(
+ self,
+ ) -> DBSpendUpdateTransactions:
+ """Flush all updates from the queue and return all updates aggregated by entity type."""
+ updates = await self.flush_all_updates_from_in_memory_queue()
+ verbose_proxy_logger.debug("Aggregating updates by entity type: %s", updates)
+ return self.get_aggregated_db_spend_update_transactions(updates)
+
+ def get_aggregated_db_spend_update_transactions(
+ self, updates: List[SpendUpdateQueueItem]
+ ) -> DBSpendUpdateTransactions:
+ """Aggregate updates by entity type."""
+ # Initialize all transaction lists as empty dicts
+ db_spend_update_transactions = DBSpendUpdateTransactions(
+ user_list_transactions={},
+ end_user_list_transactions={},
+ key_list_transactions={},
+ team_list_transactions={},
+ team_member_list_transactions={},
+ org_list_transactions={},
+ )
+
+ # Map entity types to their corresponding transaction dictionary keys
+ entity_type_to_dict_key = {
+ Litellm_EntityType.USER: "user_list_transactions",
+ Litellm_EntityType.END_USER: "end_user_list_transactions",
+ Litellm_EntityType.KEY: "key_list_transactions",
+ Litellm_EntityType.TEAM: "team_list_transactions",
+ Litellm_EntityType.TEAM_MEMBER: "team_member_list_transactions",
+ Litellm_EntityType.ORGANIZATION: "org_list_transactions",
+ }
+
+ for update in updates:
+ entity_type = update.get("entity_type")
+ entity_id = update.get("entity_id") or ""
+ response_cost = update.get("response_cost") or 0
+
+ if entity_type is None:
+ verbose_proxy_logger.debug(
+ "Skipping update spend for update: %s, because entity_type is None",
+ update,
+ )
+ continue
+
+ dict_key = entity_type_to_dict_key.get(entity_type)
+ if dict_key is None:
+ verbose_proxy_logger.debug(
+ "Skipping update spend for update: %s, because entity_type is not in entity_type_to_dict_key",
+ update,
+ )
+ continue # Skip unknown entity types
+
+ # Type-safe access using if/elif statements
+ if dict_key == "user_list_transactions":
+ transactions_dict = db_spend_update_transactions[
+ "user_list_transactions"
+ ]
+ elif dict_key == "end_user_list_transactions":
+ transactions_dict = db_spend_update_transactions[
+ "end_user_list_transactions"
+ ]
+ elif dict_key == "key_list_transactions":
+ transactions_dict = db_spend_update_transactions[
+ "key_list_transactions"
+ ]
+ elif dict_key == "team_list_transactions":
+ transactions_dict = db_spend_update_transactions[
+ "team_list_transactions"
+ ]
+ elif dict_key == "team_member_list_transactions":
+ transactions_dict = db_spend_update_transactions[
+ "team_member_list_transactions"
+ ]
+ elif dict_key == "org_list_transactions":
+ transactions_dict = db_spend_update_transactions[
+ "org_list_transactions"
+ ]
+ else:
+ continue
+
+ if transactions_dict is None:
+ transactions_dict = {}
+
+            # type: ignore - dict_key is guaranteed to be one of ("user_list_transactions", "end_user_list_transactions", "key_list_transactions", "team_list_transactions", "team_member_list_transactions", "org_list_transactions")
+ db_spend_update_transactions[dict_key] = transactions_dict # type: ignore
+
+ if entity_id not in transactions_dict:
+ transactions_dict[entity_id] = 0
+
+ transactions_dict[entity_id] += response_cost or 0
+
+ return db_spend_update_transactions
diff --git a/litellm/proxy/example_config_yaml/spend_tracking_config.yaml b/litellm/proxy/example_config_yaml/spend_tracking_config.yaml
new file mode 100644
index 0000000000..fe8d73d26a
--- /dev/null
+++ b/litellm/proxy/example_config_yaml/spend_tracking_config.yaml
@@ -0,0 +1,15 @@
+model_list:
+ - model_name: fake-openai-endpoint
+ litellm_params:
+ model: openai/fake
+ api_key: fake-key
+ api_base: https://exampleopenaiendpoint-production.up.railway.app/
+
+general_settings:
+ use_redis_transaction_buffer: true
+
+litellm_settings:
+ cache: True
+ cache_params:
+ type: redis
+ supported_call_types: []
\ No newline at end of file
diff --git a/litellm/proxy/hooks/proxy_track_cost_callback.py b/litellm/proxy/hooks/proxy_track_cost_callback.py
index 39c1eeace9..9ae1eb4c34 100644
--- a/litellm/proxy/hooks/proxy_track_cost_callback.py
+++ b/litellm/proxy/hooks/proxy_track_cost_callback.py
@@ -13,7 +13,6 @@ from litellm.litellm_core_utils.core_helpers import (
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.auth_checks import log_db_metrics
-from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter
from litellm.proxy.utils import ProxyUpdateSpend
from litellm.types.utils import (
StandardLoggingPayload,
@@ -37,6 +36,8 @@ class _ProxyDBLogger(CustomLogger):
if _ProxyDBLogger._should_track_errors_in_db() is False:
return
+ from litellm.proxy.proxy_server import proxy_logging_obj
+
_metadata = dict(
StandardLoggingUserAPIKeyMetadata(
user_api_key_hash=user_api_key_dict.api_key,
@@ -66,7 +67,7 @@ class _ProxyDBLogger(CustomLogger):
request_data.get("proxy_server_request") or {}
)
request_data["litellm_params"]["metadata"] = existing_metadata
- await DBSpendUpdateWriter.update_database(
+ await proxy_logging_obj.db_spend_update_writer.update_database(
token=user_api_key_dict.api_key,
response_cost=0.0,
user_id=user_api_key_dict.user_id,
@@ -136,7 +137,7 @@ class _ProxyDBLogger(CustomLogger):
end_user_id=end_user_id,
):
## UPDATE DATABASE
- await DBSpendUpdateWriter.update_database(
+ await proxy_logging_obj.db_spend_update_writer.update_database(
token=user_api_key,
response_cost=response_cost,
user_id=user_id,
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index 9cdd1e7236..17658df903 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -1,9 +1,6 @@
model_list:
- - model_name: gpt-4o
+ - model_name: fake-openai-endpoint
litellm_params:
- model: openai/gpt-4o
- api_key: sk-xxxxxxx
-
-general_settings:
- service_account_settings:
- enforced_params: ["user"] # this means the "user" param is enforced for all requests made through any service account keys
\ No newline at end of file
+ model: openai/fake
+ api_key: fake-key
+ api_base: https://exampleopenaiendpoint-production.up.railway.app/
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index f612c88ccc..67d1882a11 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -1111,12 +1111,6 @@ def jsonify_object(data: dict) -> dict:
class PrismaClient:
- user_list_transactions: dict = {}
- end_user_list_transactions: dict = {}
- key_list_transactions: dict = {}
- team_list_transactions: dict = {}
- team_member_list_transactions: dict = {} # key is ["team_id" + "user_id"]
- org_list_transactions: dict = {}
spend_log_transactions: List = []
daily_user_spend_transactions: Dict[str, DailyUserSpendTransaction] = {}
@@ -2479,7 +2473,10 @@ def _hash_token_if_needed(token: str) -> str:
class ProxyUpdateSpend:
@staticmethod
async def update_end_user_spend(
- n_retry_times: int, prisma_client: PrismaClient, proxy_logging_obj: ProxyLogging
+ n_retry_times: int,
+ prisma_client: PrismaClient,
+ proxy_logging_obj: ProxyLogging,
+ end_user_list_transactions: Dict[str, float],
):
for i in range(n_retry_times + 1):
start_time = time.time()
@@ -2491,7 +2488,7 @@ class ProxyUpdateSpend:
for (
end_user_id,
response_cost,
- ) in prisma_client.end_user_list_transactions.items():
+ ) in end_user_list_transactions.items():
if litellm.max_end_user_budget is not None:
pass
batcher.litellm_endusertable.upsert(
@@ -2518,10 +2515,6 @@ class ProxyUpdateSpend:
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
- finally:
- prisma_client.end_user_list_transactions = (
- {}
- ) # reset the end user list transactions - prevent bad data from causing issues
@staticmethod
async def update_spend_logs(
diff --git a/tests/litellm/proxy/db/test_spend_update_queue.py b/tests/litellm/proxy/db/test_spend_update_queue.py
new file mode 100644
index 0000000000..89d494a070
--- /dev/null
+++ b/tests/litellm/proxy/db/test_spend_update_queue.py
@@ -0,0 +1,152 @@
+import asyncio
+import json
+import os
+import sys
+
+import pytest
+from fastapi.testclient import TestClient
+
+from litellm.proxy._types import Litellm_EntityType, SpendUpdateQueueItem
+from litellm.proxy.db.spend_update_queue import SpendUpdateQueue
+
+sys.path.insert(
+ 0, os.path.abspath("../../..")
+) # Adds the parent directory to the system path
+
+
+@pytest.fixture
+def spend_queue():
+ return SpendUpdateQueue()
+
+
+@pytest.mark.asyncio
+async def test_add_update(spend_queue):
+ # Test adding a single update
+ update: SpendUpdateQueueItem = {
+ "entity_type": Litellm_EntityType.USER,
+ "entity_id": "user123",
+ "response_cost": 0.5,
+ }
+ await spend_queue.add_update(update)
+
+ # Verify update was added by checking queue size
+ assert spend_queue.update_queue.qsize() == 1
+
+
+@pytest.mark.asyncio
+async def test_missing_response_cost(spend_queue):
+ # Test with missing response_cost - should default to 0
+ update: SpendUpdateQueueItem = {
+ "entity_type": Litellm_EntityType.USER,
+ "entity_id": "user123",
+ }
+
+ await spend_queue.add_update(update)
+ aggregated = (
+ await spend_queue.flush_and_get_aggregated_db_spend_update_transactions()
+ )
+
+ # Should have created entry with 0 cost
+ assert aggregated["user_list_transactions"]["user123"] == 0
+
+
+@pytest.mark.asyncio
+async def test_missing_entity_id(spend_queue):
+ # Test with missing entity_id - should default to empty string
+ update: SpendUpdateQueueItem = {
+ "entity_type": Litellm_EntityType.USER,
+ "response_cost": 1.0,
+ }
+
+ await spend_queue.add_update(update)
+ aggregated = (
+ await spend_queue.flush_and_get_aggregated_db_spend_update_transactions()
+ )
+
+ # Should use empty string as key
+ assert aggregated["user_list_transactions"][""] == 1.0
+
+
+@pytest.mark.asyncio
+async def test_none_values(spend_queue):
+ # Test with None values
+ update: SpendUpdateQueueItem = {
+ "entity_type": Litellm_EntityType.USER,
+ "entity_id": None, # type: ignore
+ "response_cost": None,
+ }
+
+ await spend_queue.add_update(update)
+ aggregated = (
+ await spend_queue.flush_and_get_aggregated_db_spend_update_transactions()
+ )
+
+ # Should handle None values gracefully
+ assert aggregated["user_list_transactions"][""] == 0
+
+
+@pytest.mark.asyncio
+async def test_multiple_updates_with_missing_fields(spend_queue):
+ # Test multiple updates with various missing fields
+ updates: list[SpendUpdateQueueItem] = [
+ {
+ "entity_type": Litellm_EntityType.USER,
+ "entity_id": "user123",
+ "response_cost": 0.5,
+ },
+ {
+ "entity_type": Litellm_EntityType.USER,
+ "entity_id": "user123", # missing response_cost
+ },
+ {
+ "entity_type": Litellm_EntityType.USER, # missing entity_id
+ "response_cost": 1.5,
+ },
+ ]
+
+ for update in updates:
+ await spend_queue.add_update(update)
+
+ aggregated = (
+ await spend_queue.flush_and_get_aggregated_db_spend_update_transactions()
+ )
+
+ # Verify aggregation
+ assert (
+ aggregated["user_list_transactions"]["user123"] == 0.5
+ ) # only the first update with valid cost
+ assert (
+ aggregated["user_list_transactions"][""] == 1.5
+ ) # update with missing entity_id
+
+
+@pytest.mark.asyncio
+async def test_unknown_entity_type(spend_queue):
+ # Test with unknown entity type
+ update: SpendUpdateQueueItem = {
+ "entity_type": "UNKNOWN_TYPE", # type: ignore
+ "entity_id": "123",
+ "response_cost": 0.5,
+ }
+
+ await spend_queue.add_update(update)
+ aggregated = (
+ await spend_queue.flush_and_get_aggregated_db_spend_update_transactions()
+ )
+
+ # Should ignore unknown entity type
+ assert all(len(transactions) == 0 for transactions in aggregated.values())
+
+
+@pytest.mark.asyncio
+async def test_missing_entity_type(spend_queue):
+ # Test with missing entity type
+ update: SpendUpdateQueueItem = {"entity_id": "123", "response_cost": 0.5}
+
+ await spend_queue.add_update(update)
+ aggregated = (
+ await spend_queue.flush_and_get_aggregated_db_spend_update_transactions()
+ )
+
+ # Should ignore updates without entity type
+ assert all(len(transactions) == 0 for transactions in aggregated.values())
diff --git a/tests/local_testing/test_update_spend.py b/tests/local_testing/test_update_spend.py
index fffa3062d7..cc2c94af27 100644
--- a/tests/local_testing/test_update_spend.py
+++ b/tests/local_testing/test_update_spend.py
@@ -62,6 +62,8 @@ from litellm.proxy._types import (
KeyRequest,
NewUserRequest,
UpdateKeyRequest,
+ SpendUpdateQueueItem,
+ Litellm_EntityType,
)
proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
@@ -93,7 +95,13 @@ def prisma_client():
@pytest.mark.asyncio
async def test_batch_update_spend(prisma_client):
- prisma_client.user_list_transactions["test-litellm-user-5"] = 23
+ await proxy_logging_obj.db_spend_update_writer.spend_update_queue.add_update(
+ SpendUpdateQueueItem(
+ entity_type=Litellm_EntityType.USER,
+ entity_id="test-litellm-user-5",
+ response_cost=23,
+ )
+ )
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
await litellm.proxy.proxy_server.prisma_client.connect()
diff --git a/tests/proxy_unit_tests/test_proxy_utils.py b/tests/proxy_unit_tests/test_proxy_utils.py
index b28948094e..1281d50863 100644
--- a/tests/proxy_unit_tests/test_proxy_utils.py
+++ b/tests/proxy_unit_tests/test_proxy_utils.py
@@ -1485,21 +1485,18 @@ from litellm.proxy.utils import ProxyUpdateSpend
async def test_end_user_transactions_reset():
# Setup
mock_client = MagicMock()
- mock_client.end_user_list_transactions = {"1": 10.0} # Bad log
+ end_user_list_transactions = {"1": 10.0} # Bad log
mock_client.db.tx = AsyncMock(side_effect=Exception("DB Error"))
# Call function - should raise error
with pytest.raises(Exception):
await ProxyUpdateSpend.update_end_user_spend(
- n_retry_times=0, prisma_client=mock_client, proxy_logging_obj=MagicMock()
+ n_retry_times=0,
+ prisma_client=mock_client,
+ proxy_logging_obj=MagicMock(),
+ end_user_list_transactions=end_user_list_transactions,
)
- # Verify cleanup happened
- assert (
- mock_client.end_user_list_transactions == {}
- ), "Transactions list should be empty after error"
-
-
@pytest.mark.asyncio
async def test_spend_logs_cleanup_after_error():
# Setup test data
diff --git a/tests/proxy_unit_tests/test_unit_test_proxy_hooks.py b/tests/proxy_unit_tests/test_unit_test_proxy_hooks.py
index 129be6d754..46863889d2 100644
--- a/tests/proxy_unit_tests/test_unit_test_proxy_hooks.py
+++ b/tests/proxy_unit_tests/test_unit_test_proxy_hooks.py
@@ -24,9 +24,10 @@ async def test_disable_spend_logs():
"litellm.proxy.proxy_server.prisma_client", mock_prisma_client
):
from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter
+ db_spend_update_writer = DBSpendUpdateWriter()
# Call update_database with disable_spend_logs=True
- await DBSpendUpdateWriter.update_database(
+ await db_spend_update_writer.update_database(
token="fake-token",
response_cost=0.1,
user_id="user123",
diff --git a/tests/proxy_unit_tests/test_update_spend.py b/tests/proxy_unit_tests/test_update_spend.py
index 641768a7d2..1fb2479792 100644
--- a/tests/proxy_unit_tests/test_update_spend.py
+++ b/tests/proxy_unit_tests/test_update_spend.py
@@ -27,12 +27,6 @@ class MockPrismaClient:
# Initialize transaction lists
self.spend_log_transactions = []
- self.user_list_transactons = {}
- self.end_user_list_transactons = {}
- self.key_list_transactons = {}
- self.team_list_transactons = {}
- self.team_member_list_transactons = {}
- self.org_list_transactons = {}
self.daily_user_spend_transactions = {}
def jsonify_object(self, obj):
diff --git a/tests/otel_tests/local_test_spend_accuracy_tests.py b/tests/spend_tracking_tests/test_spend_accuracy_tests.py
similarity index 96%
rename from tests/otel_tests/local_test_spend_accuracy_tests.py
rename to tests/spend_tracking_tests/test_spend_accuracy_tests.py
index 6d756219c7..93228c2d06 100644
--- a/tests/otel_tests/local_test_spend_accuracy_tests.py
+++ b/tests/spend_tracking_tests/test_spend_accuracy_tests.py
@@ -52,7 +52,7 @@ Additional Test Scenarios:
async def create_organization(session, organization_alias: str):
"""Helper function to create a new organization"""
- url = "http://0.0.0.0:4002/organization/new"
+ url = "http://0.0.0.0:4000/organization/new"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
data = {"organization_alias": organization_alias}
async with session.post(url, headers=headers, json=data) as response:
@@ -61,7 +61,7 @@ async def create_organization(session, organization_alias: str):
async def create_team(session, org_id: str):
"""Helper function to create a new team under an organization"""
- url = "http://0.0.0.0:4002/team/new"
+ url = "http://0.0.0.0:4000/team/new"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
data = {"organization_id": org_id, "team_alias": f"test-team-{uuid.uuid4()}"}
async with session.post(url, headers=headers, json=data) as response:
@@ -70,7 +70,7 @@ async def create_team(session, org_id: str):
async def create_user(session, org_id: str):
"""Helper function to create a new user"""
- url = "http://0.0.0.0:4002/user/new"
+ url = "http://0.0.0.0:4000/user/new"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
data = {"user_name": f"test-user-{uuid.uuid4()}"}
async with session.post(url, headers=headers, json=data) as response:
@@ -79,7 +79,7 @@ async def create_user(session, org_id: str):
async def generate_key(session, user_id: str, team_id: str):
"""Helper function to generate a key for a specific user and team"""
- url = "http://0.0.0.0:4002/key/generate"
+ url = "http://0.0.0.0:4000/key/generate"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
data = {"user_id": user_id, "team_id": team_id}
async with session.post(url, headers=headers, json=data) as response:
@@ -91,7 +91,7 @@ async def chat_completion(session, key: str):
from openai import AsyncOpenAI
import uuid
- client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4002/v1")
+ client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000/v1")
response = await client.chat.completions.create(
model="fake-openai-endpoint",
@@ -102,7 +102,7 @@ async def chat_completion(session, key: str):
async def get_spend_info(session, entity_type: str, entity_id: str):
"""Helper function to get spend information for an entity"""
- url = f"http://0.0.0.0:4002/{entity_type}/info"
+ url = f"http://0.0.0.0:4000/{entity_type}/info"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
if entity_type == "key":
data = {"key": entity_id}
@@ -156,7 +156,7 @@ async def test_basic_spend_accuracy():
response = await chat_completion(session, key)
print("response: ", response)
- # wait 10 seconds for spend to be updated
+ # wait 15 seconds for spend to be updated
await asyncio.sleep(15)
# Get spend information for each entity
@@ -235,7 +235,7 @@ async def test_long_term_spend_accuracy_with_bursts():
print(f"Burst 1 - Request {i+1}/{BURST_1_REQUESTS} completed")
# Wait for spend to be updated
- await asyncio.sleep(8)
+ await asyncio.sleep(15)
# Check intermediate spend
intermediate_key_info = await get_spend_info(session, "key", key)
@@ -248,7 +248,7 @@ async def test_long_term_spend_accuracy_with_bursts():
print(f"Burst 2 - Request {i+1}/{BURST_2_REQUESTS} completed")
# Wait for spend to be updated
- await asyncio.sleep(8)
+ await asyncio.sleep(15)
# Get final spend information for each entity
key_info = await get_spend_info(session, "key", key)