From 9b3d78c4f3dcbb7794c6b961f9891f5ac26baf93 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 11 Jan 2024 23:01:28 +0530 Subject: [PATCH] fix(dynamo_db.py): if table create fails, tell user what the table + hash key needs to be --- .pre-commit-config.yaml | 2 +- litellm/proxy/db/dynamo_db.py | 104 ++++++++++++++++++++++++---------- 2 files changed, 75 insertions(+), 31 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d806023db..8ab4e3e92 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,7 +14,7 @@ repos: hooks: - id: mypy name: mypy - entry: python3 -m mypy . --ignore-missing-imports + entry: python3 -m mypy --ignore-missing-imports language: system types: [python] files: ^litellm/ \ No newline at end of file diff --git a/litellm/proxy/db/dynamo_db.py b/litellm/proxy/db/dynamo_db.py index fec22ea47..660eee910 100644 --- a/litellm/proxy/db/dynamo_db.py +++ b/litellm/proxy/db/dynamo_db.py @@ -5,57 +5,89 @@ from aiodynamo.http.httpx import HTTPX from aiodynamo.models import Throughput, KeySchema, KeySpec, KeyType, PayPerRequest from yarl import URL from litellm.proxy.db.base_client import CustomDB -from litellm.proxy._types import DynamoDBArgs, DBTableNames, LiteLLM_VerificationToken, LiteLLM_Config, LiteLLM_UserTable +from litellm.proxy._types import ( + DynamoDBArgs, + DBTableNames, + LiteLLM_VerificationToken, + LiteLLM_Config, + LiteLLM_UserTable, +) from litellm import get_secret from typing import Any, List, Literal, Optional, Union -from aiodynamo.expressions import (UpdateExpression, F, Value) +from aiodynamo.expressions import UpdateExpression, F, Value from aiodynamo.models import ReturnValues from aiodynamo.http.aiohttp import AIOHTTP from aiohttp import ClientSession from datetime import datetime + class DynamoDBWrapper(CustomDB): credentials: Credentials + def __init__(self, database_arguments: DynamoDBArgs): self.throughput_type = None if database_arguments.billing_mode == "PAY_PER_REQUEST": self.throughput_type = PayPerRequest() elif database_arguments.billing_mode == "PROVISIONED_THROUGHPUT": - if database_arguments.read_capacity_units is not None and isinstance(database_arguments.read_capacity_units, int) and database_arguments.write_capacity_units is not None and isinstance(database_arguments.write_capacity_units, int): - self.throughput_type = Throughput(read=database_arguments.read_capacity_units, write=database_arguments.write_capacity_units) # type: ignore - else: - raise Exception(f"Invalid args passed in. Need to set both read_capacity_units and write_capacity_units. Args passed in - {database_arguments}") + if ( + database_arguments.read_capacity_units is not None + and isinstance(database_arguments.read_capacity_units, int) + and database_arguments.write_capacity_units is not None + and isinstance(database_arguments.write_capacity_units, int) + ): + self.throughput_type = Throughput(read=database_arguments.read_capacity_units, write=database_arguments.write_capacity_units) # type: ignore + else: + raise Exception( + f"Invalid args passed in. Need to set both read_capacity_units and write_capacity_units. Args passed in - {database_arguments}" + ) self.region_name = database_arguments.region_name async def connect(self): """ - Connect to DB, and creating / updating any tables + Connect to DB, and creating / updating any tables """ async with ClientSession() as session: client = Client(AIOHTTP(session), Credentials.auto(), self.region_name) ## User table = client.table(DBTableNames.user.value) if not await table.exists(): - await table.create( + try: + await table.create( self.throughput_type, KeySchema(hash_key=KeySpec("user_id", KeyType.string)), ) - ## Token + except: + raise Exception( + f"Failed to create table - {DBTableNames.user.value}.\nPlease create a new table called {DBTableNames.user.value}\nAND set `hash_key` as 'user_id'" + ) + ## Token table = client.table(DBTableNames.key.value) if not await table.exists(): - await table.create( + try: + await table.create( self.throughput_type, KeySchema(hash_key=KeySpec("token", KeyType.string)), ) - ## Config + except: + raise Exception( + f"Failed to create table - {DBTableNames.key.value}.\nPlease create a new table called {DBTableNames.key.value}\nAND set `hash_key` as 'token'" + ) + ## Config table = client.table(DBTableNames.config.value) if not await table.exists(): - await table.create( + try: + await table.create( self.throughput_type, KeySchema(hash_key=KeySpec("param_name", KeyType.string)), ) + except: + raise Exception( + f"Failed to create table - {DBTableNames.config.value}.\nPlease create a new table called {DBTableNames.config.value}\nAND set `hash_key` as 'token'" + ) - async def insert_data(self, value: Any, table_name: Literal['user', 'key', 'config']): + async def insert_data( + self, value: Any, table_name: Literal["user", "key", "config"] + ): async with ClientSession() as session: client = Client(AIOHTTP(session), Credentials.auto(), self.region_name) table = None @@ -65,14 +97,16 @@ class DynamoDBWrapper(CustomDB): table = client.table(DBTableNames.key.value) elif table_name == DBTableNames.config.name: table = client.table(DBTableNames.config.value) - + for k, v in value.items(): if isinstance(v, datetime): value[k] = v.isoformat() await table.put_item(item=value) - async def get_data(self, key: str, value: str, table_name: Literal['user', 'key', 'config']): + async def get_data( + self, key: str, value: str, table_name: Literal["user", "key", "config"] + ): async with ClientSession() as session: client = Client(AIOHTTP(session), Credentials.auto(), self.region_name) table = None @@ -82,7 +116,7 @@ class DynamoDBWrapper(CustomDB): table = client.table(DBTableNames.key.value) elif table_name == DBTableNames.config.name: table = client.table(DBTableNames.config.value) - + response = await table.get_item({key: value}) new_response: Any = None @@ -90,23 +124,30 @@ class DynamoDBWrapper(CustomDB): new_response = LiteLLM_UserTable(**response) elif table_name == DBTableNames.key.name: new_response = {} - for k, v in response.items(): # handle json string - if (k == "aliases" or k == "config" or k == "metadata") and v is not None and isinstance(v, str): + for k, v in response.items(): # handle json string + if ( + (k == "aliases" or k == "config" or k == "metadata") + and v is not None + and isinstance(v, str) + ): new_response[k] = json.loads(v) - else: + else: new_response[k] = v new_response = LiteLLM_VerificationToken(**new_response) elif table_name == DBTableNames.config.name: new_response = LiteLLM_Config(**response) return new_response - - async def update_data(self, key: str, value: Any, table_name: Literal['user', 'key', 'config']): + async def update_data( + self, key: str, value: Any, table_name: Literal["user", "key", "config"] + ): async with ClientSession() as session: client = Client(AIOHTTP(session), Credentials.auto(), self.region_name) table = None key_name = None - data_obj: Optional[Union[LiteLLM_Config, LiteLLM_UserTable, LiteLLM_VerificationToken]] = None + data_obj: Optional[ + Union[LiteLLM_Config, LiteLLM_UserTable, LiteLLM_VerificationToken] + ] = None if table_name == DBTableNames.user.name: table = client.table(DBTableNames.user.value) key_name = "user_id" @@ -122,10 +163,11 @@ class DynamoDBWrapper(CustomDB): key_name = "param_name" data_obj = LiteLLM_Config(param_name=key, **value) - if data_obj is None: - raise Exception(f"invalid table name passed in - {table_name}. Unable to load valid data object - {data_obj}.") + if data_obj is None: + raise Exception( + f"invalid table name passed in - {table_name}. Unable to load valid data object - {data_obj}." + ) # Initialize an empty UpdateExpression - actions: List = [] for field in data_obj.fields_set(): @@ -143,12 +185,14 @@ class DynamoDBWrapper(CustomDB): result = await table.update_item( key={key_name: key}, update_expression=update_expression, - return_values=ReturnValues.none + return_values=ReturnValues.none, ) return result - - async def delete_data(self, keys: List[str], table_name: Literal['user', 'key', 'config']): + + async def delete_data( + self, keys: List[str], table_name: Literal["user", "key", "config"] + ): """ - Not Implemented yet. + Not Implemented yet. """ - return super().delete_data(keys, table_name) \ No newline at end of file + return super().delete_data(keys, table_name)