mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
feat(proxy_server.py): adds working dynamo db support for key gen
This commit is contained in:
parent
4cfa010dbd
commit
35f9666dc2
5 changed files with 362 additions and 30 deletions
|
@ -1,11 +1,13 @@
|
|||
from typing import Optional, List, Any, Literal
|
||||
from typing import Optional, List, Any, Literal, Union
|
||||
import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx
|
||||
import litellm, backoff
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy._types import UserAPIKeyAuth, DynamoDBArgs
|
||||
from litellm.caching import DualCache
|
||||
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
|
||||
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.proxy.db.base_client import CustomDB
|
||||
from litellm.proxy.db.dynamo_db import DynamoDBWrapper
|
||||
from fastapi import HTTPException, status
|
||||
import smtplib
|
||||
from email.mime.text import MIMEText
|
||||
|
@ -332,7 +334,7 @@ class PrismaClient:
|
|||
self.proxy_logging_obj.failure_handler(original_exception=e)
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
@backoff.on_exception(
|
||||
backoff.expo,
|
||||
Exception, # base exception to catch for the backoff
|
||||
|
@ -580,6 +582,52 @@ class PrismaClient:
|
|||
)
|
||||
raise e
|
||||
|
||||
class DBClient:
|
||||
"""
|
||||
Routes requests for CustomAuth
|
||||
|
||||
[TODO] route b/w customauth and prisma
|
||||
"""
|
||||
def __init__(self, custom_db_type: Literal["dynamo_db"], custom_db_args: Optional[dict]=None) -> None:
|
||||
if custom_db_type == "dynamo_db":
|
||||
self.db = DynamoDBWrapper(database_arguments=DynamoDBArgs(**custom_db_args))
|
||||
|
||||
async def get_data(self, key: str, value: Any, table_name: Literal["user", "key", "config"]):
|
||||
"""
|
||||
Check if key valid
|
||||
"""
|
||||
return await self.db.get_data(key=key, value=value, table_name=table_name)
|
||||
|
||||
async def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]):
|
||||
"""
|
||||
For new key / user logic
|
||||
"""
|
||||
return await self.db.insert_data(value=value, table_name=table_name)
|
||||
|
||||
async def update_data(self, key: str, value: Any, table_name: Literal["user", "key", "config"]):
|
||||
"""
|
||||
For cost tracking logic
|
||||
"""
|
||||
return await self.db.update_data(key=key, value=value, table_name=table_name)
|
||||
|
||||
async def delete_data(self, keys: List[str], table_name: Literal["user", "key", "config"]):
|
||||
"""
|
||||
For /key/delete endpoints
|
||||
"""
|
||||
return await self.db.delete_data(keys=keys, table_name=table_name)
|
||||
|
||||
async def connect(self):
|
||||
"""
|
||||
For connecting to db and creating / updating any tables
|
||||
"""
|
||||
return await self.db.connect()
|
||||
|
||||
async def disconnect(self):
|
||||
"""
|
||||
For closing connection on server shutdown
|
||||
"""
|
||||
return await self.db.disconnect()
|
||||
|
||||
|
||||
### CUSTOM FILE ###
|
||||
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
|
||||
|
@ -621,7 +669,7 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
|
|||
|
||||
|
||||
### HELPER FUNCTIONS ###
|
||||
async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
|
||||
async def _cache_user_row(user_id: str, cache: DualCache, db: Union[PrismaClient, DBClient]):
|
||||
"""
|
||||
Check if a user_id exists in cache,
|
||||
if not retrieve it.
|
||||
|
@ -630,7 +678,10 @@ async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
|
|||
cache_key = f"{user_id}_user_api_key_user_id"
|
||||
response = cache.get_cache(key=cache_key)
|
||||
if response is None: # Cache miss
|
||||
user_row = await db.get_data(user_id=user_id)
|
||||
if isinstance(db, PrismaClient):
|
||||
user_row = await db.get_data(user_id=user_id)
|
||||
elif isinstance(db, DBClient):
|
||||
user_row = await db.get_data(key="user_id", value=user_id, table_name="user")
|
||||
if user_row is not None:
|
||||
print_verbose(f"User Row: {user_row}, type = {type(user_row)}")
|
||||
if hasattr(user_row, "model_dump_json") and callable(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue