Merge pull request #1761 from BerriAI/litellm_fix_dynamic_callbacks

fix(utils.py): override default success callbacks with dynamic callbacks if set
This commit is contained in:
Krish Dholakia 2024-02-02 13:06:55 -08:00 committed by GitHub
commit 93fb0134e5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 115 additions and 64 deletions

View file

@ -10,7 +10,7 @@ from fastapi import HTTPException
import json, traceback import json, traceback
class CacheControlCheck(CustomLogger): class _PROXY_CacheControlCheck(CustomLogger):
# Class variables or attributes # Class variables or attributes
def __init__(self): def __init__(self):
pass pass

View file

@ -7,7 +7,7 @@ from fastapi import HTTPException
import json, traceback import json, traceback
class MaxBudgetLimiter(CustomLogger): class _PROXY_MaxBudgetLimiter(CustomLogger):
# Class variables or attributes # Class variables or attributes
def __init__(self): def __init__(self):
pass pass

View file

@ -9,7 +9,7 @@ from litellm import ModelResponse
from datetime import datetime from datetime import datetime
class MaxParallelRequestsHandler(CustomLogger): class _PROXY_MaxParallelRequestsHandler(CustomLogger):
user_api_key_cache = None user_api_key_cache = None
# Class variables or attributes # Class variables or attributes

View file

@ -566,7 +566,6 @@ async def user_api_key_auth(
and (not general_settings.get("allow_user_auth", False)) and (not general_settings.get("allow_user_auth", False))
): ):
# enters this block when allow_user_auth is set to False # enters this block when allow_user_auth is set to False
assert not general_settings.get("allow_user_auth", False)
if route == "/key/info": if route == "/key/info":
# check if user can access this route # check if user can access this route
query_params = request.query_params query_params = request.query_params
@ -679,16 +678,17 @@ def cost_tracking():
if prisma_client is not None or custom_db_client is not None: if prisma_client is not None or custom_db_client is not None:
if isinstance(litellm.success_callback, list): if isinstance(litellm.success_callback, list):
verbose_proxy_logger.debug("setting litellm success callback to track cost") verbose_proxy_logger.debug("setting litellm success callback to track cost")
if (track_cost_callback) not in litellm.success_callback: # type: ignore if (_PROXY_track_cost_callback) not in litellm.success_callback: # type: ignore
litellm.success_callback.append(track_cost_callback) # type: ignore litellm.success_callback.append(_PROXY_track_cost_callback) # type: ignore
async def track_cost_callback( async def _PROXY_track_cost_callback(
kwargs, # kwargs to completion kwargs, # kwargs to completion
completion_response: litellm.ModelResponse, # response from completion completion_response: litellm.ModelResponse, # response from completion
start_time=None, start_time=None,
end_time=None, # start/end time for completion end_time=None, # start/end time for completion
): ):
verbose_proxy_logger.debug(f"INSIDE _PROXY_track_cost_callback")
global prisma_client, custom_db_client global prisma_client, custom_db_client
try: try:
# check if it has collected an entire stream response # check if it has collected an entire stream response
@ -752,8 +752,8 @@ async def update_database(
end_time=None, end_time=None,
): ):
try: try:
verbose_proxy_logger.debug( verbose_proxy_logger.info(
f"Enters prisma db call, token: {token}; user_id: {user_id}" f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}"
) )
### UPDATE USER SPEND ### ### UPDATE USER SPEND ###
@ -865,18 +865,16 @@ async def update_database(
) )
payload["spend"] = response_cost payload["spend"] = response_cost
if prisma_client is not None: if prisma_client is not None:
await prisma_client.insert_data(data=payload, table_name="spend") await prisma_client.insert_data(data=payload, table_name="spend")
elif custom_db_client is not None: elif custom_db_client is not None:
await custom_db_client.insert_data(payload, table_name="spend") await custom_db_client.insert_data(payload, table_name="spend")
tasks = [] asyncio.create_task(_update_user_db())
tasks.append(_update_user_db()) asyncio.create_task(_update_key_db())
tasks.append(_update_key_db()) asyncio.create_task(_insert_spend_log_to_db())
tasks.append(_insert_spend_log_to_db()) verbose_proxy_logger.info("Successfully updated spend in all 3 tables")
await asyncio.gather(*tasks)
except Exception as e: except Exception as e:
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"Error updating Prisma database: {traceback.format_exc()}" f"Error updating Prisma database: {traceback.format_exc()}"
@ -3934,7 +3932,7 @@ def _has_user_setup_sso():
""" """
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
ui_username = os.getenv("UI_USERNAME") ui_username = os.getenv("UI_USERNAME", None)
sso_setup = ( sso_setup = (
(microsoft_client_id is not None) (microsoft_client_id is not None)

View file

@ -8,9 +8,11 @@ from litellm.proxy._types import (
LiteLLM_SpendLogs, LiteLLM_SpendLogs,
) )
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler from litellm.proxy.hooks.parallel_request_limiter import (
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter _PROXY_MaxParallelRequestsHandler,
from litellm.proxy.hooks.cache_control_check import CacheControlCheck )
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.db.base_client import CustomDB from litellm.proxy.db.base_client import CustomDB
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
@ -41,9 +43,9 @@ class ProxyLogging:
## INITIALIZE LITELLM CALLBACKS ## ## INITIALIZE LITELLM CALLBACKS ##
self.call_details: dict = {} self.call_details: dict = {}
self.call_details["user_api_key_cache"] = user_api_key_cache self.call_details["user_api_key_cache"] = user_api_key_cache
self.max_parallel_request_limiter = MaxParallelRequestsHandler() self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler()
self.max_budget_limiter = MaxBudgetLimiter() self.max_budget_limiter = _PROXY_MaxBudgetLimiter()
self.cache_control_check = CacheControlCheck() self.cache_control_check = _PROXY_CacheControlCheck()
self.alerting: Optional[List] = None self.alerting: Optional[List] = None
self.alerting_threshold: float = 300 # default to 5 min. threshold self.alerting_threshold: float = 300 # default to 5 min. threshold
pass pass
@ -522,7 +524,6 @@ class PrismaClient:
response = await self.db.litellm_verificationtoken.find_many( response = await self.db.litellm_verificationtoken.find_many(
order={"spend": "desc"}, order={"spend": "desc"},
) )
print_verbose(f"PrismaClient: response={response}")
if response is not None: if response is not None:
return response return response
else: else:
@ -1200,8 +1201,6 @@ async def reset_budget(prisma_client: PrismaClient):
table_name="user", query_type="find_all", reset_at=now table_name="user", query_type="find_all", reset_at=now
) )
verbose_proxy_logger.debug(f"users_to_reset from get_data: {users_to_reset}")
if users_to_reset is not None and len(users_to_reset) > 0: if users_to_reset is not None and len(users_to_reset) > 0:
for user in users_to_reset: for user in users_to_reset:
user.spend = 0.0 user.spend = 0.0

View file

@ -213,7 +213,9 @@ def test_call_with_user_over_budget(custom_db_client):
print("result from user auth with new key", result) print("result from user auth with new key", result)
# update spend using track_cost callback, make 2nd request, it should fail # update spend using track_cost callback, make 2nd request, it should fail
from litellm.proxy.proxy_server import track_cost_callback from litellm.proxy.proxy_server import (
_PROXY_track_cost_callback as track_cost_callback,
)
from litellm import ModelResponse, Choices, Message, Usage from litellm import ModelResponse, Choices, Message, Usage
resp = ModelResponse( resp = ModelResponse(
@ -244,7 +246,7 @@ def test_call_with_user_over_budget(custom_db_client):
}, },
completion_response=resp, completion_response=resp,
) )
await asyncio.sleep(5)
# use generated key to auth in # use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token) result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result) print("result from user auth with new key", result)
@ -285,7 +287,9 @@ def test_call_with_user_over_budget_stream(custom_db_client):
print("result from user auth with new key", result) print("result from user auth with new key", result)
# update spend using track_cost callback, make 2nd request, it should fail # update spend using track_cost callback, make 2nd request, it should fail
from litellm.proxy.proxy_server import track_cost_callback from litellm.proxy.proxy_server import (
_PROXY_track_cost_callback as track_cost_callback,
)
from litellm import ModelResponse, Choices, Message, Usage from litellm import ModelResponse, Choices, Message, Usage
resp = ModelResponse( resp = ModelResponse(
@ -317,7 +321,7 @@ def test_call_with_user_over_budget_stream(custom_db_client):
}, },
completion_response=ModelResponse(), completion_response=ModelResponse(),
) )
await asyncio.sleep(5)
# use generated key to auth in # use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token) result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result) print("result from user auth with new key", result)
@ -357,7 +361,9 @@ def test_call_with_user_key_budget(custom_db_client):
print("result from user auth with new key", result) print("result from user auth with new key", result)
# update spend using track_cost callback, make 2nd request, it should fail # update spend using track_cost callback, make 2nd request, it should fail
from litellm.proxy.proxy_server import track_cost_callback from litellm.proxy.proxy_server import (
_PROXY_track_cost_callback as track_cost_callback,
)
from litellm import ModelResponse, Choices, Message, Usage from litellm import ModelResponse, Choices, Message, Usage
resp = ModelResponse( resp = ModelResponse(
@ -388,7 +394,7 @@ def test_call_with_user_key_budget(custom_db_client):
}, },
completion_response=resp, completion_response=resp,
) )
await asyncio.sleep(5)
# use generated key to auth in # use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token) result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result) print("result from user auth with new key", result)
@ -429,7 +435,9 @@ def test_call_with_key_over_budget_stream(custom_db_client):
print("result from user auth with new key", result) print("result from user auth with new key", result)
# update spend using track_cost callback, make 2nd request, it should fail # update spend using track_cost callback, make 2nd request, it should fail
from litellm.proxy.proxy_server import track_cost_callback from litellm.proxy.proxy_server import (
_PROXY_track_cost_callback as track_cost_callback,
)
from litellm import ModelResponse, Choices, Message, Usage from litellm import ModelResponse, Choices, Message, Usage
resp = ModelResponse( resp = ModelResponse(
@ -461,7 +469,7 @@ def test_call_with_key_over_budget_stream(custom_db_client):
}, },
completion_response=ModelResponse(), completion_response=ModelResponse(),
) )
await asyncio.sleep(5)
# use generated key to auth in # use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token) result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result) print("result from user auth with new key", result)

View file

@ -242,7 +242,9 @@ def test_call_with_user_over_budget(prisma_client):
print("result from user auth with new key", result) print("result from user auth with new key", result)
# update spend using track_cost callback, make 2nd request, it should fail # update spend using track_cost callback, make 2nd request, it should fail
from litellm.proxy.proxy_server import track_cost_callback from litellm.proxy.proxy_server import (
_PROXY_track_cost_callback as track_cost_callback,
)
from litellm import ModelResponse, Choices, Message, Usage from litellm import ModelResponse, Choices, Message, Usage
resp = ModelResponse( resp = ModelResponse(
@ -275,7 +277,7 @@ def test_call_with_user_over_budget(prisma_client):
start_time=datetime.now(), start_time=datetime.now(),
end_time=datetime.now(), end_time=datetime.now(),
) )
await asyncio.sleep(5)
# use generated key to auth in # use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token) result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result) print("result from user auth with new key", result)
@ -323,7 +325,9 @@ def test_call_with_proxy_over_budget(prisma_client):
print("result from user auth with new key", result) print("result from user auth with new key", result)
# update spend using track_cost callback, make 2nd request, it should fail # update spend using track_cost callback, make 2nd request, it should fail
from litellm.proxy.proxy_server import track_cost_callback from litellm.proxy.proxy_server import (
_PROXY_track_cost_callback as track_cost_callback,
)
from litellm import ModelResponse, Choices, Message, Usage from litellm import ModelResponse, Choices, Message, Usage
resp = ModelResponse( resp = ModelResponse(
@ -356,7 +360,7 @@ def test_call_with_proxy_over_budget(prisma_client):
start_time=datetime.now(), start_time=datetime.now(),
end_time=datetime.now(), end_time=datetime.now(),
) )
await asyncio.sleep(5)
# use generated key to auth in # use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token) result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result) print("result from user auth with new key", result)
@ -401,7 +405,9 @@ def test_call_with_user_over_budget_stream(prisma_client):
print("result from user auth with new key", result) print("result from user auth with new key", result)
# update spend using track_cost callback, make 2nd request, it should fail # update spend using track_cost callback, make 2nd request, it should fail
from litellm.proxy.proxy_server import track_cost_callback from litellm.proxy.proxy_server import (
_PROXY_track_cost_callback as track_cost_callback,
)
from litellm import ModelResponse, Choices, Message, Usage from litellm import ModelResponse, Choices, Message, Usage
resp = ModelResponse( resp = ModelResponse(
@ -435,7 +441,7 @@ def test_call_with_user_over_budget_stream(prisma_client):
start_time=datetime.now(), start_time=datetime.now(),
end_time=datetime.now(), end_time=datetime.now(),
) )
await asyncio.sleep(5)
# use generated key to auth in # use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token) result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result) print("result from user auth with new key", result)
@ -488,7 +494,9 @@ def test_call_with_proxy_over_budget_stream(prisma_client):
print("result from user auth with new key", result) print("result from user auth with new key", result)
# update spend using track_cost callback, make 2nd request, it should fail # update spend using track_cost callback, make 2nd request, it should fail
from litellm.proxy.proxy_server import track_cost_callback from litellm.proxy.proxy_server import (
_PROXY_track_cost_callback as track_cost_callback,
)
from litellm import ModelResponse, Choices, Message, Usage from litellm import ModelResponse, Choices, Message, Usage
resp = ModelResponse( resp = ModelResponse(
@ -522,7 +530,7 @@ def test_call_with_proxy_over_budget_stream(prisma_client):
start_time=datetime.now(), start_time=datetime.now(),
end_time=datetime.now(), end_time=datetime.now(),
) )
await asyncio.sleep(5)
# use generated key to auth in # use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token) result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result) print("result from user auth with new key", result)
@ -891,7 +899,9 @@ def test_call_with_key_over_budget(prisma_client):
print("result from user auth with new key", result) print("result from user auth with new key", result)
# update spend using track_cost callback, make 2nd request, it should fail # update spend using track_cost callback, make 2nd request, it should fail
from litellm.proxy.proxy_server import track_cost_callback from litellm.proxy.proxy_server import (
_PROXY_track_cost_callback as track_cost_callback,
)
from litellm import ModelResponse, Choices, Message, Usage from litellm import ModelResponse, Choices, Message, Usage
from litellm.caching import Cache from litellm.caching import Cache
@ -931,7 +941,7 @@ def test_call_with_key_over_budget(prisma_client):
start_time=datetime.now(), start_time=datetime.now(),
end_time=datetime.now(), end_time=datetime.now(),
) )
await asyncio.sleep(4)
# test spend_log was written and we can read it # test spend_log was written and we can read it
spend_logs = await view_spend_logs(request_id=request_id) spend_logs = await view_spend_logs(request_id=request_id)
@ -955,6 +965,8 @@ def test_call_with_key_over_budget(prisma_client):
asyncio.run(test()) asyncio.run(test())
except Exception as e: except Exception as e:
# print(f"Error - {str(e)}")
traceback.print_exc()
error_detail = e.message error_detail = e.message
assert "Authentication Error, ExceededTokenBudget:" in error_detail assert "Authentication Error, ExceededTokenBudget:" in error_detail
print(vars(e)) print(vars(e))
@ -983,7 +995,9 @@ async def test_call_with_key_never_over_budget(prisma_client):
print("result from user auth with new key", result) print("result from user auth with new key", result)
# update spend using track_cost callback, make 2nd request, it should fail # update spend using track_cost callback, make 2nd request, it should fail
from litellm.proxy.proxy_server import track_cost_callback from litellm.proxy.proxy_server import (
_PROXY_track_cost_callback as track_cost_callback,
)
from litellm import ModelResponse, Choices, Message, Usage from litellm import ModelResponse, Choices, Message, Usage
import time import time
@ -1022,7 +1036,7 @@ async def test_call_with_key_never_over_budget(prisma_client):
start_time=datetime.now(), start_time=datetime.now(),
end_time=datetime.now(), end_time=datetime.now(),
) )
await asyncio.sleep(5)
# use generated key to auth in # use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token) result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result) print("result from user auth with new key", result)
@ -1058,7 +1072,9 @@ async def test_call_with_key_over_budget_stream(prisma_client):
print("result from user auth with new key", result) print("result from user auth with new key", result)
# update spend using track_cost callback, make 2nd request, it should fail # update spend using track_cost callback, make 2nd request, it should fail
from litellm.proxy.proxy_server import track_cost_callback from litellm.proxy.proxy_server import (
_PROXY_track_cost_callback as track_cost_callback,
)
from litellm import ModelResponse, Choices, Message, Usage from litellm import ModelResponse, Choices, Message, Usage
import time import time
@ -1096,7 +1112,7 @@ async def test_call_with_key_over_budget_stream(prisma_client):
start_time=datetime.now(), start_time=datetime.now(),
end_time=datetime.now(), end_time=datetime.now(),
) )
await asyncio.sleep(5)
# use generated key to auth in # use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token) result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result) print("result from user auth with new key", result)

View file

@ -18,7 +18,9 @@ from litellm import Router
from litellm.proxy.utils import ProxyLogging from litellm.proxy.utils import ProxyLogging
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler from litellm.proxy.hooks.parallel_request_limiter import (
_PROXY_MaxParallelRequestsHandler as MaxParallelRequestsHandler,
)
from datetime import datetime from datetime import datetime
## On Request received ## On Request received

View file

@ -774,14 +774,14 @@ class Logging:
self.streaming_chunks = [] # for generating complete stream response self.streaming_chunks = [] # for generating complete stream response
self.sync_streaming_chunks = [] # for generating complete stream response self.sync_streaming_chunks = [] # for generating complete stream response
self.model_call_details = {} self.model_call_details = {}
self.dynamic_input_callbacks = [] # callbacks set for just that call self.dynamic_input_callbacks = [] # [TODO] callbacks set for just that call
self.dynamic_failure_callbacks = [] # callbacks set for just that call self.dynamic_failure_callbacks = [] # [TODO] callbacks set for just that call
self.dynamic_success_callbacks = ( self.dynamic_success_callbacks = (
dynamic_success_callbacks or [] dynamic_success_callbacks # callbacks set for just that call
) # callbacks set for just that call )
self.dynamic_async_success_callbacks = ( self.dynamic_async_success_callbacks = (
dynamic_async_success_callbacks or [] dynamic_async_success_callbacks # callbacks set for just that call
) # callbacks set for just that call )
## DYNAMIC LANGFUSE KEYS ## ## DYNAMIC LANGFUSE KEYS ##
self.langfuse_public_key = langfuse_public_key self.langfuse_public_key = langfuse_public_key
self.langfuse_secret = langfuse_secret self.langfuse_secret = langfuse_secret
@ -1145,7 +1145,19 @@ class Logging:
f"Model={self.model} not found in completion cost map." f"Model={self.model} not found in completion cost map."
) )
self.model_call_details["response_cost"] = None self.model_call_details["response_cost"] = None
callbacks = litellm.success_callback + self.dynamic_success_callbacks if self.dynamic_success_callbacks is not None and isinstance(
self.dynamic_success_callbacks, list
):
callbacks = self.dynamic_success_callbacks
## keep the internal functions ##
for callback in litellm.success_callback:
if (
isinstance(callback, CustomLogger)
and "_PROXY_" in callback.__class__.__name__
):
callbacks.append(callback)
else:
callbacks = litellm.success_callback
for callback in callbacks: for callback in callbacks:
try: try:
if callback == "lite_debugger": if callback == "lite_debugger":
@ -1406,9 +1418,6 @@ class Logging:
""" """
Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions. Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
""" """
verbose_logger.debug(
f"Async success callbacks: {litellm._async_success_callback}"
)
start_time, end_time, result = self._success_handler_helper_fn( start_time, end_time, result = self._success_handler_helper_fn(
start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit
) )
@ -1452,9 +1461,22 @@ class Logging:
) )
self.model_call_details["response_cost"] = None self.model_call_details["response_cost"] = None
callbacks = ( if self.dynamic_async_success_callbacks is not None and isinstance(
litellm._async_success_callback + self.dynamic_async_success_callbacks self.dynamic_async_success_callbacks, list
) ):
callbacks = self.dynamic_async_success_callbacks
## keep the internal functions ##
for callback in litellm._async_success_callback:
callback_name = ""
if isinstance(callback, CustomLogger):
callback_name = callback.__class__.__name__
if callable(callback):
callback_name = callback.__name__
if "_PROXY_" in callback_name:
callbacks.append(callback)
else:
callbacks = litellm._async_success_callback
verbose_logger.debug(f"Async success callbacks: {callbacks}")
for callback in callbacks: for callback in callbacks:
try: try:
if callback == "cache" and litellm.cache is not None: if callback == "cache" and litellm.cache is not None:
@ -1501,6 +1523,7 @@ class Logging:
end_time=end_time, end_time=end_time,
) )
if callable(callback): # custom logger functions if callable(callback): # custom logger functions
print_verbose(f"Making async function logging call")
if self.stream: if self.stream:
if "complete_streaming_response" in self.model_call_details: if "complete_streaming_response" in self.model_call_details:
await customLogger.async_log_event( await customLogger.async_log_event(
@ -1958,8 +1981,8 @@ def client(original_function):
for index in reversed(removed_async_items): for index in reversed(removed_async_items):
litellm.failure_callback.pop(index) litellm.failure_callback.pop(index)
### DYNAMIC CALLBACKS ### ### DYNAMIC CALLBACKS ###
dynamic_success_callbacks = [] dynamic_success_callbacks = None
dynamic_async_success_callbacks = [] dynamic_async_success_callbacks = None
if kwargs.get("success_callback", None) is not None and isinstance( if kwargs.get("success_callback", None) is not None and isinstance(
kwargs["success_callback"], list kwargs["success_callback"], list
): ):
@ -1969,8 +1992,13 @@ def client(original_function):
inspect.iscoroutinefunction(callback) inspect.iscoroutinefunction(callback)
or callback == "dynamodb" or callback == "dynamodb"
or callback == "s3" or callback == "s3"
):
if dynamic_async_success_callbacks is not None and isinstance(
dynamic_async_success_callbacks, list
): ):
dynamic_async_success_callbacks.append(callback) dynamic_async_success_callbacks.append(callback)
else:
dynamic_async_success_callbacks = [callback]
removed_async_items.append(index) removed_async_items.append(index)
# Pop the async items from success_callback in reverse order to avoid index issues # Pop the async items from success_callback in reverse order to avoid index issues
for index in reversed(removed_async_items): for index in reversed(removed_async_items):