diff --git a/.circleci/config.yml b/.circleci/config.yml index e1488a9083..d2a6aafef7 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1450,7 +1450,7 @@ jobs: command: | pwd ls - python -m pytest -s -vv tests/*.py -x --junitxml=test-results/junit.xml --durations=5 --ignore=tests/otel_tests --ignore=tests/pass_through_tests --ignore=tests/proxy_admin_ui_tests --ignore=tests/load_tests --ignore=tests/llm_translation --ignore=tests/llm_responses_api_testing --ignore=tests/mcp_tests --ignore=tests/image_gen_tests --ignore=tests/pass_through_unit_tests + python -m pytest -s -vv tests/*.py -x --junitxml=test-results/junit.xml --durations=5 --ignore=tests/otel_tests --ignore=tests/spend_tracking_tests --ignore=tests/pass_through_tests --ignore=tests/proxy_admin_ui_tests --ignore=tests/load_tests --ignore=tests/llm_translation --ignore=tests/llm_responses_api_testing --ignore=tests/mcp_tests --ignore=tests/image_gen_tests --ignore=tests/pass_through_unit_tests no_output_timeout: 120m # Store test results @@ -1743,6 +1743,96 @@ jobs: # Store test results - store_test_results: path: test-results + proxy_spend_accuracy_tests: + machine: + image: ubuntu-2204:2023.10.1 + resource_class: xlarge + working_directory: ~/project + steps: + - checkout + - setup_google_dns + - run: + name: Install Docker CLI (In case it's not already installed) + command: | + sudo apt-get update + sudo apt-get install -y docker-ce docker-ce-cli containerd.io + - run: + name: Install Python 3.9 + command: | + curl https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh --output miniconda.sh + bash miniconda.sh -b -p $HOME/miniconda + export PATH="$HOME/miniconda/bin:$PATH" + conda init bash + source ~/.bashrc + conda create -n myenv python=3.9 -y + conda activate myenv + python --version + - run: + name: Install Dependencies + command: | + pip install "pytest==7.3.1" + pip install "pytest-asyncio==0.21.1" + pip install aiohttp + python -m pip install --upgrade pip + python 
-m pip install -r requirements.txt + - run: + name: Build Docker image + command: docker build -t my-app:latest -f ./docker/Dockerfile.database . + - run: + name: Run Docker container + # intentionally give bad redis credentials here + # the OTEL test - should get this as a trace + command: | + docker run -d \ + -p 4000:4000 \ + -e DATABASE_URL=$PROXY_DATABASE_URL \ + -e REDIS_HOST=$REDIS_HOST \ + -e REDIS_PASSWORD=$REDIS_PASSWORD \ + -e REDIS_PORT=$REDIS_PORT \ + -e LITELLM_MASTER_KEY="sk-1234" \ + -e OPENAI_API_KEY=$OPENAI_API_KEY \ + -e LITELLM_LICENSE=$LITELLM_LICENSE \ + -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \ + -e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \ + -e USE_DDTRACE=True \ + -e DD_API_KEY=$DD_API_KEY \ + -e DD_SITE=$DD_SITE \ + -e AWS_REGION_NAME=$AWS_REGION_NAME \ + --name my-app \ + -v $(pwd)/litellm/proxy/example_config_yaml/spend_tracking_config.yaml:/app/config.yaml \ + my-app:latest \ + --config /app/config.yaml \ + --port 4000 \ + --detailed_debug \ + - run: + name: Install curl and dockerize + command: | + sudo apt-get update + sudo apt-get install -y curl + sudo wget https://github.com/jwilder/dockerize/releases/download/v0.6.1/dockerize-linux-amd64-v0.6.1.tar.gz + sudo tar -C /usr/local/bin -xzvf dockerize-linux-amd64-v0.6.1.tar.gz + sudo rm dockerize-linux-amd64-v0.6.1.tar.gz + - run: + name: Start outputting logs + command: docker logs -f my-app + background: true + - run: + name: Wait for app to be ready + command: dockerize -wait http://localhost:4000 -timeout 5m + - run: + name: Run tests + command: | + pwd + ls + python -m pytest -vv tests/spend_tracking_tests -x --junitxml=test-results/junit.xml --durations=5 + no_output_timeout: + 120m + # Clean up first container + - run: + name: Stop and remove first container + command: | + docker stop my-app + docker rm my-app proxy_multi_instance_tests: machine: @@ -2553,6 +2643,12 @@ workflows: only: - main - /litellm_.*/ + - proxy_spend_accuracy_tests: + filters: + branches: + only: + - 
main + - /litellm_.*/ - proxy_multi_instance_tests: filters: branches: @@ -2714,6 +2810,7 @@ workflows: - installing_litellm_on_python - installing_litellm_on_python_3_13 - proxy_logging_guardrails_model_info_tests + - proxy_spend_accuracy_tests - proxy_multi_instance_tests - proxy_store_model_in_db_tests - proxy_build_from_pip_tests diff --git a/cookbook/misc/dev_release.txt b/cookbook/misc/dev_release.txt index 717a6da546..bd40f89e6f 100644 --- a/cookbook/misc/dev_release.txt +++ b/cookbook/misc/dev_release.txt @@ -1,2 +1,11 @@ python3 -m build -twine upload --verbose dist/litellm-1.18.13.dev4.tar.gz -u __token__ - \ No newline at end of file +twine upload --verbose dist/litellm-1.18.13.dev4.tar.gz -u __token__ - + + +Note: You might need to create a MANIFEST.in file in the repository root in case the build process fails + +Place this in MANIFEST.in +recursive-exclude venv * +recursive-exclude myenv * +recursive-exclude py313_env * +recursive-exclude **/.venv * diff --git a/docs/my-website/docs/enterprise.md b/docs/my-website/docs/enterprise.md index 5aeeb710ff..706ca33714 100644 --- a/docs/my-website/docs/enterprise.md +++ b/docs/my-website/docs/enterprise.md @@ -1,3 +1,5 @@ +import Image from '@theme/IdealImage'; + # Enterprise For companies that need SSO, user management and professional support for LiteLLM Proxy @@ -7,6 +9,8 @@ Get free 7-day trial key [here](https://www.litellm.ai/#trial) Includes all enterprise features. 
+ + [**Procurement available via AWS / Azure Marketplace**](./data_security.md#legalcompliance-faqs) diff --git a/docs/my-website/img/enterprise_vs_oss.png b/docs/my-website/img/enterprise_vs_oss.png new file mode 100644 index 0000000000..f2b58fbc14 Binary files /dev/null and b/docs/my-website/img/enterprise_vs_oss.png differ diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 9536442475..aacc9f525b 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -2751,3 +2751,9 @@ class DBSpendUpdateTransactions(TypedDict): team_list_transactions: Optional[Dict[str, float]] team_member_list_transactions: Optional[Dict[str, float]] org_list_transactions: Optional[Dict[str, float]] + + +class SpendUpdateQueueItem(TypedDict, total=False): + entity_type: Litellm_EntityType + entity_id: str + response_cost: Optional[float] diff --git a/litellm/proxy/db/db_spend_update_writer.py b/litellm/proxy/db/db_spend_update_writer.py index f46b03b57a..5bf255feae 100644 --- a/litellm/proxy/db/db_spend_update_writer.py +++ b/litellm/proxy/db/db_spend_update_writer.py @@ -22,9 +22,11 @@ from litellm.proxy._types import ( Litellm_EntityType, LiteLLM_UserTable, SpendLogsPayload, + SpendUpdateQueueItem, ) from litellm.proxy.db.pod_lock_manager import PodLockManager from litellm.proxy.db.redis_update_buffer import RedisUpdateBuffer +from litellm.proxy.db.spend_update_queue import SpendUpdateQueue if TYPE_CHECKING: from litellm.proxy.utils import PrismaClient, ProxyLogging @@ -48,10 +50,11 @@ class DBSpendUpdateWriter: self.redis_cache = redis_cache self.redis_update_buffer = RedisUpdateBuffer(redis_cache=self.redis_cache) self.pod_lock_manager = PodLockManager(cronjob_id=DB_SPEND_UPDATE_JOB_NAME) + self.spend_update_queue = SpendUpdateQueue() - @staticmethod async def update_database( # LiteLLM management object fields + self, token: Optional[str], user_id: Optional[str], end_user_id: Optional[str], @@ -84,7 +87,7 @@ class DBSpendUpdateWriter: hashed_token = token 
asyncio.create_task( - DBSpendUpdateWriter._update_user_db( + self._update_user_db( response_cost=response_cost, user_id=user_id, prisma_client=prisma_client, @@ -94,14 +97,14 @@ class DBSpendUpdateWriter: ) ) asyncio.create_task( - DBSpendUpdateWriter._update_key_db( + self._update_key_db( response_cost=response_cost, hashed_token=hashed_token, prisma_client=prisma_client, ) ) asyncio.create_task( - DBSpendUpdateWriter._update_team_db( + self._update_team_db( response_cost=response_cost, team_id=team_id, user_id=user_id, @@ -109,7 +112,7 @@ class DBSpendUpdateWriter: ) ) asyncio.create_task( - DBSpendUpdateWriter._update_org_db( + self._update_org_db( response_cost=response_cost, org_id=org_id, prisma_client=prisma_client, @@ -135,56 +138,8 @@ class DBSpendUpdateWriter: f"Error updating Prisma database: {traceback.format_exc()}" ) - @staticmethod - async def _update_transaction_list( - response_cost: Optional[float], - entity_id: Optional[str], - transaction_list: dict, - entity_type: Litellm_EntityType, - debug_msg: Optional[str] = None, - prisma_client: Optional[PrismaClient] = None, - ) -> bool: - """ - Common helper method to update a transaction list for an entity - - Args: - response_cost: The cost to add - entity_id: The ID of the entity to update - transaction_list: The transaction list dictionary to update - entity_type: The type of entity (from EntityType enum) - debug_msg: Optional custom debug message - - Returns: - bool: True if update happened, False otherwise - """ - try: - if debug_msg: - verbose_proxy_logger.debug(debug_msg) - else: - verbose_proxy_logger.debug( - f"adding spend to {entity_type.value} db. Response cost: {response_cost}. {entity_type.value}_id: {entity_id}." - ) - if prisma_client is None: - return False - - if entity_id is None: - verbose_proxy_logger.debug( - f"track_cost_callback: {entity_type.value}_id is None. 
Not tracking spend for {entity_type.value}" - ) - return False - transaction_list[entity_id] = response_cost + transaction_list.get( - entity_id, 0 - ) - return True - - except Exception as e: - verbose_proxy_logger.info( - f"Update {entity_type.value.capitalize()} DB failed to execute - {str(e)}\n{traceback.format_exc()}" - ) - raise e - - @staticmethod async def _update_key_db( + self, response_cost: Optional[float], hashed_token: Optional[str], prisma_client: Optional[PrismaClient], @@ -193,13 +148,12 @@ class DBSpendUpdateWriter: if hashed_token is None or prisma_client is None: return - await DBSpendUpdateWriter._update_transaction_list( - response_cost=response_cost, - entity_id=hashed_token, - transaction_list=prisma_client.key_list_transactions, - entity_type=Litellm_EntityType.KEY, - debug_msg=f"adding spend to key db. Response cost: {response_cost}. Token: {hashed_token}.", - prisma_client=prisma_client, + await self.spend_update_queue.add_update( + update=SpendUpdateQueueItem( + entity_type=Litellm_EntityType.KEY, + entity_id=hashed_token, + response_cost=response_cost, + ) ) except Exception as e: verbose_proxy_logger.exception( @@ -207,8 +161,8 @@ class DBSpendUpdateWriter: ) raise e - @staticmethod async def _update_user_db( + self, response_cost: Optional[float], user_id: Optional[str], prisma_client: Optional[PrismaClient], @@ -234,21 +188,21 @@ class DBSpendUpdateWriter: for _id in user_ids: if _id is not None: - await DBSpendUpdateWriter._update_transaction_list( - response_cost=response_cost, - entity_id=_id, - transaction_list=prisma_client.user_list_transactions, - entity_type=Litellm_EntityType.USER, - prisma_client=prisma_client, + await self.spend_update_queue.add_update( + update=SpendUpdateQueueItem( + entity_type=Litellm_EntityType.USER, + entity_id=_id, + response_cost=response_cost, + ) ) if end_user_id is not None: - await DBSpendUpdateWriter._update_transaction_list( - response_cost=response_cost, - entity_id=end_user_id, - 
transaction_list=prisma_client.end_user_list_transactions, - entity_type=Litellm_EntityType.END_USER, - prisma_client=prisma_client, + await self.spend_update_queue.add_update( + update=SpendUpdateQueueItem( + entity_type=Litellm_EntityType.END_USER, + entity_id=end_user_id, + response_cost=response_cost, + ) ) except Exception as e: verbose_proxy_logger.info( @@ -256,8 +210,8 @@ class DBSpendUpdateWriter: + f"Update User DB call failed to execute {str(e)}\n{traceback.format_exc()}" ) - @staticmethod async def _update_team_db( + self, response_cost: Optional[float], team_id: Optional[str], user_id: Optional[str], @@ -270,12 +224,12 @@ class DBSpendUpdateWriter: ) return - await DBSpendUpdateWriter._update_transaction_list( - response_cost=response_cost, - entity_id=team_id, - transaction_list=prisma_client.team_list_transactions, - entity_type=Litellm_EntityType.TEAM, - prisma_client=prisma_client, + await self.spend_update_queue.add_update( + update=SpendUpdateQueueItem( + entity_type=Litellm_EntityType.TEAM, + entity_id=team_id, + response_cost=response_cost, + ) ) try: @@ -283,12 +237,12 @@ class DBSpendUpdateWriter: if user_id is not None: # key is "team_id::::user_id::" team_member_key = f"team_id::{team_id}::user_id::{user_id}" - await DBSpendUpdateWriter._update_transaction_list( - response_cost=response_cost, - entity_id=team_member_key, - transaction_list=prisma_client.team_member_list_transactions, - entity_type=Litellm_EntityType.TEAM_MEMBER, - prisma_client=prisma_client, + await self.spend_update_queue.add_update( + update=SpendUpdateQueueItem( + entity_type=Litellm_EntityType.TEAM_MEMBER, + entity_id=team_member_key, + response_cost=response_cost, + ) ) except Exception: pass @@ -298,8 +252,8 @@ class DBSpendUpdateWriter: ) raise e - @staticmethod async def _update_org_db( + self, response_cost: Optional[float], org_id: Optional[str], prisma_client: Optional[PrismaClient], @@ -311,12 +265,12 @@ class DBSpendUpdateWriter: ) return - await 
DBSpendUpdateWriter._update_transaction_list( - response_cost=response_cost, - entity_id=org_id, - transaction_list=prisma_client.org_list_transactions, - entity_type=Litellm_EntityType.ORGANIZATION, - prisma_client=prisma_client, + await self.spend_update_queue.add_update( + update=SpendUpdateQueueItem( + entity_type=Litellm_EntityType.ORGANIZATION, + entity_id=org_id, + response_cost=response_cost, + ) ) except Exception as e: verbose_proxy_logger.info( @@ -435,7 +389,7 @@ class DBSpendUpdateWriter: - Only 1 pod will commit to db at a time (based on if it can acquire the lock over writing to DB) """ await self.redis_update_buffer.store_in_memory_spend_updates_in_redis( - prisma_client=prisma_client, + spend_update_queue=self.spend_update_queue, ) # Only commit from redis to db if this pod is the leader @@ -447,7 +401,7 @@ class DBSpendUpdateWriter: await self.redis_update_buffer.get_all_update_transactions_from_redis_buffer() ) if db_spend_update_transactions is not None: - await DBSpendUpdateWriter._commit_spend_updates_to_db( + await self._commit_spend_updates_to_db( prisma_client=prisma_client, n_retry_times=n_retry_times, proxy_logging_obj=proxy_logging_obj, @@ -471,23 +425,18 @@ class DBSpendUpdateWriter: Note: This flow causes Deadlocks in production (1K RPS+). Use self._commit_spend_updates_to_db_with_redis() instead if you expect 1K+ RPS. 
""" - db_spend_update_transactions = DBSpendUpdateTransactions( - user_list_transactions=prisma_client.user_list_transactions, - end_user_list_transactions=prisma_client.end_user_list_transactions, - key_list_transactions=prisma_client.key_list_transactions, - team_list_transactions=prisma_client.team_list_transactions, - team_member_list_transactions=prisma_client.team_member_list_transactions, - org_list_transactions=prisma_client.org_list_transactions, + db_spend_update_transactions = ( + await self.spend_update_queue.flush_and_get_aggregated_db_spend_update_transactions() ) - await DBSpendUpdateWriter._commit_spend_updates_to_db( + await self._commit_spend_updates_to_db( prisma_client=prisma_client, n_retry_times=n_retry_times, proxy_logging_obj=proxy_logging_obj, db_spend_update_transactions=db_spend_update_transactions, ) - @staticmethod async def _commit_spend_updates_to_db( # noqa: PLR0915 + self, prisma_client: PrismaClient, n_retry_times: int, proxy_logging_obj: ProxyLogging, @@ -526,9 +475,6 @@ class DBSpendUpdateWriter: where={"user_id": user_id}, data={"spend": {"increment": response_cost}}, ) - prisma_client.user_list_transactions = ( - {} - ) # Clear the remaining transactions after processing all batches in the loop. break except DB_CONNECTION_ERROR_TYPES as e: if ( @@ -561,6 +507,7 @@ class DBSpendUpdateWriter: n_retry_times=n_retry_times, prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj, + end_user_list_transactions=end_user_list_transactions, ) ### UPDATE KEY TABLE ### key_list_transactions = db_spend_update_transactions["key_list_transactions"] @@ -583,9 +530,6 @@ class DBSpendUpdateWriter: where={"token": token}, data={"spend": {"increment": response_cost}}, ) - prisma_client.key_list_transactions = ( - {} - ) # Clear the remaining transactions after processing all batches in the loop. 
break except DB_CONNECTION_ERROR_TYPES as e: if ( @@ -632,9 +576,6 @@ class DBSpendUpdateWriter: where={"team_id": team_id}, data={"spend": {"increment": response_cost}}, ) - prisma_client.team_list_transactions = ( - {} - ) # Clear the remaining transactions after processing all batches in the loop. break except DB_CONNECTION_ERROR_TYPES as e: if ( @@ -684,9 +625,6 @@ class DBSpendUpdateWriter: where={"team_id": team_id, "user_id": user_id}, data={"spend": {"increment": response_cost}}, ) - prisma_client.team_member_list_transactions = ( - {} - ) # Clear the remaining transactions after processing all batches in the loop. break except DB_CONNECTION_ERROR_TYPES as e: if ( @@ -725,9 +663,6 @@ class DBSpendUpdateWriter: where={"organization_id": org_id}, data={"spend": {"increment": response_cost}}, ) - prisma_client.org_list_transactions = ( - {} - ) # Clear the remaining transactions after processing all batches in the loop. break except DB_CONNECTION_ERROR_TYPES as e: if ( diff --git a/litellm/proxy/db/redis_update_buffer.py b/litellm/proxy/db/redis_update_buffer.py index f98fc9300f..1a3fd3d42d 100644 --- a/litellm/proxy/db/redis_update_buffer.py +++ b/litellm/proxy/db/redis_update_buffer.py @@ -12,6 +12,7 @@ from litellm.caching import RedisCache from litellm.constants import MAX_REDIS_BUFFER_DEQUEUE_COUNT, REDIS_UPDATE_BUFFER_KEY from litellm.litellm_core_utils.safe_json_dumps import safe_dumps from litellm.proxy._types import DBSpendUpdateTransactions +from litellm.proxy.db.spend_update_queue import SpendUpdateQueue from litellm.secret_managers.main import str_to_bool if TYPE_CHECKING: @@ -54,7 +55,7 @@ class RedisUpdateBuffer: async def store_in_memory_spend_updates_in_redis( self, - prisma_client: PrismaClient, + spend_update_queue: SpendUpdateQueue, ): """ Stores the in-memory spend updates to Redis @@ -78,13 +79,12 @@ class RedisUpdateBuffer: "redis_cache is None, skipping store_in_memory_spend_updates_in_redis" ) return - db_spend_update_transactions: 
DBSpendUpdateTransactions = DBSpendUpdateTransactions( - user_list_transactions=prisma_client.user_list_transactions, - end_user_list_transactions=prisma_client.end_user_list_transactions, - key_list_transactions=prisma_client.key_list_transactions, - team_list_transactions=prisma_client.team_list_transactions, - team_member_list_transactions=prisma_client.team_member_list_transactions, - org_list_transactions=prisma_client.org_list_transactions, + + db_spend_update_transactions = ( + await spend_update_queue.flush_and_get_aggregated_db_spend_update_transactions() + ) + verbose_proxy_logger.debug( + "ALL DB SPEND UPDATE TRANSACTIONS: %s", db_spend_update_transactions ) # only store in redis if there are any updates to commit @@ -100,9 +100,6 @@ class RedisUpdateBuffer: values=list_of_transactions, ) - # clear the in-memory spend updates - RedisUpdateBuffer._clear_all_in_memory_spend_updates(prisma_client) - @staticmethod def _number_of_transactions_to_store_in_redis( db_spend_update_transactions: DBSpendUpdateTransactions, @@ -116,20 +113,6 @@ class RedisUpdateBuffer: num_transactions += len(v) return num_transactions - @staticmethod - def _clear_all_in_memory_spend_updates( - prisma_client: PrismaClient, - ): - """ - Clears all in-memory spend updates - """ - prisma_client.user_list_transactions = {} - prisma_client.end_user_list_transactions = {} - prisma_client.key_list_transactions = {} - prisma_client.team_list_transactions = {} - prisma_client.team_member_list_transactions = {} - prisma_client.org_list_transactions = {} - @staticmethod def _remove_prefix_from_keys(data: Dict[str, Any], prefix: str) -> Dict[str, Any]: """ diff --git a/litellm/proxy/db/spend_update_queue.py b/litellm/proxy/db/spend_update_queue.py new file mode 100644 index 0000000000..28e05246fa --- /dev/null +++ b/litellm/proxy/db/spend_update_queue.py @@ -0,0 +1,132 @@ +import asyncio +from typing import TYPE_CHECKING, Any, List + +from litellm._logging import verbose_proxy_logger +from 
litellm.proxy._types import ( + DBSpendUpdateTransactions, + Litellm_EntityType, + SpendUpdateQueueItem, +) + +if TYPE_CHECKING: + from litellm.proxy.utils import PrismaClient +else: + PrismaClient = Any + + +class SpendUpdateQueue: + """ + In memory buffer for spend updates that should be committed to the database + """ + + def __init__( + self, + ): + self.update_queue: asyncio.Queue[SpendUpdateQueueItem] = asyncio.Queue() + + async def add_update(self, update: SpendUpdateQueueItem) -> None: + """Enqueue an update. Each update is a SpendUpdateQueueItem, e.g. {'entity_type': Litellm_EntityType.USER, 'entity_id': '123', 'response_cost': 1.2}.""" + verbose_proxy_logger.debug("Adding update to queue: %s", update) + await self.update_queue.put(update) + + async def flush_all_updates_from_in_memory_queue( + self, + ) -> List[SpendUpdateQueueItem]: + """Get all updates from the queue.""" + updates: List[SpendUpdateQueueItem] = [] + while not self.update_queue.empty(): + updates.append(await self.update_queue.get()) + return updates + + async def flush_and_get_aggregated_db_spend_update_transactions( + self, + ) -> DBSpendUpdateTransactions: + """Flush all updates from the queue and return all updates aggregated by entity type.""" + updates = await self.flush_all_updates_from_in_memory_queue() + verbose_proxy_logger.debug("Aggregating updates by entity type: %s", updates) + return self.get_aggregated_db_spend_update_transactions(updates) + + def get_aggregated_db_spend_update_transactions( + self, updates: List[SpendUpdateQueueItem] + ) -> DBSpendUpdateTransactions: + """Aggregate updates by entity type.""" + # Initialize all transaction lists as empty dicts + db_spend_update_transactions = DBSpendUpdateTransactions( + user_list_transactions={}, + end_user_list_transactions={}, + key_list_transactions={}, + team_list_transactions={}, + team_member_list_transactions={}, + org_list_transactions={}, + ) + + # Map entity types to their corresponding transaction dictionary keys + entity_type_to_dict_key = { + 
Litellm_EntityType.USER: "user_list_transactions", + Litellm_EntityType.END_USER: "end_user_list_transactions", + Litellm_EntityType.KEY: "key_list_transactions", + Litellm_EntityType.TEAM: "team_list_transactions", + Litellm_EntityType.TEAM_MEMBER: "team_member_list_transactions", + Litellm_EntityType.ORGANIZATION: "org_list_transactions", + } + + for update in updates: + entity_type = update.get("entity_type") + entity_id = update.get("entity_id") or "" + response_cost = update.get("response_cost") or 0 + + if entity_type is None: + verbose_proxy_logger.debug( + "Skipping update spend for update: %s, because entity_type is None", + update, + ) + continue + + dict_key = entity_type_to_dict_key.get(entity_type) + if dict_key is None: + verbose_proxy_logger.debug( + "Skipping update spend for update: %s, because entity_type is not in entity_type_to_dict_key", + update, + ) + continue # Skip unknown entity types + + # Type-safe access using if/elif statements + if dict_key == "user_list_transactions": + transactions_dict = db_spend_update_transactions[ + "user_list_transactions" + ] + elif dict_key == "end_user_list_transactions": + transactions_dict = db_spend_update_transactions[ + "end_user_list_transactions" + ] + elif dict_key == "key_list_transactions": + transactions_dict = db_spend_update_transactions[ + "key_list_transactions" + ] + elif dict_key == "team_list_transactions": + transactions_dict = db_spend_update_transactions[ + "team_list_transactions" + ] + elif dict_key == "team_member_list_transactions": + transactions_dict = db_spend_update_transactions[ + "team_member_list_transactions" + ] + elif dict_key == "org_list_transactions": + transactions_dict = db_spend_update_transactions[ + "org_list_transactions" + ] + else: + continue + + if transactions_dict is None: + transactions_dict = {} + + # type ignore: dict_key is guaranteed to be one of "one of ("user_list_transactions", "end_user_list_transactions", "key_list_transactions", 
"team_list_transactions", "team_member_list_transactions", "org_list_transactions")" + db_spend_update_transactions[dict_key] = transactions_dict # type: ignore + + if entity_id not in transactions_dict: + transactions_dict[entity_id] = 0 + + transactions_dict[entity_id] += response_cost or 0 + + return db_spend_update_transactions diff --git a/litellm/proxy/example_config_yaml/spend_tracking_config.yaml b/litellm/proxy/example_config_yaml/spend_tracking_config.yaml new file mode 100644 index 0000000000..fe8d73d26a --- /dev/null +++ b/litellm/proxy/example_config_yaml/spend_tracking_config.yaml @@ -0,0 +1,15 @@ +model_list: + - model_name: fake-openai-endpoint + litellm_params: + model: openai/fake + api_key: fake-key + api_base: https://exampleopenaiendpoint-production.up.railway.app/ + +general_settings: + use_redis_transaction_buffer: true + +litellm_settings: + cache: True + cache_params: + type: redis + supported_call_types: [] \ No newline at end of file diff --git a/litellm/proxy/hooks/proxy_track_cost_callback.py b/litellm/proxy/hooks/proxy_track_cost_callback.py index 39c1eeace9..9ae1eb4c34 100644 --- a/litellm/proxy/hooks/proxy_track_cost_callback.py +++ b/litellm/proxy/hooks/proxy_track_cost_callback.py @@ -13,7 +13,6 @@ from litellm.litellm_core_utils.core_helpers import ( from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.auth.auth_checks import log_db_metrics -from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter from litellm.proxy.utils import ProxyUpdateSpend from litellm.types.utils import ( StandardLoggingPayload, @@ -37,6 +36,8 @@ class _ProxyDBLogger(CustomLogger): if _ProxyDBLogger._should_track_errors_in_db() is False: return + from litellm.proxy.proxy_server import proxy_logging_obj + _metadata = dict( StandardLoggingUserAPIKeyMetadata( user_api_key_hash=user_api_key_dict.api_key, @@ -66,7 +67,7 @@ class 
_ProxyDBLogger(CustomLogger): request_data.get("proxy_server_request") or {} ) request_data["litellm_params"]["metadata"] = existing_metadata - await DBSpendUpdateWriter.update_database( + await proxy_logging_obj.db_spend_update_writer.update_database( token=user_api_key_dict.api_key, response_cost=0.0, user_id=user_api_key_dict.user_id, @@ -136,7 +137,7 @@ class _ProxyDBLogger(CustomLogger): end_user_id=end_user_id, ): ## UPDATE DATABASE - await DBSpendUpdateWriter.update_database( + await proxy_logging_obj.db_spend_update_writer.update_database( token=user_api_key, response_cost=response_cost, user_id=user_id, diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 9cdd1e7236..17658df903 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -1,9 +1,6 @@ model_list: - - model_name: gpt-4o + - model_name: fake-openai-endpoint litellm_params: - model: openai/gpt-4o - api_key: sk-xxxxxxx - -general_settings: - service_account_settings: - enforced_params: ["user"] # this means the "user" param is enforced for all requests made through any service account keys \ No newline at end of file + model: openai/fake + api_key: fake-key + api_base: https://exampleopenaiendpoint-production.up.railway.app/ diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index f612c88ccc..67d1882a11 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1111,12 +1111,6 @@ def jsonify_object(data: dict) -> dict: class PrismaClient: - user_list_transactions: dict = {} - end_user_list_transactions: dict = {} - key_list_transactions: dict = {} - team_list_transactions: dict = {} - team_member_list_transactions: dict = {} # key is ["team_id" + "user_id"] - org_list_transactions: dict = {} spend_log_transactions: List = [] daily_user_spend_transactions: Dict[str, DailyUserSpendTransaction] = {} @@ -2479,7 +2473,10 @@ def _hash_token_if_needed(token: str) -> str: class ProxyUpdateSpend: @staticmethod async def 
update_end_user_spend( - n_retry_times: int, prisma_client: PrismaClient, proxy_logging_obj: ProxyLogging + n_retry_times: int, + prisma_client: PrismaClient, + proxy_logging_obj: ProxyLogging, + end_user_list_transactions: Dict[str, float], ): for i in range(n_retry_times + 1): start_time = time.time() @@ -2491,7 +2488,7 @@ class ProxyUpdateSpend: for ( end_user_id, response_cost, - ) in prisma_client.end_user_list_transactions.items(): + ) in end_user_list_transactions.items(): if litellm.max_end_user_budget is not None: pass batcher.litellm_endusertable.upsert( @@ -2518,10 +2515,6 @@ class ProxyUpdateSpend: _raise_failed_update_spend_exception( e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj ) - finally: - prisma_client.end_user_list_transactions = ( - {} - ) # reset the end user list transactions - prevent bad data from causing issues @staticmethod async def update_spend_logs( diff --git a/tests/litellm/proxy/db/test_spend_update_queue.py b/tests/litellm/proxy/db/test_spend_update_queue.py new file mode 100644 index 0000000000..89d494a070 --- /dev/null +++ b/tests/litellm/proxy/db/test_spend_update_queue.py @@ -0,0 +1,152 @@ +import asyncio +import json +import os +import sys + +import pytest +from fastapi.testclient import TestClient + +from litellm.proxy._types import Litellm_EntityType, SpendUpdateQueueItem +from litellm.proxy.db.spend_update_queue import SpendUpdateQueue + +sys.path.insert( + 0, os.path.abspath("../../..") +) # Adds the parent directory to the system path + + +@pytest.fixture +def spend_queue(): + return SpendUpdateQueue() + + +@pytest.mark.asyncio +async def test_add_update(spend_queue): + # Test adding a single update + update: SpendUpdateQueueItem = { + "entity_type": Litellm_EntityType.USER, + "entity_id": "user123", + "response_cost": 0.5, + } + await spend_queue.add_update(update) + + # Verify update was added by checking queue size + assert spend_queue.update_queue.qsize() == 1 + + +@pytest.mark.asyncio +async def 
test_missing_response_cost(spend_queue): + # Test with missing response_cost - should default to 0 + update: SpendUpdateQueueItem = { + "entity_type": Litellm_EntityType.USER, + "entity_id": "user123", + } + + await spend_queue.add_update(update) + aggregated = ( + await spend_queue.flush_and_get_aggregated_db_spend_update_transactions() + ) + + # Should have created entry with 0 cost + assert aggregated["user_list_transactions"]["user123"] == 0 + + +@pytest.mark.asyncio +async def test_missing_entity_id(spend_queue): + # Test with missing entity_id - should default to empty string + update: SpendUpdateQueueItem = { + "entity_type": Litellm_EntityType.USER, + "response_cost": 1.0, + } + + await spend_queue.add_update(update) + aggregated = ( + await spend_queue.flush_and_get_aggregated_db_spend_update_transactions() + ) + + # Should use empty string as key + assert aggregated["user_list_transactions"][""] == 1.0 + + +@pytest.mark.asyncio +async def test_none_values(spend_queue): + # Test with None values + update: SpendUpdateQueueItem = { + "entity_type": Litellm_EntityType.USER, + "entity_id": None, # type: ignore + "response_cost": None, + } + + await spend_queue.add_update(update) + aggregated = ( + await spend_queue.flush_and_get_aggregated_db_spend_update_transactions() + ) + + # Should handle None values gracefully + assert aggregated["user_list_transactions"][""] == 0 + + +@pytest.mark.asyncio +async def test_multiple_updates_with_missing_fields(spend_queue): + # Test multiple updates with various missing fields + updates: list[SpendUpdateQueueItem] = [ + { + "entity_type": Litellm_EntityType.USER, + "entity_id": "user123", + "response_cost": 0.5, + }, + { + "entity_type": Litellm_EntityType.USER, + "entity_id": "user123", # missing response_cost + }, + { + "entity_type": Litellm_EntityType.USER, # missing entity_id + "response_cost": 1.5, + }, + ] + + for update in updates: + await spend_queue.add_update(update) + + aggregated = ( + await 
spend_queue.flush_and_get_aggregated_db_spend_update_transactions() + ) + + # Verify aggregation + assert ( + aggregated["user_list_transactions"]["user123"] == 0.5 + ) # only the first update with valid cost + assert ( + aggregated["user_list_transactions"][""] == 1.5 + ) # update with missing entity_id + + +@pytest.mark.asyncio +async def test_unknown_entity_type(spend_queue): + # Test with unknown entity type + update: SpendUpdateQueueItem = { + "entity_type": "UNKNOWN_TYPE", # type: ignore + "entity_id": "123", + "response_cost": 0.5, + } + + await spend_queue.add_update(update) + aggregated = ( + await spend_queue.flush_and_get_aggregated_db_spend_update_transactions() + ) + + # Should ignore unknown entity type + assert all(len(transactions) == 0 for transactions in aggregated.values()) + + +@pytest.mark.asyncio +async def test_missing_entity_type(spend_queue): + # Test with missing entity type + update: SpendUpdateQueueItem = {"entity_id": "123", "response_cost": 0.5} + + await spend_queue.add_update(update) + aggregated = ( + await spend_queue.flush_and_get_aggregated_db_spend_update_transactions() + ) + + # Should ignore updates without entity type + assert all(len(transactions) == 0 for transactions in aggregated.values()) diff --git a/tests/local_testing/test_update_spend.py b/tests/local_testing/test_update_spend.py index fffa3062d7..cc2c94af27 100644 --- a/tests/local_testing/test_update_spend.py +++ b/tests/local_testing/test_update_spend.py @@ -62,6 +62,8 @@ from litellm.proxy._types import ( KeyRequest, NewUserRequest, UpdateKeyRequest, + SpendUpdateQueueItem, + Litellm_EntityType, ) proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache()) @@ -93,7 +95,13 @@ def prisma_client(): @pytest.mark.asyncio async def test_batch_update_spend(prisma_client): - prisma_client.user_list_transactions["test-litellm-user-5"] = 23 + await proxy_logging_obj.db_spend_update_writer.spend_update_queue.add_update( + SpendUpdateQueueItem( + 
entity_type=Litellm_EntityType.USER, + entity_id="test-litellm-user-5", + response_cost=23, + ) + ) setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") await litellm.proxy.proxy_server.prisma_client.connect() diff --git a/tests/proxy_unit_tests/test_proxy_utils.py b/tests/proxy_unit_tests/test_proxy_utils.py index b28948094e..1281d50863 100644 --- a/tests/proxy_unit_tests/test_proxy_utils.py +++ b/tests/proxy_unit_tests/test_proxy_utils.py @@ -1485,21 +1485,18 @@ from litellm.proxy.utils import ProxyUpdateSpend async def test_end_user_transactions_reset(): # Setup mock_client = MagicMock() - mock_client.end_user_list_transactions = {"1": 10.0} # Bad log + end_user_list_transactions = {"1": 10.0} # Bad log mock_client.db.tx = AsyncMock(side_effect=Exception("DB Error")) # Call function - should raise error with pytest.raises(Exception): await ProxyUpdateSpend.update_end_user_spend( - n_retry_times=0, prisma_client=mock_client, proxy_logging_obj=MagicMock() + n_retry_times=0, + prisma_client=mock_client, + proxy_logging_obj=MagicMock(), + end_user_list_transactions=end_user_list_transactions, ) - # Verify cleanup happened - assert ( - mock_client.end_user_list_transactions == {} - ), "Transactions list should be empty after error" - - @pytest.mark.asyncio async def test_spend_logs_cleanup_after_error(): # Setup test data diff --git a/tests/proxy_unit_tests/test_unit_test_proxy_hooks.py b/tests/proxy_unit_tests/test_unit_test_proxy_hooks.py index 129be6d754..46863889d2 100644 --- a/tests/proxy_unit_tests/test_unit_test_proxy_hooks.py +++ b/tests/proxy_unit_tests/test_unit_test_proxy_hooks.py @@ -24,9 +24,10 @@ async def test_disable_spend_logs(): "litellm.proxy.proxy_server.prisma_client", mock_prisma_client ): from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter + db_spend_update_writer = DBSpendUpdateWriter() # Call update_database with disable_spend_logs=True - await 
DBSpendUpdateWriter.update_database( + await db_spend_update_writer.update_database( token="fake-token", response_cost=0.1, user_id="user123", diff --git a/tests/proxy_unit_tests/test_update_spend.py b/tests/proxy_unit_tests/test_update_spend.py index 641768a7d2..1fb2479792 100644 --- a/tests/proxy_unit_tests/test_update_spend.py +++ b/tests/proxy_unit_tests/test_update_spend.py @@ -27,12 +27,6 @@ class MockPrismaClient: # Initialize transaction lists self.spend_log_transactions = [] - self.user_list_transactons = {} - self.end_user_list_transactons = {} - self.key_list_transactons = {} - self.team_list_transactons = {} - self.team_member_list_transactons = {} - self.org_list_transactons = {} self.daily_user_spend_transactions = {} def jsonify_object(self, obj): diff --git a/tests/otel_tests/local_test_spend_accuracy_tests.py b/tests/spend_tracking_tests/test_spend_accuracy_tests.py similarity index 96% rename from tests/otel_tests/local_test_spend_accuracy_tests.py rename to tests/spend_tracking_tests/test_spend_accuracy_tests.py index 6d756219c7..93228c2d06 100644 --- a/tests/otel_tests/local_test_spend_accuracy_tests.py +++ b/tests/spend_tracking_tests/test_spend_accuracy_tests.py @@ -52,7 +52,7 @@ Additional Test Scenarios: async def create_organization(session, organization_alias: str): """Helper function to create a new organization""" - url = "http://0.0.0.0:4002/organization/new" + url = "http://0.0.0.0:4000/organization/new" headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} data = {"organization_alias": organization_alias} async with session.post(url, headers=headers, json=data) as response: @@ -61,7 +61,7 @@ async def create_organization(session, organization_alias: str): async def create_team(session, org_id: str): """Helper function to create a new team under an organization""" - url = "http://0.0.0.0:4002/team/new" + url = "http://0.0.0.0:4000/team/new" headers = {"Authorization": "Bearer sk-1234", "Content-Type": 
"application/json"} data = {"organization_id": org_id, "team_alias": f"test-team-{uuid.uuid4()}"} async with session.post(url, headers=headers, json=data) as response: @@ -70,7 +70,7 @@ async def create_team(session, org_id: str): async def create_user(session, org_id: str): """Helper function to create a new user""" - url = "http://0.0.0.0:4002/user/new" + url = "http://0.0.0.0:4000/user/new" headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} data = {"user_name": f"test-user-{uuid.uuid4()}"} async with session.post(url, headers=headers, json=data) as response: @@ -79,7 +79,7 @@ async def create_user(session, org_id: str): async def generate_key(session, user_id: str, team_id: str): """Helper function to generate a key for a specific user and team""" - url = "http://0.0.0.0:4002/key/generate" + url = "http://0.0.0.0:4000/key/generate" headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} data = {"user_id": user_id, "team_id": team_id} async with session.post(url, headers=headers, json=data) as response: @@ -91,7 +91,7 @@ async def chat_completion(session, key: str): from openai import AsyncOpenAI import uuid - client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4002/v1") + client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000/v1") response = await client.chat.completions.create( model="fake-openai-endpoint", @@ -102,7 +102,7 @@ async def chat_completion(session, key: str): async def get_spend_info(session, entity_type: str, entity_id: str): """Helper function to get spend information for an entity""" - url = f"http://0.0.0.0:4002/{entity_type}/info" + url = f"http://0.0.0.0:4000/{entity_type}/info" headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} if entity_type == "key": data = {"key": entity_id} @@ -156,7 +156,7 @@ async def test_basic_spend_accuracy(): response = await chat_completion(session, key) print("response: ", response) - # wait 10 seconds for 
spend to be updated + # wait 15 seconds for spend to be updated await asyncio.sleep(15) # Get spend information for each entity @@ -235,7 +235,7 @@ async def test_long_term_spend_accuracy_with_bursts(): print(f"Burst 1 - Request {i+1}/{BURST_1_REQUESTS} completed") # Wait for spend to be updated - await asyncio.sleep(8) + await asyncio.sleep(15) # Check intermediate spend intermediate_key_info = await get_spend_info(session, "key", key) @@ -248,7 +248,7 @@ async def test_long_term_spend_accuracy_with_bursts(): print(f"Burst 2 - Request {i+1}/{BURST_2_REQUESTS} completed") # Wait for spend to be updated - await asyncio.sleep(8) + await asyncio.sleep(15) # Get final spend information for each entity key_info = await get_spend_info(session, "key", key)