mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
fix(dynamo_db.py): add cost tracking support for key + user
This commit is contained in:
parent
9b3d78c4f3
commit
f94a37a836
4 changed files with 163 additions and 107 deletions
|
@ -255,10 +255,10 @@ class PrismaClient:
|
|||
print_verbose(
|
||||
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
|
||||
)
|
||||
## init logging object
|
||||
## init logging object
|
||||
self.proxy_logging_obj = proxy_logging_obj
|
||||
|
||||
if os.getenv("DATABASE_URL", None) is None: # setup hasn't taken place
|
||||
if os.getenv("DATABASE_URL", None) is None: # setup hasn't taken place
|
||||
os.environ["DATABASE_URL"] = database_url
|
||||
# Save the current working directory
|
||||
original_dir = os.getcwd()
|
||||
|
@ -334,7 +334,7 @@ class PrismaClient:
|
|||
self.proxy_logging_obj.failure_handler(original_exception=e)
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
@backoff.on_exception(
|
||||
backoff.expo,
|
||||
Exception, # base exception to catch for the backoff
|
||||
|
@ -582,47 +582,60 @@ class PrismaClient:
|
|||
)
|
||||
raise e
|
||||
|
||||
|
||||
class DBClient:
|
||||
"""
|
||||
Routes requests for CustomAuth
|
||||
Routes requests for CustomAuth
|
||||
|
||||
[TODO] route b/w customauth and prisma
|
||||
"""
|
||||
def __init__(self, custom_db_type: Literal["dynamo_db"], custom_db_args: dict) -> None:
|
||||
|
||||
def __init__(
|
||||
self, custom_db_type: Literal["dynamo_db"], custom_db_args: dict
|
||||
) -> None:
|
||||
if custom_db_type == "dynamo_db":
|
||||
self.db = DynamoDBWrapper(database_arguments=DynamoDBArgs(**custom_db_args))
|
||||
|
||||
async def get_data(self, key: str, value: Any, table_name: Literal["user", "key", "config"]):
|
||||
async def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
|
||||
"""
|
||||
Check if key valid
|
||||
"""
|
||||
return await self.db.get_data(key=key, value=value, table_name=table_name)
|
||||
|
||||
async def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]):
|
||||
return await self.db.get_data(key=key, table_name=table_name)
|
||||
|
||||
async def insert_data(
|
||||
self, value: Any, table_name: Literal["user", "key", "config"]
|
||||
):
|
||||
"""
|
||||
For new key / user logic
|
||||
"""
|
||||
return await self.db.insert_data(value=value, table_name=table_name)
|
||||
|
||||
async def update_data(self, key: str, value: Any, table_name: Literal["user", "key", "config"]):
|
||||
async def update_data(
|
||||
self, key: str, value: Any, table_name: Literal["user", "key", "config"]
|
||||
):
|
||||
"""
|
||||
For cost tracking logic
|
||||
|
||||
key - hash_key value \n
|
||||
value - dict with updated values
|
||||
"""
|
||||
return await self.db.update_data(key=key, value=value, table_name=table_name)
|
||||
|
||||
async def delete_data(self, keys: List[str], table_name: Literal["user", "key", "config"]):
|
||||
async def delete_data(
|
||||
self, keys: List[str], table_name: Literal["user", "key", "config"]
|
||||
):
|
||||
"""
|
||||
For /key/delete endpoints
|
||||
"""
|
||||
return await self.db.delete_data(keys=keys, table_name=table_name)
|
||||
|
||||
|
||||
async def connect(self):
|
||||
"""
|
||||
For connecting to db and creating / updating any tables
|
||||
"""
|
||||
For connecting to db and creating / updating any tables
|
||||
"""
|
||||
return await self.db.connect()
|
||||
|
||||
async def disconnect(self):
|
||||
async def disconnect(self):
|
||||
"""
|
||||
For closing connection on server shutdown
|
||||
"""
|
||||
|
@ -669,7 +682,9 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
|
|||
|
||||
|
||||
### HELPER FUNCTIONS ###
|
||||
async def _cache_user_row(user_id: str, cache: DualCache, db: Union[PrismaClient, DBClient]):
|
||||
async def _cache_user_row(
|
||||
user_id: str, cache: DualCache, db: Union[PrismaClient, DBClient]
|
||||
):
|
||||
"""
|
||||
Check if a user_id exists in cache,
|
||||
if not retrieve it.
|
||||
|
@ -681,7 +696,7 @@ async def _cache_user_row(user_id: str, cache: DualCache, db: Union[PrismaClient
|
|||
if isinstance(db, PrismaClient):
|
||||
user_row = await db.get_data(user_id=user_id)
|
||||
elif isinstance(db, DBClient):
|
||||
user_row = await db.get_data(key="user_id", value=user_id, table_name="user")
|
||||
user_row = await db.get_data(key=user_id, table_name="user")
|
||||
if user_row is not None:
|
||||
print_verbose(f"User Row: {user_row}, type = {type(user_row)}")
|
||||
if hasattr(user_row, "model_dump_json") and callable(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue