feat(proxy_server.py): adds working dynamo db support for key gen

This commit is contained in:
Krrish Dholakia 2024-01-09 18:11:24 +05:30
parent b09f38e835
commit 00c258c165
5 changed files with 362 additions and 30 deletions

View file

@ -4,7 +4,6 @@ from typing import Optional, List, Union, Dict, Literal
from datetime import datetime from datetime import datetime
import uuid, json import uuid, json
class LiteLLMBase(BaseModel): class LiteLLMBase(BaseModel):
""" """
Implements default functions, all pydantic objects should have. Implements default functions, all pydantic objects should have.
@ -181,6 +180,12 @@ class KeyManagementSystem(enum.Enum):
AZURE_KEY_VAULT = "azure_key_vault" AZURE_KEY_VAULT = "azure_key_vault"
LOCAL = "local" LOCAL = "local"
class DynamoDBArgs(LiteLLMBase):
    """Config for instantiating the DynamoDB client (consumed by DynamoDBWrapper)."""

    # DynamoDB billing mode; selects PayPerRequest vs provisioned Throughput.
    billing_mode: Literal["PROVISIONED_THROUGHPUT", "PAY_PER_REQUEST"]
    # Only read when billing_mode == "PROVISIONED_THROUGHPUT".
    read_capacity_units: Optional[int] = None
    # Only read when billing_mode == "PROVISIONED_THROUGHPUT".
    write_capacity_units: Optional[int] = None
    # AWS region the tables live in; passed straight to the aiodynamo Client.
    region_name: Optional[str] = None
class ConfigGeneralSettings(LiteLLMBase): class ConfigGeneralSettings(LiteLLMBase):
""" """
@ -206,6 +211,8 @@ class ConfigGeneralSettings(LiteLLMBase):
None, None,
description="connect to a postgres db - needed for generating temporary keys + tracking spend / key", description="connect to a postgres db - needed for generating temporary keys + tracking spend / key",
) )
database_type: Optional[Literal["dynamo_db"]] = Field(None, description="to use dynamodb instead of postgres db")
database_args: Optional[DynamoDBArgs] = Field(None, description="custom args for instantiating dynamodb client - e.g. billing provision")
otel: Optional[bool] = Field( otel: Optional[bool] = Field(
None, None,
description="[BETA] OpenTelemetry support - this might change, use with caution.", description="[BETA] OpenTelemetry support - this might change, use with caution.",
@ -258,3 +265,35 @@ class ConfigYAML(LiteLLMBase):
class Config: class Config:
protected_namespaces = () protected_namespaces = ()
class DBTableNames(enum.Enum):
    """Logical table name (enum member name) -> physical DynamoDB table name (value)."""

    user = "LiteLLM_UserTable"
    key = "LiteLLM_VerificationToken"
    config = "LiteLLM_Config"
class LiteLLM_VerificationToken(LiteLLMBase):
    """Row shape for the LiteLLM_VerificationToken ("key") table."""

    # Hash key of the table (see DynamoDBWrapper.connect).
    token: str
    spend: float = 0.0
    # ISO-format expiry string, or None for non-expiring keys.
    expires: Union[str, None]
    models: List[str]
    # aliases / config / metadata may arrive as JSON strings from the db;
    # DynamoDBWrapper.get_data json.loads them before building this model.
    aliases: Dict[str, str] = {}
    config: Dict[str, str] = {}
    user_id: Union[str, None]
    max_parallel_requests: Union[int, None]
    metadata: Dict[str, str] = {}
class LiteLLM_Config(LiteLLMBase):
    """Row shape for the LiteLLM_Config table (param_name is the hash key)."""

    param_name: str
    param_value: Dict
class LiteLLM_UserTable(LiteLLMBase):
    """Row shape for the LiteLLM_UserTable table (user_id is the hash key)."""

    user_id: str
    max_budget: Optional[float]
    spend: float = 0.0
    user_email: Optional[str]

    @root_validator(pre=True)
    def set_model_info(cls, values):
        # Coerce an explicit spend=None (e.g. missing attribute on the db row)
        # back to the 0.0 default so the float field validates.
        if values.get("spend") is None:
            values.update({"spend": 0.0})
        return values

View file

@ -0,0 +1,43 @@
from typing import Any, Literal, List
class CustomDB:
    """
    Implements a base class that we expect any custom db implementation (e.g. DynamoDB) to follow.

    NOTE(review): the methods here are declared sync, but DynamoDBWrapper
    implements them as async coroutines — confirm whether this base is meant
    to be async throughout.
    """

    def __init__(self) -> None:
        pass

    def get_data(self, key: str, value: str, table_name: Literal["user", "key", "config"]):
        """
        Check if key valid
        """
        pass

    def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]):
        """
        For new key / user logic
        """
        pass

    def update_data(self, key: str, value: Any, table_name: Literal["user", "key", "config"]):
        """
        For cost tracking logic
        """
        pass

    def delete_data(self, keys: List[str], table_name: Literal["user", "key", "config"]):
        """
        For /key/delete endpoints
        """
        # No body besides the docstring: calling this returns None
        # (DynamoDBWrapper.delete_data relies on that via super()).

    def connect(self, ):
        """
        For connecting to db and creating / updating any tables
        """
        pass

    def disconnect(self, ):
        """
        For closing connection on server shutdown
        """
        pass

View file

@ -0,0 +1,145 @@
import json
from aiodynamo.client import Client
from aiodynamo.credentials import Credentials, StaticCredentials
from aiodynamo.http.httpx import HTTPX
from aiodynamo.models import Throughput, KeySchema, KeySpec, KeyType, PayPerRequest
from yarl import URL
from litellm.proxy.db.base_client import CustomDB
from litellm.proxy._types import DynamoDBArgs, DBTableNames, LiteLLM_VerificationToken, LiteLLM_Config, LiteLLM_UserTable
from litellm import get_secret
from typing import Any, List, Literal, Optional
from aiodynamo.expressions import UpdateExpression, F
from aiodynamo.models import ReturnValues
from aiodynamo.http.aiohttp import AIOHTTP
from aiohttp import ClientSession
from datetime import datetime
class DynamoDBWrapper(CustomDB):
    """
    DynamoDB implementation of the CustomDB interface, built on aiodynamo.

    Each call opens a fresh aiohttp ClientSession + aiodynamo Client.
    NOTE(review): per-call session creation is extra overhead — confirm this
    is acceptable vs. holding one session for the wrapper's lifetime.
    """

    credentials: Credentials

    def __init__(self, database_arguments: DynamoDBArgs):
        # Translate the proxy's billing-mode config into an aiodynamo
        # throughput object, used when tables are created in connect().
        self.throughput_type = None
        if database_arguments.billing_mode == "PAY_PER_REQUEST":
            self.throughput_type = PayPerRequest()
        elif database_arguments.billing_mode == "PROVISIONED_THROUGHPUT":
            self.throughput_type = Throughput(read=database_arguments.read_capacity_units, write=database_arguments.write_capacity_units)
        self.region_name = database_arguments.region_name

    async def connect(self):
        """
        Connect to DB, and creating / updating any tables
        """
        async with ClientSession() as session:
            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
            ## User — hash key: user_id
            table = client.table(DBTableNames.user.value)
            if not await table.exists():
                await table.create(
                    self.throughput_type,
                    KeySchema(hash_key=KeySpec("user_id", KeyType.string)),
                )
            ## Token — hash key: token
            table = client.table(DBTableNames.key.value)
            if not await table.exists():
                await table.create(
                    self.throughput_type,
                    KeySchema(hash_key=KeySpec("token", KeyType.string)),
                )
            ## Config — hash key: param_name
            table = client.table(DBTableNames.config.value)
            if not await table.exists():
                await table.create(
                    self.throughput_type,
                    KeySchema(hash_key=KeySpec("param_name", KeyType.string)),
                )

    async def insert_data(self, value: Any, table_name: Literal['user', 'key', 'config']):
        """
        Put one item (a dict) into the table selected by table_name.

        datetime values are serialized to ISO strings before the write
        (mutates the caller's dict in place).
        """
        async with ClientSession() as session:
            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
            table = None
            # table_name Literal values match the DBTableNames member names.
            if table_name == DBTableNames.user.name:
                table = client.table(DBTableNames.user.value)
            elif table_name == DBTableNames.key.name:
                table = client.table(DBTableNames.key.value)
            elif table_name == DBTableNames.config.name:
                table = client.table(DBTableNames.config.value)
            # NOTE(review): if no branch matched, table stays None and
            # put_item below raises AttributeError.

            for k, v in value.items():
                if isinstance(v, datetime):
                    value[k] = v.isoformat()

            await table.put_item(item=value)

    async def get_data(self, key: str, value: str, table_name: Literal['user', 'key', 'config']):
        """
        Fetch the item whose hash key `key` equals `value`, and parse it into
        the matching pydantic model (LiteLLM_UserTable / LiteLLM_VerificationToken /
        LiteLLM_Config).
        """
        async with ClientSession() as session:
            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
            table = None
            if table_name == DBTableNames.user.name:
                table = client.table(DBTableNames.user.value)
            elif table_name == DBTableNames.key.name:
                table = client.table(DBTableNames.key.value)
            elif table_name == DBTableNames.config.name:
                table = client.table(DBTableNames.config.value)

            response = await table.get_item({key: value})

            if table_name == DBTableNames.user.name:
                new_response = LiteLLM_UserTable(**response)
            elif table_name == DBTableNames.key.name:
                new_response = {}
                for k, v in response.items():  # handle json string
                    # aliases/config/metadata may be stored as JSON strings;
                    # decode them so the pydantic Dict fields validate.
                    if (k == "aliases" or k == "config" or k == "metadata") and v is not None and isinstance(v, str):
                        new_response[k] = json.loads(v)
                    else:
                        new_response[k] = v
                new_response = LiteLLM_VerificationToken(**new_response)
            elif table_name == DBTableNames.config.name:
                new_response = LiteLLM_Config(**response)
            # NOTE(review): if no branch matched, new_response is unbound here
            # and this raises UnboundLocalError.
            return new_response

    async def update_data(self, key: str, value: Any, table_name: Literal['user', 'key', 'config']):
        """
        Update the item with hash key `key`, setting every field present in
        `value` (validated through the table's pydantic model first).
        """
        async with ClientSession() as session:
            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
            table = None
            key_name = None
            data_obj = None
            if table_name == DBTableNames.user.name:
                table = client.table(DBTableNames.user.value)
                key_name = "user_id"
                data_obj = LiteLLM_UserTable(user_id=key, **value)
            elif table_name == DBTableNames.key.name:
                table = client.table(DBTableNames.key.value)
                key_name = "token"
                data_obj = LiteLLM_VerificationToken(token=key, **value)
            elif table_name == DBTableNames.config.name:
                table = client.table(DBTableNames.config.value)
                key_name = "param_name"
                data_obj = LiteLLM_Config(param_name=key, **value)
            # NOTE(review): unmatched table_name leaves data_obj None and the
            # loop below raises AttributeError.

            # Initialize an empty UpdateExpression
            update_expression = UpdateExpression()
            # Add updates for each field that has been modified
            for field in data_obj.model_fields_set:
                # If a Pydantic model has a __fields_set__ attribute, it's a set of fields that were set when the model was instantiated
                field_value = getattr(data_obj, field)
                if isinstance(field_value, datetime):
                    # DynamoDB has no native datetime type; store ISO strings.
                    field_value = field_value.isoformat()
                update_expression = update_expression.set(F(field), field_value)

            # Perform the update in DynamoDB
            result = await table.update_item(
                key={key_name: key},
                update_expression=update_expression,
                return_values=ReturnValues.NONE
            )
            return result

    async def delete_data(self, keys: List[str], table_name: Literal['user', 'key', 'config']):
        """
        Not Implemented yet.
        """
        # Falls through to the no-op base class method (returns None).
        return super().delete_data(keys, table_name)

View file

@ -68,6 +68,7 @@ def generate_feedback_box():
import litellm import litellm
from litellm.proxy.utils import ( from litellm.proxy.utils import (
PrismaClient, PrismaClient,
DBClient,
get_instance_fn, get_instance_fn,
ProxyLogging, ProxyLogging,
_cache_user_row, _cache_user_row,
@ -142,6 +143,7 @@ worker_config = None
master_key = None master_key = None
otel_logging = False otel_logging = False
prisma_client: Optional[PrismaClient] = None prisma_client: Optional[PrismaClient] = None
custom_db_client: Optional[DBClient] = None
user_api_key_cache = DualCache() user_api_key_cache = DualCache()
user_custom_auth = None user_custom_auth = None
use_background_health_checks = None use_background_health_checks = None
@ -185,12 +187,12 @@ def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
async def user_api_key_auth( async def user_api_key_auth(
request: Request, api_key: str = fastapi.Security(api_key_header) request: Request, api_key: str = fastapi.Security(api_key_header)
) -> UserAPIKeyAuth: ) -> UserAPIKeyAuth:
global master_key, prisma_client, llm_model_list, user_custom_auth global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client
try: try:
if isinstance(api_key, str): if isinstance(api_key, str):
api_key = _get_bearer_token(api_key=api_key) api_key = _get_bearer_token(api_key=api_key)
### USER-DEFINED AUTH FUNCTION ### ### USER-DEFINED AUTH FUNCTION ###
if user_custom_auth: if user_custom_auth is not None:
response = await user_custom_auth(request=request, api_key=api_key) response = await user_custom_auth(request=request, api_key=api_key)
return UserAPIKeyAuth.model_validate(response) return UserAPIKeyAuth.model_validate(response)
### LITELLM-DEFINED AUTH FUNCTION ### ### LITELLM-DEFINED AUTH FUNCTION ###
@ -232,7 +234,7 @@ async def user_api_key_auth(
) )
if ( if (
prisma_client is None prisma_client is None and custom_db_client is None
): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error ): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
raise Exception("No connected db.") raise Exception("No connected db.")
@ -242,9 +244,24 @@ async def user_api_key_auth(
if valid_token is None: if valid_token is None:
## check db ## check db
verbose_proxy_logger.debug(f"api key: {api_key}") verbose_proxy_logger.debug(f"api key: {api_key}")
if prisma_client is not None:
valid_token = await prisma_client.get_data( valid_token = await prisma_client.get_data(
token=api_key, expires=datetime.utcnow().replace(tzinfo=timezone.utc) token=api_key, expires=datetime.utcnow().replace(tzinfo=timezone.utc)
) )
elif custom_db_client is not None:
valid_token = await custom_db_client.get_data(key="token", value=api_key, table_name="key")
# Token exists, now check expiration.
if valid_token.expires is not None:
expiry_time = datetime.fromisoformat(valid_token.expires)
if expiry_time >= datetime.utcnow():
# Token exists and is not expired.
return response
else:
# Token exists but is expired.
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="expired user key",
)
verbose_proxy_logger.debug(f"valid token from prisma: {valid_token}") verbose_proxy_logger.debug(f"valid token from prisma: {valid_token}")
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60) user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
elif valid_token is not None: elif valid_token is not None:
@ -280,6 +297,7 @@ async def user_api_key_auth(
This makes the user row data accessible to pre-api call hooks. This makes the user row data accessible to pre-api call hooks.
""" """
if prisma_client is not None:
asyncio.create_task( asyncio.create_task(
_cache_user_row( _cache_user_row(
user_id=valid_token.user_id, user_id=valid_token.user_id,
@ -287,6 +305,14 @@ async def user_api_key_auth(
db=prisma_client, db=prisma_client,
) )
) )
elif custom_db_client is not None:
asyncio.create_task(
_cache_user_row(
user_id=valid_token.user_id,
cache=user_api_key_cache,
db=custom_db_client,
)
)
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict) return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
else: else:
raise Exception(f"Invalid token") raise Exception(f"Invalid token")
@ -611,7 +637,7 @@ class ProxyConfig:
""" """
Load config values into proxy global state Load config values into proxy global state
""" """
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue, custom_db_client
# Load existing config # Load existing config
config = await self.get_config(config_file_path=config_file_path) config = await self.get_config(config_file_path=config_file_path)
@ -785,11 +811,17 @@ class ProxyConfig:
if master_key and master_key.startswith("os.environ/"): if master_key and master_key.startswith("os.environ/"):
master_key = litellm.get_secret(master_key) master_key = litellm.get_secret(master_key)
### CUSTOM API KEY AUTH ### ### CUSTOM API KEY AUTH ###
## pass filepath
custom_auth = general_settings.get("custom_auth", None) custom_auth = general_settings.get("custom_auth", None)
if custom_auth: if custom_auth is not None:
user_custom_auth = get_instance_fn( user_custom_auth = get_instance_fn(
value=custom_auth, config_file_path=config_file_path value=custom_auth, config_file_path=config_file_path
) )
## dynamodb
database_type = general_settings.get("database_type", None)
if database_type is not None and database_type == "dynamo_db":
database_args = general_settings.get("database_args", None)
custom_db_client = DBClient(custom_db_args=database_args, custom_db_type=database_type)
### BACKGROUND HEALTH CHECKS ### ### BACKGROUND HEALTH CHECKS ###
# Enable background health checks # Enable background health checks
use_background_health_checks = general_settings.get( use_background_health_checks = general_settings.get(
@ -856,9 +888,9 @@ async def generate_key_helper_fn(
max_parallel_requests: Optional[int] = None, max_parallel_requests: Optional[int] = None,
metadata: Optional[dict] = {}, metadata: Optional[dict] = {},
): ):
global prisma_client global prisma_client, custom_db_client
if prisma_client is None: if prisma_client is None and custom_db_client is None:
raise Exception( raise Exception(
f"Connect Proxy to database to generate keys - https://docs.litellm.ai/docs/proxy/virtual_keys " f"Connect Proxy to database to generate keys - https://docs.litellm.ai/docs/proxy/virtual_keys "
) )
@ -897,7 +929,12 @@ async def generate_key_helper_fn(
user_id = user_id or str(uuid.uuid4()) user_id = user_id or str(uuid.uuid4())
try: try:
# Create a new verification token (you may want to enhance this logic based on your needs) # Create a new verification token (you may want to enhance this logic based on your needs)
verification_token_data = { user_data = {
"max_budget": max_budget,
"user_email": user_email,
"user_id": user_id
}
key_data = {
"token": token, "token": token,
"expires": expires, "expires": expires,
"models": models, "models": models,
@ -907,19 +944,24 @@ async def generate_key_helper_fn(
"user_id": user_id, "user_id": user_id,
"max_parallel_requests": max_parallel_requests, "max_parallel_requests": max_parallel_requests,
"metadata": metadata_json, "metadata": metadata_json,
"max_budget": max_budget,
"user_email": user_email,
} }
if prisma_client is not None:
verification_token_data = key_data.update(user_data)
verbose_proxy_logger.debug("PrismaClient: Before Insert Data") verbose_proxy_logger.debug("PrismaClient: Before Insert Data")
new_verification_token = await prisma_client.insert_data( await prisma_client.insert_data(
data=verification_token_data data=verification_token_data
) )
elif custom_db_client is not None:
## CREATE USER (If necessary)
await custom_db_client.insert_data(value=user_data, table_name="user")
## CREATE KEY
await custom_db_client.insert_data(value=key_data, table_name="key")
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
return { return {
"token": token, "token": token,
"expires": new_verification_token.expires, "expires": expires,
"user_id": user_id, "user_id": user_id,
"max_budget": max_budget, "max_budget": max_budget,
} }
@ -1188,12 +1230,24 @@ async def startup_event():
if prisma_client is not None: if prisma_client is not None:
await prisma_client.connect() await prisma_client.connect()
if custom_db_client is not None:
await custom_db_client.connect()
if prisma_client is not None and master_key is not None: if prisma_client is not None and master_key is not None:
# add master key to db # add master key to db
print(f'prisma_client: {prisma_client}')
await generate_key_helper_fn( await generate_key_helper_fn(
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
) )
if custom_db_client is not None and master_key is not None:
# add master key to db
print(f'custom_db_client: {custom_db_client}')
await generate_key_helper_fn(
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
)
#### API ENDPOINTS #### #### API ENDPOINTS ####
@router.get( @router.get(

View file

@ -1,11 +1,13 @@
from typing import Optional, List, Any, Literal from typing import Optional, List, Any, Literal, Union
import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx
import litellm, backoff import litellm, backoff
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth, DynamoDBArgs
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.db.base_client import CustomDB
from litellm.proxy.db.dynamo_db import DynamoDBWrapper
from fastapi import HTTPException, status from fastapi import HTTPException, status
import smtplib import smtplib
from email.mime.text import MIMEText from email.mime.text import MIMEText
@ -580,6 +582,52 @@ class PrismaClient:
) )
raise e raise e
class DBClient:
    """
    Routes requests for CustomAuth

    [TODO] route b/w customauth and prisma
    """

    def __init__(self, custom_db_type: Literal["dynamo_db"], custom_db_args: Optional[dict] = None) -> None:
        """
        Args:
            custom_db_type: which custom db backend to use; only "dynamo_db" is supported.
            custom_db_args: kwargs for DynamoDBArgs (billing mode, capacity units, region).

        Raises:
            ValueError: on an unrecognized custom_db_type. Previously self.db was
                silently left unset in that case, deferring the failure to an
                AttributeError on the first get/insert/update call.
        """
        if custom_db_type == "dynamo_db":
            self.db = DynamoDBWrapper(database_arguments=DynamoDBArgs(**custom_db_args))
        else:
            # Fail fast at construction time instead of leaving self.db unset.
            raise ValueError(f"Unsupported custom_db_type: {custom_db_type}")

    async def get_data(self, key: str, value: Any, table_name: Literal["user", "key", "config"]):
        """
        Check if key valid
        """
        return await self.db.get_data(key=key, value=value, table_name=table_name)

    async def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]):
        """
        For new key / user logic
        """
        return await self.db.insert_data(value=value, table_name=table_name)

    async def update_data(self, key: str, value: Any, table_name: Literal["user", "key", "config"]):
        """
        For cost tracking logic
        """
        return await self.db.update_data(key=key, value=value, table_name=table_name)

    async def delete_data(self, keys: List[str], table_name: Literal["user", "key", "config"]):
        """
        For /key/delete endpoints
        """
        return await self.db.delete_data(keys=keys, table_name=table_name)

    async def connect(self):
        """
        For connecting to db and creating / updating any tables
        """
        return await self.db.connect()

    async def disconnect(self):
        """
        For closing connection on server shutdown
        """
        return await self.db.disconnect()
### CUSTOM FILE ### ### CUSTOM FILE ###
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any: def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
@ -621,7 +669,7 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
### HELPER FUNCTIONS ### ### HELPER FUNCTIONS ###
async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient): async def _cache_user_row(user_id: str, cache: DualCache, db: Union[PrismaClient, DBClient]):
""" """
Check if a user_id exists in cache, Check if a user_id exists in cache,
if not retrieve it. if not retrieve it.
@ -630,7 +678,10 @@ async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
cache_key = f"{user_id}_user_api_key_user_id" cache_key = f"{user_id}_user_api_key_user_id"
response = cache.get_cache(key=cache_key) response = cache.get_cache(key=cache_key)
if response is None: # Cache miss if response is None: # Cache miss
if isinstance(db, PrismaClient):
user_row = await db.get_data(user_id=user_id) user_row = await db.get_data(user_id=user_id)
elif isinstance(db, DBClient):
user_row = await db.get_data(key="user_id", value=user_id, table_name="user")
if user_row is not None: if user_row is not None:
print_verbose(f"User Row: {user_row}, type = {type(user_row)}") print_verbose(f"User Row: {user_row}, type = {type(user_row)}")
if hasattr(user_row, "model_dump_json") and callable( if hasattr(user_row, "model_dump_json") and callable(