Merge pull request #1761 from BerriAI/litellm_fix_dynamic_callbacks

fix(utils.py): override default success callbacks with dynamic callbacks if set
This commit is contained in:
Krish Dholakia 2024-02-02 13:06:55 -08:00 committed by GitHub
commit 93fb0134e5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 115 additions and 64 deletions

View file

@@ -10,7 +10,7 @@ from fastapi import HTTPException
import json, traceback
class CacheControlCheck(CustomLogger):
class _PROXY_CacheControlCheck(CustomLogger):
# Class variables or attributes
def __init__(self):
pass

View file

@@ -7,7 +7,7 @@ from fastapi import HTTPException
import json, traceback
class MaxBudgetLimiter(CustomLogger):
class _PROXY_MaxBudgetLimiter(CustomLogger):
# Class variables or attributes
def __init__(self):
pass

View file

@@ -9,7 +9,7 @@ from litellm import ModelResponse
from datetime import datetime
class MaxParallelRequestsHandler(CustomLogger):
class _PROXY_MaxParallelRequestsHandler(CustomLogger):
user_api_key_cache = None
# Class variables or attributes

View file

@@ -566,7 +566,6 @@ async def user_api_key_auth(
and (not general_settings.get("allow_user_auth", False))
):
# enters this block when allow_user_auth is set to False
assert not general_settings.get("allow_user_auth", False)
if route == "/key/info":
# check if user can access this route
query_params = request.query_params
@@ -679,16 +678,17 @@ def cost_tracking():
if prisma_client is not None or custom_db_client is not None:
if isinstance(litellm.success_callback, list):
verbose_proxy_logger.debug("setting litellm success callback to track cost")
if (track_cost_callback) not in litellm.success_callback: # type: ignore
litellm.success_callback.append(track_cost_callback) # type: ignore
if (_PROXY_track_cost_callback) not in litellm.success_callback: # type: ignore
litellm.success_callback.append(_PROXY_track_cost_callback) # type: ignore
async def track_cost_callback(
async def _PROXY_track_cost_callback(
kwargs, # kwargs to completion
completion_response: litellm.ModelResponse, # response from completion
start_time=None,
end_time=None, # start/end time for completion
):
verbose_proxy_logger.debug(f"INSIDE _PROXY_track_cost_callback")
global prisma_client, custom_db_client
try:
# check if it has collected an entire stream response
@@ -752,8 +752,8 @@ async def update_database(
end_time=None,
):
try:
verbose_proxy_logger.debug(
f"Enters prisma db call, token: {token}; user_id: {user_id}"
verbose_proxy_logger.info(
f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}"
)
### UPDATE USER SPEND ###
@@ -865,18 +865,16 @@ async def update_database(
)
payload["spend"] = response_cost
if prisma_client is not None:
await prisma_client.insert_data(data=payload, table_name="spend")
elif custom_db_client is not None:
await custom_db_client.insert_data(payload, table_name="spend")
tasks = []
tasks.append(_update_user_db())
tasks.append(_update_key_db())
tasks.append(_insert_spend_log_to_db())
await asyncio.gather(*tasks)
asyncio.create_task(_update_user_db())
asyncio.create_task(_update_key_db())
asyncio.create_task(_insert_spend_log_to_db())
verbose_proxy_logger.info("Successfully updated spend in all 3 tables")
except Exception as e:
verbose_proxy_logger.debug(
f"Error updating Prisma database: {traceback.format_exc()}"
@@ -3934,7 +3932,7 @@ def _has_user_setup_sso():
"""
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
ui_username = os.getenv("UI_USERNAME")
ui_username = os.getenv("UI_USERNAME", None)
sso_setup = (
(microsoft_client_id is not None)

View file

@@ -8,9 +8,11 @@ from litellm.proxy._types import (
LiteLLM_SpendLogs,
)
from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
from litellm.proxy.hooks.cache_control_check import CacheControlCheck
from litellm.proxy.hooks.parallel_request_limiter import (
_PROXY_MaxParallelRequestsHandler,
)
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.db.base_client import CustomDB
from litellm._logging import verbose_proxy_logger
@@ -41,9 +43,9 @@ class ProxyLogging:
## INITIALIZE LITELLM CALLBACKS ##
self.call_details: dict = {}
self.call_details["user_api_key_cache"] = user_api_key_cache
self.max_parallel_request_limiter = MaxParallelRequestsHandler()
self.max_budget_limiter = MaxBudgetLimiter()
self.cache_control_check = CacheControlCheck()
self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler()
self.max_budget_limiter = _PROXY_MaxBudgetLimiter()
self.cache_control_check = _PROXY_CacheControlCheck()
self.alerting: Optional[List] = None
self.alerting_threshold: float = 300 # default to 5 min. threshold
pass
@@ -522,7 +524,6 @@ class PrismaClient:
response = await self.db.litellm_verificationtoken.find_many(
order={"spend": "desc"},
)
print_verbose(f"PrismaClient: response={response}")
if response is not None:
return response
else:
@@ -1200,8 +1201,6 @@ async def reset_budget(prisma_client: PrismaClient):
table_name="user", query_type="find_all", reset_at=now
)
verbose_proxy_logger.debug(f"users_to_reset from get_data: {users_to_reset}")
if users_to_reset is not None and len(users_to_reset) > 0:
for user in users_to_reset:
user.spend = 0.0

View file

@@ -213,7 +213,9 @@ def test_call_with_user_over_budget(custom_db_client):
print("result from user auth with new key", result)
# update spend using track_cost callback, make 2nd request, it should fail
from litellm.proxy.proxy_server import track_cost_callback
from litellm.proxy.proxy_server import (
_PROXY_track_cost_callback as track_cost_callback,
)
from litellm import ModelResponse, Choices, Message, Usage
resp = ModelResponse(
@@ -244,7 +246,7 @@ def test_call_with_user_over_budget(custom_db_client):
},
completion_response=resp,
)
await asyncio.sleep(5)
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result)
@@ -285,7 +287,9 @@ def test_call_with_user_over_budget_stream(custom_db_client):
print("result from user auth with new key", result)
# update spend using track_cost callback, make 2nd request, it should fail
from litellm.proxy.proxy_server import track_cost_callback
from litellm.proxy.proxy_server import (
_PROXY_track_cost_callback as track_cost_callback,
)
from litellm import ModelResponse, Choices, Message, Usage
resp = ModelResponse(
@@ -317,7 +321,7 @@ def test_call_with_user_over_budget_stream(custom_db_client):
},
completion_response=ModelResponse(),
)
await asyncio.sleep(5)
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result)
@@ -357,7 +361,9 @@ def test_call_with_user_key_budget(custom_db_client):
print("result from user auth with new key", result)
# update spend using track_cost callback, make 2nd request, it should fail
from litellm.proxy.proxy_server import track_cost_callback
from litellm.proxy.proxy_server import (
_PROXY_track_cost_callback as track_cost_callback,
)
from litellm import ModelResponse, Choices, Message, Usage
resp = ModelResponse(
@@ -388,7 +394,7 @@ def test_call_with_user_key_budget(custom_db_client):
},
completion_response=resp,
)
await asyncio.sleep(5)
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result)
@@ -429,7 +435,9 @@ def test_call_with_key_over_budget_stream(custom_db_client):
print("result from user auth with new key", result)
# update spend using track_cost callback, make 2nd request, it should fail
from litellm.proxy.proxy_server import track_cost_callback
from litellm.proxy.proxy_server import (
_PROXY_track_cost_callback as track_cost_callback,
)
from litellm import ModelResponse, Choices, Message, Usage
resp = ModelResponse(
@@ -461,7 +469,7 @@ def test_call_with_key_over_budget_stream(custom_db_client):
},
completion_response=ModelResponse(),
)
await asyncio.sleep(5)
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result)

View file

@@ -242,7 +242,9 @@ def test_call_with_user_over_budget(prisma_client):
print("result from user auth with new key", result)
# update spend using track_cost callback, make 2nd request, it should fail
from litellm.proxy.proxy_server import track_cost_callback
from litellm.proxy.proxy_server import (
_PROXY_track_cost_callback as track_cost_callback,
)
from litellm import ModelResponse, Choices, Message, Usage
resp = ModelResponse(
@@ -275,7 +277,7 @@ def test_call_with_user_over_budget(prisma_client):
start_time=datetime.now(),
end_time=datetime.now(),
)
await asyncio.sleep(5)
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result)
@@ -323,7 +325,9 @@ def test_call_with_proxy_over_budget(prisma_client):
print("result from user auth with new key", result)
# update spend using track_cost callback, make 2nd request, it should fail
from litellm.proxy.proxy_server import track_cost_callback
from litellm.proxy.proxy_server import (
_PROXY_track_cost_callback as track_cost_callback,
)
from litellm import ModelResponse, Choices, Message, Usage
resp = ModelResponse(
@@ -356,7 +360,7 @@ def test_call_with_proxy_over_budget(prisma_client):
start_time=datetime.now(),
end_time=datetime.now(),
)
await asyncio.sleep(5)
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result)
@@ -401,7 +405,9 @@ def test_call_with_user_over_budget_stream(prisma_client):
print("result from user auth with new key", result)
# update spend using track_cost callback, make 2nd request, it should fail
from litellm.proxy.proxy_server import track_cost_callback
from litellm.proxy.proxy_server import (
_PROXY_track_cost_callback as track_cost_callback,
)
from litellm import ModelResponse, Choices, Message, Usage
resp = ModelResponse(
@@ -435,7 +441,7 @@ def test_call_with_user_over_budget_stream(prisma_client):
start_time=datetime.now(),
end_time=datetime.now(),
)
await asyncio.sleep(5)
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result)
@@ -488,7 +494,9 @@ def test_call_with_proxy_over_budget_stream(prisma_client):
print("result from user auth with new key", result)
# update spend using track_cost callback, make 2nd request, it should fail
from litellm.proxy.proxy_server import track_cost_callback
from litellm.proxy.proxy_server import (
_PROXY_track_cost_callback as track_cost_callback,
)
from litellm import ModelResponse, Choices, Message, Usage
resp = ModelResponse(
@@ -522,7 +530,7 @@ def test_call_with_proxy_over_budget_stream(prisma_client):
start_time=datetime.now(),
end_time=datetime.now(),
)
await asyncio.sleep(5)
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result)
@@ -891,7 +899,9 @@ def test_call_with_key_over_budget(prisma_client):
print("result from user auth with new key", result)
# update spend using track_cost callback, make 2nd request, it should fail
from litellm.proxy.proxy_server import track_cost_callback
from litellm.proxy.proxy_server import (
_PROXY_track_cost_callback as track_cost_callback,
)
from litellm import ModelResponse, Choices, Message, Usage
from litellm.caching import Cache
@@ -931,7 +941,7 @@ def test_call_with_key_over_budget(prisma_client):
start_time=datetime.now(),
end_time=datetime.now(),
)
await asyncio.sleep(4)
# test spend_log was written and we can read it
spend_logs = await view_spend_logs(request_id=request_id)
@@ -955,6 +965,8 @@ def test_call_with_key_over_budget(prisma_client):
asyncio.run(test())
except Exception as e:
# print(f"Error - {str(e)}")
traceback.print_exc()
error_detail = e.message
assert "Authentication Error, ExceededTokenBudget:" in error_detail
print(vars(e))
@@ -983,7 +995,9 @@ async def test_call_with_key_never_over_budget(prisma_client):
print("result from user auth with new key", result)
# update spend using track_cost callback, make 2nd request, it should fail
from litellm.proxy.proxy_server import track_cost_callback
from litellm.proxy.proxy_server import (
_PROXY_track_cost_callback as track_cost_callback,
)
from litellm import ModelResponse, Choices, Message, Usage
import time
@@ -1022,7 +1036,7 @@ async def test_call_with_key_never_over_budget(prisma_client):
start_time=datetime.now(),
end_time=datetime.now(),
)
await asyncio.sleep(5)
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result)
@@ -1058,7 +1072,9 @@ async def test_call_with_key_over_budget_stream(prisma_client):
print("result from user auth with new key", result)
# update spend using track_cost callback, make 2nd request, it should fail
from litellm.proxy.proxy_server import track_cost_callback
from litellm.proxy.proxy_server import (
_PROXY_track_cost_callback as track_cost_callback,
)
from litellm import ModelResponse, Choices, Message, Usage
import time
@@ -1096,7 +1112,7 @@ async def test_call_with_key_over_budget_stream(prisma_client):
start_time=datetime.now(),
end_time=datetime.now(),
)
await asyncio.sleep(5)
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result)

View file

@@ -18,7 +18,9 @@ from litellm import Router
from litellm.proxy.utils import ProxyLogging
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
from litellm.proxy.hooks.parallel_request_limiter import (
_PROXY_MaxParallelRequestsHandler as MaxParallelRequestsHandler,
)
from datetime import datetime
## On Request received

View file

@@ -774,14 +774,14 @@ class Logging:
self.streaming_chunks = [] # for generating complete stream response
self.sync_streaming_chunks = [] # for generating complete stream response
self.model_call_details = {}
self.dynamic_input_callbacks = [] # callbacks set for just that call
self.dynamic_failure_callbacks = [] # callbacks set for just that call
self.dynamic_input_callbacks = [] # [TODO] callbacks set for just that call
self.dynamic_failure_callbacks = [] # [TODO] callbacks set for just that call
self.dynamic_success_callbacks = (
dynamic_success_callbacks or []
) # callbacks set for just that call
dynamic_success_callbacks # callbacks set for just that call
)
self.dynamic_async_success_callbacks = (
dynamic_async_success_callbacks or []
) # callbacks set for just that call
dynamic_async_success_callbacks # callbacks set for just that call
)
## DYNAMIC LANGFUSE KEYS ##
self.langfuse_public_key = langfuse_public_key
self.langfuse_secret = langfuse_secret
@@ -1145,7 +1145,19 @@ class Logging:
f"Model={self.model} not found in completion cost map."
)
self.model_call_details["response_cost"] = None
callbacks = litellm.success_callback + self.dynamic_success_callbacks
if self.dynamic_success_callbacks is not None and isinstance(
self.dynamic_success_callbacks, list
):
callbacks = self.dynamic_success_callbacks
## keep the internal functions ##
for callback in litellm.success_callback:
if (
isinstance(callback, CustomLogger)
and "_PROXY_" in callback.__class__.__name__
):
callbacks.append(callback)
else:
callbacks = litellm.success_callback
for callback in callbacks:
try:
if callback == "lite_debugger":
@@ -1406,9 +1418,6 @@ class Logging:
"""
Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
"""
verbose_logger.debug(
f"Async success callbacks: {litellm._async_success_callback}"
)
start_time, end_time, result = self._success_handler_helper_fn(
start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit
)
@@ -1452,9 +1461,22 @@ class Logging:
)
self.model_call_details["response_cost"] = None
callbacks = (
litellm._async_success_callback + self.dynamic_async_success_callbacks
)
if self.dynamic_async_success_callbacks is not None and isinstance(
self.dynamic_async_success_callbacks, list
):
callbacks = self.dynamic_async_success_callbacks
## keep the internal functions ##
for callback in litellm._async_success_callback:
callback_name = ""
if isinstance(callback, CustomLogger):
callback_name = callback.__class__.__name__
if callable(callback):
callback_name = callback.__name__
if "_PROXY_" in callback_name:
callbacks.append(callback)
else:
callbacks = litellm._async_success_callback
verbose_logger.debug(f"Async success callbacks: {callbacks}")
for callback in callbacks:
try:
if callback == "cache" and litellm.cache is not None:
@@ -1501,6 +1523,7 @@ class Logging:
end_time=end_time,
)
if callable(callback): # custom logger functions
print_verbose(f"Making async function logging call")
if self.stream:
if "complete_streaming_response" in self.model_call_details:
await customLogger.async_log_event(
@@ -1958,8 +1981,8 @@ def client(original_function):
for index in reversed(removed_async_items):
litellm.failure_callback.pop(index)
### DYNAMIC CALLBACKS ###
dynamic_success_callbacks = []
dynamic_async_success_callbacks = []
dynamic_success_callbacks = None
dynamic_async_success_callbacks = None
if kwargs.get("success_callback", None) is not None and isinstance(
kwargs["success_callback"], list
):
@@ -1970,7 +1993,12 @@ def client(original_function):
or callback == "dynamodb"
or callback == "s3"
):
dynamic_async_success_callbacks.append(callback)
if dynamic_async_success_callbacks is not None and isinstance(
dynamic_async_success_callbacks, list
):
dynamic_async_success_callbacks.append(callback)
else:
dynamic_async_success_callbacks = [callback]
removed_async_items.append(index)
# Pop the async items from success_callback in reverse order to avoid index issues
for index in reversed(removed_async_items):