feat(proxy_server.py): adds working dynamo db support for key gen

This commit is contained in:
Krrish Dholakia 2024-01-09 18:11:24 +05:30
parent 4cfa010dbd
commit 35f9666dc2
5 changed files with 362 additions and 30 deletions

View file

@ -68,6 +68,7 @@ def generate_feedback_box():
import litellm
from litellm.proxy.utils import (
PrismaClient,
DBClient,
get_instance_fn,
ProxyLogging,
_cache_user_row,
@ -142,6 +143,7 @@ worker_config = None
master_key = None
otel_logging = False
prisma_client: Optional[PrismaClient] = None
custom_db_client: Optional[DBClient] = None
user_api_key_cache = DualCache()
user_custom_auth = None
use_background_health_checks = None
@ -185,12 +187,12 @@ def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
async def user_api_key_auth(
request: Request, api_key: str = fastapi.Security(api_key_header)
) -> UserAPIKeyAuth:
global master_key, prisma_client, llm_model_list, user_custom_auth
global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client
try:
if isinstance(api_key, str):
api_key = _get_bearer_token(api_key=api_key)
### USER-DEFINED AUTH FUNCTION ###
if user_custom_auth:
if user_custom_auth is not None:
response = await user_custom_auth(request=request, api_key=api_key)
return UserAPIKeyAuth.model_validate(response)
### LITELLM-DEFINED AUTH FUNCTION ###
@ -232,7 +234,7 @@ async def user_api_key_auth(
)
if (
prisma_client is None
prisma_client is None and custom_db_client is None
): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
raise Exception("No connected db.")
@ -242,9 +244,24 @@ async def user_api_key_auth(
if valid_token is None:
## check db
verbose_proxy_logger.debug(f"api key: {api_key}")
valid_token = await prisma_client.get_data(
token=api_key, expires=datetime.utcnow().replace(tzinfo=timezone.utc)
)
if prisma_client is not None:
valid_token = await prisma_client.get_data(
token=api_key, expires=datetime.utcnow().replace(tzinfo=timezone.utc)
)
elif custom_db_client is not None:
valid_token = await custom_db_client.get_data(key="token", value=api_key, table_name="key")
# Token exists, now check expiration.
if valid_token.expires is not None:
expiry_time = datetime.fromisoformat(valid_token.expires)
if expiry_time >= datetime.utcnow():
# Token exists and is not expired.
return response
else:
# Token exists but is expired.
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="expired user key",
)
verbose_proxy_logger.debug(f"valid token from prisma: {valid_token}")
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
elif valid_token is not None:
@ -280,13 +297,22 @@ async def user_api_key_auth(
This makes the user row data accessible to pre-api call hooks.
"""
asyncio.create_task(
_cache_user_row(
user_id=valid_token.user_id,
cache=user_api_key_cache,
db=prisma_client,
if prisma_client is not None:
asyncio.create_task(
_cache_user_row(
user_id=valid_token.user_id,
cache=user_api_key_cache,
db=prisma_client,
)
)
elif custom_db_client is not None:
asyncio.create_task(
_cache_user_row(
user_id=valid_token.user_id,
cache=user_api_key_cache,
db=custom_db_client,
)
)
)
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
else:
raise Exception(f"Invalid token")
@ -611,7 +637,7 @@ class ProxyConfig:
"""
Load config values into proxy global state
"""
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue, custom_db_client
# Load existing config
config = await self.get_config(config_file_path=config_file_path)
@ -785,11 +811,17 @@ class ProxyConfig:
if master_key and master_key.startswith("os.environ/"):
master_key = litellm.get_secret(master_key)
### CUSTOM API KEY AUTH ###
## pass filepath
custom_auth = general_settings.get("custom_auth", None)
if custom_auth:
if custom_auth is not None:
user_custom_auth = get_instance_fn(
value=custom_auth, config_file_path=config_file_path
)
## dynamodb
database_type = general_settings.get("database_type", None)
if database_type is not None and database_type == "dynamo_db":
database_args = general_settings.get("database_args", None)
custom_db_client = DBClient(custom_db_args=database_args, custom_db_type=database_type)
### BACKGROUND HEALTH CHECKS ###
# Enable background health checks
use_background_health_checks = general_settings.get(
@ -856,9 +888,9 @@ async def generate_key_helper_fn(
max_parallel_requests: Optional[int] = None,
metadata: Optional[dict] = {},
):
global prisma_client
global prisma_client, custom_db_client
if prisma_client is None:
if prisma_client is None and custom_db_client is None:
raise Exception(
f"Connect Proxy to database to generate keys - https://docs.litellm.ai/docs/proxy/virtual_keys "
)
@ -897,7 +929,12 @@ async def generate_key_helper_fn(
user_id = user_id or str(uuid.uuid4())
try:
# Create a new verification token (you may want to enhance this logic based on your needs)
verification_token_data = {
user_data = {
"max_budget": max_budget,
"user_email": user_email,
"user_id": user_id
}
key_data = {
"token": token,
"expires": expires,
"models": models,
@ -907,19 +944,24 @@ async def generate_key_helper_fn(
"user_id": user_id,
"max_parallel_requests": max_parallel_requests,
"metadata": metadata_json,
"max_budget": max_budget,
"user_email": user_email,
}
verbose_proxy_logger.debug("PrismaClient: Before Insert Data")
new_verification_token = await prisma_client.insert_data(
data=verification_token_data
)
if prisma_client is not None:
verification_token_data = key_data.update(user_data)
verbose_proxy_logger.debug("PrismaClient: Before Insert Data")
await prisma_client.insert_data(
data=verification_token_data
)
elif custom_db_client is not None:
## CREATE USER (If necessary)
await custom_db_client.insert_data(value=user_data, table_name="user")
## CREATE KEY
await custom_db_client.insert_data(value=key_data, table_name="key")
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
return {
"token": token,
"expires": new_verification_token.expires,
"expires": expires,
"user_id": user_id,
"max_budget": max_budget,
}
@ -1187,13 +1229,25 @@ async def startup_event():
verbose_proxy_logger.debug(f"prisma client - {prisma_client}")
if prisma_client is not None:
await prisma_client.connect()
if custom_db_client is not None:
await custom_db_client.connect()
if prisma_client is not None and master_key is not None:
# add master key to db
print(f'prisma_client: {prisma_client}')
await generate_key_helper_fn(
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
)
if custom_db_client is not None and master_key is not None:
# add master key to db
print(f'custom_db_client: {custom_db_client}')
await generate_key_helper_fn(
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
)
#### API ENDPOINTS ####
@router.get(