fix(dynamo_db.py): if table create fails, tell user what the table + hash key needs to be

This commit is contained in:
Krrish Dholakia 2024-01-11 23:01:28 +05:30
parent 0ed54ddc7f
commit 9b3d78c4f3
2 changed files with 75 additions and 31 deletions

View file

@ -14,7 +14,7 @@ repos:
hooks:
- id: mypy
name: mypy
entry: python3 -m mypy . --ignore-missing-imports
entry: python3 -m mypy --ignore-missing-imports
language: system
types: [python]
files: ^litellm/

View file

@ -5,57 +5,89 @@ from aiodynamo.http.httpx import HTTPX
from aiodynamo.models import Throughput, KeySchema, KeySpec, KeyType, PayPerRequest
from yarl import URL
from litellm.proxy.db.base_client import CustomDB
from litellm.proxy._types import DynamoDBArgs, DBTableNames, LiteLLM_VerificationToken, LiteLLM_Config, LiteLLM_UserTable
from litellm.proxy._types import (
DynamoDBArgs,
DBTableNames,
LiteLLM_VerificationToken,
LiteLLM_Config,
LiteLLM_UserTable,
)
from litellm import get_secret
from typing import Any, List, Literal, Optional, Union
from aiodynamo.expressions import (UpdateExpression, F, Value)
from aiodynamo.expressions import UpdateExpression, F, Value
from aiodynamo.models import ReturnValues
from aiodynamo.http.aiohttp import AIOHTTP
from aiohttp import ClientSession
from datetime import datetime
class DynamoDBWrapper(CustomDB):
credentials: Credentials
def __init__(self, database_arguments: DynamoDBArgs):
self.throughput_type = None
if database_arguments.billing_mode == "PAY_PER_REQUEST":
self.throughput_type = PayPerRequest()
elif database_arguments.billing_mode == "PROVISIONED_THROUGHPUT":
if database_arguments.read_capacity_units is not None and isinstance(database_arguments.read_capacity_units, int) and database_arguments.write_capacity_units is not None and isinstance(database_arguments.write_capacity_units, int):
self.throughput_type = Throughput(read=database_arguments.read_capacity_units, write=database_arguments.write_capacity_units) # type: ignore
else:
raise Exception(f"Invalid args passed in. Need to set both read_capacity_units and write_capacity_units. Args passed in - {database_arguments}")
if (
database_arguments.read_capacity_units is not None
and isinstance(database_arguments.read_capacity_units, int)
and database_arguments.write_capacity_units is not None
and isinstance(database_arguments.write_capacity_units, int)
):
self.throughput_type = Throughput(read=database_arguments.read_capacity_units, write=database_arguments.write_capacity_units) # type: ignore
else:
raise Exception(
f"Invalid args passed in. Need to set both read_capacity_units and write_capacity_units. Args passed in - {database_arguments}"
)
self.region_name = database_arguments.region_name
async def connect(self):
"""
Connect to DB, and creating / updating any tables
Connect to DB, and creating / updating any tables
"""
async with ClientSession() as session:
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
## User
table = client.table(DBTableNames.user.value)
if not await table.exists():
await table.create(
try:
await table.create(
self.throughput_type,
KeySchema(hash_key=KeySpec("user_id", KeyType.string)),
)
## Token
except:
raise Exception(
f"Failed to create table - {DBTableNames.user.value}.\nPlease create a new table called {DBTableNames.user.value}\nAND set `hash_key` as 'user_id'"
)
## Token
table = client.table(DBTableNames.key.value)
if not await table.exists():
await table.create(
try:
await table.create(
self.throughput_type,
KeySchema(hash_key=KeySpec("token", KeyType.string)),
)
## Config
except:
raise Exception(
f"Failed to create table - {DBTableNames.key.value}.\nPlease create a new table called {DBTableNames.key.value}\nAND set `hash_key` as 'token'"
)
## Config
table = client.table(DBTableNames.config.value)
if not await table.exists():
await table.create(
try:
await table.create(
self.throughput_type,
KeySchema(hash_key=KeySpec("param_name", KeyType.string)),
)
except:
raise Exception(
f"Failed to create table - {DBTableNames.config.value}.\nPlease create a new table called {DBTableNames.config.value}\nAND set `hash_key` as 'token'"
)
async def insert_data(self, value: Any, table_name: Literal['user', 'key', 'config']):
async def insert_data(
self, value: Any, table_name: Literal["user", "key", "config"]
):
async with ClientSession() as session:
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
table = None
@ -65,14 +97,16 @@ class DynamoDBWrapper(CustomDB):
table = client.table(DBTableNames.key.value)
elif table_name == DBTableNames.config.name:
table = client.table(DBTableNames.config.value)
for k, v in value.items():
if isinstance(v, datetime):
value[k] = v.isoformat()
await table.put_item(item=value)
async def get_data(self, key: str, value: str, table_name: Literal['user', 'key', 'config']):
async def get_data(
self, key: str, value: str, table_name: Literal["user", "key", "config"]
):
async with ClientSession() as session:
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
table = None
@ -82,7 +116,7 @@ class DynamoDBWrapper(CustomDB):
table = client.table(DBTableNames.key.value)
elif table_name == DBTableNames.config.name:
table = client.table(DBTableNames.config.value)
response = await table.get_item({key: value})
new_response: Any = None
@ -90,23 +124,30 @@ class DynamoDBWrapper(CustomDB):
new_response = LiteLLM_UserTable(**response)
elif table_name == DBTableNames.key.name:
new_response = {}
for k, v in response.items(): # handle json string
if (k == "aliases" or k == "config" or k == "metadata") and v is not None and isinstance(v, str):
for k, v in response.items(): # handle json string
if (
(k == "aliases" or k == "config" or k == "metadata")
and v is not None
and isinstance(v, str)
):
new_response[k] = json.loads(v)
else:
else:
new_response[k] = v
new_response = LiteLLM_VerificationToken(**new_response)
elif table_name == DBTableNames.config.name:
new_response = LiteLLM_Config(**response)
return new_response
async def update_data(self, key: str, value: Any, table_name: Literal['user', 'key', 'config']):
async def update_data(
self, key: str, value: Any, table_name: Literal["user", "key", "config"]
):
async with ClientSession() as session:
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
table = None
key_name = None
data_obj: Optional[Union[LiteLLM_Config, LiteLLM_UserTable, LiteLLM_VerificationToken]] = None
data_obj: Optional[
Union[LiteLLM_Config, LiteLLM_UserTable, LiteLLM_VerificationToken]
] = None
if table_name == DBTableNames.user.name:
table = client.table(DBTableNames.user.value)
key_name = "user_id"
@ -122,10 +163,11 @@ class DynamoDBWrapper(CustomDB):
key_name = "param_name"
data_obj = LiteLLM_Config(param_name=key, **value)
if data_obj is None:
raise Exception(f"invalid table name passed in - {table_name}. Unable to load valid data object - {data_obj}.")
if data_obj is None:
raise Exception(
f"invalid table name passed in - {table_name}. Unable to load valid data object - {data_obj}."
)
# Initialize an empty UpdateExpression
actions: List = []
for field in data_obj.fields_set():
@ -143,12 +185,14 @@ class DynamoDBWrapper(CustomDB):
result = await table.update_item(
key={key_name: key},
update_expression=update_expression,
return_values=ReturnValues.none
return_values=ReturnValues.none,
)
return result
async def delete_data(self, keys: List[str], table_name: Literal['user', 'key', 'config']):
async def delete_data(
self, keys: List[str], table_name: Literal["user", "key", "config"]
):
"""
Not Implemented yet.
Not Implemented yet.
"""
return super().delete_data(keys, table_name)
return super().delete_data(keys, table_name)