diff --git a/litellm/proxy/hooks/cache_control_check.py b/litellm/proxy/hooks/cache_control_check.py index 670e7554d..c50c4ec1f 100644 --- a/litellm/proxy/hooks/cache_control_check.py +++ b/litellm/proxy/hooks/cache_control_check.py @@ -10,7 +10,7 @@ from fastapi import HTTPException import json, traceback -class CacheControlCheck(CustomLogger): +class _PROXY_CacheControlCheck(CustomLogger): # Class variables or attributes def __init__(self): pass diff --git a/litellm/proxy/hooks/max_budget_limiter.py b/litellm/proxy/hooks/max_budget_limiter.py index fa24c9f0f..442cc53e3 100644 --- a/litellm/proxy/hooks/max_budget_limiter.py +++ b/litellm/proxy/hooks/max_budget_limiter.py @@ -7,7 +7,7 @@ from fastapi import HTTPException import json, traceback -class MaxBudgetLimiter(CustomLogger): +class _PROXY_MaxBudgetLimiter(CustomLogger): # Class variables or attributes def __init__(self): pass diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index 5c1893ea5..ca60421a5 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -9,7 +9,7 @@ from litellm import ModelResponse from datetime import datetime -class MaxParallelRequestsHandler(CustomLogger): +class _PROXY_MaxParallelRequestsHandler(CustomLogger): user_api_key_cache = None # Class variables or attributes diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 6a38eb68a..fed61cf89 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -566,7 +566,6 @@ async def user_api_key_auth( and (not general_settings.get("allow_user_auth", False)) ): # enters this block when allow_user_auth is set to False - assert not general_settings.get("allow_user_auth", False) if route == "/key/info": # check if user can access this route query_params = request.query_params @@ -679,16 +678,17 @@ def cost_tracking(): if prisma_client is not None or custom_db_client is not None: if isinstance(litellm.success_callback, list): verbose_proxy_logger.debug("setting litellm success callback to track cost") - if (track_cost_callback) not in litellm.success_callback: # type: ignore - litellm.success_callback.append(track_cost_callback) # type: ignore + if (_PROXY_track_cost_callback) not in litellm.success_callback: # type: ignore + litellm.success_callback.append(_PROXY_track_cost_callback) # type: ignore -async def track_cost_callback( +async def _PROXY_track_cost_callback( kwargs, # kwargs to completion completion_response: litellm.ModelResponse, # response from completion start_time=None, end_time=None, # start/end time for completion ): + verbose_proxy_logger.debug(f"INSIDE _PROXY_track_cost_callback") global prisma_client, custom_db_client try: # check if it has collected an entire stream response @@ -752,8 +752,8 @@ async def update_database( end_time=None, ): try: - verbose_proxy_logger.debug( - f"Enters prisma db call, token: {token}; user_id: {user_id}" + verbose_proxy_logger.info( + f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}" ) ### UPDATE USER SPEND ### @@ -865,18 +865,16 @@ async def update_database( ) payload["spend"] = response_cost - if prisma_client is not None: await prisma_client.insert_data(data=payload, table_name="spend") elif custom_db_client is not None: await custom_db_client.insert_data(payload, table_name="spend") - tasks = [] - tasks.append(_update_user_db()) - tasks.append(_update_key_db()) - tasks.append(_insert_spend_log_to_db()) - await asyncio.gather(*tasks) + asyncio.create_task(_update_user_db()) + asyncio.create_task(_update_key_db()) + asyncio.create_task(_insert_spend_log_to_db()) + verbose_proxy_logger.info("Successfully updated spend in all 3 tables") except Exception as e: verbose_proxy_logger.debug( f"Error updating Prisma database: {traceback.format_exc()}" @@ -3934,7 +3932,7 @@ def _has_user_setup_sso(): """ microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) - ui_username = os.getenv("UI_USERNAME") + ui_username = os.getenv("UI_USERNAME", None) sso_setup = ( (microsoft_client_id is not None) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index aff50b44d..905b9424e 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -8,9 +8,11 @@ from litellm.proxy._types import ( LiteLLM_SpendLogs, ) from litellm.caching import DualCache -from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler -from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter -from litellm.proxy.hooks.cache_control_check import CacheControlCheck +from litellm.proxy.hooks.parallel_request_limiter import ( + _PROXY_MaxParallelRequestsHandler, +) +from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter +from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck from litellm.integrations.custom_logger import CustomLogger from litellm.proxy.db.base_client import CustomDB from litellm._logging import verbose_proxy_logger @@ -41,9 +43,9 @@ class ProxyLogging: ## INITIALIZE LITELLM CALLBACKS ## self.call_details: dict = {} self.call_details["user_api_key_cache"] = user_api_key_cache - self.max_parallel_request_limiter = MaxParallelRequestsHandler() - self.max_budget_limiter = MaxBudgetLimiter() - self.cache_control_check = CacheControlCheck() + self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler() + self.max_budget_limiter = _PROXY_MaxBudgetLimiter() + self.cache_control_check = _PROXY_CacheControlCheck() self.alerting: Optional[List] = None self.alerting_threshold: float = 300 # default to 5 min. threshold pass @@ -522,7 +524,6 @@ class PrismaClient: response = await self.db.litellm_verificationtoken.find_many( order={"spend": "desc"}, ) - print_verbose(f"PrismaClient: response={response}") if response is not None: return response else: @@ -1200,8 +1201,6 @@ async def reset_budget(prisma_client: PrismaClient): table_name="user", query_type="find_all", reset_at=now ) - verbose_proxy_logger.debug(f"users_to_reset from get_data: {users_to_reset}") - if users_to_reset is not None and len(users_to_reset) > 0: for user in users_to_reset: user.spend = 0.0 diff --git a/litellm/tests/test_key_generate_dynamodb.py b/litellm/tests/test_key_generate_dynamodb.py index 418098333..61d0ff6a6 100644 --- a/litellm/tests/test_key_generate_dynamodb.py +++ b/litellm/tests/test_key_generate_dynamodb.py @@ -213,7 +213,9 @@ def test_call_with_user_over_budget(custom_db_client): print("result from user auth with new key", result) # update spend using track_cost callback, make 2nd request, it should fail - from litellm.proxy.proxy_server import track_cost_callback + from litellm.proxy.proxy_server import ( + _PROXY_track_cost_callback as track_cost_callback, + ) from litellm import ModelResponse, Choices, Message, Usage resp = ModelResponse( @@ -244,7 +246,7 @@ def test_call_with_user_over_budget(custom_db_client): }, completion_response=resp, ) - + await asyncio.sleep(5) # use generated key to auth in result = await user_api_key_auth(request=request, api_key=bearer_token) print("result from user auth with new key", result) @@ -285,7 +287,9 @@ def test_call_with_user_over_budget_stream(custom_db_client): print("result from user auth with new key", result) # update spend using track_cost callback, make 2nd request, it should fail - from litellm.proxy.proxy_server import track_cost_callback + from litellm.proxy.proxy_server import ( + _PROXY_track_cost_callback as track_cost_callback, + ) from litellm import ModelResponse, Choices, Message, Usage resp = ModelResponse( @@ -317,7 +321,7 @@ def test_call_with_user_over_budget_stream(custom_db_client): }, completion_response=ModelResponse(), ) - + await asyncio.sleep(5) # use generated key to auth in result = await user_api_key_auth(request=request, api_key=bearer_token) print("result from user auth with new key", result) @@ -357,7 +361,9 @@ def test_call_with_user_key_budget(custom_db_client): print("result from user auth with new key", result) # update spend using track_cost callback, make 2nd request, it should fail - from litellm.proxy.proxy_server import track_cost_callback + from litellm.proxy.proxy_server import ( + _PROXY_track_cost_callback as track_cost_callback, + ) from litellm import ModelResponse, Choices, Message, Usage resp = ModelResponse( @@ -388,7 +394,7 @@ def test_call_with_user_key_budget(custom_db_client): }, completion_response=resp, ) - + await asyncio.sleep(5) # use generated key to auth in result = await user_api_key_auth(request=request, api_key=bearer_token) print("result from user auth with new key", result) @@ -429,7 +435,9 @@ def test_call_with_key_over_budget_stream(custom_db_client): print("result from user auth with new key", result) # update spend using track_cost callback, make 2nd request, it should fail - from litellm.proxy.proxy_server import track_cost_callback + from litellm.proxy.proxy_server import ( + _PROXY_track_cost_callback as track_cost_callback, + ) from litellm import ModelResponse, Choices, Message, Usage resp = ModelResponse( @@ -461,7 +469,7 @@ def test_call_with_key_over_budget_stream(custom_db_client): }, completion_response=ModelResponse(), ) - + await asyncio.sleep(5) # use generated key to auth in result = await user_api_key_auth(request=request, api_key=bearer_token) print("result from user auth with new key", result) diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index d8ffcf022..19f4e008d 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -242,7 +242,9 @@ def test_call_with_user_over_budget(prisma_client): print("result from user auth with new key", result) # update spend using track_cost callback, make 2nd request, it should fail - from litellm.proxy.proxy_server import track_cost_callback + from litellm.proxy.proxy_server import ( + _PROXY_track_cost_callback as track_cost_callback, + ) from litellm import ModelResponse, Choices, Message, Usage resp = ModelResponse( @@ -275,7 +277,7 @@ def test_call_with_user_over_budget(prisma_client): start_time=datetime.now(), end_time=datetime.now(), ) - + await asyncio.sleep(5) # use generated key to auth in result = await user_api_key_auth(request=request, api_key=bearer_token) print("result from user auth with new key", result) @@ -323,7 +325,9 @@ def test_call_with_proxy_over_budget(prisma_client): print("result from user auth with new key", result) # update spend using track_cost callback, make 2nd request, it should fail - from litellm.proxy.proxy_server import track_cost_callback + from litellm.proxy.proxy_server import ( + _PROXY_track_cost_callback as track_cost_callback, + ) from litellm import ModelResponse, Choices, Message, Usage resp = ModelResponse( @@ -356,7 +360,7 @@ def test_call_with_proxy_over_budget(prisma_client): start_time=datetime.now(), end_time=datetime.now(), ) - + await asyncio.sleep(5) # use generated key to auth in result = await user_api_key_auth(request=request, api_key=bearer_token) print("result from user auth with new key", result) @@ -401,7 +405,9 @@ def test_call_with_user_over_budget_stream(prisma_client): print("result from user auth with new key", result) # update spend using track_cost callback, make 2nd request, it should fail - from litellm.proxy.proxy_server import track_cost_callback + from litellm.proxy.proxy_server import ( + _PROXY_track_cost_callback as track_cost_callback, + ) from litellm import ModelResponse, Choices, Message, Usage resp = ModelResponse( @@ -435,7 +441,7 @@ def test_call_with_user_over_budget_stream(prisma_client): start_time=datetime.now(), end_time=datetime.now(), ) - + await asyncio.sleep(5) # use generated key to auth in result = await user_api_key_auth(request=request, api_key=bearer_token) print("result from user auth with new key", result) @@ -488,7 +494,9 @@ def test_call_with_proxy_over_budget_stream(prisma_client): print("result from user auth with new key", result) # update spend using track_cost callback, make 2nd request, it should fail - from litellm.proxy.proxy_server import track_cost_callback + from litellm.proxy.proxy_server import ( + _PROXY_track_cost_callback as track_cost_callback, + ) from litellm import ModelResponse, Choices, Message, Usage resp = ModelResponse( @@ -522,7 +530,7 @@ def test_call_with_proxy_over_budget_stream(prisma_client): start_time=datetime.now(), end_time=datetime.now(), ) - + await asyncio.sleep(5) # use generated key to auth in result = await user_api_key_auth(request=request, api_key=bearer_token) print("result from user auth with new key", result) @@ -891,7 +899,9 @@ def test_call_with_key_over_budget(prisma_client): print("result from user auth with new key", result) # update spend using track_cost callback, make 2nd request, it should fail - from litellm.proxy.proxy_server import track_cost_callback + from litellm.proxy.proxy_server import ( + _PROXY_track_cost_callback as track_cost_callback, + ) from litellm import ModelResponse, Choices, Message, Usage from litellm.caching import Cache @@ -931,7 +941,7 @@ def test_call_with_key_over_budget(prisma_client): start_time=datetime.now(), end_time=datetime.now(), ) - + await asyncio.sleep(4) # test spend_log was written and we can read it spend_logs = await view_spend_logs(request_id=request_id) @@ -955,6 +965,8 @@ def test_call_with_key_over_budget(prisma_client): asyncio.run(test()) except Exception as e: + # print(f"Error - {str(e)}") + traceback.print_exc() error_detail = e.message assert "Authentication Error, ExceededTokenBudget:" in error_detail print(vars(e)) @@ -983,7 +995,9 @@ async def test_call_with_key_never_over_budget(prisma_client): print("result from user auth with new key", result) # update spend using track_cost callback, make 2nd request, it should fail - from litellm.proxy.proxy_server import track_cost_callback + from litellm.proxy.proxy_server import ( + _PROXY_track_cost_callback as track_cost_callback, + ) from litellm import ModelResponse, Choices, Message, Usage import time @@ -1022,7 +1036,7 @@ async def test_call_with_key_never_over_budget(prisma_client): start_time=datetime.now(), end_time=datetime.now(), ) - + await asyncio.sleep(5) # use generated key to auth in result = await user_api_key_auth(request=request, api_key=bearer_token) print("result from user auth with new key", result) @@ -1058,7 +1072,9 @@ async def test_call_with_key_over_budget_stream(prisma_client): print("result from user auth with new key", result) # update spend using track_cost callback, make 2nd request, it should fail - from litellm.proxy.proxy_server import track_cost_callback + from litellm.proxy.proxy_server import ( + _PROXY_track_cost_callback as track_cost_callback, + ) from litellm import ModelResponse, Choices, Message, Usage import time @@ -1096,7 +1112,7 @@ async def test_call_with_key_over_budget_stream(prisma_client): start_time=datetime.now(), end_time=datetime.now(), ) - + await asyncio.sleep(5) # use generated key to auth in result = await user_api_key_auth(request=request, api_key=bearer_token) print("result from user auth with new key", result) diff --git a/litellm/tests/test_parallel_request_limiter.py b/litellm/tests/test_parallel_request_limiter.py index dee909eaa..1155e5794 100644 --- a/litellm/tests/test_parallel_request_limiter.py +++ b/litellm/tests/test_parallel_request_limiter.py @@ -18,7 +18,9 @@ from litellm import Router from litellm.proxy.utils import ProxyLogging from litellm.proxy._types import UserAPIKeyAuth from litellm.caching import DualCache -from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler +from litellm.proxy.hooks.parallel_request_limiter import ( + _PROXY_MaxParallelRequestsHandler as MaxParallelRequestsHandler, +) from datetime import datetime ## On Request received diff --git a/litellm/utils.py b/litellm/utils.py index e1caf5aef..d27c658c2 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -774,14 +774,14 @@ class Logging: self.streaming_chunks = [] # for generating complete stream response self.sync_streaming_chunks = [] # for generating complete stream response self.model_call_details = {} - self.dynamic_input_callbacks = [] # callbacks set for just that call - self.dynamic_failure_callbacks = [] # callbacks set for just that call + self.dynamic_input_callbacks = [] # [TODO] callbacks set for just that call + self.dynamic_failure_callbacks = [] # [TODO] callbacks set for just that call self.dynamic_success_callbacks = ( - dynamic_success_callbacks or [] - ) # callbacks set for just that call + dynamic_success_callbacks # callbacks set for just that call + ) self.dynamic_async_success_callbacks = ( - dynamic_async_success_callbacks or [] - ) # callbacks set for just that call + dynamic_async_success_callbacks # callbacks set for just that call + ) ## DYNAMIC LANGFUSE KEYS ## self.langfuse_public_key = langfuse_public_key self.langfuse_secret = langfuse_secret @@ -1145,7 +1145,19 @@ class Logging: f"Model={self.model} not found in completion cost map." ) self.model_call_details["response_cost"] = None - callbacks = litellm.success_callback + self.dynamic_success_callbacks + if self.dynamic_success_callbacks is not None and isinstance( + self.dynamic_success_callbacks, list + ): + callbacks = self.dynamic_success_callbacks + ## keep the internal functions ## + for callback in litellm.success_callback: + if ( + isinstance(callback, CustomLogger) + and "_PROXY_" in callback.__class__.__name__ + ): + callbacks.append(callback) + else: + callbacks = litellm.success_callback for callback in callbacks: try: if callback == "lite_debugger": @@ -1406,9 +1418,6 @@ class Logging: """ Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions. """ - verbose_logger.debug( - f"Async success callbacks: {litellm._async_success_callback}" - ) start_time, end_time, result = self._success_handler_helper_fn( start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit ) @@ -1452,9 +1461,22 @@ class Logging: ) self.model_call_details["response_cost"] = None - callbacks = ( - litellm._async_success_callback + self.dynamic_async_success_callbacks - ) + if self.dynamic_async_success_callbacks is not None and isinstance( + self.dynamic_async_success_callbacks, list + ): + callbacks = self.dynamic_async_success_callbacks + ## keep the internal functions ## + for callback in litellm._async_success_callback: + callback_name = "" + if isinstance(callback, CustomLogger): + callback_name = callback.__class__.__name__ + if callable(callback): + callback_name = callback.__name__ + if "_PROXY_" in callback_name: + callbacks.append(callback) + else: + callbacks = litellm._async_success_callback + verbose_logger.debug(f"Async success callbacks: {callbacks}") for callback in callbacks: try: if callback == "cache" and litellm.cache is not None: @@ -1501,6 +1523,7 @@ class Logging: end_time=end_time, ) if callable(callback): # custom logger functions + print_verbose(f"Making async function logging call") if self.stream: if "complete_streaming_response" in self.model_call_details: await customLogger.async_log_event( @@ -1958,8 +1981,8 @@ def client(original_function): for index in reversed(removed_async_items): litellm.failure_callback.pop(index) ### DYNAMIC CALLBACKS ### - dynamic_success_callbacks = [] - dynamic_async_success_callbacks = [] + dynamic_success_callbacks = None + dynamic_async_success_callbacks = None if kwargs.get("success_callback", None) is not None and isinstance( kwargs["success_callback"], list ): @@ -1970,7 +1993,12 @@ def client(original_function): or callback == "dynamodb" or callback == "s3" ): - dynamic_async_success_callbacks.append(callback) + if dynamic_async_success_callbacks is not None and isinstance( + dynamic_async_success_callbacks, list + ): + dynamic_async_success_callbacks.append(callback) + else: + dynamic_async_success_callbacks = [callback] removed_async_items.append(index) # Pop the async items from success_callback in reverse order to avoid index issues for index in reversed(removed_async_items):