fix(dynamo_db.py): add cost tracking support for key + user

This commit is contained in:
Krrish Dholakia 2024-01-11 23:56:41 +05:30
parent 9b3d78c4f3
commit f94a37a836
4 changed files with 163 additions and 107 deletions

View file

@ -255,10 +255,10 @@ class PrismaClient:
print_verbose(
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
)
## init logging object
## init logging object
self.proxy_logging_obj = proxy_logging_obj
if os.getenv("DATABASE_URL", None) is None: # setup hasn't taken place
if os.getenv("DATABASE_URL", None) is None: # setup hasn't taken place
os.environ["DATABASE_URL"] = database_url
# Save the current working directory
original_dir = os.getcwd()
@ -334,7 +334,7 @@ class PrismaClient:
self.proxy_logging_obj.failure_handler(original_exception=e)
)
raise e
@backoff.on_exception(
backoff.expo,
Exception, # base exception to catch for the backoff
@ -582,47 +582,60 @@ class PrismaClient:
)
raise e
class DBClient:
"""
Routes requests for CustomAuth
Routes requests for CustomAuth
[TODO] route b/w customauth and prisma
"""
def __init__(self, custom_db_type: Literal["dynamo_db"], custom_db_args: dict) -> None:
def __init__(
self, custom_db_type: Literal["dynamo_db"], custom_db_args: dict
) -> None:
if custom_db_type == "dynamo_db":
self.db = DynamoDBWrapper(database_arguments=DynamoDBArgs(**custom_db_args))
async def get_data(self, key: str, value: Any, table_name: Literal["user", "key", "config"]):
async def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
"""
Check if key valid
"""
return await self.db.get_data(key=key, value=value, table_name=table_name)
async def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]):
return await self.db.get_data(key=key, table_name=table_name)
async def insert_data(
self, value: Any, table_name: Literal["user", "key", "config"]
):
"""
For new key / user logic
"""
return await self.db.insert_data(value=value, table_name=table_name)
async def update_data(self, key: str, value: Any, table_name: Literal["user", "key", "config"]):
async def update_data(
self, key: str, value: Any, table_name: Literal["user", "key", "config"]
):
"""
For cost tracking logic
key - hash_key value \n
value - dict with updated values
"""
return await self.db.update_data(key=key, value=value, table_name=table_name)
async def delete_data(self, keys: List[str], table_name: Literal["user", "key", "config"]):
async def delete_data(
self, keys: List[str], table_name: Literal["user", "key", "config"]
):
"""
For /key/delete endpoints
"""
return await self.db.delete_data(keys=keys, table_name=table_name)
async def connect(self):
"""
For connecting to db and creating / updating any tables
"""
For connecting to db and creating / updating any tables
"""
return await self.db.connect()
async def disconnect(self):
async def disconnect(self):
"""
For closing connection on server shutdown
"""
@ -669,7 +682,9 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
### HELPER FUNCTIONS ###
async def _cache_user_row(user_id: str, cache: DualCache, db: Union[PrismaClient, DBClient]):
async def _cache_user_row(
user_id: str, cache: DualCache, db: Union[PrismaClient, DBClient]
):
"""
Check if a user_id exists in cache,
if not retrieve it.
@ -681,7 +696,7 @@ async def _cache_user_row(user_id: str, cache: DualCache, db: Union[PrismaClient
if isinstance(db, PrismaClient):
user_row = await db.get_data(user_id=user_id)
elif isinstance(db, DBClient):
user_row = await db.get_data(key="user_id", value=user_id, table_name="user")
user_row = await db.get_data(key=user_id, table_name="user")
if user_row is not None:
print_verbose(f"User Row: {user_row}, type = {type(user_row)}")
if hasattr(user_row, "model_dump_json") and callable(