diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index 3d8a5a3493..263189a867 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -4,7 +4,6 @@ from typing import Optional, List, Union, Dict, Literal
 from datetime import datetime
 import uuid, json
 
-
 class LiteLLMBase(BaseModel):
     """
     Implements default functions, all pydantic objects should have.
@@ -181,6 +180,12 @@ class KeyManagementSystem(enum.Enum):
     AZURE_KEY_VAULT = "azure_key_vault"
     LOCAL = "local"
 
+class DynamoDBArgs(LiteLLMBase):
+    billing_mode: Literal["PROVISIONED_THROUGHPUT", "PAY_PER_REQUEST"]
+    read_capacity_units: Optional[int] = None
+    write_capacity_units: Optional[int] = None
+    region_name: Optional[str] = None
+
 
 class ConfigGeneralSettings(LiteLLMBase):
     """
@@ -206,6 +211,8 @@ class ConfigGeneralSettings(LiteLLMBase):
         None,
         description="connect to a postgres db - needed for generating temporary keys + tracking spend / key",
     )
+    database_type: Optional[Literal["dynamo_db"]] = Field(None, description="to use dynamodb instead of postgres db")
+    database_args: Optional[DynamoDBArgs] = Field(None, description="custom args for instantiating dynamodb client - e.g. billing provision")
     otel: Optional[bool] = Field(
         None,
         description="[BETA] OpenTelemetry support - this might change, use with caution.",
@@ -258,3 +265,35 @@ class ConfigYAML(LiteLLMBase):
 
     class Config:
         protected_namespaces = ()
+
+class DBTableNames(enum.Enum):
+    user = "LiteLLM_UserTable"
+    key = "LiteLLM_VerificationToken"
+    config = "LiteLLM_Config"
+
+class LiteLLM_VerificationToken(LiteLLMBase):
+    token: str
+    spend: float = 0.0
+    expires: Union[str, None]
+    models: List[str]
+    aliases: Dict[str, str] = {}
+    config: Dict[str, str] = {}
+    user_id: Union[str, None]
+    max_parallel_requests: Union[int, None]
+    metadata: Dict[str, str] = {}
+
+class LiteLLM_Config(LiteLLMBase):
+    param_name: str
+    param_value: Dict
+
+class LiteLLM_UserTable(LiteLLMBase):
+    user_id: str
+    max_budget: Optional[float]
+    spend: float = 0.0
+    user_email: Optional[str]
+
+    @root_validator(pre=True)
+    def set_model_info(cls, values):
+        if values.get("spend") is None:
+            values.update({"spend": 0.0})
+        return values
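For reference, a minimal sketch of how the new `DynamoDBArgs` model might be populated — in the proxy it is built from `general_settings.database_args`; the values below are illustrative placeholders, not part of the patch:

```python
from litellm.proxy._types import DynamoDBArgs

# Provisioned-capacity example; with billing_mode="PAY_PER_REQUEST" the
# read/write capacity fields can be omitted.
args = DynamoDBArgs(
    billing_mode="PROVISIONED_THROUGHPUT",
    read_capacity_units=5,
    write_capacity_units=5,
    region_name="us-west-2",
)
```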
diff --git a/litellm/proxy/db/base_client.py b/litellm/proxy/db/base_client.py
new file mode 100644
index 0000000000..9eb7e5ddbf
--- /dev/null
+++ b/litellm/proxy/db/base_client.py
@@ -0,0 +1,43 @@
+from typing import Any, Literal, List
+class CustomDB:
+    """
+    Implements a base class that we expect any custom db implementation (e.g. DynamoDB) to follow
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def get_data(self, key: str, value: str, table_name: Literal["user", "key", "config"]):
+        """
+        Check if key valid
+        """
+        pass
+
+    def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]):
+        """
+        For new key / user logic
+        """
+        pass
+
+    def update_data(self, key: str, value: Any, table_name: Literal["user", "key", "config"]):
+        """
+        For cost tracking logic
+        """
+        pass
+
+    def delete_data(self, keys: List[str], table_name: Literal["user", "key", "config"]):
+        """
+        For /key/delete endpoints
+        """
+
+    def connect(self, ):
+        """
+        For connecting to db and creating / updating any tables
+        """
+        pass
+
+    def disconnect(self, ):
+        """
+        For closing connection on server shutdown
+        """
+        pass
\ No newline at end of file
diff --git a/litellm/proxy/db/dynamo_db.py b/litellm/proxy/db/dynamo_db.py
new file mode 100644
index 0000000000..2a12a7b27f
--- /dev/null
+++ b/litellm/proxy/db/dynamo_db.py
@@ -0,0 +1,145 @@
+import json
+from aiodynamo.client import Client
+from aiodynamo.credentials import Credentials, StaticCredentials
+from aiodynamo.http.httpx import HTTPX
+from aiodynamo.models import Throughput, KeySchema, KeySpec, KeyType, PayPerRequest
+from yarl import URL
+from litellm.proxy.db.base_client import CustomDB
+from litellm.proxy._types import DynamoDBArgs, DBTableNames, LiteLLM_VerificationToken, LiteLLM_Config, LiteLLM_UserTable
+from litellm import get_secret
+from typing import Any, List, Literal, Optional
+from aiodynamo.expressions import UpdateExpression, F
+from aiodynamo.models import ReturnValues
+from aiodynamo.http.aiohttp import AIOHTTP
+from aiohttp import ClientSession
+from datetime import datetime
+
+class DynamoDBWrapper(CustomDB):
+    credentials: Credentials
+
+    def __init__(self, database_arguments: DynamoDBArgs):
+        self.throughput_type = None
+        if database_arguments.billing_mode == "PAY_PER_REQUEST":
+            self.throughput_type = PayPerRequest()
+        elif database_arguments.billing_mode == "PROVISIONED_THROUGHPUT":
+            self.throughput_type = Throughput(read=database_arguments.read_capacity_units, write=database_arguments.write_capacity_units)
+        self.region_name = database_arguments.region_name
+
+    async def connect(self):
+        """
+        Connect to DB and create / update any tables
+        """
+        async with ClientSession() as session:
+            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
+            ## User
+            table = client.table(DBTableNames.user.value)
+            if not await table.exists():
+                await table.create(
+                    self.throughput_type,
+                    KeySchema(hash_key=KeySpec("user_id", KeyType.string)),
+                )
+            ## Token
+            table = client.table(DBTableNames.key.value)
+            if not await table.exists():
+                await table.create(
+                    self.throughput_type,
+                    KeySchema(hash_key=KeySpec("token", KeyType.string)),
+                )
+            ## Config
+            table = client.table(DBTableNames.config.value)
+            if not await table.exists():
+                await table.create(
+                    self.throughput_type,
+                    KeySchema(hash_key=KeySpec("param_name", KeyType.string)),
+                )
+
+    async def insert_data(self, value: Any, table_name: Literal['user', 'key', 'config']):
+        async with ClientSession() as session:
+            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
+            table = None
+            if table_name == DBTableNames.user.name:
+                table = client.table(DBTableNames.user.value)
+            elif table_name == DBTableNames.key.name:
+                table = client.table(DBTableNames.key.value)
+            elif table_name == DBTableNames.config.name:
+                table = client.table(DBTableNames.config.value)
+
+            for k, v in value.items():
+                if isinstance(v, datetime):
+                    value[k] = v.isoformat()
+
+            await table.put_item(item=value)
+
+    async def get_data(self, key: str, value: str, table_name: Literal['user', 'key', 'config']):
+        async with ClientSession() as session:
+            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
+            table = None
+            if table_name == DBTableNames.user.name:
+                table = client.table(DBTableNames.user.value)
+            elif table_name == DBTableNames.key.name:
+                table = client.table(DBTableNames.key.value)
+            elif table_name == DBTableNames.config.name:
+                table = client.table(DBTableNames.config.value)
+
+            response = await table.get_item({key: value})
+
+            if table_name == DBTableNames.user.name:
+                new_response = LiteLLM_UserTable(**response)
+            elif table_name == DBTableNames.key.name:
+                new_response = {}
+                for k, v in response.items():  # handle json string
+                    if (k == "aliases" or k == "config" or k == "metadata") and v is not None and isinstance(v, str):
+                        new_response[k] = json.loads(v)
+                    else:
+                        new_response[k] = v
+                new_response = LiteLLM_VerificationToken(**new_response)
+            elif table_name == DBTableNames.config.name:
+                new_response = LiteLLM_Config(**response)
+            return new_response
+
+    async def update_data(self, key: str, value: Any, table_name: Literal['user', 'key', 'config']):
+        async with ClientSession() as session:
+            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
+            table = None
+            key_name = None
+            data_obj = None
+            if table_name == DBTableNames.user.name:
+                table = client.table(DBTableNames.user.value)
+                key_name = "user_id"
+                data_obj = LiteLLM_UserTable(user_id=key, **value)
+
+            elif table_name == DBTableNames.key.name:
+                table = client.table(DBTableNames.key.value)
+                key_name = "token"
+                data_obj = LiteLLM_VerificationToken(token=key, **value)
+
+            elif table_name == DBTableNames.config.name:
+                table = client.table(DBTableNames.config.value)
+                key_name = "param_name"
+                data_obj = LiteLLM_Config(param_name=key, **value)
+
+            # Initialize an empty UpdateExpression
+            update_expression = UpdateExpression()
+
+            # Add updates for each field that has been modified
+            for field in data_obj.model_fields_set:
+                # If a Pydantic model has a __fields_set__ attribute, it's a set of fields that were set when the model was instantiated
+                field_value = getattr(data_obj, field)
+                if isinstance(field_value, datetime):
+                    field_value = field_value.isoformat()
+                update_expression = update_expression.set(F(field), field_value)
+
+            # Perform the update in DynamoDB
+            result = await table.update_item(
+                key={key_name: key},
+                update_expression=update_expression,
+                return_values=ReturnValues.NONE
+            )
+            return result
+
+    async def delete_data(self, keys: List[str], table_name: Literal['user', 'key', 'config']):
+        """
+        Not Implemented yet.
+        """
+        return super().delete_data(keys, table_name)
\ No newline at end of file
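A rough usage sketch for the new wrapper (illustrative only; it assumes AWS credentials are discoverable by `Credentials.auto()`, e.g. via environment variables):

```python
import asyncio

from litellm.proxy._types import DynamoDBArgs
from litellm.proxy.db.dynamo_db import DynamoDBWrapper

db = DynamoDBWrapper(
    database_arguments=DynamoDBArgs(
        billing_mode="PAY_PER_REQUEST", region_name="us-west-2"
    )
)
# connect() creates the LiteLLM_UserTable / LiteLLM_VerificationToken /
# LiteLLM_Config tables (see DBTableNames) if they don't exist yet.
asyncio.run(db.connect())
```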
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index e93c9baf11..61f3fbd889 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -68,6 +68,7 @@ def generate_feedback_box():
 import litellm
 from litellm.proxy.utils import (
     PrismaClient,
+    DBClient,
     get_instance_fn,
     ProxyLogging,
     _cache_user_row,
@@ -142,6 +143,7 @@ worker_config = None
 master_key = None
 otel_logging = False
 prisma_client: Optional[PrismaClient] = None
+custom_db_client: Optional[DBClient] = None
 user_api_key_cache = DualCache()
 user_custom_auth = None
 use_background_health_checks = None
@@ -185,12 +187,12 @@ def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
 async def user_api_key_auth(
     request: Request, api_key: str = fastapi.Security(api_key_header)
 ) -> UserAPIKeyAuth:
-    global master_key, prisma_client, llm_model_list, user_custom_auth
+    global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client
     try:
         if isinstance(api_key, str):
             api_key = _get_bearer_token(api_key=api_key)
         ### USER-DEFINED AUTH FUNCTION ###
-        if user_custom_auth:
+        if user_custom_auth is not None:
             response = await user_custom_auth(request=request, api_key=api_key)
             return UserAPIKeyAuth.model_validate(response)
         ### LITELLM-DEFINED AUTH FUNCTION ###
@@ -232,7 +234,7 @@ async def user_api_key_auth(
             )
 
         if (
-            prisma_client is None
+            prisma_client is None and custom_db_client is None
         ):  # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
             raise Exception("No connected db.")
 
@@ -242,9 +244,24 @@ async def user_api_key_auth(
         if valid_token is None:
             ## check db
             verbose_proxy_logger.debug(f"api key: {api_key}")
-            valid_token = await prisma_client.get_data(
-                token=api_key, expires=datetime.utcnow().replace(tzinfo=timezone.utc)
-            )
+            if prisma_client is not None:
+                valid_token = await prisma_client.get_data(
+                    token=api_key, expires=datetime.utcnow().replace(tzinfo=timezone.utc)
+                )
+            elif custom_db_client is not None:
+                valid_token = await custom_db_client.get_data(key="token", value=api_key, table_name="key")
+                # Token exists, now check expiration.
+                if valid_token.expires is not None:
+                    expiry_time = datetime.fromisoformat(valid_token.expires)
+                    if expiry_time >= datetime.utcnow():
+                        # Token exists and is not expired.
+                        return response
+                    else:
+                        # Token exists but is expired.
+                        raise HTTPException(
+                            status_code=status.HTTP_403_FORBIDDEN,
+                            detail="expired user key",
+                        )
             verbose_proxy_logger.debug(f"valid token from prisma: {valid_token}")
             user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
         elif valid_token is not None:
@@ -280,13 +297,22 @@ async def user_api_key_auth(
 
                 This makes the user row data accessible to pre-api call hooks.
""" - asyncio.create_task( - _cache_user_row( - user_id=valid_token.user_id, - cache=user_api_key_cache, - db=prisma_client, + if prisma_client is not None: + asyncio.create_task( + _cache_user_row( + user_id=valid_token.user_id, + cache=user_api_key_cache, + db=prisma_client, + ) + ) + elif custom_db_client is not None: + asyncio.create_task( + _cache_user_row( + user_id=valid_token.user_id, + cache=user_api_key_cache, + db=custom_db_client, + ) ) - ) return UserAPIKeyAuth(api_key=api_key, **valid_token_dict) else: raise Exception(f"Invalid token") @@ -611,7 +637,7 @@ class ProxyConfig: """ Load config values into proxy global state """ - global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue + global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue, custom_db_client # Load existing config config = await self.get_config(config_file_path=config_file_path) @@ -785,11 +811,17 @@ class ProxyConfig: if master_key and master_key.startswith("os.environ/"): master_key = litellm.get_secret(master_key) ### CUSTOM API KEY AUTH ### + ## pass filepath custom_auth = general_settings.get("custom_auth", None) - if custom_auth: + if custom_auth is not None: user_custom_auth = get_instance_fn( value=custom_auth, config_file_path=config_file_path ) + ## dynamodb + database_type = general_settings.get("database_type", None) + if database_type is not None and database_type == "dynamo_db": + database_args = general_settings.get("database_args", None) + custom_db_client = DBClient(custom_db_args=database_args, custom_db_type=database_type) ### BACKGROUND HEALTH CHECKS ### # Enable background health checks use_background_health_checks = general_settings.get( @@ -856,9 +888,9 @@ async def generate_key_helper_fn( max_parallel_requests: Optional[int] = None, metadata: Optional[dict] = {}, ): - global prisma_client + global prisma_client, custom_db_client - if prisma_client is None: + if prisma_client is None and custom_db_client is None: raise Exception( f"Connect Proxy to database to generate keys - https://docs.litellm.ai/docs/proxy/virtual_keys " ) @@ -897,7 +929,12 @@ async def generate_key_helper_fn( user_id = user_id or str(uuid.uuid4()) try: # Create a new verification token (you may want to enhance this logic based on your needs) - verification_token_data = { + user_data = { + "max_budget": max_budget, + "user_email": user_email, + "user_id": user_id + } + key_data = { "token": token, "expires": expires, "models": models, @@ -907,19 +944,24 @@ async def generate_key_helper_fn( "user_id": user_id, "max_parallel_requests": max_parallel_requests, "metadata": metadata_json, - "max_budget": max_budget, - "user_email": user_email, } - verbose_proxy_logger.debug("PrismaClient: Before Insert Data") - new_verification_token = await prisma_client.insert_data( - data=verification_token_data - ) + if prisma_client is not None: + verification_token_data = key_data.update(user_data) + verbose_proxy_logger.debug("PrismaClient: Before Insert Data") + await prisma_client.insert_data( + data=verification_token_data + ) + elif custom_db_client is not None: + ## CREATE USER (If necessary) + await custom_db_client.insert_data(value=user_data, table_name="user") + ## CREATE KEY + await custom_db_client.insert_data(value=key_data, table_name="key") except Exception as e: traceback.print_exc() raise 
         raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
     return {
         "token": token,
-        "expires": new_verification_token.expires,
+        "expires": expires,
         "user_id": user_id,
         "max_budget": max_budget,
     }
@@ -1187,13 +1229,25 @@ async def startup_event():
     verbose_proxy_logger.debug(f"prisma client - {prisma_client}")
     if prisma_client is not None:
         await prisma_client.connect()
+
+    if custom_db_client is not None:
+        await custom_db_client.connect()
 
     if prisma_client is not None and master_key is not None:
         # add master key to db
+        print(f'prisma_client: {prisma_client}')
         await generate_key_helper_fn(
             duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
         )
 
+    if custom_db_client is not None and master_key is not None:
+        # add master key to db
+        print(f'custom_db_client: {custom_db_client}')
+        await generate_key_helper_fn(
+            duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
+        )
+
+
 
 #### API ENDPOINTS ####
 @router.get(
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 798c02b647..833eba24fc 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -1,11 +1,13 @@
-from typing import Optional, List, Any, Literal
+from typing import Optional, List, Any, Literal, Union
 import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx
 import litellm, backoff
-from litellm.proxy._types import UserAPIKeyAuth
+from litellm.proxy._types import UserAPIKeyAuth, DynamoDBArgs
 from litellm.caching import DualCache
 from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
 from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
 from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy.db.base_client import CustomDB
+from litellm.proxy.db.dynamo_db import DynamoDBWrapper
 from fastapi import HTTPException, status
 import smtplib
 from email.mime.text import MIMEText
@@ -332,7 +334,7 @@ class PrismaClient:
                 self.proxy_logging_obj.failure_handler(original_exception=e)
             )
             raise e
-    
+
     @backoff.on_exception(
         backoff.expo,
         Exception,  # base exception to catch for the backoff
@@ -580,6 +582,52 @@ class PrismaClient:
             )
             raise e
 
+class DBClient:
+    """
+    Routes requests for CustomAuth
+
+    [TODO] route b/w customauth and prisma
+    """
+    def __init__(self, custom_db_type: Literal["dynamo_db"], custom_db_args: Optional[dict]=None) -> None:
+        if custom_db_type == "dynamo_db":
+            self.db = DynamoDBWrapper(database_arguments=DynamoDBArgs(**custom_db_args))
+
+    async def get_data(self, key: str, value: Any, table_name: Literal["user", "key", "config"]):
+        """
+        Check if key valid
+        """
+        return await self.db.get_data(key=key, value=value, table_name=table_name)
+
+    async def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]):
+        """
+        For new key / user logic
+        """
+        return await self.db.insert_data(value=value, table_name=table_name)
+
+    async def update_data(self, key: str, value: Any, table_name: Literal["user", "key", "config"]):
+        """
+        For cost tracking logic
+        """
+        return await self.db.update_data(key=key, value=value, table_name=table_name)
+
+    async def delete_data(self, keys: List[str], table_name: Literal["user", "key", "config"]):
+        """
+        For /key/delete endpoints
+        """
+        return await self.db.delete_data(keys=keys, table_name=table_name)
+
+    async def connect(self):
+        """
+        For connecting to db and creating / updating any tables
+        """
+        return await self.db.connect()
+
+    async def disconnect(self):
+        """
+        For closing connection on server shutdown
+        """
+        return await self.db.disconnect()
+
 
 ### CUSTOM FILE ###
 def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
@@ -621,7 +669,7 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
 
 ### HELPER FUNCTIONS ###
-async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
+async def _cache_user_row(user_id: str, cache: DualCache, db: Union[PrismaClient, DBClient]):
     """
     Check if a user_id exists in cache, if not retrieve it.
@@ -630,7 +678,10 @@ async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
     cache_key = f"{user_id}_user_api_key_user_id"
     response = cache.get_cache(key=cache_key)
     if response is None:  # Cache miss
-        user_row = await db.get_data(user_id=user_id)
+        if isinstance(db, PrismaClient):
+            user_row = await db.get_data(user_id=user_id)
+        elif isinstance(db, DBClient):
+            user_row = await db.get_data(key="user_id", value=user_id, table_name="user")
     if user_row is not None:
         print_verbose(f"User Row: {user_row}, type = {type(user_row)}")
         if hasattr(user_row, "model_dump_json") and callable(
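Finally, a minimal sketch of how the proxy ends up wiring this together once `database_type: dynamo_db` is configured (mirrors `ProxyConfig.load_config` and the auth/startup paths above; the argument values are placeholders):

```python
from litellm.proxy.utils import DBClient

custom_db_client = DBClient(
    custom_db_type="dynamo_db",
    custom_db_args={"billing_mode": "PAY_PER_REQUEST", "region_name": "us-west-2"},
)

# on startup:      await custom_db_client.connect()
# on key auth:     valid_token = await custom_db_client.get_data(
#                      key="token", value=api_key, table_name="key"
#                  )
# on key creation: await custom_db_client.insert_data(value=key_data, table_name="key")
```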