diff --git a/dist/litellm-1.16.21.dev1-py3-none-any.whl b/dist/litellm-1.16.21.dev1-py3-none-any.whl new file mode 100644 index 0000000000..889690a17d Binary files /dev/null and b/dist/litellm-1.16.21.dev1-py3-none-any.whl differ diff --git a/dist/litellm-1.16.21.dev1.tar.gz b/dist/litellm-1.16.21.dev1.tar.gz new file mode 100644 index 0000000000..17d0c0b76d Binary files /dev/null and b/dist/litellm-1.16.21.dev1.tar.gz differ diff --git a/dist/litellm-1.16.21.dev2-py3-none-any.whl b/dist/litellm-1.16.21.dev2-py3-none-any.whl new file mode 100644 index 0000000000..c174f2628c Binary files /dev/null and b/dist/litellm-1.16.21.dev2-py3-none-any.whl differ diff --git a/dist/litellm-1.16.21.dev2.tar.gz b/dist/litellm-1.16.21.dev2.tar.gz new file mode 100644 index 0000000000..7f81f4553f Binary files /dev/null and b/dist/litellm-1.16.21.dev2.tar.gz differ diff --git a/dist/litellm-1.16.21.dev3-py3-none-any.whl b/dist/litellm-1.16.21.dev3-py3-none-any.whl new file mode 100644 index 0000000000..3ce32c6ac2 Binary files /dev/null and b/dist/litellm-1.16.21.dev3-py3-none-any.whl differ diff --git a/dist/litellm-1.16.21.dev3.tar.gz b/dist/litellm-1.16.21.dev3.tar.gz new file mode 100644 index 0000000000..757044faf9 Binary files /dev/null and b/dist/litellm-1.16.21.dev3.tar.gz differ diff --git a/docs/my-website/docs/proxy/virtual_keys.md b/docs/my-website/docs/proxy/virtual_keys.md index fbdc3a9067..59dc35f371 100644 --- a/docs/my-website/docs/proxy/virtual_keys.md +++ b/docs/my-website/docs/proxy/virtual_keys.md @@ -246,3 +246,35 @@ general_settings: $ litellm --config /path/to/config.yaml ``` + +## [BETA] Dynamo DB + +Only live in `v1.16.21.dev1`. + +### Step 1. Save keys to env + +```env +AWS_ACCESS_KEY_ID = "your-aws-access-key-id" +AWS_SECRET_ACCESS_KEY = "your-aws-secret-access-key" +``` + +### Step 2. Add details to config + +```yaml +general_settings: + master_key: sk-1234 + database_type: "dynamo_db" + database_args: { # 👈 all args - https://github.com/BerriAI/litellm/blob/befbcbb7ac8f59835ce47415c128decf37aac328/litellm/proxy/_types.py#L190 + "billing_mode": "PAY_PER_REQUEST", + "region_name": "us-west-2" + } +``` + +### Step 3. Generate Key + +```bash +curl --location 'http://0.0.0.0:8000/key/generate' \ +--header 'Authorization: Bearer sk-1234' \ +--header 'Content-Type: application/json' \ +--data '{"models": ["azure-models"], "aliases": {"mistral-7b": "gpt-3.5-turbo"}, "duration": null}' +``` \ No newline at end of file diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 3d8a5a3493..77f9aabced 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -17,6 +17,13 @@ class LiteLLMBase(BaseModel): # if using pydantic v1 return self.dict() + def fields_set(self): + try: + return self.model_fields_set # noqa + except: + # if using pydantic v1 + return self.__fields_set__ + ######### Request Class Definition ###### class ProxyChatCompletionRequest(LiteLLMBase): @@ -182,6 +189,16 @@ class KeyManagementSystem(enum.Enum): LOCAL = "local" +class DynamoDBArgs(LiteLLMBase): + billing_mode: Literal["PROVISIONED_THROUGHPUT", "PAY_PER_REQUEST"] + read_capacity_units: Optional[int] = None + write_capacity_units: Optional[int] = None + region_name: str + user_table_name: str = "LiteLLM_UserTable" + key_table_name: str = "LiteLLM_VerificationToken" + config_table_name: str = "LiteLLM_Config" + + class ConfigGeneralSettings(LiteLLMBase): """ Documents all the fields supported by `general_settings` in config.yaml @@ -206,6 +223,13 @@ class ConfigGeneralSettings(LiteLLMBase): None, description="connect to a postgres db - needed for generating temporary keys + tracking spend / key", ) + database_type: Optional[Literal["dynamo_db"]] = Field( + None, description="to use dynamodb instead of postgres db" + ) + database_args: Optional[DynamoDBArgs] = Field( + None, + description="custom args for instantiating dynamodb client - e.g. billing provision", + ) otel: Optional[bool] = Field( None, description="[BETA] OpenTelemetry support - this might change, use with caution.", @@ -258,3 +282,33 @@ class ConfigYAML(LiteLLMBase): class Config: protected_namespaces = () + + +class LiteLLM_VerificationToken(LiteLLMBase): + token: str + spend: float = 0.0 + expires: Union[str, None] + models: List[str] + aliases: Dict[str, str] = {} + config: Dict[str, str] = {} + user_id: Union[str, None] + max_parallel_requests: Union[int, None] + metadata: Dict[str, str] = {} + + +class LiteLLM_Config(LiteLLMBase): + param_name: str + param_value: Dict + + +class LiteLLM_UserTable(LiteLLMBase): + user_id: str + max_budget: Optional[float] + spend: float = 0.0 + user_email: Optional[str] + + @root_validator(pre=True) + def set_model_info(cls, values): + if values.get("spend") is None: + values.update({"spend": 0.0}) + return values diff --git a/litellm/proxy/db/base_client.py b/litellm/proxy/db/base_client.py new file mode 100644 index 0000000000..07f0ecdc47 --- /dev/null +++ b/litellm/proxy/db/base_client.py @@ -0,0 +1,53 @@ +from typing import Any, Literal, List + + +class CustomDB: + """ + Implements a base class that we expect any custom db implementation (e.g. DynamoDB) to follow + """ + + def __init__(self) -> None: + pass + + def get_data(self, key: str, table_name: Literal["user", "key", "config"]): + """ + Check if key valid + """ + pass + + def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]): + """ + For new key / user logic + """ + pass + + def update_data( + self, key: str, value: Any, table_name: Literal["user", "key", "config"] + ): + """ + For cost tracking logic + """ + pass + + def delete_data( + self, keys: List[str], table_name: Literal["user", "key", "config"] + ): + """ + For /key/delete endpoint s + """ + + def connect( + self, + ): + """ + For connecting to db and creating / updating any tables + """ + pass + + def disconnect( + self, + ): + """ + For closing connection on server shutdown + """ + pass diff --git a/litellm/proxy/db/dynamo_db.py b/litellm/proxy/db/dynamo_db.py new file mode 100644 index 0000000000..0196b16cf9 --- /dev/null +++ b/litellm/proxy/db/dynamo_db.py @@ -0,0 +1,204 @@ +import json +from aiodynamo.client import Client +from aiodynamo.credentials import Credentials, StaticCredentials +from aiodynamo.http.httpx import HTTPX +from aiodynamo.models import Throughput, KeySchema, KeySpec, KeyType, PayPerRequest +from yarl import URL +from litellm.proxy.db.base_client import CustomDB +from litellm.proxy._types import ( + DynamoDBArgs, + LiteLLM_VerificationToken, + LiteLLM_Config, + LiteLLM_UserTable, +) +from litellm import get_secret +from typing import Any, List, Literal, Optional, Union +from aiodynamo.expressions import UpdateExpression, F, Value +from aiodynamo.models import ReturnValues +from aiodynamo.http.aiohttp import AIOHTTP +from aiohttp import ClientSession +from datetime import datetime + + +class DynamoDBWrapper(CustomDB): + credentials: Credentials + + def __init__(self, database_arguments: DynamoDBArgs): + self.throughput_type = None + if database_arguments.billing_mode == "PAY_PER_REQUEST": + self.throughput_type = PayPerRequest() + elif database_arguments.billing_mode == "PROVISIONED_THROUGHPUT": + if ( + database_arguments.read_capacity_units is not None + and isinstance(database_arguments.read_capacity_units, int) + and database_arguments.write_capacity_units is not None + and isinstance(database_arguments.write_capacity_units, int) + ): + self.throughput_type = Throughput(read=database_arguments.read_capacity_units, write=database_arguments.write_capacity_units) # type: ignore + else: + raise Exception( + f"Invalid args passed in. Need to set both read_capacity_units and write_capacity_units. Args passed in - {database_arguments}" + ) + self.database_arguments = database_arguments + self.region_name = database_arguments.region_name + + async def connect(self): + """ + Connect to DB, and creating / updating any tables + """ + async with ClientSession() as session: + client = Client(AIOHTTP(session), Credentials.auto(), self.region_name) + ## User + try: + error_occurred = False + table = client.table(self.database_arguments.user_table_name) + if not await table.exists(): + await table.create( + self.throughput_type, + KeySchema(hash_key=KeySpec("user_id", KeyType.string)), + ) + except Exception as e: + error_occurred = True + if error_occurred == True: + raise Exception( + f"Failed to create table - {self.database_arguments.user_table_name}.\nPlease create a new table called {self.database_arguments.user_table_name}\nAND set `hash_key` as 'user_id'" + ) + ## Token + try: + error_occurred = False + table = client.table(self.database_arguments.key_table_name) + if not await table.exists(): + await table.create( + self.throughput_type, + KeySchema(hash_key=KeySpec("token", KeyType.string)), + ) + except Exception as e: + error_occurred = True + if error_occurred == True: + raise Exception( + f"Failed to create table - {self.database_arguments.key_table_name}.\nPlease create a new table called {self.database_arguments.key_table_name}\nAND set `hash_key` as 'token'" + ) + ## Config + try: + error_occurred = False + table = client.table(self.database_arguments.config_table_name) + if not await table.exists(): + await table.create( + self.throughput_type, + KeySchema(hash_key=KeySpec("param_name", KeyType.string)), + ) + except Exception as e: + error_occurred = True + if error_occurred == True: + raise Exception( + f"Failed to create table - {self.database_arguments.config_table_name}.\nPlease create a new table called {self.database_arguments.config_table_name}\nAND set `hash_key` as 'param_name'" + ) + + async def insert_data( + self, value: Any, table_name: Literal["user", "key", "config"] + ): + async with ClientSession() as session: + client = Client(AIOHTTP(session), Credentials.auto(), self.region_name) + table = None + if table_name == "user": + table = client.table(self.database_arguments.user_table_name) + elif table_name == "key": + table = client.table(self.database_arguments.key_table_name) + elif table_name == "config": + table = client.table(self.database_arguments.config_table_name) + + for k, v in value.items(): + if isinstance(v, datetime): + value[k] = v.isoformat() + + await table.put_item(item=value) + + async def get_data(self, key: str, table_name: Literal["user", "key", "config"]): + async with ClientSession() as session: + client = Client(AIOHTTP(session), Credentials.auto(), self.region_name) + table = None + key_name = None + if table_name == "user": + table = client.table(self.database_arguments.user_table_name) + key_name = "user_id" + elif table_name == "key": + table = client.table(self.database_arguments.key_table_name) + key_name = "token" + elif table_name == "config": + table = client.table(self.database_arguments.config_table_name) + key_name = "param_name" + + response = await table.get_item({key_name: key}) + + new_response: Any = None + if table_name == "user": + new_response = LiteLLM_UserTable(**response) + elif table_name == "key": + new_response = {} + for k, v in response.items(): # handle json string + if ( + (k == "aliases" or k == "config" or k == "metadata") + and v is not None + and isinstance(v, str) + ): + new_response[k] = json.loads(v) + else: + new_response[k] = v + new_response = LiteLLM_VerificationToken(**new_response) + elif table_name == "config": + new_response = LiteLLM_Config(**response) + return new_response + + async def update_data( + self, key: str, value: dict, table_name: Literal["user", "key", "config"] + ): + async with ClientSession() as session: + client = Client(AIOHTTP(session), Credentials.auto(), self.region_name) + table = None + key_name = None + try: + if table_name == "user": + table = client.table(self.database_arguments.user_table_name) + key_name = "user_id" + + elif table_name == "key": + table = client.table(self.database_arguments.key_table_name) + key_name = "token" + + elif table_name == "config": + table = client.table(self.database_arguments.config_table_name) + key_name = "param_name" + else: + raise Exception( + f"Invalid table name. Needs to be one of - {self.database_arguments.user_table_name}, {self.database_arguments.key_table_name}, {self.database_arguments.config_table_name}" + ) + except Exception as e: + raise Exception(f"Error connecting to table - {str(e)}") + + # Initialize an empty UpdateExpression + + actions: List = [] + for k, v in value.items(): + # Convert datetime object to ISO8601 string + if isinstance(v, datetime): + v = v.isoformat() + + # Accumulate updates + actions.append((F(k), Value(value=v))) + + update_expression = UpdateExpression(set_updates=actions) + # Perform the update in DynamoDB + result = await table.update_item( + key={key_name: key}, + update_expression=update_expression, + return_values=ReturnValues.none, + ) + return result + + async def delete_data( + self, keys: List[str], table_name: Literal["user", "key", "config"] + ): + """ + Not Implemented yet. + """ + return super().delete_data(keys, table_name) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 0ea2793f33..ebb8f6e869 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -67,6 +67,7 @@ def generate_feedback_box(): import litellm from litellm.proxy.utils import ( PrismaClient, + DBClient, get_instance_fn, ProxyLogging, _cache_user_row, @@ -141,6 +142,7 @@ worker_config = None master_key = None otel_logging = False prisma_client: Optional[PrismaClient] = None +custom_db_client: Optional[DBClient] = None user_api_key_cache = DualCache() user_custom_auth = None use_background_health_checks = None @@ -184,12 +186,12 @@ def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict: async def user_api_key_auth( request: Request, api_key: str = fastapi.Security(api_key_header) ) -> UserAPIKeyAuth: - global master_key, prisma_client, llm_model_list, user_custom_auth + global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client try: if isinstance(api_key, str): api_key = _get_bearer_token(api_key=api_key) ### USER-DEFINED AUTH FUNCTION ### - if user_custom_auth: + if user_custom_auth is not None: response = await user_custom_auth(request=request, api_key=api_key) return UserAPIKeyAuth.model_validate(response) ### LITELLM-DEFINED AUTH FUNCTION ### @@ -231,7 +233,7 @@ async def user_api_key_auth( ) if ( - prisma_client is None + prisma_client is None and custom_db_client is None ): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error raise Exception("No connected db.") @@ -241,15 +243,22 @@ async def user_api_key_auth( if valid_token is None: ## check db verbose_proxy_logger.debug(f"api key: {api_key}") - valid_token = await prisma_client.get_data( - token=api_key, - ) - expires = datetime.utcnow().replace(tzinfo=timezone.utc) + if prisma_client is not None: + valid_token = await prisma_client.get_data( + token=api_key, + ) + + expires = datetime.utcnow().replace(tzinfo=timezone.utc) + elif custom_db_client is not None: + valid_token = await custom_db_client.get_data( + key=api_key, table_name="key" + ) # Token exists, now check expiration. - if valid_token.expires is not None and expires is not None: - if valid_token.expires >= expires: + if valid_token.expires is not None: + expiry_time = datetime.fromisoformat(valid_token.expires) + if expiry_time >= datetime.utcnow(): # Token exists and is not expired. - return valid_token + return response else: # Token exists but is expired. raise HTTPException( @@ -291,13 +300,22 @@ async def user_api_key_auth( This makes the user row data accessible to pre-api call hooks. """ - asyncio.create_task( - _cache_user_row( - user_id=valid_token.user_id, - cache=user_api_key_cache, - db=prisma_client, + if prisma_client is not None: + asyncio.create_task( + _cache_user_row( + user_id=valid_token.user_id, + cache=user_api_key_cache, + db=prisma_client, + ) + ) + elif custom_db_client is not None: + asyncio.create_task( + _cache_user_row( + user_id=valid_token.user_id, + cache=user_api_key_cache, + db=custom_db_client, + ) ) - ) return UserAPIKeyAuth(api_key=api_key, **valid_token_dict) else: raise Exception(f"Invalid token") @@ -371,8 +389,8 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False): def cost_tracking(): - global prisma_client - if prisma_client is not None: + global prisma_client, custom_db_client + if prisma_client is not None or custom_db_client is not None: if isinstance(litellm.success_callback, list): verbose_proxy_logger.debug("setting litellm success callback to track cost") if (track_cost_callback) not in litellm.success_callback: # type: ignore @@ -385,7 +403,7 @@ async def track_cost_callback( start_time=None, end_time=None, # start/end time for completion ): - global prisma_client + global prisma_client, custom_db_client try: # check if it has collected an entire stream response verbose_proxy_logger.debug( @@ -404,10 +422,10 @@ async def track_cost_callback( user_id = kwargs["litellm_params"]["metadata"].get( "user_api_key_user_id", None ) - if user_api_key and prisma_client: - await update_prisma_database( - token=user_api_key, response_cost=response_cost - ) + if user_api_key and ( + prisma_client is not None or custom_db_client is not None + ): + await update_database(token=user_api_key, response_cost=response_cost) elif kwargs["stream"] == False: # for non streaming responses response_cost = litellm.completion_cost( completion_response=completion_response @@ -418,15 +436,17 @@ async def track_cost_callback( user_id = kwargs["litellm_params"]["metadata"].get( "user_api_key_user_id", None ) - if user_api_key and prisma_client: - await update_prisma_database( + if user_api_key and ( + prisma_client is not None or custom_db_client is not None + ): + await update_database( token=user_api_key, response_cost=response_cost, user_id=user_id ) except Exception as e: verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}") -async def update_prisma_database(token, response_cost, user_id=None): +async def update_database(token, response_cost, user_id=None): try: verbose_proxy_logger.debug( f"Enters prisma db call, token: {token}; user_id: {user_id}" @@ -436,7 +456,12 @@ async def update_prisma_database(token, response_cost, user_id=None): async def _update_user_db(): if user_id is None: return - existing_spend_obj = await prisma_client.get_data(user_id=user_id) + if prisma_client is not None: + existing_spend_obj = await prisma_client.get_data(user_id=user_id) + elif custom_db_client is not None: + existing_spend_obj = await custom_db_client.get_data( + key=user_id, table_name="user" + ) if existing_spend_obj is None: existing_spend = 0 else: @@ -447,23 +472,49 @@ async def update_prisma_database(token, response_cost, user_id=None): verbose_proxy_logger.debug(f"new cost: {new_spend}") # Update the cost column for the given user id - await prisma_client.update_data(user_id=user_id, data={"spend": new_spend}) + if prisma_client is not None: + await prisma_client.update_data( + user_id=user_id, data={"spend": new_spend} + ) + elif custom_db_client is not None: + await custom_db_client.update_data( + key=user_id, value={"spend": new_spend}, table_name="user" + ) ### UPDATE KEY SPEND ### async def _update_key_db(): - # Fetch the existing cost for the given token - existing_spend_obj = await prisma_client.get_data(token=token) - verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}") - if existing_spend_obj is None: - existing_spend = 0 - else: - existing_spend = existing_spend_obj.spend - # Calculate the new cost by adding the existing cost and response_cost - new_spend = existing_spend + response_cost + if prisma_client is not None: + # Fetch the existing cost for the given token + existing_spend_obj = await prisma_client.get_data(token=token) + verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}") + if existing_spend_obj is None: + existing_spend = 0 + else: + existing_spend = existing_spend_obj.spend + # Calculate the new cost by adding the existing cost and response_cost + new_spend = existing_spend + response_cost - verbose_proxy_logger.debug(f"new cost: {new_spend}") - # Update the cost column for the given token - await prisma_client.update_data(token=token, data={"spend": new_spend}) + verbose_proxy_logger.debug(f"new cost: {new_spend}") + # Update the cost column for the given token + await prisma_client.update_data(token=token, data={"spend": new_spend}) + elif custom_db_client is not None: + # Fetch the existing cost for the given token + existing_spend_obj = await custom_db_client.get_data( + key=token, table_name="key" + ) + verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}") + if existing_spend_obj is None: + existing_spend = 0 + else: + existing_spend = existing_spend_obj.spend + # Calculate the new cost by adding the existing cost and response_cost + new_spend = existing_spend + response_cost + + verbose_proxy_logger.debug(f"new cost: {new_spend}") + # Update the cost column for the given token + await custom_db_client.update_data( + key=token, value={"spend": new_spend}, table_name="key" + ) tasks = [] tasks.append(_update_user_db()) @@ -622,7 +673,7 @@ class ProxyConfig: """ Load config values into proxy global state """ - global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue + global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue, custom_db_client # Load existing config config = await self.get_config(config_file_path=config_file_path) @@ -787,8 +838,6 @@ class ProxyConfig: verbose_proxy_logger.debug(f"GOING INTO LITELLM.GET_SECRET!") database_url = litellm.get_secret(database_url) verbose_proxy_logger.debug(f"RETRIEVED DB URL: {database_url}") - ## COST TRACKING ## - cost_tracking() ### MASTER KEY ### master_key = general_settings.get( "master_key", litellm.get_secret("LITELLM_MASTER_KEY", None) @@ -796,11 +845,23 @@ class ProxyConfig: if master_key and master_key.startswith("os.environ/"): master_key = litellm.get_secret(master_key) ### CUSTOM API KEY AUTH ### + ## pass filepath custom_auth = general_settings.get("custom_auth", None) - if custom_auth: + if custom_auth is not None: user_custom_auth = get_instance_fn( value=custom_auth, config_file_path=config_file_path ) + ## dynamodb + database_type = general_settings.get("database_type", None) + if database_type is not None and ( + database_type == "dynamo_db" or database_type == "dynamodb" + ): + database_args = general_settings.get("database_args", None) + custom_db_client = DBClient( + custom_db_args=database_args, custom_db_type=database_type + ) + ## COST TRACKING ## + cost_tracking() ### BACKGROUND HEALTH CHECKS ### # Enable background health checks use_background_health_checks = general_settings.get( @@ -867,9 +928,9 @@ async def generate_key_helper_fn( max_parallel_requests: Optional[int] = None, metadata: Optional[dict] = {}, ): - global prisma_client + global prisma_client, custom_db_client - if prisma_client is None: + if prisma_client is None and custom_db_client is None: raise Exception( f"Connect Proxy to database to generate keys - https://docs.litellm.ai/docs/proxy/virtual_keys " ) @@ -908,7 +969,13 @@ async def generate_key_helper_fn( user_id = user_id or str(uuid.uuid4()) try: # Create a new verification token (you may want to enhance this logic based on your needs) - verification_token_data = { + user_data = { + "max_budget": max_budget, + "user_email": user_email, + "user_id": user_id, + "spend": spend, + } + key_data = { "token": token, "expires": expires, "models": models, @@ -918,19 +985,23 @@ async def generate_key_helper_fn( "user_id": user_id, "max_parallel_requests": max_parallel_requests, "metadata": metadata_json, - "max_budget": max_budget, - "user_email": user_email, } - verbose_proxy_logger.debug("PrismaClient: Before Insert Data") - new_verification_token = await prisma_client.insert_data( - data=verification_token_data - ) + if prisma_client is not None: + verification_token_data = dict(key_data) + verification_token_data.update(user_data) + verbose_proxy_logger.debug("PrismaClient: Before Insert Data") + await prisma_client.insert_data(data=verification_token_data) + elif custom_db_client is not None: + ## CREATE USER (If necessary) + await custom_db_client.insert_data(value=user_data, table_name="user") + ## CREATE KEY + await custom_db_client.insert_data(value=key_data, table_name="key") except Exception as e: traceback.print_exc() raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) return { "token": token, - "expires": new_verification_token.expires, + "expires": expires, "user_id": user_id, "max_budget": max_budget, } @@ -1174,12 +1245,21 @@ async def startup_event(): if prisma_client is not None: await prisma_client.connect() + if custom_db_client is not None: + await custom_db_client.connect() + if prisma_client is not None and master_key is not None: # add master key to db await generate_key_helper_fn( duration=None, models=[], aliases={}, config={}, spend=0, token=master_key ) + if custom_db_client is not None and master_key is not None: + # add master key to db + await generate_key_helper_fn( + duration=None, models=[], aliases={}, config={}, spend=0, token=master_key + ) + #### API ENDPOINTS #### @router.get( diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index b99dbb9e4b..26dcae8468 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1,11 +1,13 @@ -from typing import Optional, List, Any, Literal +from typing import Optional, List, Any, Literal, Union import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx import litellm, backoff -from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy._types import UserAPIKeyAuth, DynamoDBArgs from litellm.caching import DualCache from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter from litellm.integrations.custom_logger import CustomLogger +from litellm.proxy.db.base_client import CustomDB +from litellm.proxy.db.dynamo_db import DynamoDBWrapper from fastapi import HTTPException, status import smtplib from email.mime.text import MIMEText @@ -578,6 +580,65 @@ class PrismaClient: raise e +class DBClient: + """ + Routes requests for CustomAuth + + [TODO] route b/w customauth and prisma + """ + + def __init__( + self, custom_db_type: Literal["dynamo_db"], custom_db_args: dict + ) -> None: + if custom_db_type == "dynamo_db": + self.db = DynamoDBWrapper(database_arguments=DynamoDBArgs(**custom_db_args)) + + async def get_data(self, key: str, table_name: Literal["user", "key", "config"]): + """ + Check if key valid + """ + return await self.db.get_data(key=key, table_name=table_name) + + async def insert_data( + self, value: Any, table_name: Literal["user", "key", "config"] + ): + """ + For new key / user logic + """ + return await self.db.insert_data(value=value, table_name=table_name) + + async def update_data( + self, key: str, value: Any, table_name: Literal["user", "key", "config"] + ): + """ + For cost tracking logic + + key - hash_key value \n + value - dict with updated values + """ + return await self.db.update_data(key=key, value=value, table_name=table_name) + + async def delete_data( + self, keys: List[str], table_name: Literal["user", "key", "config"] + ): + """ + For /key/delete endpoints + """ + return await self.db.delete_data(keys=keys, table_name=table_name) + + async def connect(self): + """ + For connecting to db and creating / updating any tables + """ + return await self.db.connect() + + async def disconnect(self): + """ + For closing connection on server shutdown + """ + return await self.db.disconnect() + + ### CUSTOM FILE ### def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any: try: @@ -618,7 +679,9 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any: ### HELPER FUNCTIONS ### -async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient): +async def _cache_user_row( + user_id: str, cache: DualCache, db: Union[PrismaClient, DBClient] +): """ Check if a user_id exists in cache, if not retrieve it. @@ -627,7 +690,10 @@ async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient): cache_key = f"{user_id}_user_api_key_user_id" response = cache.get_cache(key=cache_key) if response is None: # Cache miss - user_row = await db.get_data(user_id=user_id) + if isinstance(db, PrismaClient): + user_row = await db.get_data(user_id=user_id) + elif isinstance(db, DBClient): + user_row = await db.get_data(key=user_id, table_name="user") if user_row is not None: print_verbose(f"User Row: {user_row}, type = {type(user_row)}") if hasattr(user_row, "model_dump_json") and callable(