Merge pull request #1761 from BerriAI/litellm_fix_dynamic_callbacks

fix(utils.py): override default success callbacks with dynamic callbacks if set
This commit is contained in:
Krish Dholakia 2024-02-02 13:06:55 -08:00 committed by GitHub
commit 93fb0134e5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 115 additions and 64 deletions

View file

@ -566,7 +566,6 @@ async def user_api_key_auth(
and (not general_settings.get("allow_user_auth", False))
):
# enters this block when allow_user_auth is set to False
assert not general_settings.get("allow_user_auth", False)
if route == "/key/info":
# check if user can access this route
query_params = request.query_params
@ -679,16 +678,17 @@ def cost_tracking():
if prisma_client is not None or custom_db_client is not None:
if isinstance(litellm.success_callback, list):
verbose_proxy_logger.debug("setting litellm success callback to track cost")
if (track_cost_callback) not in litellm.success_callback: # type: ignore
litellm.success_callback.append(track_cost_callback) # type: ignore
if (_PROXY_track_cost_callback) not in litellm.success_callback: # type: ignore
litellm.success_callback.append(_PROXY_track_cost_callback) # type: ignore
async def track_cost_callback(
async def _PROXY_track_cost_callback(
kwargs, # kwargs to completion
completion_response: litellm.ModelResponse, # response from completion
start_time=None,
end_time=None, # start/end time for completion
):
verbose_proxy_logger.debug(f"INSIDE _PROXY_track_cost_callback")
global prisma_client, custom_db_client
try:
# check if it has collected an entire stream response
@ -752,8 +752,8 @@ async def update_database(
end_time=None,
):
try:
verbose_proxy_logger.debug(
f"Enters prisma db call, token: {token}; user_id: {user_id}"
verbose_proxy_logger.info(
f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}"
)
### UPDATE USER SPEND ###
@ -865,18 +865,16 @@ async def update_database(
)
payload["spend"] = response_cost
if prisma_client is not None:
await prisma_client.insert_data(data=payload, table_name="spend")
elif custom_db_client is not None:
await custom_db_client.insert_data(payload, table_name="spend")
tasks = []
tasks.append(_update_user_db())
tasks.append(_update_key_db())
tasks.append(_insert_spend_log_to_db())
await asyncio.gather(*tasks)
asyncio.create_task(_update_user_db())
asyncio.create_task(_update_key_db())
asyncio.create_task(_insert_spend_log_to_db())
verbose_proxy_logger.info("Successfully updated spend in all 3 tables")
except Exception as e:
verbose_proxy_logger.debug(
f"Error updating Prisma database: {traceback.format_exc()}"
@ -3934,7 +3932,7 @@ def _has_user_setup_sso():
"""
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
ui_username = os.getenv("UI_USERNAME")
ui_username = os.getenv("UI_USERNAME", None)
sso_setup = (
(microsoft_client_id is not None)