mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
Merge branch 'main' into litellm_fix_service_account_behavior
This commit is contained in:
commit
4ddca7a79c
18 changed files with 518 additions and 194 deletions
|
@ -1450,7 +1450,7 @@ jobs:
|
|||
command: |
|
||||
pwd
|
||||
ls
|
||||
python -m pytest -s -vv tests/*.py -x --junitxml=test-results/junit.xml --durations=5 --ignore=tests/otel_tests --ignore=tests/pass_through_tests --ignore=tests/proxy_admin_ui_tests --ignore=tests/load_tests --ignore=tests/llm_translation --ignore=tests/llm_responses_api_testing --ignore=tests/mcp_tests --ignore=tests/image_gen_tests --ignore=tests/pass_through_unit_tests
|
||||
python -m pytest -s -vv tests/*.py -x --junitxml=test-results/junit.xml --durations=5 --ignore=tests/otel_tests --ignore=tests/spend_tracking_tests --ignore=tests/pass_through_tests --ignore=tests/proxy_admin_ui_tests --ignore=tests/load_tests --ignore=tests/llm_translation --ignore=tests/llm_responses_api_testing --ignore=tests/mcp_tests --ignore=tests/image_gen_tests --ignore=tests/pass_through_unit_tests
|
||||
no_output_timeout: 120m
|
||||
|
||||
# Store test results
|
||||
|
@ -1743,6 +1743,96 @@ jobs:
|
|||
# Store test results
|
||||
- store_test_results:
|
||||
path: test-results
|
||||
proxy_spend_accuracy_tests:
|
||||
machine:
|
||||
image: ubuntu-2204:2023.10.1
|
||||
resource_class: xlarge
|
||||
working_directory: ~/project
|
||||
steps:
|
||||
- checkout
|
||||
- setup_google_dns
|
||||
- run:
|
||||
name: Install Docker CLI (In case it's not already installed)
|
||||
command: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y docker-ce docker-ce-cli containerd.io
|
||||
- run:
|
||||
name: Install Python 3.9
|
||||
command: |
|
||||
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh --output miniconda.sh
|
||||
bash miniconda.sh -b -p $HOME/miniconda
|
||||
export PATH="$HOME/miniconda/bin:$PATH"
|
||||
conda init bash
|
||||
source ~/.bashrc
|
||||
conda create -n myenv python=3.9 -y
|
||||
conda activate myenv
|
||||
python --version
|
||||
- run:
|
||||
name: Install Dependencies
|
||||
command: |
|
||||
pip install "pytest==7.3.1"
|
||||
pip install "pytest-asyncio==0.21.1"
|
||||
pip install aiohttp
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install -r requirements.txt
|
||||
- run:
|
||||
name: Build Docker image
|
||||
command: docker build -t my-app:latest -f ./docker/Dockerfile.database .
|
||||
- run:
|
||||
name: Run Docker container
|
||||
# intentionally give bad redis credentials here
|
||||
# the OTEL test - should get this as a trace
|
||||
command: |
|
||||
docker run -d \
|
||||
-p 4000:4000 \
|
||||
-e DATABASE_URL=$PROXY_DATABASE_URL \
|
||||
-e REDIS_HOST=$REDIS_HOST \
|
||||
-e REDIS_PASSWORD=$REDIS_PASSWORD \
|
||||
-e REDIS_PORT=$REDIS_PORT \
|
||||
-e LITELLM_MASTER_KEY="sk-1234" \
|
||||
-e OPENAI_API_KEY=$OPENAI_API_KEY \
|
||||
-e LITELLM_LICENSE=$LITELLM_LICENSE \
|
||||
-e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
|
||||
-e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
|
||||
-e USE_DDTRACE=True \
|
||||
-e DD_API_KEY=$DD_API_KEY \
|
||||
-e DD_SITE=$DD_SITE \
|
||||
-e AWS_REGION_NAME=$AWS_REGION_NAME \
|
||||
--name my-app \
|
||||
-v $(pwd)/litellm/proxy/example_config_yaml/spend_tracking_config.yaml:/app/config.yaml \
|
||||
my-app:latest \
|
||||
--config /app/config.yaml \
|
||||
--port 4000 \
|
||||
--detailed_debug \
|
||||
- run:
|
||||
name: Install curl and dockerize
|
||||
command: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y curl
|
||||
sudo wget https://github.com/jwilder/dockerize/releases/download/v0.6.1/dockerize-linux-amd64-v0.6.1.tar.gz
|
||||
sudo tar -C /usr/local/bin -xzvf dockerize-linux-amd64-v0.6.1.tar.gz
|
||||
sudo rm dockerize-linux-amd64-v0.6.1.tar.gz
|
||||
- run:
|
||||
name: Start outputting logs
|
||||
command: docker logs -f my-app
|
||||
background: true
|
||||
- run:
|
||||
name: Wait for app to be ready
|
||||
command: dockerize -wait http://localhost:4000 -timeout 5m
|
||||
- run:
|
||||
name: Run tests
|
||||
command: |
|
||||
pwd
|
||||
ls
|
||||
python -m pytest -vv tests/spend_tracking_tests -x --junitxml=test-results/junit.xml --durations=5
|
||||
no_output_timeout:
|
||||
120m
|
||||
# Clean up first container
|
||||
- run:
|
||||
name: Stop and remove first container
|
||||
command: |
|
||||
docker stop my-app
|
||||
docker rm my-app
|
||||
|
||||
proxy_multi_instance_tests:
|
||||
machine:
|
||||
|
@ -2553,6 +2643,12 @@ workflows:
|
|||
only:
|
||||
- main
|
||||
- /litellm_.*/
|
||||
- proxy_spend_accuracy_tests:
|
||||
filters:
|
||||
branches:
|
||||
only:
|
||||
- main
|
||||
- /litellm_.*/
|
||||
- proxy_multi_instance_tests:
|
||||
filters:
|
||||
branches:
|
||||
|
@ -2714,6 +2810,7 @@ workflows:
|
|||
- installing_litellm_on_python
|
||||
- installing_litellm_on_python_3_13
|
||||
- proxy_logging_guardrails_model_info_tests
|
||||
- proxy_spend_accuracy_tests
|
||||
- proxy_multi_instance_tests
|
||||
- proxy_store_model_in_db_tests
|
||||
- proxy_build_from_pip_tests
|
||||
|
|
|
@ -1,2 +1,11 @@
|
|||
python3 -m build
|
||||
twine upload --verbose dist/litellm-1.18.13.dev4.tar.gz -u __token__ -
|
||||
twine upload --verbose dist/litellm-1.18.13.dev4.tar.gz -u __token__ -
|
||||
|
||||
|
||||
Note: You might need to create a MANIFEST.ini file in the project root for the build process, in case the build fails
|
||||
|
||||
Place this in MANIFEST.ini
|
||||
recursive-exclude venv *
|
||||
recursive-exclude myenv *
|
||||
recursive-exclude py313_env *
|
||||
recursive-exclude **/.venv *
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import Image from '@theme/IdealImage';
|
||||
|
||||
# Enterprise
|
||||
For companies that need SSO, user management and professional support for LiteLLM Proxy
|
||||
|
||||
|
@ -7,6 +9,8 @@ Get free 7-day trial key [here](https://www.litellm.ai/#trial)
|
|||
|
||||
Includes all enterprise features.
|
||||
|
||||
<Image img={require('../img/enterprise_vs_oss.png')} />
|
||||
|
||||
[**Procurement available via AWS / Azure Marketplace**](./data_security.md#legalcompliance-faqs)
|
||||
|
||||
|
||||
|
|
BIN
docs/my-website/img/enterprise_vs_oss.png
Normal file
BIN
docs/my-website/img/enterprise_vs_oss.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 61 KiB |
|
@ -2751,3 +2751,9 @@ class DBSpendUpdateTransactions(TypedDict):
|
|||
team_list_transactions: Optional[Dict[str, float]]
|
||||
team_member_list_transactions: Optional[Dict[str, float]]
|
||||
org_list_transactions: Optional[Dict[str, float]]
|
||||
|
||||
|
||||
class SpendUpdateQueueItem(TypedDict, total=False):
    """
    A single pending spend update for one entity.

    Declared with ``total=False``, so any of the keys may be absent.
    """

    # Which kind of entity the spend is attributed to
    entity_type: Litellm_EntityType
    # Identifier of that entity
    entity_id: str
    # Cost to add to the entity's spend; may be None
    response_cost: Optional[float]
|
||||
|
|
|
@ -22,9 +22,11 @@ from litellm.proxy._types import (
|
|||
Litellm_EntityType,
|
||||
LiteLLM_UserTable,
|
||||
SpendLogsPayload,
|
||||
SpendUpdateQueueItem,
|
||||
)
|
||||
from litellm.proxy.db.pod_lock_manager import PodLockManager
|
||||
from litellm.proxy.db.redis_update_buffer import RedisUpdateBuffer
|
||||
from litellm.proxy.db.spend_update_queue import SpendUpdateQueue
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||
|
@ -48,10 +50,11 @@ class DBSpendUpdateWriter:
|
|||
self.redis_cache = redis_cache
|
||||
self.redis_update_buffer = RedisUpdateBuffer(redis_cache=self.redis_cache)
|
||||
self.pod_lock_manager = PodLockManager(cronjob_id=DB_SPEND_UPDATE_JOB_NAME)
|
||||
self.spend_update_queue = SpendUpdateQueue()
|
||||
|
||||
@staticmethod
|
||||
async def update_database(
|
||||
# LiteLLM management object fields
|
||||
self,
|
||||
token: Optional[str],
|
||||
user_id: Optional[str],
|
||||
end_user_id: Optional[str],
|
||||
|
@ -84,7 +87,7 @@ class DBSpendUpdateWriter:
|
|||
hashed_token = token
|
||||
|
||||
asyncio.create_task(
|
||||
DBSpendUpdateWriter._update_user_db(
|
||||
self._update_user_db(
|
||||
response_cost=response_cost,
|
||||
user_id=user_id,
|
||||
prisma_client=prisma_client,
|
||||
|
@ -94,14 +97,14 @@ class DBSpendUpdateWriter:
|
|||
)
|
||||
)
|
||||
asyncio.create_task(
|
||||
DBSpendUpdateWriter._update_key_db(
|
||||
self._update_key_db(
|
||||
response_cost=response_cost,
|
||||
hashed_token=hashed_token,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
)
|
||||
asyncio.create_task(
|
||||
DBSpendUpdateWriter._update_team_db(
|
||||
self._update_team_db(
|
||||
response_cost=response_cost,
|
||||
team_id=team_id,
|
||||
user_id=user_id,
|
||||
|
@ -109,7 +112,7 @@ class DBSpendUpdateWriter:
|
|||
)
|
||||
)
|
||||
asyncio.create_task(
|
||||
DBSpendUpdateWriter._update_org_db(
|
||||
self._update_org_db(
|
||||
response_cost=response_cost,
|
||||
org_id=org_id,
|
||||
prisma_client=prisma_client,
|
||||
|
@ -135,56 +138,8 @@ class DBSpendUpdateWriter:
|
|||
f"Error updating Prisma database: {traceback.format_exc()}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _update_transaction_list(
|
||||
response_cost: Optional[float],
|
||||
entity_id: Optional[str],
|
||||
transaction_list: dict,
|
||||
entity_type: Litellm_EntityType,
|
||||
debug_msg: Optional[str] = None,
|
||||
prisma_client: Optional[PrismaClient] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Common helper method to update a transaction list for an entity
|
||||
|
||||
Args:
|
||||
response_cost: The cost to add
|
||||
entity_id: The ID of the entity to update
|
||||
transaction_list: The transaction list dictionary to update
|
||||
entity_type: The type of entity (from EntityType enum)
|
||||
debug_msg: Optional custom debug message
|
||||
|
||||
Returns:
|
||||
bool: True if update happened, False otherwise
|
||||
"""
|
||||
try:
|
||||
if debug_msg:
|
||||
verbose_proxy_logger.debug(debug_msg)
|
||||
else:
|
||||
verbose_proxy_logger.debug(
|
||||
f"adding spend to {entity_type.value} db. Response cost: {response_cost}. {entity_type.value}_id: {entity_id}."
|
||||
)
|
||||
if prisma_client is None:
|
||||
return False
|
||||
|
||||
if entity_id is None:
|
||||
verbose_proxy_logger.debug(
|
||||
f"track_cost_callback: {entity_type.value}_id is None. Not tracking spend for {entity_type.value}"
|
||||
)
|
||||
return False
|
||||
transaction_list[entity_id] = response_cost + transaction_list.get(
|
||||
entity_id, 0
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.info(
|
||||
f"Update {entity_type.value.capitalize()} DB failed to execute - {str(e)}\n{traceback.format_exc()}"
|
||||
)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
async def _update_key_db(
|
||||
self,
|
||||
response_cost: Optional[float],
|
||||
hashed_token: Optional[str],
|
||||
prisma_client: Optional[PrismaClient],
|
||||
|
@ -193,13 +148,12 @@ class DBSpendUpdateWriter:
|
|||
if hashed_token is None or prisma_client is None:
|
||||
return
|
||||
|
||||
await DBSpendUpdateWriter._update_transaction_list(
|
||||
response_cost=response_cost,
|
||||
entity_id=hashed_token,
|
||||
transaction_list=prisma_client.key_list_transactions,
|
||||
entity_type=Litellm_EntityType.KEY,
|
||||
debug_msg=f"adding spend to key db. Response cost: {response_cost}. Token: {hashed_token}.",
|
||||
prisma_client=prisma_client,
|
||||
await self.spend_update_queue.add_update(
|
||||
update=SpendUpdateQueueItem(
|
||||
entity_type=Litellm_EntityType.KEY,
|
||||
entity_id=hashed_token,
|
||||
response_cost=response_cost,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
|
@ -207,8 +161,8 @@ class DBSpendUpdateWriter:
|
|||
)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
async def _update_user_db(
|
||||
self,
|
||||
response_cost: Optional[float],
|
||||
user_id: Optional[str],
|
||||
prisma_client: Optional[PrismaClient],
|
||||
|
@ -234,21 +188,21 @@ class DBSpendUpdateWriter:
|
|||
|
||||
for _id in user_ids:
|
||||
if _id is not None:
|
||||
await DBSpendUpdateWriter._update_transaction_list(
|
||||
response_cost=response_cost,
|
||||
entity_id=_id,
|
||||
transaction_list=prisma_client.user_list_transactions,
|
||||
entity_type=Litellm_EntityType.USER,
|
||||
prisma_client=prisma_client,
|
||||
await self.spend_update_queue.add_update(
|
||||
update=SpendUpdateQueueItem(
|
||||
entity_type=Litellm_EntityType.USER,
|
||||
entity_id=_id,
|
||||
response_cost=response_cost,
|
||||
)
|
||||
)
|
||||
|
||||
if end_user_id is not None:
|
||||
await DBSpendUpdateWriter._update_transaction_list(
|
||||
response_cost=response_cost,
|
||||
entity_id=end_user_id,
|
||||
transaction_list=prisma_client.end_user_list_transactions,
|
||||
entity_type=Litellm_EntityType.END_USER,
|
||||
prisma_client=prisma_client,
|
||||
await self.spend_update_queue.add_update(
|
||||
update=SpendUpdateQueueItem(
|
||||
entity_type=Litellm_EntityType.END_USER,
|
||||
entity_id=end_user_id,
|
||||
response_cost=response_cost,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.info(
|
||||
|
@ -256,8 +210,8 @@ class DBSpendUpdateWriter:
|
|||
+ f"Update User DB call failed to execute {str(e)}\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _update_team_db(
|
||||
self,
|
||||
response_cost: Optional[float],
|
||||
team_id: Optional[str],
|
||||
user_id: Optional[str],
|
||||
|
@ -270,12 +224,12 @@ class DBSpendUpdateWriter:
|
|||
)
|
||||
return
|
||||
|
||||
await DBSpendUpdateWriter._update_transaction_list(
|
||||
response_cost=response_cost,
|
||||
entity_id=team_id,
|
||||
transaction_list=prisma_client.team_list_transactions,
|
||||
entity_type=Litellm_EntityType.TEAM,
|
||||
prisma_client=prisma_client,
|
||||
await self.spend_update_queue.add_update(
|
||||
update=SpendUpdateQueueItem(
|
||||
entity_type=Litellm_EntityType.TEAM,
|
||||
entity_id=team_id,
|
||||
response_cost=response_cost,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -283,12 +237,12 @@ class DBSpendUpdateWriter:
|
|||
if user_id is not None:
|
||||
# key is "team_id::<value>::user_id::<value>"
|
||||
team_member_key = f"team_id::{team_id}::user_id::{user_id}"
|
||||
await DBSpendUpdateWriter._update_transaction_list(
|
||||
response_cost=response_cost,
|
||||
entity_id=team_member_key,
|
||||
transaction_list=prisma_client.team_member_list_transactions,
|
||||
entity_type=Litellm_EntityType.TEAM_MEMBER,
|
||||
prisma_client=prisma_client,
|
||||
await self.spend_update_queue.add_update(
|
||||
update=SpendUpdateQueueItem(
|
||||
entity_type=Litellm_EntityType.TEAM_MEMBER,
|
||||
entity_id=team_member_key,
|
||||
response_cost=response_cost,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
@ -298,8 +252,8 @@ class DBSpendUpdateWriter:
|
|||
)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
async def _update_org_db(
|
||||
self,
|
||||
response_cost: Optional[float],
|
||||
org_id: Optional[str],
|
||||
prisma_client: Optional[PrismaClient],
|
||||
|
@ -311,12 +265,12 @@ class DBSpendUpdateWriter:
|
|||
)
|
||||
return
|
||||
|
||||
await DBSpendUpdateWriter._update_transaction_list(
|
||||
response_cost=response_cost,
|
||||
entity_id=org_id,
|
||||
transaction_list=prisma_client.org_list_transactions,
|
||||
entity_type=Litellm_EntityType.ORGANIZATION,
|
||||
prisma_client=prisma_client,
|
||||
await self.spend_update_queue.add_update(
|
||||
update=SpendUpdateQueueItem(
|
||||
entity_type=Litellm_EntityType.ORGANIZATION,
|
||||
entity_id=org_id,
|
||||
response_cost=response_cost,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.info(
|
||||
|
@ -435,7 +389,7 @@ class DBSpendUpdateWriter:
|
|||
- Only 1 pod will commit to db at a time (based on if it can acquire the lock over writing to DB)
|
||||
"""
|
||||
await self.redis_update_buffer.store_in_memory_spend_updates_in_redis(
|
||||
prisma_client=prisma_client,
|
||||
spend_update_queue=self.spend_update_queue,
|
||||
)
|
||||
|
||||
# Only commit from redis to db if this pod is the leader
|
||||
|
@ -447,7 +401,7 @@ class DBSpendUpdateWriter:
|
|||
await self.redis_update_buffer.get_all_update_transactions_from_redis_buffer()
|
||||
)
|
||||
if db_spend_update_transactions is not None:
|
||||
await DBSpendUpdateWriter._commit_spend_updates_to_db(
|
||||
await self._commit_spend_updates_to_db(
|
||||
prisma_client=prisma_client,
|
||||
n_retry_times=n_retry_times,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
|
@ -471,23 +425,18 @@ class DBSpendUpdateWriter:
|
|||
|
||||
Note: This flow causes Deadlocks in production (1K RPS+). Use self._commit_spend_updates_to_db_with_redis() instead if you expect 1K+ RPS.
|
||||
"""
|
||||
db_spend_update_transactions = DBSpendUpdateTransactions(
|
||||
user_list_transactions=prisma_client.user_list_transactions,
|
||||
end_user_list_transactions=prisma_client.end_user_list_transactions,
|
||||
key_list_transactions=prisma_client.key_list_transactions,
|
||||
team_list_transactions=prisma_client.team_list_transactions,
|
||||
team_member_list_transactions=prisma_client.team_member_list_transactions,
|
||||
org_list_transactions=prisma_client.org_list_transactions,
|
||||
db_spend_update_transactions = (
|
||||
await self.spend_update_queue.flush_and_get_aggregated_db_spend_update_transactions()
|
||||
)
|
||||
await DBSpendUpdateWriter._commit_spend_updates_to_db(
|
||||
await self._commit_spend_updates_to_db(
|
||||
prisma_client=prisma_client,
|
||||
n_retry_times=n_retry_times,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
db_spend_update_transactions=db_spend_update_transactions,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _commit_spend_updates_to_db( # noqa: PLR0915
|
||||
self,
|
||||
prisma_client: PrismaClient,
|
||||
n_retry_times: int,
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
|
@ -526,9 +475,6 @@ class DBSpendUpdateWriter:
|
|||
where={"user_id": user_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
prisma_client.user_list_transactions = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
break
|
||||
except DB_CONNECTION_ERROR_TYPES as e:
|
||||
if (
|
||||
|
@ -561,6 +507,7 @@ class DBSpendUpdateWriter:
|
|||
n_retry_times=n_retry_times,
|
||||
prisma_client=prisma_client,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
end_user_list_transactions=end_user_list_transactions,
|
||||
)
|
||||
### UPDATE KEY TABLE ###
|
||||
key_list_transactions = db_spend_update_transactions["key_list_transactions"]
|
||||
|
@ -583,9 +530,6 @@ class DBSpendUpdateWriter:
|
|||
where={"token": token},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
prisma_client.key_list_transactions = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
break
|
||||
except DB_CONNECTION_ERROR_TYPES as e:
|
||||
if (
|
||||
|
@ -632,9 +576,6 @@ class DBSpendUpdateWriter:
|
|||
where={"team_id": team_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
prisma_client.team_list_transactions = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
break
|
||||
except DB_CONNECTION_ERROR_TYPES as e:
|
||||
if (
|
||||
|
@ -684,9 +625,6 @@ class DBSpendUpdateWriter:
|
|||
where={"team_id": team_id, "user_id": user_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
prisma_client.team_member_list_transactions = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
break
|
||||
except DB_CONNECTION_ERROR_TYPES as e:
|
||||
if (
|
||||
|
@ -725,9 +663,6 @@ class DBSpendUpdateWriter:
|
|||
where={"organization_id": org_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
prisma_client.org_list_transactions = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
break
|
||||
except DB_CONNECTION_ERROR_TYPES as e:
|
||||
if (
|
||||
|
|
|
@ -12,6 +12,7 @@ from litellm.caching import RedisCache
|
|||
from litellm.constants import MAX_REDIS_BUFFER_DEQUEUE_COUNT, REDIS_UPDATE_BUFFER_KEY
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.proxy._types import DBSpendUpdateTransactions
|
||||
from litellm.proxy.db.spend_update_queue import SpendUpdateQueue
|
||||
from litellm.secret_managers.main import str_to_bool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -54,7 +55,7 @@ class RedisUpdateBuffer:
|
|||
|
||||
async def store_in_memory_spend_updates_in_redis(
|
||||
self,
|
||||
prisma_client: PrismaClient,
|
||||
spend_update_queue: SpendUpdateQueue,
|
||||
):
|
||||
"""
|
||||
Stores the in-memory spend updates to Redis
|
||||
|
@ -78,13 +79,12 @@ class RedisUpdateBuffer:
|
|||
"redis_cache is None, skipping store_in_memory_spend_updates_in_redis"
|
||||
)
|
||||
return
|
||||
db_spend_update_transactions: DBSpendUpdateTransactions = DBSpendUpdateTransactions(
|
||||
user_list_transactions=prisma_client.user_list_transactions,
|
||||
end_user_list_transactions=prisma_client.end_user_list_transactions,
|
||||
key_list_transactions=prisma_client.key_list_transactions,
|
||||
team_list_transactions=prisma_client.team_list_transactions,
|
||||
team_member_list_transactions=prisma_client.team_member_list_transactions,
|
||||
org_list_transactions=prisma_client.org_list_transactions,
|
||||
|
||||
db_spend_update_transactions = (
|
||||
await spend_update_queue.flush_and_get_aggregated_db_spend_update_transactions()
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
"ALL DB SPEND UPDATE TRANSACTIONS: %s", db_spend_update_transactions
|
||||
)
|
||||
|
||||
# only store in redis if there are any updates to commit
|
||||
|
@ -100,9 +100,6 @@ class RedisUpdateBuffer:
|
|||
values=list_of_transactions,
|
||||
)
|
||||
|
||||
# clear the in-memory spend updates
|
||||
RedisUpdateBuffer._clear_all_in_memory_spend_updates(prisma_client)
|
||||
|
||||
@staticmethod
|
||||
def _number_of_transactions_to_store_in_redis(
|
||||
db_spend_update_transactions: DBSpendUpdateTransactions,
|
||||
|
@ -116,20 +113,6 @@ class RedisUpdateBuffer:
|
|||
num_transactions += len(v)
|
||||
return num_transactions
|
||||
|
||||
@staticmethod
|
||||
def _clear_all_in_memory_spend_updates(
|
||||
prisma_client: PrismaClient,
|
||||
):
|
||||
"""
|
||||
Clears all in-memory spend updates
|
||||
"""
|
||||
prisma_client.user_list_transactions = {}
|
||||
prisma_client.end_user_list_transactions = {}
|
||||
prisma_client.key_list_transactions = {}
|
||||
prisma_client.team_list_transactions = {}
|
||||
prisma_client.team_member_list_transactions = {}
|
||||
prisma_client.org_list_transactions = {}
|
||||
|
||||
@staticmethod
|
||||
def _remove_prefix_from_keys(data: Dict[str, Any], prefix: str) -> Dict[str, Any]:
|
||||
"""
|
||||
|
|
132
litellm/proxy/db/spend_update_queue.py
Normal file
132
litellm/proxy/db/spend_update_queue.py
Normal file
|
@ -0,0 +1,132 @@
|
|||
import asyncio
|
||||
from typing import TYPE_CHECKING, Any, List
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import (
|
||||
DBSpendUpdateTransactions,
|
||||
Litellm_EntityType,
|
||||
SpendUpdateQueueItem,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
else:
|
||||
PrismaClient = Any
|
||||
|
||||
|
||||
class SpendUpdateQueue:
    """
    In memory buffer for spend updates that should be committed to the database

    Updates are enqueued one at a time with ``add_update`` and later drained
    and summed per entity with
    ``flush_and_get_aggregated_db_spend_update_transactions``.
    """

    def __init__(
        self,
    ):
        self.update_queue: asyncio.Queue[SpendUpdateQueueItem] = asyncio.Queue()

    async def add_update(self, update: SpendUpdateQueueItem) -> None:
        """
        Enqueue a single spend update.

        Args:
            update: A ``SpendUpdateQueueItem``, e.g.
                ``{'entity_type': Litellm_EntityType.USER, 'entity_id': '123', 'response_cost': 1.2}``
        """
        verbose_proxy_logger.debug("Adding update to queue: %s", update)
        await self.update_queue.put(update)

    async def flush_all_updates_from_in_memory_queue(
        self,
    ) -> List[SpendUpdateQueueItem]:
        """
        Remove every update currently in the queue and return them in FIFO order.

        Returns:
            The drained updates; an empty list if the queue was empty.
        """
        updates: List[SpendUpdateQueueItem] = []
        # empty() is checked immediately before each get_nowait() with no await
        # in between, so get_nowait() cannot raise QueueEmpty here.
        while not self.update_queue.empty():
            updates.append(self.update_queue.get_nowait())
        return updates

    async def flush_and_get_aggregated_db_spend_update_transactions(
        self,
    ) -> DBSpendUpdateTransactions:
        """Flush all updates from the queue and return all updates aggregated by entity type."""
        updates = await self.flush_all_updates_from_in_memory_queue()
        verbose_proxy_logger.debug("Aggregating updates by entity type: %s", updates)
        return self.get_aggregated_db_spend_update_transactions(updates)

    def get_aggregated_db_spend_update_transactions(
        self, updates: List[SpendUpdateQueueItem]
    ) -> DBSpendUpdateTransactions:
        """
        Aggregate updates by entity type.

        For every update, its ``response_cost`` is added to the running total
        stored under its ``entity_id`` in the transaction dict that matches its
        ``entity_type``.

        Args:
            updates: The updates to aggregate. Updates whose ``entity_type`` is
                missing or not one of the mapped entity types are skipped.
                A missing/None ``entity_id`` is aggregated under ``""`` and a
                missing/None ``response_cost`` counts as 0.

        Returns:
            A ``DBSpendUpdateTransactions`` whose six transaction dicts are
            always present (empty when no update targeted them).
        """
        # Initialize all transaction lists as empty dicts
        db_spend_update_transactions = DBSpendUpdateTransactions(
            user_list_transactions={},
            end_user_list_transactions={},
            key_list_transactions={},
            team_list_transactions={},
            team_member_list_transactions={},
            org_list_transactions={},
        )

        # Map entity types to their corresponding transaction dictionary keys
        entity_type_to_dict_key = {
            Litellm_EntityType.USER: "user_list_transactions",
            Litellm_EntityType.END_USER: "end_user_list_transactions",
            Litellm_EntityType.KEY: "key_list_transactions",
            Litellm_EntityType.TEAM: "team_list_transactions",
            Litellm_EntityType.TEAM_MEMBER: "team_member_list_transactions",
            Litellm_EntityType.ORGANIZATION: "org_list_transactions",
        }

        for update in updates:
            entity_type = update.get("entity_type")
            entity_id = update.get("entity_id") or ""
            response_cost = update.get("response_cost") or 0

            if entity_type is None:
                verbose_proxy_logger.debug(
                    "Skipping update spend for update: %s, because entity_type is None",
                    update,
                )
                continue

            dict_key = entity_type_to_dict_key.get(entity_type)
            if dict_key is None:
                verbose_proxy_logger.debug(
                    "Skipping update spend for update: %s, because entity_type is not in entity_type_to_dict_key",
                    update,
                )
                continue  # Skip unknown entity types

            # type ignore: dict_key comes from entity_type_to_dict_key, so it is
            # guaranteed to be one of the six DBSpendUpdateTransactions keys.
            transactions_dict = db_spend_update_transactions.get(dict_key)  # type: ignore
            if transactions_dict is None:
                # The TypedDict fields are Optional; fall back to a fresh dict.
                transactions_dict = {}
                db_spend_update_transactions[dict_key] = transactions_dict  # type: ignore

            transactions_dict[entity_id] = (
                transactions_dict.get(entity_id, 0) + response_cost
            )

        return db_spend_update_transactions
|
15
litellm/proxy/example_config_yaml/spend_tracking_config.yaml
Normal file
15
litellm/proxy/example_config_yaml/spend_tracking_config.yaml
Normal file
|
@ -0,0 +1,15 @@
|
|||
model_list:
|
||||
- model_name: fake-openai-endpoint
|
||||
litellm_params:
|
||||
model: openai/fake
|
||||
api_key: fake-key
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
|
||||
general_settings:
|
||||
use_redis_transaction_buffer: true
|
||||
|
||||
litellm_settings:
|
||||
cache: True
|
||||
cache_params:
|
||||
type: redis
|
||||
supported_call_types: []
|
|
@ -13,7 +13,6 @@ from litellm.litellm_core_utils.core_helpers import (
|
|||
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.auth.auth_checks import log_db_metrics
|
||||
from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter
|
||||
from litellm.proxy.utils import ProxyUpdateSpend
|
||||
from litellm.types.utils import (
|
||||
StandardLoggingPayload,
|
||||
|
@ -37,6 +36,8 @@ class _ProxyDBLogger(CustomLogger):
|
|||
if _ProxyDBLogger._should_track_errors_in_db() is False:
|
||||
return
|
||||
|
||||
from litellm.proxy.proxy_server import proxy_logging_obj
|
||||
|
||||
_metadata = dict(
|
||||
StandardLoggingUserAPIKeyMetadata(
|
||||
user_api_key_hash=user_api_key_dict.api_key,
|
||||
|
@ -66,7 +67,7 @@ class _ProxyDBLogger(CustomLogger):
|
|||
request_data.get("proxy_server_request") or {}
|
||||
)
|
||||
request_data["litellm_params"]["metadata"] = existing_metadata
|
||||
await DBSpendUpdateWriter.update_database(
|
||||
await proxy_logging_obj.db_spend_update_writer.update_database(
|
||||
token=user_api_key_dict.api_key,
|
||||
response_cost=0.0,
|
||||
user_id=user_api_key_dict.user_id,
|
||||
|
@ -136,7 +137,7 @@ class _ProxyDBLogger(CustomLogger):
|
|||
end_user_id=end_user_id,
|
||||
):
|
||||
## UPDATE DATABASE
|
||||
await DBSpendUpdateWriter.update_database(
|
||||
await proxy_logging_obj.db_spend_update_writer.update_database(
|
||||
token=user_api_key,
|
||||
response_cost=response_cost,
|
||||
user_id=user_id,
|
||||
|
|
|
@ -1,9 +1,6 @@
|
|||
model_list:
|
||||
- model_name: gpt-4o
|
||||
- model_name: fake-openai-endpoint
|
||||
litellm_params:
|
||||
model: openai/gpt-4o
|
||||
api_key: sk-xxxxxxx
|
||||
|
||||
general_settings:
|
||||
service_account_settings:
|
||||
enforced_params: ["user"] # this means the "user" param is enforced for all requests made through any service account keys
|
||||
model: openai/fake
|
||||
api_key: fake-key
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
|
|
|
@ -1111,12 +1111,6 @@ def jsonify_object(data: dict) -> dict:
|
|||
|
||||
|
||||
class PrismaClient:
|
||||
user_list_transactions: dict = {}
|
||||
end_user_list_transactions: dict = {}
|
||||
key_list_transactions: dict = {}
|
||||
team_list_transactions: dict = {}
|
||||
team_member_list_transactions: dict = {} # key is ["team_id" + "user_id"]
|
||||
org_list_transactions: dict = {}
|
||||
spend_log_transactions: List = []
|
||||
daily_user_spend_transactions: Dict[str, DailyUserSpendTransaction] = {}
|
||||
|
||||
|
@ -2479,7 +2473,10 @@ def _hash_token_if_needed(token: str) -> str:
|
|||
class ProxyUpdateSpend:
|
||||
@staticmethod
|
||||
async def update_end_user_spend(
|
||||
n_retry_times: int, prisma_client: PrismaClient, proxy_logging_obj: ProxyLogging
|
||||
n_retry_times: int,
|
||||
prisma_client: PrismaClient,
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
end_user_list_transactions: Dict[str, float],
|
||||
):
|
||||
for i in range(n_retry_times + 1):
|
||||
start_time = time.time()
|
||||
|
@ -2491,7 +2488,7 @@ class ProxyUpdateSpend:
|
|||
for (
|
||||
end_user_id,
|
||||
response_cost,
|
||||
) in prisma_client.end_user_list_transactions.items():
|
||||
) in end_user_list_transactions.items():
|
||||
if litellm.max_end_user_budget is not None:
|
||||
pass
|
||||
batcher.litellm_endusertable.upsert(
|
||||
|
@ -2518,10 +2515,6 @@ class ProxyUpdateSpend:
|
|||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
finally:
|
||||
prisma_client.end_user_list_transactions = (
|
||||
{}
|
||||
) # reset the end user list transactions - prevent bad data from causing issues
|
||||
|
||||
@staticmethod
|
||||
async def update_spend_logs(
|
||||
|
|
152
tests/litellm/proxy/db/test_spend_update_queue.py
Normal file
152
tests/litellm/proxy/db/test_spend_update_queue.py
Normal file
|
@ -0,0 +1,152 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from litellm.proxy._types import Litellm_EntityType, SpendUpdateQueueItem
|
||||
from litellm.proxy.db.spend_update_queue import SpendUpdateQueue
|
||||
|
||||
# Put the directory three levels above the current working directory at the
# front of the module search path (the relative path is resolved against the
# working directory, not against this file's location).
sys.path.insert(0, os.path.abspath("../../.."))
|
||||
|
||||
|
||||
@pytest.fixture
def spend_queue():
    """Build a new ``SpendUpdateQueue`` instance for the requesting test."""
    queue = SpendUpdateQueue()
    return queue
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_add_update(spend_queue):
    """Adding a single update leaves exactly one item on the queue."""
    item: SpendUpdateQueueItem = {
        "entity_type": Litellm_EntityType.USER,
        "entity_id": "user123",
        "response_cost": 0.5,
    }

    await spend_queue.add_update(item)

    # The update is considered added when the underlying queue grew by one.
    queued_count = spend_queue.update_queue.qsize()
    assert queued_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_missing_response_cost(spend_queue):
    """An update without ``response_cost`` is aggregated with a cost of 0."""
    item: SpendUpdateQueueItem = {
        "entity_type": Litellm_EntityType.USER,
        "entity_id": "user123",
    }
    await spend_queue.add_update(item)

    totals = await spend_queue.flush_and_get_aggregated_db_spend_update_transactions()

    # The user entry must exist and carry a zero cost.
    user_spend = totals["user_list_transactions"]
    assert user_spend["user123"] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_missing_entity_id(spend_queue):
    """An update without ``entity_id`` is aggregated under the empty-string key."""
    item: SpendUpdateQueueItem = {
        "entity_type": Litellm_EntityType.USER,
        "response_cost": 1.0,
    }
    await spend_queue.add_update(item)

    totals = await spend_queue.flush_and_get_aggregated_db_spend_update_transactions()

    # The cost is expected to be filed under "" when no entity id was given.
    user_spend = totals["user_list_transactions"]
    assert user_spend[""] == 1.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_none_values(spend_queue):
    """Explicit ``None`` id and cost are aggregated as key ``""`` with cost 0."""
    item: SpendUpdateQueueItem = {
        "entity_type": Litellm_EntityType.USER,
        "entity_id": None,  # type: ignore
        "response_cost": None,
    }
    await spend_queue.add_update(item)

    totals = await spend_queue.flush_and_get_aggregated_db_spend_update_transactions()

    # None values must not raise; they collapse to the empty key and zero cost.
    user_spend = totals["user_list_transactions"]
    assert user_spend[""] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_multiple_updates_with_missing_fields(spend_queue):
    """Several updates with different missing fields aggregate per entity id."""
    complete: SpendUpdateQueueItem = {
        "entity_type": Litellm_EntityType.USER,
        "entity_id": "user123",
        "response_cost": 0.5,
    }
    without_cost: SpendUpdateQueueItem = {
        "entity_type": Litellm_EntityType.USER,
        "entity_id": "user123",
    }
    without_id: SpendUpdateQueueItem = {
        "entity_type": Litellm_EntityType.USER,
        "response_cost": 1.5,
    }

    # Enqueue in the same order: complete, missing cost, missing id.
    for item in (complete, without_cost, without_id):
        await spend_queue.add_update(item)

    totals = await spend_queue.flush_and_get_aggregated_db_spend_update_transactions()
    user_spend = totals["user_list_transactions"]

    # Only the first "user123" update carries a cost.
    assert user_spend["user123"] == 0.5
    # The update with no entity id lands under the empty-string key.
    assert user_spend[""] == 1.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_unknown_entity_type(spend_queue):
    """An update with an unrecognised entity type produces no transactions."""
    item: SpendUpdateQueueItem = {
        "entity_type": "UNKNOWN_TYPE",  # type: ignore
        "entity_id": "123",
        "response_cost": 0.5,
    }
    await spend_queue.add_update(item)

    totals = await spend_queue.flush_and_get_aggregated_db_spend_update_transactions()

    # Every aggregated bucket must stay empty.
    for bucket in totals.values():
        assert len(bucket) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_missing_entity_type(spend_queue):
    """An update without ``entity_type`` produces no transactions."""
    item: SpendUpdateQueueItem = {"entity_id": "123", "response_cost": 0.5}
    await spend_queue.add_update(item)

    totals = await spend_queue.flush_and_get_aggregated_db_spend_update_transactions()

    # Every aggregated bucket must stay empty.
    for bucket in totals.values():
        assert len(bucket) == 0
|
|
@ -62,6 +62,8 @@ from litellm.proxy._types import (
|
|||
KeyRequest,
|
||||
NewUserRequest,
|
||||
UpdateKeyRequest,
|
||||
SpendUpdateQueueItem,
|
||||
Litellm_EntityType,
|
||||
)
|
||||
|
||||
proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
|
||||
|
@ -93,7 +95,13 @@ def prisma_client():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_update_spend(prisma_client):
|
||||
prisma_client.user_list_transactions["test-litellm-user-5"] = 23
|
||||
await proxy_logging_obj.db_spend_update_writer.spend_update_queue.add_update(
|
||||
SpendUpdateQueueItem(
|
||||
entity_type=Litellm_EntityType.USER,
|
||||
entity_id="test-litellm-user-5",
|
||||
response_cost=23,
|
||||
)
|
||||
)
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
|
|
|
@ -1485,21 +1485,18 @@ from litellm.proxy.utils import ProxyUpdateSpend
|
|||
async def test_end_user_transactions_reset():
|
||||
# Setup
|
||||
mock_client = MagicMock()
|
||||
mock_client.end_user_list_transactions = {"1": 10.0} # Bad log
|
||||
end_user_list_transactions = {"1": 10.0} # Bad log
|
||||
mock_client.db.tx = AsyncMock(side_effect=Exception("DB Error"))
|
||||
|
||||
# Call function - should raise error
|
||||
with pytest.raises(Exception):
|
||||
await ProxyUpdateSpend.update_end_user_spend(
|
||||
n_retry_times=0, prisma_client=mock_client, proxy_logging_obj=MagicMock()
|
||||
n_retry_times=0,
|
||||
prisma_client=mock_client,
|
||||
proxy_logging_obj=MagicMock(),
|
||||
end_user_list_transactions=end_user_list_transactions,
|
||||
)
|
||||
|
||||
# Verify cleanup happened
|
||||
assert (
|
||||
mock_client.end_user_list_transactions == {}
|
||||
), "Transactions list should be empty after error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spend_logs_cleanup_after_error():
|
||||
# Setup test data
|
||||
|
|
|
@ -24,9 +24,10 @@ async def test_disable_spend_logs():
|
|||
"litellm.proxy.proxy_server.prisma_client", mock_prisma_client
|
||||
):
|
||||
from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter
|
||||
db_spend_update_writer = DBSpendUpdateWriter()
|
||||
|
||||
# Call update_database with disable_spend_logs=True
|
||||
await DBSpendUpdateWriter.update_database(
|
||||
await db_spend_update_writer.update_database(
|
||||
token="fake-token",
|
||||
response_cost=0.1,
|
||||
user_id="user123",
|
||||
|
|
|
@ -27,12 +27,6 @@ class MockPrismaClient:
|
|||
|
||||
# Initialize transaction lists
|
||||
self.spend_log_transactions = []
|
||||
self.user_list_transactons = {}
|
||||
self.end_user_list_transactons = {}
|
||||
self.key_list_transactons = {}
|
||||
self.team_list_transactons = {}
|
||||
self.team_member_list_transactons = {}
|
||||
self.org_list_transactons = {}
|
||||
self.daily_user_spend_transactions = {}
|
||||
|
||||
def jsonify_object(self, obj):
|
||||
|
|
|
@ -52,7 +52,7 @@ Additional Test Scenarios:
|
|||
|
||||
async def create_organization(session, organization_alias: str):
|
||||
"""Helper function to create a new organization"""
|
||||
url = "http://0.0.0.0:4002/organization/new"
|
||||
url = "http://0.0.0.0:4000/organization/new"
|
||||
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||
data = {"organization_alias": organization_alias}
|
||||
async with session.post(url, headers=headers, json=data) as response:
|
||||
|
@ -61,7 +61,7 @@ async def create_organization(session, organization_alias: str):
|
|||
|
||||
async def create_team(session, org_id: str):
|
||||
"""Helper function to create a new team under an organization"""
|
||||
url = "http://0.0.0.0:4002/team/new"
|
||||
url = "http://0.0.0.0:4000/team/new"
|
||||
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||
data = {"organization_id": org_id, "team_alias": f"test-team-{uuid.uuid4()}"}
|
||||
async with session.post(url, headers=headers, json=data) as response:
|
||||
|
@ -70,7 +70,7 @@ async def create_team(session, org_id: str):
|
|||
|
||||
async def create_user(session, org_id: str):
|
||||
"""Helper function to create a new user"""
|
||||
url = "http://0.0.0.0:4002/user/new"
|
||||
url = "http://0.0.0.0:4000/user/new"
|
||||
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||
data = {"user_name": f"test-user-{uuid.uuid4()}"}
|
||||
async with session.post(url, headers=headers, json=data) as response:
|
||||
|
@ -79,7 +79,7 @@ async def create_user(session, org_id: str):
|
|||
|
||||
async def generate_key(session, user_id: str, team_id: str):
|
||||
"""Helper function to generate a key for a specific user and team"""
|
||||
url = "http://0.0.0.0:4002/key/generate"
|
||||
url = "http://0.0.0.0:4000/key/generate"
|
||||
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||
data = {"user_id": user_id, "team_id": team_id}
|
||||
async with session.post(url, headers=headers, json=data) as response:
|
||||
|
@ -91,7 +91,7 @@ async def chat_completion(session, key: str):
|
|||
from openai import AsyncOpenAI
|
||||
import uuid
|
||||
|
||||
client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4002/v1")
|
||||
client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000/v1")
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model="fake-openai-endpoint",
|
||||
|
@ -102,7 +102,7 @@ async def chat_completion(session, key: str):
|
|||
|
||||
async def get_spend_info(session, entity_type: str, entity_id: str):
|
||||
"""Helper function to get spend information for an entity"""
|
||||
url = f"http://0.0.0.0:4002/{entity_type}/info"
|
||||
url = f"http://0.0.0.0:4000/{entity_type}/info"
|
||||
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||
if entity_type == "key":
|
||||
data = {"key": entity_id}
|
||||
|
@ -156,7 +156,7 @@ async def test_basic_spend_accuracy():
|
|||
response = await chat_completion(session, key)
|
||||
print("response: ", response)
|
||||
|
||||
# wait 10 seconds for spend to be updated
|
||||
# wait 15 seconds for spend to be updated
|
||||
await asyncio.sleep(15)
|
||||
|
||||
# Get spend information for each entity
|
||||
|
@ -235,7 +235,7 @@ async def test_long_term_spend_accuracy_with_bursts():
|
|||
print(f"Burst 1 - Request {i+1}/{BURST_1_REQUESTS} completed")
|
||||
|
||||
# Wait for spend to be updated
|
||||
await asyncio.sleep(8)
|
||||
await asyncio.sleep(15)
|
||||
|
||||
# Check intermediate spend
|
||||
intermediate_key_info = await get_spend_info(session, "key", key)
|
||||
|
@ -248,7 +248,7 @@ async def test_long_term_spend_accuracy_with_bursts():
|
|||
print(f"Burst 2 - Request {i+1}/{BURST_2_REQUESTS} completed")
|
||||
|
||||
# Wait for spend to be updated
|
||||
await asyncio.sleep(8)
|
||||
await asyncio.sleep(15)
|
||||
|
||||
# Get final spend information for each entity
|
||||
key_info = await get_spend_info(session, "key", key)
|
Loading…
Add table
Add a link
Reference in a new issue