mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Merge branch 'main' into litellm_fix_service_account_behavior
This commit is contained in:
commit
4ddca7a79c
18 changed files with 518 additions and 194 deletions
|
@ -1450,7 +1450,7 @@ jobs:
|
||||||
command: |
|
command: |
|
||||||
pwd
|
pwd
|
||||||
ls
|
ls
|
||||||
python -m pytest -s -vv tests/*.py -x --junitxml=test-results/junit.xml --durations=5 --ignore=tests/otel_tests --ignore=tests/pass_through_tests --ignore=tests/proxy_admin_ui_tests --ignore=tests/load_tests --ignore=tests/llm_translation --ignore=tests/llm_responses_api_testing --ignore=tests/mcp_tests --ignore=tests/image_gen_tests --ignore=tests/pass_through_unit_tests
|
python -m pytest -s -vv tests/*.py -x --junitxml=test-results/junit.xml --durations=5 --ignore=tests/otel_tests --ignore=tests/spend_tracking_tests --ignore=tests/pass_through_tests --ignore=tests/proxy_admin_ui_tests --ignore=tests/load_tests --ignore=tests/llm_translation --ignore=tests/llm_responses_api_testing --ignore=tests/mcp_tests --ignore=tests/image_gen_tests --ignore=tests/pass_through_unit_tests
|
||||||
no_output_timeout: 120m
|
no_output_timeout: 120m
|
||||||
|
|
||||||
# Store test results
|
# Store test results
|
||||||
|
@ -1743,6 +1743,96 @@ jobs:
|
||||||
# Store test results
|
# Store test results
|
||||||
- store_test_results:
|
- store_test_results:
|
||||||
path: test-results
|
path: test-results
|
||||||
|
proxy_spend_accuracy_tests:
|
||||||
|
machine:
|
||||||
|
image: ubuntu-2204:2023.10.1
|
||||||
|
resource_class: xlarge
|
||||||
|
working_directory: ~/project
|
||||||
|
steps:
|
||||||
|
- checkout
|
||||||
|
- setup_google_dns
|
||||||
|
- run:
|
||||||
|
name: Install Docker CLI (In case it's not already installed)
|
||||||
|
command: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y docker-ce docker-ce-cli containerd.io
|
||||||
|
- run:
|
||||||
|
name: Install Python 3.9
|
||||||
|
command: |
|
||||||
|
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh --output miniconda.sh
|
||||||
|
bash miniconda.sh -b -p $HOME/miniconda
|
||||||
|
export PATH="$HOME/miniconda/bin:$PATH"
|
||||||
|
conda init bash
|
||||||
|
source ~/.bashrc
|
||||||
|
conda create -n myenv python=3.9 -y
|
||||||
|
conda activate myenv
|
||||||
|
python --version
|
||||||
|
- run:
|
||||||
|
name: Install Dependencies
|
||||||
|
command: |
|
||||||
|
pip install "pytest==7.3.1"
|
||||||
|
pip install "pytest-asyncio==0.21.1"
|
||||||
|
pip install aiohttp
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
python -m pip install -r requirements.txt
|
||||||
|
- run:
|
||||||
|
name: Build Docker image
|
||||||
|
command: docker build -t my-app:latest -f ./docker/Dockerfile.database .
|
||||||
|
- run:
|
||||||
|
name: Run Docker container
|
||||||
|
# intentionally give bad redis credentials here
|
||||||
|
# the OTEL test - should get this as a trace
|
||||||
|
command: |
|
||||||
|
docker run -d \
|
||||||
|
-p 4000:4000 \
|
||||||
|
-e DATABASE_URL=$PROXY_DATABASE_URL \
|
||||||
|
-e REDIS_HOST=$REDIS_HOST \
|
||||||
|
-e REDIS_PASSWORD=$REDIS_PASSWORD \
|
||||||
|
-e REDIS_PORT=$REDIS_PORT \
|
||||||
|
-e LITELLM_MASTER_KEY="sk-1234" \
|
||||||
|
-e OPENAI_API_KEY=$OPENAI_API_KEY \
|
||||||
|
-e LITELLM_LICENSE=$LITELLM_LICENSE \
|
||||||
|
-e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
|
||||||
|
-e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
|
||||||
|
-e USE_DDTRACE=True \
|
||||||
|
-e DD_API_KEY=$DD_API_KEY \
|
||||||
|
-e DD_SITE=$DD_SITE \
|
||||||
|
-e AWS_REGION_NAME=$AWS_REGION_NAME \
|
||||||
|
--name my-app \
|
||||||
|
-v $(pwd)/litellm/proxy/example_config_yaml/spend_tracking_config.yaml:/app/config.yaml \
|
||||||
|
my-app:latest \
|
||||||
|
--config /app/config.yaml \
|
||||||
|
--port 4000 \
|
||||||
|
--detailed_debug \
|
||||||
|
- run:
|
||||||
|
name: Install curl and dockerize
|
||||||
|
command: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y curl
|
||||||
|
sudo wget https://github.com/jwilder/dockerize/releases/download/v0.6.1/dockerize-linux-amd64-v0.6.1.tar.gz
|
||||||
|
sudo tar -C /usr/local/bin -xzvf dockerize-linux-amd64-v0.6.1.tar.gz
|
||||||
|
sudo rm dockerize-linux-amd64-v0.6.1.tar.gz
|
||||||
|
- run:
|
||||||
|
name: Start outputting logs
|
||||||
|
command: docker logs -f my-app
|
||||||
|
background: true
|
||||||
|
- run:
|
||||||
|
name: Wait for app to be ready
|
||||||
|
command: dockerize -wait http://localhost:4000 -timeout 5m
|
||||||
|
- run:
|
||||||
|
name: Run tests
|
||||||
|
command: |
|
||||||
|
pwd
|
||||||
|
ls
|
||||||
|
python -m pytest -vv tests/spend_tracking_tests -x --junitxml=test-results/junit.xml --durations=5
|
||||||
|
no_output_timeout:
|
||||||
|
120m
|
||||||
|
# Clean up first container
|
||||||
|
- run:
|
||||||
|
name: Stop and remove first container
|
||||||
|
command: |
|
||||||
|
docker stop my-app
|
||||||
|
docker rm my-app
|
||||||
|
|
||||||
proxy_multi_instance_tests:
|
proxy_multi_instance_tests:
|
||||||
machine:
|
machine:
|
||||||
|
@ -2553,6 +2643,12 @@ workflows:
|
||||||
only:
|
only:
|
||||||
- main
|
- main
|
||||||
- /litellm_.*/
|
- /litellm_.*/
|
||||||
|
- proxy_spend_accuracy_tests:
|
||||||
|
filters:
|
||||||
|
branches:
|
||||||
|
only:
|
||||||
|
- main
|
||||||
|
- /litellm_.*/
|
||||||
- proxy_multi_instance_tests:
|
- proxy_multi_instance_tests:
|
||||||
filters:
|
filters:
|
||||||
branches:
|
branches:
|
||||||
|
@ -2714,6 +2810,7 @@ workflows:
|
||||||
- installing_litellm_on_python
|
- installing_litellm_on_python
|
||||||
- installing_litellm_on_python_3_13
|
- installing_litellm_on_python_3_13
|
||||||
- proxy_logging_guardrails_model_info_tests
|
- proxy_logging_guardrails_model_info_tests
|
||||||
|
- proxy_spend_accuracy_tests
|
||||||
- proxy_multi_instance_tests
|
- proxy_multi_instance_tests
|
||||||
- proxy_store_model_in_db_tests
|
- proxy_store_model_in_db_tests
|
||||||
- proxy_build_from_pip_tests
|
- proxy_build_from_pip_tests
|
||||||
|
|
|
@ -1,2 +1,11 @@
|
||||||
python3 -m build
|
python3 -m build
|
||||||
twine upload --verbose dist/litellm-1.18.13.dev4.tar.gz -u __token__ -
|
twine upload --verbose dist/litellm-1.18.13.dev4.tar.gz -u __token__ -
|
||||||
|
|
||||||
|
|
||||||
|
Note: You might need to create a MANIFEST.in file in the repository root for the build process in case it fails
|
||||||
|
|
||||||
|
Place this in MANIFEST.in
|
||||||
|
recursive-exclude venv *
|
||||||
|
recursive-exclude myenv *
|
||||||
|
recursive-exclude py313_env *
|
||||||
|
recursive-exclude **/.venv *
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
import Image from '@theme/IdealImage';
|
||||||
|
|
||||||
# Enterprise
|
# Enterprise
|
||||||
For companies that need SSO, user management and professional support for LiteLLM Proxy
|
For companies that need SSO, user management and professional support for LiteLLM Proxy
|
||||||
|
|
||||||
|
@ -7,6 +9,8 @@ Get free 7-day trial key [here](https://www.litellm.ai/#trial)
|
||||||
|
|
||||||
Includes all enterprise features.
|
Includes all enterprise features.
|
||||||
|
|
||||||
|
<Image img={require('../img/enterprise_vs_oss.png')} />
|
||||||
|
|
||||||
[**Procurement available via AWS / Azure Marketplace**](./data_security.md#legalcompliance-faqs)
|
[**Procurement available via AWS / Azure Marketplace**](./data_security.md#legalcompliance-faqs)
|
||||||
|
|
||||||
|
|
||||||
|
|
BIN
docs/my-website/img/enterprise_vs_oss.png
Normal file
BIN
docs/my-website/img/enterprise_vs_oss.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 61 KiB |
|
@ -2751,3 +2751,9 @@ class DBSpendUpdateTransactions(TypedDict):
|
||||||
team_list_transactions: Optional[Dict[str, float]]
|
team_list_transactions: Optional[Dict[str, float]]
|
||||||
team_member_list_transactions: Optional[Dict[str, float]]
|
team_member_list_transactions: Optional[Dict[str, float]]
|
||||||
org_list_transactions: Optional[Dict[str, float]]
|
org_list_transactions: Optional[Dict[str, float]]
|
||||||
|
|
||||||
|
|
||||||
|
class SpendUpdateQueueItem(TypedDict, total=False):
|
||||||
|
entity_type: Litellm_EntityType
|
||||||
|
entity_id: str
|
||||||
|
response_cost: Optional[float]
|
||||||
|
|
|
@ -22,9 +22,11 @@ from litellm.proxy._types import (
|
||||||
Litellm_EntityType,
|
Litellm_EntityType,
|
||||||
LiteLLM_UserTable,
|
LiteLLM_UserTable,
|
||||||
SpendLogsPayload,
|
SpendLogsPayload,
|
||||||
|
SpendUpdateQueueItem,
|
||||||
)
|
)
|
||||||
from litellm.proxy.db.pod_lock_manager import PodLockManager
|
from litellm.proxy.db.pod_lock_manager import PodLockManager
|
||||||
from litellm.proxy.db.redis_update_buffer import RedisUpdateBuffer
|
from litellm.proxy.db.redis_update_buffer import RedisUpdateBuffer
|
||||||
|
from litellm.proxy.db.spend_update_queue import SpendUpdateQueue
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||||
|
@ -48,10 +50,11 @@ class DBSpendUpdateWriter:
|
||||||
self.redis_cache = redis_cache
|
self.redis_cache = redis_cache
|
||||||
self.redis_update_buffer = RedisUpdateBuffer(redis_cache=self.redis_cache)
|
self.redis_update_buffer = RedisUpdateBuffer(redis_cache=self.redis_cache)
|
||||||
self.pod_lock_manager = PodLockManager(cronjob_id=DB_SPEND_UPDATE_JOB_NAME)
|
self.pod_lock_manager = PodLockManager(cronjob_id=DB_SPEND_UPDATE_JOB_NAME)
|
||||||
|
self.spend_update_queue = SpendUpdateQueue()
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def update_database(
|
async def update_database(
|
||||||
# LiteLLM management object fields
|
# LiteLLM management object fields
|
||||||
|
self,
|
||||||
token: Optional[str],
|
token: Optional[str],
|
||||||
user_id: Optional[str],
|
user_id: Optional[str],
|
||||||
end_user_id: Optional[str],
|
end_user_id: Optional[str],
|
||||||
|
@ -84,7 +87,7 @@ class DBSpendUpdateWriter:
|
||||||
hashed_token = token
|
hashed_token = token
|
||||||
|
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
DBSpendUpdateWriter._update_user_db(
|
self._update_user_db(
|
||||||
response_cost=response_cost,
|
response_cost=response_cost,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
prisma_client=prisma_client,
|
prisma_client=prisma_client,
|
||||||
|
@ -94,14 +97,14 @@ class DBSpendUpdateWriter:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
DBSpendUpdateWriter._update_key_db(
|
self._update_key_db(
|
||||||
response_cost=response_cost,
|
response_cost=response_cost,
|
||||||
hashed_token=hashed_token,
|
hashed_token=hashed_token,
|
||||||
prisma_client=prisma_client,
|
prisma_client=prisma_client,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
DBSpendUpdateWriter._update_team_db(
|
self._update_team_db(
|
||||||
response_cost=response_cost,
|
response_cost=response_cost,
|
||||||
team_id=team_id,
|
team_id=team_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
@ -109,7 +112,7 @@ class DBSpendUpdateWriter:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
DBSpendUpdateWriter._update_org_db(
|
self._update_org_db(
|
||||||
response_cost=response_cost,
|
response_cost=response_cost,
|
||||||
org_id=org_id,
|
org_id=org_id,
|
||||||
prisma_client=prisma_client,
|
prisma_client=prisma_client,
|
||||||
|
@ -135,56 +138,8 @@ class DBSpendUpdateWriter:
|
||||||
f"Error updating Prisma database: {traceback.format_exc()}"
|
f"Error updating Prisma database: {traceback.format_exc()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def _update_transaction_list(
|
|
||||||
response_cost: Optional[float],
|
|
||||||
entity_id: Optional[str],
|
|
||||||
transaction_list: dict,
|
|
||||||
entity_type: Litellm_EntityType,
|
|
||||||
debug_msg: Optional[str] = None,
|
|
||||||
prisma_client: Optional[PrismaClient] = None,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Common helper method to update a transaction list for an entity
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response_cost: The cost to add
|
|
||||||
entity_id: The ID of the entity to update
|
|
||||||
transaction_list: The transaction list dictionary to update
|
|
||||||
entity_type: The type of entity (from EntityType enum)
|
|
||||||
debug_msg: Optional custom debug message
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if update happened, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if debug_msg:
|
|
||||||
verbose_proxy_logger.debug(debug_msg)
|
|
||||||
else:
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f"adding spend to {entity_type.value} db. Response cost: {response_cost}. {entity_type.value}_id: {entity_id}."
|
|
||||||
)
|
|
||||||
if prisma_client is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if entity_id is None:
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f"track_cost_callback: {entity_type.value}_id is None. Not tracking spend for {entity_type.value}"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
transaction_list[entity_id] = response_cost + transaction_list.get(
|
|
||||||
entity_id, 0
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
verbose_proxy_logger.info(
|
|
||||||
f"Update {entity_type.value.capitalize()} DB failed to execute - {str(e)}\n{traceback.format_exc()}"
|
|
||||||
)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def _update_key_db(
|
async def _update_key_db(
|
||||||
|
self,
|
||||||
response_cost: Optional[float],
|
response_cost: Optional[float],
|
||||||
hashed_token: Optional[str],
|
hashed_token: Optional[str],
|
||||||
prisma_client: Optional[PrismaClient],
|
prisma_client: Optional[PrismaClient],
|
||||||
|
@ -193,13 +148,12 @@ class DBSpendUpdateWriter:
|
||||||
if hashed_token is None or prisma_client is None:
|
if hashed_token is None or prisma_client is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
await DBSpendUpdateWriter._update_transaction_list(
|
await self.spend_update_queue.add_update(
|
||||||
response_cost=response_cost,
|
update=SpendUpdateQueueItem(
|
||||||
entity_id=hashed_token,
|
|
||||||
transaction_list=prisma_client.key_list_transactions,
|
|
||||||
entity_type=Litellm_EntityType.KEY,
|
entity_type=Litellm_EntityType.KEY,
|
||||||
debug_msg=f"adding spend to key db. Response cost: {response_cost}. Token: {hashed_token}.",
|
entity_id=hashed_token,
|
||||||
prisma_client=prisma_client,
|
response_cost=response_cost,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.exception(
|
verbose_proxy_logger.exception(
|
||||||
|
@ -207,8 +161,8 @@ class DBSpendUpdateWriter:
|
||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def _update_user_db(
|
async def _update_user_db(
|
||||||
|
self,
|
||||||
response_cost: Optional[float],
|
response_cost: Optional[float],
|
||||||
user_id: Optional[str],
|
user_id: Optional[str],
|
||||||
prisma_client: Optional[PrismaClient],
|
prisma_client: Optional[PrismaClient],
|
||||||
|
@ -234,21 +188,21 @@ class DBSpendUpdateWriter:
|
||||||
|
|
||||||
for _id in user_ids:
|
for _id in user_ids:
|
||||||
if _id is not None:
|
if _id is not None:
|
||||||
await DBSpendUpdateWriter._update_transaction_list(
|
await self.spend_update_queue.add_update(
|
||||||
response_cost=response_cost,
|
update=SpendUpdateQueueItem(
|
||||||
entity_id=_id,
|
|
||||||
transaction_list=prisma_client.user_list_transactions,
|
|
||||||
entity_type=Litellm_EntityType.USER,
|
entity_type=Litellm_EntityType.USER,
|
||||||
prisma_client=prisma_client,
|
entity_id=_id,
|
||||||
|
response_cost=response_cost,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if end_user_id is not None:
|
if end_user_id is not None:
|
||||||
await DBSpendUpdateWriter._update_transaction_list(
|
await self.spend_update_queue.add_update(
|
||||||
response_cost=response_cost,
|
update=SpendUpdateQueueItem(
|
||||||
entity_id=end_user_id,
|
|
||||||
transaction_list=prisma_client.end_user_list_transactions,
|
|
||||||
entity_type=Litellm_EntityType.END_USER,
|
entity_type=Litellm_EntityType.END_USER,
|
||||||
prisma_client=prisma_client,
|
entity_id=end_user_id,
|
||||||
|
response_cost=response_cost,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.info(
|
verbose_proxy_logger.info(
|
||||||
|
@ -256,8 +210,8 @@ class DBSpendUpdateWriter:
|
||||||
+ f"Update User DB call failed to execute {str(e)}\n{traceback.format_exc()}"
|
+ f"Update User DB call failed to execute {str(e)}\n{traceback.format_exc()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def _update_team_db(
|
async def _update_team_db(
|
||||||
|
self,
|
||||||
response_cost: Optional[float],
|
response_cost: Optional[float],
|
||||||
team_id: Optional[str],
|
team_id: Optional[str],
|
||||||
user_id: Optional[str],
|
user_id: Optional[str],
|
||||||
|
@ -270,12 +224,12 @@ class DBSpendUpdateWriter:
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
await DBSpendUpdateWriter._update_transaction_list(
|
await self.spend_update_queue.add_update(
|
||||||
response_cost=response_cost,
|
update=SpendUpdateQueueItem(
|
||||||
entity_id=team_id,
|
|
||||||
transaction_list=prisma_client.team_list_transactions,
|
|
||||||
entity_type=Litellm_EntityType.TEAM,
|
entity_type=Litellm_EntityType.TEAM,
|
||||||
prisma_client=prisma_client,
|
entity_id=team_id,
|
||||||
|
response_cost=response_cost,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -283,12 +237,12 @@ class DBSpendUpdateWriter:
|
||||||
if user_id is not None:
|
if user_id is not None:
|
||||||
# key is "team_id::<value>::user_id::<value>"
|
# key is "team_id::<value>::user_id::<value>"
|
||||||
team_member_key = f"team_id::{team_id}::user_id::{user_id}"
|
team_member_key = f"team_id::{team_id}::user_id::{user_id}"
|
||||||
await DBSpendUpdateWriter._update_transaction_list(
|
await self.spend_update_queue.add_update(
|
||||||
response_cost=response_cost,
|
update=SpendUpdateQueueItem(
|
||||||
entity_id=team_member_key,
|
|
||||||
transaction_list=prisma_client.team_member_list_transactions,
|
|
||||||
entity_type=Litellm_EntityType.TEAM_MEMBER,
|
entity_type=Litellm_EntityType.TEAM_MEMBER,
|
||||||
prisma_client=prisma_client,
|
entity_id=team_member_key,
|
||||||
|
response_cost=response_cost,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
@ -298,8 +252,8 @@ class DBSpendUpdateWriter:
|
||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def _update_org_db(
|
async def _update_org_db(
|
||||||
|
self,
|
||||||
response_cost: Optional[float],
|
response_cost: Optional[float],
|
||||||
org_id: Optional[str],
|
org_id: Optional[str],
|
||||||
prisma_client: Optional[PrismaClient],
|
prisma_client: Optional[PrismaClient],
|
||||||
|
@ -311,12 +265,12 @@ class DBSpendUpdateWriter:
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
await DBSpendUpdateWriter._update_transaction_list(
|
await self.spend_update_queue.add_update(
|
||||||
response_cost=response_cost,
|
update=SpendUpdateQueueItem(
|
||||||
entity_id=org_id,
|
|
||||||
transaction_list=prisma_client.org_list_transactions,
|
|
||||||
entity_type=Litellm_EntityType.ORGANIZATION,
|
entity_type=Litellm_EntityType.ORGANIZATION,
|
||||||
prisma_client=prisma_client,
|
entity_id=org_id,
|
||||||
|
response_cost=response_cost,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.info(
|
verbose_proxy_logger.info(
|
||||||
|
@ -435,7 +389,7 @@ class DBSpendUpdateWriter:
|
||||||
- Only 1 pod will commit to db at a time (based on if it can acquire the lock over writing to DB)
|
- Only 1 pod will commit to db at a time (based on if it can acquire the lock over writing to DB)
|
||||||
"""
|
"""
|
||||||
await self.redis_update_buffer.store_in_memory_spend_updates_in_redis(
|
await self.redis_update_buffer.store_in_memory_spend_updates_in_redis(
|
||||||
prisma_client=prisma_client,
|
spend_update_queue=self.spend_update_queue,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only commit from redis to db if this pod is the leader
|
# Only commit from redis to db if this pod is the leader
|
||||||
|
@ -447,7 +401,7 @@ class DBSpendUpdateWriter:
|
||||||
await self.redis_update_buffer.get_all_update_transactions_from_redis_buffer()
|
await self.redis_update_buffer.get_all_update_transactions_from_redis_buffer()
|
||||||
)
|
)
|
||||||
if db_spend_update_transactions is not None:
|
if db_spend_update_transactions is not None:
|
||||||
await DBSpendUpdateWriter._commit_spend_updates_to_db(
|
await self._commit_spend_updates_to_db(
|
||||||
prisma_client=prisma_client,
|
prisma_client=prisma_client,
|
||||||
n_retry_times=n_retry_times,
|
n_retry_times=n_retry_times,
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
@ -471,23 +425,18 @@ class DBSpendUpdateWriter:
|
||||||
|
|
||||||
Note: This flow causes Deadlocks in production (1K RPS+). Use self._commit_spend_updates_to_db_with_redis() instead if you expect 1K+ RPS.
|
Note: This flow causes Deadlocks in production (1K RPS+). Use self._commit_spend_updates_to_db_with_redis() instead if you expect 1K+ RPS.
|
||||||
"""
|
"""
|
||||||
db_spend_update_transactions = DBSpendUpdateTransactions(
|
db_spend_update_transactions = (
|
||||||
user_list_transactions=prisma_client.user_list_transactions,
|
await self.spend_update_queue.flush_and_get_aggregated_db_spend_update_transactions()
|
||||||
end_user_list_transactions=prisma_client.end_user_list_transactions,
|
|
||||||
key_list_transactions=prisma_client.key_list_transactions,
|
|
||||||
team_list_transactions=prisma_client.team_list_transactions,
|
|
||||||
team_member_list_transactions=prisma_client.team_member_list_transactions,
|
|
||||||
org_list_transactions=prisma_client.org_list_transactions,
|
|
||||||
)
|
)
|
||||||
await DBSpendUpdateWriter._commit_spend_updates_to_db(
|
await self._commit_spend_updates_to_db(
|
||||||
prisma_client=prisma_client,
|
prisma_client=prisma_client,
|
||||||
n_retry_times=n_retry_times,
|
n_retry_times=n_retry_times,
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
db_spend_update_transactions=db_spend_update_transactions,
|
db_spend_update_transactions=db_spend_update_transactions,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def _commit_spend_updates_to_db( # noqa: PLR0915
|
async def _commit_spend_updates_to_db( # noqa: PLR0915
|
||||||
|
self,
|
||||||
prisma_client: PrismaClient,
|
prisma_client: PrismaClient,
|
||||||
n_retry_times: int,
|
n_retry_times: int,
|
||||||
proxy_logging_obj: ProxyLogging,
|
proxy_logging_obj: ProxyLogging,
|
||||||
|
@ -526,9 +475,6 @@ class DBSpendUpdateWriter:
|
||||||
where={"user_id": user_id},
|
where={"user_id": user_id},
|
||||||
data={"spend": {"increment": response_cost}},
|
data={"spend": {"increment": response_cost}},
|
||||||
)
|
)
|
||||||
prisma_client.user_list_transactions = (
|
|
||||||
{}
|
|
||||||
) # Clear the remaining transactions after processing all batches in the loop.
|
|
||||||
break
|
break
|
||||||
except DB_CONNECTION_ERROR_TYPES as e:
|
except DB_CONNECTION_ERROR_TYPES as e:
|
||||||
if (
|
if (
|
||||||
|
@ -561,6 +507,7 @@ class DBSpendUpdateWriter:
|
||||||
n_retry_times=n_retry_times,
|
n_retry_times=n_retry_times,
|
||||||
prisma_client=prisma_client,
|
prisma_client=prisma_client,
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
end_user_list_transactions=end_user_list_transactions,
|
||||||
)
|
)
|
||||||
### UPDATE KEY TABLE ###
|
### UPDATE KEY TABLE ###
|
||||||
key_list_transactions = db_spend_update_transactions["key_list_transactions"]
|
key_list_transactions = db_spend_update_transactions["key_list_transactions"]
|
||||||
|
@ -583,9 +530,6 @@ class DBSpendUpdateWriter:
|
||||||
where={"token": token},
|
where={"token": token},
|
||||||
data={"spend": {"increment": response_cost}},
|
data={"spend": {"increment": response_cost}},
|
||||||
)
|
)
|
||||||
prisma_client.key_list_transactions = (
|
|
||||||
{}
|
|
||||||
) # Clear the remaining transactions after processing all batches in the loop.
|
|
||||||
break
|
break
|
||||||
except DB_CONNECTION_ERROR_TYPES as e:
|
except DB_CONNECTION_ERROR_TYPES as e:
|
||||||
if (
|
if (
|
||||||
|
@ -632,9 +576,6 @@ class DBSpendUpdateWriter:
|
||||||
where={"team_id": team_id},
|
where={"team_id": team_id},
|
||||||
data={"spend": {"increment": response_cost}},
|
data={"spend": {"increment": response_cost}},
|
||||||
)
|
)
|
||||||
prisma_client.team_list_transactions = (
|
|
||||||
{}
|
|
||||||
) # Clear the remaining transactions after processing all batches in the loop.
|
|
||||||
break
|
break
|
||||||
except DB_CONNECTION_ERROR_TYPES as e:
|
except DB_CONNECTION_ERROR_TYPES as e:
|
||||||
if (
|
if (
|
||||||
|
@ -684,9 +625,6 @@ class DBSpendUpdateWriter:
|
||||||
where={"team_id": team_id, "user_id": user_id},
|
where={"team_id": team_id, "user_id": user_id},
|
||||||
data={"spend": {"increment": response_cost}},
|
data={"spend": {"increment": response_cost}},
|
||||||
)
|
)
|
||||||
prisma_client.team_member_list_transactions = (
|
|
||||||
{}
|
|
||||||
) # Clear the remaining transactions after processing all batches in the loop.
|
|
||||||
break
|
break
|
||||||
except DB_CONNECTION_ERROR_TYPES as e:
|
except DB_CONNECTION_ERROR_TYPES as e:
|
||||||
if (
|
if (
|
||||||
|
@ -725,9 +663,6 @@ class DBSpendUpdateWriter:
|
||||||
where={"organization_id": org_id},
|
where={"organization_id": org_id},
|
||||||
data={"spend": {"increment": response_cost}},
|
data={"spend": {"increment": response_cost}},
|
||||||
)
|
)
|
||||||
prisma_client.org_list_transactions = (
|
|
||||||
{}
|
|
||||||
) # Clear the remaining transactions after processing all batches in the loop.
|
|
||||||
break
|
break
|
||||||
except DB_CONNECTION_ERROR_TYPES as e:
|
except DB_CONNECTION_ERROR_TYPES as e:
|
||||||
if (
|
if (
|
||||||
|
|
|
@ -12,6 +12,7 @@ from litellm.caching import RedisCache
|
||||||
from litellm.constants import MAX_REDIS_BUFFER_DEQUEUE_COUNT, REDIS_UPDATE_BUFFER_KEY
|
from litellm.constants import MAX_REDIS_BUFFER_DEQUEUE_COUNT, REDIS_UPDATE_BUFFER_KEY
|
||||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||||
from litellm.proxy._types import DBSpendUpdateTransactions
|
from litellm.proxy._types import DBSpendUpdateTransactions
|
||||||
|
from litellm.proxy.db.spend_update_queue import SpendUpdateQueue
|
||||||
from litellm.secret_managers.main import str_to_bool
|
from litellm.secret_managers.main import str_to_bool
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -54,7 +55,7 @@ class RedisUpdateBuffer:
|
||||||
|
|
||||||
async def store_in_memory_spend_updates_in_redis(
|
async def store_in_memory_spend_updates_in_redis(
|
||||||
self,
|
self,
|
||||||
prisma_client: PrismaClient,
|
spend_update_queue: SpendUpdateQueue,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Stores the in-memory spend updates to Redis
|
Stores the in-memory spend updates to Redis
|
||||||
|
@ -78,13 +79,12 @@ class RedisUpdateBuffer:
|
||||||
"redis_cache is None, skipping store_in_memory_spend_updates_in_redis"
|
"redis_cache is None, skipping store_in_memory_spend_updates_in_redis"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
db_spend_update_transactions: DBSpendUpdateTransactions = DBSpendUpdateTransactions(
|
|
||||||
user_list_transactions=prisma_client.user_list_transactions,
|
db_spend_update_transactions = (
|
||||||
end_user_list_transactions=prisma_client.end_user_list_transactions,
|
await spend_update_queue.flush_and_get_aggregated_db_spend_update_transactions()
|
||||||
key_list_transactions=prisma_client.key_list_transactions,
|
)
|
||||||
team_list_transactions=prisma_client.team_list_transactions,
|
verbose_proxy_logger.debug(
|
||||||
team_member_list_transactions=prisma_client.team_member_list_transactions,
|
"ALL DB SPEND UPDATE TRANSACTIONS: %s", db_spend_update_transactions
|
||||||
org_list_transactions=prisma_client.org_list_transactions,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# only store in redis if there are any updates to commit
|
# only store in redis if there are any updates to commit
|
||||||
|
@ -100,9 +100,6 @@ class RedisUpdateBuffer:
|
||||||
values=list_of_transactions,
|
values=list_of_transactions,
|
||||||
)
|
)
|
||||||
|
|
||||||
# clear the in-memory spend updates
|
|
||||||
RedisUpdateBuffer._clear_all_in_memory_spend_updates(prisma_client)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _number_of_transactions_to_store_in_redis(
|
def _number_of_transactions_to_store_in_redis(
|
||||||
db_spend_update_transactions: DBSpendUpdateTransactions,
|
db_spend_update_transactions: DBSpendUpdateTransactions,
|
||||||
|
@ -116,20 +113,6 @@ class RedisUpdateBuffer:
|
||||||
num_transactions += len(v)
|
num_transactions += len(v)
|
||||||
return num_transactions
|
return num_transactions
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _clear_all_in_memory_spend_updates(
|
|
||||||
prisma_client: PrismaClient,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Clears all in-memory spend updates
|
|
||||||
"""
|
|
||||||
prisma_client.user_list_transactions = {}
|
|
||||||
prisma_client.end_user_list_transactions = {}
|
|
||||||
prisma_client.key_list_transactions = {}
|
|
||||||
prisma_client.team_list_transactions = {}
|
|
||||||
prisma_client.team_member_list_transactions = {}
|
|
||||||
prisma_client.org_list_transactions = {}
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _remove_prefix_from_keys(data: Dict[str, Any], prefix: str) -> Dict[str, Any]:
|
def _remove_prefix_from_keys(data: Dict[str, Any], prefix: str) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
|
|
132
litellm/proxy/db/spend_update_queue.py
Normal file
132
litellm/proxy/db/spend_update_queue.py
Normal file
|
@ -0,0 +1,132 @@
|
||||||
|
import asyncio
|
||||||
|
from typing import TYPE_CHECKING, Any, List
|
||||||
|
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from litellm.proxy._types import (
|
||||||
|
DBSpendUpdateTransactions,
|
||||||
|
Litellm_EntityType,
|
||||||
|
SpendUpdateQueueItem,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from litellm.proxy.utils import PrismaClient
|
||||||
|
else:
|
||||||
|
PrismaClient = Any
|
||||||
|
|
||||||
|
|
||||||
|
class SpendUpdateQueue:
|
||||||
|
"""
|
||||||
|
In memory buffer for spend updates that should be committed to the database
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
self.update_queue: asyncio.Queue[SpendUpdateQueueItem] = asyncio.Queue()
|
||||||
|
|
||||||
|
async def add_update(self, update: SpendUpdateQueueItem) -> None:
|
||||||
|
"""Enqueue an update. Each update might be a dict like {'entity_type': 'user', 'entity_id': '123', 'response_cost': 1.2}."""
|
||||||
|
verbose_proxy_logger.debug("Adding update to queue: %s", update)
|
||||||
|
await self.update_queue.put(update)
|
||||||
|
|
||||||
|
async def flush_all_updates_from_in_memory_queue(
|
||||||
|
self,
|
||||||
|
) -> List[SpendUpdateQueueItem]:
|
||||||
|
"""Get all updates from the queue."""
|
||||||
|
updates: List[SpendUpdateQueueItem] = []
|
||||||
|
while not self.update_queue.empty():
|
||||||
|
updates.append(await self.update_queue.get())
|
||||||
|
return updates
|
||||||
|
|
||||||
|
async def flush_and_get_aggregated_db_spend_update_transactions(
|
||||||
|
self,
|
||||||
|
) -> DBSpendUpdateTransactions:
|
||||||
|
"""Flush all updates from the queue and return all updates aggregated by entity type."""
|
||||||
|
updates = await self.flush_all_updates_from_in_memory_queue()
|
||||||
|
verbose_proxy_logger.debug("Aggregating updates by entity type: %s", updates)
|
||||||
|
return self.get_aggregated_db_spend_update_transactions(updates)
|
||||||
|
|
||||||
|
def get_aggregated_db_spend_update_transactions(
|
||||||
|
self, updates: List[SpendUpdateQueueItem]
|
||||||
|
) -> DBSpendUpdateTransactions:
|
||||||
|
"""Aggregate updates by entity type."""
|
||||||
|
# Initialize all transaction lists as empty dicts
|
||||||
|
db_spend_update_transactions = DBSpendUpdateTransactions(
|
||||||
|
user_list_transactions={},
|
||||||
|
end_user_list_transactions={},
|
||||||
|
key_list_transactions={},
|
||||||
|
team_list_transactions={},
|
||||||
|
team_member_list_transactions={},
|
||||||
|
org_list_transactions={},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Map entity types to their corresponding transaction dictionary keys
|
||||||
|
entity_type_to_dict_key = {
|
||||||
|
Litellm_EntityType.USER: "user_list_transactions",
|
||||||
|
Litellm_EntityType.END_USER: "end_user_list_transactions",
|
||||||
|
Litellm_EntityType.KEY: "key_list_transactions",
|
||||||
|
Litellm_EntityType.TEAM: "team_list_transactions",
|
||||||
|
Litellm_EntityType.TEAM_MEMBER: "team_member_list_transactions",
|
||||||
|
Litellm_EntityType.ORGANIZATION: "org_list_transactions",
|
||||||
|
}
|
||||||
|
|
||||||
|
for update in updates:
|
||||||
|
entity_type = update.get("entity_type")
|
||||||
|
entity_id = update.get("entity_id") or ""
|
||||||
|
response_cost = update.get("response_cost") or 0
|
||||||
|
|
||||||
|
if entity_type is None:
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
"Skipping update spend for update: %s, because entity_type is None",
|
||||||
|
update,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
dict_key = entity_type_to_dict_key.get(entity_type)
|
||||||
|
if dict_key is None:
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
"Skipping update spend for update: %s, because entity_type is not in entity_type_to_dict_key",
|
||||||
|
update,
|
||||||
|
)
|
||||||
|
continue # Skip unknown entity types
|
||||||
|
|
||||||
|
# Type-safe access using if/elif statements
|
||||||
|
if dict_key == "user_list_transactions":
|
||||||
|
transactions_dict = db_spend_update_transactions[
|
||||||
|
"user_list_transactions"
|
||||||
|
]
|
||||||
|
elif dict_key == "end_user_list_transactions":
|
||||||
|
transactions_dict = db_spend_update_transactions[
|
||||||
|
"end_user_list_transactions"
|
||||||
|
]
|
||||||
|
elif dict_key == "key_list_transactions":
|
||||||
|
transactions_dict = db_spend_update_transactions[
|
||||||
|
"key_list_transactions"
|
||||||
|
]
|
||||||
|
elif dict_key == "team_list_transactions":
|
||||||
|
transactions_dict = db_spend_update_transactions[
|
||||||
|
"team_list_transactions"
|
||||||
|
]
|
||||||
|
elif dict_key == "team_member_list_transactions":
|
||||||
|
transactions_dict = db_spend_update_transactions[
|
||||||
|
"team_member_list_transactions"
|
||||||
|
]
|
||||||
|
elif dict_key == "org_list_transactions":
|
||||||
|
transactions_dict = db_spend_update_transactions[
|
||||||
|
"org_list_transactions"
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if transactions_dict is None:
|
||||||
|
transactions_dict = {}
|
||||||
|
|
||||||
|
# type ignore: dict_key is guaranteed to be one of "one of ("user_list_transactions", "end_user_list_transactions", "key_list_transactions", "team_list_transactions", "team_member_list_transactions", "org_list_transactions")"
|
||||||
|
db_spend_update_transactions[dict_key] = transactions_dict # type: ignore
|
||||||
|
|
||||||
|
if entity_id not in transactions_dict:
|
||||||
|
transactions_dict[entity_id] = 0
|
||||||
|
|
||||||
|
transactions_dict[entity_id] += response_cost or 0
|
||||||
|
|
||||||
|
return db_spend_update_transactions
|
15
litellm/proxy/example_config_yaml/spend_tracking_config.yaml
Normal file
15
litellm/proxy/example_config_yaml/spend_tracking_config.yaml
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
model_list:
|
||||||
|
- model_name: fake-openai-endpoint
|
||||||
|
litellm_params:
|
||||||
|
model: openai/fake
|
||||||
|
api_key: fake-key
|
||||||
|
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||||
|
|
||||||
|
general_settings:
|
||||||
|
use_redis_transaction_buffer: true
|
||||||
|
|
||||||
|
litellm_settings:
|
||||||
|
cache: True
|
||||||
|
cache_params:
|
||||||
|
type: redis
|
||||||
|
supported_call_types: []
|
|
@ -13,7 +13,6 @@ from litellm.litellm_core_utils.core_helpers import (
|
||||||
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
|
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
from litellm.proxy.auth.auth_checks import log_db_metrics
|
from litellm.proxy.auth.auth_checks import log_db_metrics
|
||||||
from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter
|
|
||||||
from litellm.proxy.utils import ProxyUpdateSpend
|
from litellm.proxy.utils import ProxyUpdateSpend
|
||||||
from litellm.types.utils import (
|
from litellm.types.utils import (
|
||||||
StandardLoggingPayload,
|
StandardLoggingPayload,
|
||||||
|
@ -37,6 +36,8 @@ class _ProxyDBLogger(CustomLogger):
|
||||||
if _ProxyDBLogger._should_track_errors_in_db() is False:
|
if _ProxyDBLogger._should_track_errors_in_db() is False:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
from litellm.proxy.proxy_server import proxy_logging_obj
|
||||||
|
|
||||||
_metadata = dict(
|
_metadata = dict(
|
||||||
StandardLoggingUserAPIKeyMetadata(
|
StandardLoggingUserAPIKeyMetadata(
|
||||||
user_api_key_hash=user_api_key_dict.api_key,
|
user_api_key_hash=user_api_key_dict.api_key,
|
||||||
|
@ -66,7 +67,7 @@ class _ProxyDBLogger(CustomLogger):
|
||||||
request_data.get("proxy_server_request") or {}
|
request_data.get("proxy_server_request") or {}
|
||||||
)
|
)
|
||||||
request_data["litellm_params"]["metadata"] = existing_metadata
|
request_data["litellm_params"]["metadata"] = existing_metadata
|
||||||
await DBSpendUpdateWriter.update_database(
|
await proxy_logging_obj.db_spend_update_writer.update_database(
|
||||||
token=user_api_key_dict.api_key,
|
token=user_api_key_dict.api_key,
|
||||||
response_cost=0.0,
|
response_cost=0.0,
|
||||||
user_id=user_api_key_dict.user_id,
|
user_id=user_api_key_dict.user_id,
|
||||||
|
@ -136,7 +137,7 @@ class _ProxyDBLogger(CustomLogger):
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
):
|
):
|
||||||
## UPDATE DATABASE
|
## UPDATE DATABASE
|
||||||
await DBSpendUpdateWriter.update_database(
|
await proxy_logging_obj.db_spend_update_writer.update_database(
|
||||||
token=user_api_key,
|
token=user_api_key,
|
||||||
response_cost=response_cost,
|
response_cost=response_cost,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|
|
@ -1,9 +1,6 @@
|
||||||
model_list:
|
model_list:
|
||||||
- model_name: gpt-4o
|
- model_name: fake-openai-endpoint
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: openai/gpt-4o
|
model: openai/fake
|
||||||
api_key: sk-xxxxxxx
|
api_key: fake-key
|
||||||
|
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||||
general_settings:
|
|
||||||
service_account_settings:
|
|
||||||
enforced_params: ["user"] # this means the "user" param is enforced for all requests made through any service account keys
|
|
||||||
|
|
|
@ -1111,12 +1111,6 @@ def jsonify_object(data: dict) -> dict:
|
||||||
|
|
||||||
|
|
||||||
class PrismaClient:
|
class PrismaClient:
|
||||||
user_list_transactions: dict = {}
|
|
||||||
end_user_list_transactions: dict = {}
|
|
||||||
key_list_transactions: dict = {}
|
|
||||||
team_list_transactions: dict = {}
|
|
||||||
team_member_list_transactions: dict = {} # key is ["team_id" + "user_id"]
|
|
||||||
org_list_transactions: dict = {}
|
|
||||||
spend_log_transactions: List = []
|
spend_log_transactions: List = []
|
||||||
daily_user_spend_transactions: Dict[str, DailyUserSpendTransaction] = {}
|
daily_user_spend_transactions: Dict[str, DailyUserSpendTransaction] = {}
|
||||||
|
|
||||||
|
@ -2479,7 +2473,10 @@ def _hash_token_if_needed(token: str) -> str:
|
||||||
class ProxyUpdateSpend:
|
class ProxyUpdateSpend:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def update_end_user_spend(
|
async def update_end_user_spend(
|
||||||
n_retry_times: int, prisma_client: PrismaClient, proxy_logging_obj: ProxyLogging
|
n_retry_times: int,
|
||||||
|
prisma_client: PrismaClient,
|
||||||
|
proxy_logging_obj: ProxyLogging,
|
||||||
|
end_user_list_transactions: Dict[str, float],
|
||||||
):
|
):
|
||||||
for i in range(n_retry_times + 1):
|
for i in range(n_retry_times + 1):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
@ -2491,7 +2488,7 @@ class ProxyUpdateSpend:
|
||||||
for (
|
for (
|
||||||
end_user_id,
|
end_user_id,
|
||||||
response_cost,
|
response_cost,
|
||||||
) in prisma_client.end_user_list_transactions.items():
|
) in end_user_list_transactions.items():
|
||||||
if litellm.max_end_user_budget is not None:
|
if litellm.max_end_user_budget is not None:
|
||||||
pass
|
pass
|
||||||
batcher.litellm_endusertable.upsert(
|
batcher.litellm_endusertable.upsert(
|
||||||
|
@ -2518,10 +2515,6 @@ class ProxyUpdateSpend:
|
||||||
_raise_failed_update_spend_exception(
|
_raise_failed_update_spend_exception(
|
||||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
prisma_client.end_user_list_transactions = (
|
|
||||||
{}
|
|
||||||
) # reset the end user list transactions - prevent bad data from causing issues
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def update_spend_logs(
|
async def update_spend_logs(
|
||||||
|
|
152
tests/litellm/proxy/db/test_spend_update_queue.py
Normal file
152
tests/litellm/proxy/db/test_spend_update_queue.py
Normal file
|
@ -0,0 +1,152 @@
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from litellm.proxy._types import Litellm_EntityType, SpendUpdateQueueItem
|
||||||
|
from litellm.proxy.db.spend_update_queue import SpendUpdateQueue
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def spend_queue():
    """Provide a newly constructed SpendUpdateQueue to each test."""
    queue = SpendUpdateQueue()
    return queue
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_add_update(spend_queue):
    """Enqueueing one update leaves exactly one item on the queue."""
    single_update: SpendUpdateQueueItem = SpendUpdateQueueItem(
        entity_type=Litellm_EntityType.USER,
        entity_id="user123",
        response_cost=0.5,
    )
    await spend_queue.add_update(single_update)

    # The queue size is the observable evidence that the update was stored.
    queued_count = spend_queue.update_queue.qsize()
    assert queued_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_missing_response_cost(spend_queue):
    """An update with no response_cost is aggregated as a cost of 0."""
    costless_update: SpendUpdateQueueItem = SpendUpdateQueueItem(
        entity_type=Litellm_EntityType.USER,
        entity_id="user123",
    )
    await spend_queue.add_update(costless_update)

    totals = await spend_queue.flush_and_get_aggregated_db_spend_update_transactions()

    # The user still gets an entry; its total is simply 0.
    user_totals = totals["user_list_transactions"]
    assert user_totals["user123"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_missing_entity_id(spend_queue):
    """An update with no entity_id is aggregated under the empty-string key."""
    anonymous_update: SpendUpdateQueueItem = SpendUpdateQueueItem(
        entity_type=Litellm_EntityType.USER,
        response_cost=1.0,
    )
    await spend_queue.add_update(anonymous_update)

    totals = await spend_queue.flush_and_get_aggregated_db_spend_update_transactions()

    # "" stands in for the missing entity_id.
    user_totals = totals["user_list_transactions"]
    assert user_totals[""] == 1.0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_none_values(spend_queue):
    """Explicit None for entity_id and response_cost falls back to "" and 0."""
    none_filled_update: SpendUpdateQueueItem = SpendUpdateQueueItem(
        entity_type=Litellm_EntityType.USER,
        entity_id=None,  # type: ignore
        response_cost=None,
    )
    await spend_queue.add_update(none_filled_update)

    totals = await spend_queue.flush_and_get_aggregated_db_spend_update_transactions()

    # None values must not raise; they collapse to the defaults.
    user_totals = totals["user_list_transactions"]
    assert user_totals[""] == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_multiple_updates_with_missing_fields(spend_queue):
    """Several updates with different missing fields aggregate independently."""
    complete_update: SpendUpdateQueueItem = {
        "entity_type": Litellm_EntityType.USER,
        "entity_id": "user123",
        "response_cost": 0.5,
    }
    # No response_cost: contributes 0 to user123.
    costless_update: SpendUpdateQueueItem = {
        "entity_type": Litellm_EntityType.USER,
        "entity_id": "user123",
    }
    # No entity_id: lands under the "" key.
    anonymous_update: SpendUpdateQueueItem = {
        "entity_type": Litellm_EntityType.USER,
        "response_cost": 1.5,
    }

    for queued in (complete_update, costless_update, anonymous_update):
        await spend_queue.add_update(queued)

    totals = await spend_queue.flush_and_get_aggregated_db_spend_update_transactions()
    user_totals = totals["user_list_transactions"]

    # Only the first update carried a cost for user123.
    assert user_totals["user123"] == 0.5
    # The update without an entity_id is tracked separately.
    assert user_totals[""] == 1.5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_unknown_entity_type(spend_queue):
    """An update whose entity_type is not recognised is dropped."""
    unrecognised_update: SpendUpdateQueueItem = SpendUpdateQueueItem(
        entity_type="UNKNOWN_TYPE",  # type: ignore
        entity_id="123",
        response_cost=0.5,
    )
    await spend_queue.add_update(unrecognised_update)

    totals = await spend_queue.flush_and_get_aggregated_db_spend_update_transactions()

    # Nothing may have been recorded in any bucket.
    for bucket in totals.values():
        assert len(bucket) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_missing_entity_type(spend_queue):
    """An update that carries no entity_type at all is dropped."""
    typeless_update: SpendUpdateQueueItem = SpendUpdateQueueItem(
        entity_id="123",
        response_cost=0.5,
    )
    await spend_queue.add_update(typeless_update)

    totals = await spend_queue.flush_and_get_aggregated_db_spend_update_transactions()

    # Nothing may have been recorded in any bucket.
    for bucket in totals.values():
        assert len(bucket) == 0
|
|
@ -62,6 +62,8 @@ from litellm.proxy._types import (
|
||||||
KeyRequest,
|
KeyRequest,
|
||||||
NewUserRequest,
|
NewUserRequest,
|
||||||
UpdateKeyRequest,
|
UpdateKeyRequest,
|
||||||
|
SpendUpdateQueueItem,
|
||||||
|
Litellm_EntityType,
|
||||||
)
|
)
|
||||||
|
|
||||||
proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
|
proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
|
||||||
|
@ -93,7 +95,13 @@ def prisma_client():
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_batch_update_spend(prisma_client):
|
async def test_batch_update_spend(prisma_client):
|
||||||
prisma_client.user_list_transactions["test-litellm-user-5"] = 23
|
await proxy_logging_obj.db_spend_update_writer.spend_update_queue.add_update(
|
||||||
|
SpendUpdateQueueItem(
|
||||||
|
entity_type=Litellm_EntityType.USER,
|
||||||
|
entity_id="test-litellm-user-5",
|
||||||
|
response_cost=23,
|
||||||
|
)
|
||||||
|
)
|
||||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
|
|
|
@ -1485,21 +1485,18 @@ from litellm.proxy.utils import ProxyUpdateSpend
|
||||||
async def test_end_user_transactions_reset():
|
async def test_end_user_transactions_reset():
|
||||||
# Setup
|
# Setup
|
||||||
mock_client = MagicMock()
|
mock_client = MagicMock()
|
||||||
mock_client.end_user_list_transactions = {"1": 10.0} # Bad log
|
end_user_list_transactions = {"1": 10.0} # Bad log
|
||||||
mock_client.db.tx = AsyncMock(side_effect=Exception("DB Error"))
|
mock_client.db.tx = AsyncMock(side_effect=Exception("DB Error"))
|
||||||
|
|
||||||
# Call function - should raise error
|
# Call function - should raise error
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
await ProxyUpdateSpend.update_end_user_spend(
|
await ProxyUpdateSpend.update_end_user_spend(
|
||||||
n_retry_times=0, prisma_client=mock_client, proxy_logging_obj=MagicMock()
|
n_retry_times=0,
|
||||||
|
prisma_client=mock_client,
|
||||||
|
proxy_logging_obj=MagicMock(),
|
||||||
|
end_user_list_transactions=end_user_list_transactions,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify cleanup happened
|
|
||||||
assert (
|
|
||||||
mock_client.end_user_list_transactions == {}
|
|
||||||
), "Transactions list should be empty after error"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_spend_logs_cleanup_after_error():
|
async def test_spend_logs_cleanup_after_error():
|
||||||
# Setup test data
|
# Setup test data
|
||||||
|
|
|
@ -24,9 +24,10 @@ async def test_disable_spend_logs():
|
||||||
"litellm.proxy.proxy_server.prisma_client", mock_prisma_client
|
"litellm.proxy.proxy_server.prisma_client", mock_prisma_client
|
||||||
):
|
):
|
||||||
from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter
|
from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter
|
||||||
|
db_spend_update_writer = DBSpendUpdateWriter()
|
||||||
|
|
||||||
# Call update_database with disable_spend_logs=True
|
# Call update_database with disable_spend_logs=True
|
||||||
await DBSpendUpdateWriter.update_database(
|
await db_spend_update_writer.update_database(
|
||||||
token="fake-token",
|
token="fake-token",
|
||||||
response_cost=0.1,
|
response_cost=0.1,
|
||||||
user_id="user123",
|
user_id="user123",
|
||||||
|
|
|
@ -27,12 +27,6 @@ class MockPrismaClient:
|
||||||
|
|
||||||
# Initialize transaction lists
|
# Initialize transaction lists
|
||||||
self.spend_log_transactions = []
|
self.spend_log_transactions = []
|
||||||
self.user_list_transactons = {}
|
|
||||||
self.end_user_list_transactons = {}
|
|
||||||
self.key_list_transactons = {}
|
|
||||||
self.team_list_transactons = {}
|
|
||||||
self.team_member_list_transactons = {}
|
|
||||||
self.org_list_transactons = {}
|
|
||||||
self.daily_user_spend_transactions = {}
|
self.daily_user_spend_transactions = {}
|
||||||
|
|
||||||
def jsonify_object(self, obj):
|
def jsonify_object(self, obj):
|
||||||
|
|
|
@ -52,7 +52,7 @@ Additional Test Scenarios:
|
||||||
|
|
||||||
async def create_organization(session, organization_alias: str):
|
async def create_organization(session, organization_alias: str):
|
||||||
"""Helper function to create a new organization"""
|
"""Helper function to create a new organization"""
|
||||||
url = "http://0.0.0.0:4002/organization/new"
|
url = "http://0.0.0.0:4000/organization/new"
|
||||||
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||||
data = {"organization_alias": organization_alias}
|
data = {"organization_alias": organization_alias}
|
||||||
async with session.post(url, headers=headers, json=data) as response:
|
async with session.post(url, headers=headers, json=data) as response:
|
||||||
|
@ -61,7 +61,7 @@ async def create_organization(session, organization_alias: str):
|
||||||
|
|
||||||
async def create_team(session, org_id: str):
|
async def create_team(session, org_id: str):
|
||||||
"""Helper function to create a new team under an organization"""
|
"""Helper function to create a new team under an organization"""
|
||||||
url = "http://0.0.0.0:4002/team/new"
|
url = "http://0.0.0.0:4000/team/new"
|
||||||
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||||
data = {"organization_id": org_id, "team_alias": f"test-team-{uuid.uuid4()}"}
|
data = {"organization_id": org_id, "team_alias": f"test-team-{uuid.uuid4()}"}
|
||||||
async with session.post(url, headers=headers, json=data) as response:
|
async with session.post(url, headers=headers, json=data) as response:
|
||||||
|
@ -70,7 +70,7 @@ async def create_team(session, org_id: str):
|
||||||
|
|
||||||
async def create_user(session, org_id: str):
|
async def create_user(session, org_id: str):
|
||||||
"""Helper function to create a new user"""
|
"""Helper function to create a new user"""
|
||||||
url = "http://0.0.0.0:4002/user/new"
|
url = "http://0.0.0.0:4000/user/new"
|
||||||
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||||
data = {"user_name": f"test-user-{uuid.uuid4()}"}
|
data = {"user_name": f"test-user-{uuid.uuid4()}"}
|
||||||
async with session.post(url, headers=headers, json=data) as response:
|
async with session.post(url, headers=headers, json=data) as response:
|
||||||
|
@ -79,7 +79,7 @@ async def create_user(session, org_id: str):
|
||||||
|
|
||||||
async def generate_key(session, user_id: str, team_id: str):
|
async def generate_key(session, user_id: str, team_id: str):
|
||||||
"""Helper function to generate a key for a specific user and team"""
|
"""Helper function to generate a key for a specific user and team"""
|
||||||
url = "http://0.0.0.0:4002/key/generate"
|
url = "http://0.0.0.0:4000/key/generate"
|
||||||
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||||
data = {"user_id": user_id, "team_id": team_id}
|
data = {"user_id": user_id, "team_id": team_id}
|
||||||
async with session.post(url, headers=headers, json=data) as response:
|
async with session.post(url, headers=headers, json=data) as response:
|
||||||
|
@ -91,7 +91,7 @@ async def chat_completion(session, key: str):
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4002/v1")
|
client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000/v1")
|
||||||
|
|
||||||
response = await client.chat.completions.create(
|
response = await client.chat.completions.create(
|
||||||
model="fake-openai-endpoint",
|
model="fake-openai-endpoint",
|
||||||
|
@ -102,7 +102,7 @@ async def chat_completion(session, key: str):
|
||||||
|
|
||||||
async def get_spend_info(session, entity_type: str, entity_id: str):
|
async def get_spend_info(session, entity_type: str, entity_id: str):
|
||||||
"""Helper function to get spend information for an entity"""
|
"""Helper function to get spend information for an entity"""
|
||||||
url = f"http://0.0.0.0:4002/{entity_type}/info"
|
url = f"http://0.0.0.0:4000/{entity_type}/info"
|
||||||
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||||
if entity_type == "key":
|
if entity_type == "key":
|
||||||
data = {"key": entity_id}
|
data = {"key": entity_id}
|
||||||
|
@ -156,7 +156,7 @@ async def test_basic_spend_accuracy():
|
||||||
response = await chat_completion(session, key)
|
response = await chat_completion(session, key)
|
||||||
print("response: ", response)
|
print("response: ", response)
|
||||||
|
|
||||||
# wait 10 seconds for spend to be updated
|
# wait 15 seconds for spend to be updated
|
||||||
await asyncio.sleep(15)
|
await asyncio.sleep(15)
|
||||||
|
|
||||||
# Get spend information for each entity
|
# Get spend information for each entity
|
||||||
|
@ -235,7 +235,7 @@ async def test_long_term_spend_accuracy_with_bursts():
|
||||||
print(f"Burst 1 - Request {i+1}/{BURST_1_REQUESTS} completed")
|
print(f"Burst 1 - Request {i+1}/{BURST_1_REQUESTS} completed")
|
||||||
|
|
||||||
# Wait for spend to be updated
|
# Wait for spend to be updated
|
||||||
await asyncio.sleep(8)
|
await asyncio.sleep(15)
|
||||||
|
|
||||||
# Check intermediate spend
|
# Check intermediate spend
|
||||||
intermediate_key_info = await get_spend_info(session, "key", key)
|
intermediate_key_info = await get_spend_info(session, "key", key)
|
||||||
|
@ -248,7 +248,7 @@ async def test_long_term_spend_accuracy_with_bursts():
|
||||||
print(f"Burst 2 - Request {i+1}/{BURST_2_REQUESTS} completed")
|
print(f"Burst 2 - Request {i+1}/{BURST_2_REQUESTS} completed")
|
||||||
|
|
||||||
# Wait for spend to be updated
|
# Wait for spend to be updated
|
||||||
await asyncio.sleep(8)
|
await asyncio.sleep(15)
|
||||||
|
|
||||||
# Get final spend information for each entity
|
# Get final spend information for each entity
|
||||||
key_info = await get_spend_info(session, "key", key)
|
key_info = await get_spend_info(session, "key", key)
|
Loading…
Add table
Add a link
Reference in a new issue