(v0) fixes

This commit is contained in:
ishaan-jaff 2024-03-05 15:27:06 -08:00
parent 47e1ee74f5
commit fabde529fa
2 changed files with 40 additions and 16 deletions

View file

@ -9,6 +9,10 @@ import warnings
import importlib import importlib
import warnings import warnings
import logging
logging.getLogger("prisma").setLevel(logging.DEBUG)
def showwarning(message, category, filename, lineno, file=None, line=None): def showwarning(message, category, filename, lineno, file=None, line=None):
traceback_info = f"{filename}:{lineno}: {category.__name__}: {message}\n" traceback_info = f"{filename}:{lineno}: {category.__name__}: {message}\n"
@ -1138,6 +1142,7 @@ async def update_database(
) )
# set cooldown on alert # set cooldown on alert
soft_budget_cooldown = True soft_budget_cooldown = True
# track cost per model, for the given key # track cost per model, for the given key
spend_per_model = existing_spend_obj.model_spend or {} spend_per_model = existing_spend_obj.model_spend or {}
current_model = kwargs.get("model") current_model = kwargs.get("model")
@ -1153,11 +1158,7 @@ async def update_database(
# Update the cost column for the given token # Update the cost column for the given token
await prisma_client.update_data( await prisma_client.update_data(
token=token, token=token,
data={ data={"spend": new_spend, "model_spend": spend_per_model},
"spend": new_spend,
"model_spend": spend_per_model,
"soft_budget_cooldown": soft_budget_cooldown,
},
) )
valid_token = user_api_key_cache.get_cache(key=token) valid_token = user_api_key_cache.get_cache(key=token)
@ -1211,9 +1212,9 @@ async def update_database(
payload["spend"] = response_cost payload["spend"] = response_cost
if prisma_client is not None: if prisma_client is not None:
await prisma_client.insert_data(data=payload, table_name="spend") await prisma_client.insert_data(data=payload, table_name="spend")
elif custom_db_client is not None: elif custom_db_client is not None:
await custom_db_client.insert_data(payload, table_name="spend") await custom_db_client.insert_data(payload, table_name="spend")
except Exception as e: except Exception as e:
verbose_proxy_logger.info(f"Update Spend Logs DB failed to execute") verbose_proxy_logger.info(f"Update Spend Logs DB failed to execute")

View file

@ -64,7 +64,7 @@ class ProxyLogging:
litellm.callbacks.append(self.max_parallel_request_limiter) litellm.callbacks.append(self.max_parallel_request_limiter)
litellm.callbacks.append(self.max_budget_limiter) litellm.callbacks.append(self.max_budget_limiter)
litellm.callbacks.append(self.cache_control_check) litellm.callbacks.append(self.cache_control_check)
litellm.callbacks.append(self.response_taking_too_long_callback) # litellm.callbacks.append(self.response_taking_too_long_callback)
for callback in litellm.callbacks: for callback in litellm.callbacks:
if callback not in litellm.input_callback: if callback not in litellm.input_callback:
litellm.input_callback.append(callback) litellm.input_callback.append(callback)
@ -362,7 +362,7 @@ class ProxyLogging:
else: else:
raise Exception("Missing SENTRY_DSN from environment") raise Exception("Missing SENTRY_DSN from environment")
async def failure_handler(self, original_exception): async def failure_handler(self, original_exception, traceback_str=""):
""" """
Log failed db read/writes Log failed db read/writes
@ -373,6 +373,7 @@ class ProxyLogging:
error_message = original_exception.detail error_message = original_exception.detail
else: else:
error_message = str(original_exception) error_message = str(original_exception)
error_message += traceback_str
asyncio.create_task( asyncio.create_task(
self.alerting_handler( self.alerting_handler(
message=f"DB read/write call failed: {error_message}", message=f"DB read/write call failed: {error_message}",
@ -706,8 +707,13 @@ class PrismaClient:
) )
return response return response
except Exception as e: except Exception as e:
import traceback
tracback_str = traceback.format_exc()
asyncio.create_task( asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e) self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=tracback_str
)
) )
raise e raise e
@ -912,9 +918,10 @@ class PrismaClient:
return response return response
elif table_name == "team": elif table_name == "team":
if query_type == "find_unique": if query_type == "find_unique":
response = await self.db.litellm_teamtable.find_unique( response = None
where={"team_id": team_id} # type: ignore # response = await self.db.litellm_teamtable.find_unique(
) # where={"team_id": team_id} # type: ignore
# )
elif query_type == "find_all" and user_id is not None: elif query_type == "find_all" and user_id is not None:
response = await self.db.litellm_teamtable.find_many( response = await self.db.litellm_teamtable.find_many(
where={ where={
@ -971,8 +978,12 @@ class PrismaClient:
import traceback import traceback
traceback.print_exc() traceback.print_exc()
# get tracback
traceback_string = traceback.format_exc()
asyncio.create_task( asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e) self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=traceback_string
)
) )
raise e raise e
@ -1093,8 +1104,12 @@ class PrismaClient:
except Exception as e: except Exception as e:
print_verbose(f"LiteLLM Prisma Client Exception: {e}") print_verbose(f"LiteLLM Prisma Client Exception: {e}")
import traceback
traceback_str = traceback.format_exc()
print_verbose(f"Traceback: {traceback_str}")
asyncio.create_task( asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e) self.proxy_logging_obj.failure_handler(original_exception=e, traceback_str=traceback_str) # type: ignore # noqa=traceback_str)
) )
raise e raise e
@ -1277,8 +1292,12 @@ class PrismaClient:
"\033[91m" + f"DB User Table Batch update succeeded" + "\033[0m" "\033[91m" + f"DB User Table Batch update succeeded" + "\033[0m"
) )
except Exception as e: except Exception as e:
import traceback
asyncio.create_task( asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e) self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=traceback.format_exc()
)
) )
print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m") print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")
raise e raise e
@ -1331,8 +1350,12 @@ class PrismaClient:
where={"team_id": {"in": team_id_list}} where={"team_id": {"in": team_id_list}}
) )
except Exception as e: except Exception as e:
import traceback
asyncio.create_task( asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e) self.proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=traceback.format_exc()
)
) )
raise e raise e