feat(proxy_server.py): adds working dynamo db support for key gen

This commit is contained in:
Krrish Dholakia 2024-01-09 18:11:24 +05:30
parent 4cfa010dbd
commit 35f9666dc2
5 changed files with 362 additions and 30 deletions

View file

@ -1,11 +1,13 @@
from typing import Optional, List, Any, Literal
from typing import Optional, List, Any, Literal, Union
import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx
import litellm, backoff
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy._types import UserAPIKeyAuth, DynamoDBArgs
from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.db.base_client import CustomDB
from litellm.proxy.db.dynamo_db import DynamoDBWrapper
from fastapi import HTTPException, status
import smtplib
from email.mime.text import MIMEText
@ -332,7 +334,7 @@ class PrismaClient:
self.proxy_logging_obj.failure_handler(original_exception=e)
)
raise e
@backoff.on_exception(
backoff.expo,
Exception, # base exception to catch for the backoff
@ -580,6 +582,52 @@ class PrismaClient:
)
raise e
class DBClient:
"""
Routes requests for CustomAuth
[TODO] route b/w customauth and prisma
"""
def __init__(self, custom_db_type: Literal["dynamo_db"], custom_db_args: Optional[dict]=None) -> None:
if custom_db_type == "dynamo_db":
self.db = DynamoDBWrapper(database_arguments=DynamoDBArgs(**custom_db_args))
async def get_data(self, key: str, value: Any, table_name: Literal["user", "key", "config"]):
"""
Check if key valid
"""
return await self.db.get_data(key=key, value=value, table_name=table_name)
async def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]):
"""
For new key / user logic
"""
return await self.db.insert_data(value=value, table_name=table_name)
async def update_data(self, key: str, value: Any, table_name: Literal["user", "key", "config"]):
"""
For cost tracking logic
"""
return await self.db.update_data(key=key, value=value, table_name=table_name)
async def delete_data(self, keys: List[str], table_name: Literal["user", "key", "config"]):
"""
For /key/delete endpoints
"""
return await self.db.delete_data(keys=keys, table_name=table_name)
async def connect(self):
"""
For connecting to db and creating / updating any tables
"""
return await self.db.connect()
async def disconnect(self):
"""
For closing connection on server shutdown
"""
return await self.db.disconnect()
### CUSTOM FILE ###
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
@ -621,7 +669,7 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
### HELPER FUNCTIONS ###
async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
async def _cache_user_row(user_id: str, cache: DualCache, db: Union[PrismaClient, DBClient]):
"""
Check if a user_id exists in cache,
if not retrieve it.
@ -630,7 +678,10 @@ async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
cache_key = f"{user_id}_user_api_key_user_id"
response = cache.get_cache(key=cache_key)
if response is None: # Cache miss
user_row = await db.get_data(user_id=user_id)
if isinstance(db, PrismaClient):
user_row = await db.get_data(user_id=user_id)
elif isinstance(db, DBClient):
user_row = await db.get_data(key="user_id", value=user_id, table_name="user")
if user_row is not None:
print_verbose(f"User Row: {user_row}, type = {type(user_row)}")
if hasattr(user_row, "model_dump_json") and callable(