fix(dynamo_db.py): if table create fails, tell user what the table + hash key needs to be

This commit is contained in:
Krrish Dholakia 2024-01-11 23:01:28 +05:30
parent 0ed54ddc7f
commit 9b3d78c4f3
2 changed files with 75 additions and 31 deletions

View file

@ -14,7 +14,7 @@ repos:
hooks: hooks:
- id: mypy - id: mypy
name: mypy name: mypy
entry: python3 -m mypy . --ignore-missing-imports entry: python3 -m mypy --ignore-missing-imports
language: system language: system
types: [python] types: [python]
files: ^litellm/ files: ^litellm/

View file

@ -5,57 +5,89 @@ from aiodynamo.http.httpx import HTTPX
from aiodynamo.models import Throughput, KeySchema, KeySpec, KeyType, PayPerRequest from aiodynamo.models import Throughput, KeySchema, KeySpec, KeyType, PayPerRequest
from yarl import URL from yarl import URL
from litellm.proxy.db.base_client import CustomDB from litellm.proxy.db.base_client import CustomDB
from litellm.proxy._types import DynamoDBArgs, DBTableNames, LiteLLM_VerificationToken, LiteLLM_Config, LiteLLM_UserTable from litellm.proxy._types import (
DynamoDBArgs,
DBTableNames,
LiteLLM_VerificationToken,
LiteLLM_Config,
LiteLLM_UserTable,
)
from litellm import get_secret from litellm import get_secret
from typing import Any, List, Literal, Optional, Union from typing import Any, List, Literal, Optional, Union
from aiodynamo.expressions import (UpdateExpression, F, Value) from aiodynamo.expressions import UpdateExpression, F, Value
from aiodynamo.models import ReturnValues from aiodynamo.models import ReturnValues
from aiodynamo.http.aiohttp import AIOHTTP from aiodynamo.http.aiohttp import AIOHTTP
from aiohttp import ClientSession from aiohttp import ClientSession
from datetime import datetime from datetime import datetime
class DynamoDBWrapper(CustomDB): class DynamoDBWrapper(CustomDB):
credentials: Credentials credentials: Credentials
def __init__(self, database_arguments: DynamoDBArgs): def __init__(self, database_arguments: DynamoDBArgs):
self.throughput_type = None self.throughput_type = None
if database_arguments.billing_mode == "PAY_PER_REQUEST": if database_arguments.billing_mode == "PAY_PER_REQUEST":
self.throughput_type = PayPerRequest() self.throughput_type = PayPerRequest()
elif database_arguments.billing_mode == "PROVISIONED_THROUGHPUT": elif database_arguments.billing_mode == "PROVISIONED_THROUGHPUT":
if database_arguments.read_capacity_units is not None and isinstance(database_arguments.read_capacity_units, int) and database_arguments.write_capacity_units is not None and isinstance(database_arguments.write_capacity_units, int): if (
self.throughput_type = Throughput(read=database_arguments.read_capacity_units, write=database_arguments.write_capacity_units) # type: ignore database_arguments.read_capacity_units is not None
else: and isinstance(database_arguments.read_capacity_units, int)
raise Exception(f"Invalid args passed in. Need to set both read_capacity_units and write_capacity_units. Args passed in - {database_arguments}") and database_arguments.write_capacity_units is not None
and isinstance(database_arguments.write_capacity_units, int)
):
self.throughput_type = Throughput(read=database_arguments.read_capacity_units, write=database_arguments.write_capacity_units) # type: ignore
else:
raise Exception(
f"Invalid args passed in. Need to set both read_capacity_units and write_capacity_units. Args passed in - {database_arguments}"
)
self.region_name = database_arguments.region_name self.region_name = database_arguments.region_name
async def connect(self): async def connect(self):
""" """
Connect to DB, and creating / updating any tables Connect to DB, and creating / updating any tables
""" """
async with ClientSession() as session: async with ClientSession() as session:
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name) client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
## User ## User
table = client.table(DBTableNames.user.value) table = client.table(DBTableNames.user.value)
if not await table.exists(): if not await table.exists():
await table.create( try:
await table.create(
self.throughput_type, self.throughput_type,
KeySchema(hash_key=KeySpec("user_id", KeyType.string)), KeySchema(hash_key=KeySpec("user_id", KeyType.string)),
) )
## Token except:
raise Exception(
f"Failed to create table - {DBTableNames.user.value}.\nPlease create a new table called {DBTableNames.user.value}\nAND set `hash_key` as 'user_id'"
)
## Token
table = client.table(DBTableNames.key.value) table = client.table(DBTableNames.key.value)
if not await table.exists(): if not await table.exists():
await table.create( try:
await table.create(
self.throughput_type, self.throughput_type,
KeySchema(hash_key=KeySpec("token", KeyType.string)), KeySchema(hash_key=KeySpec("token", KeyType.string)),
) )
## Config except:
raise Exception(
f"Failed to create table - {DBTableNames.key.value}.\nPlease create a new table called {DBTableNames.key.value}\nAND set `hash_key` as 'token'"
)
## Config
table = client.table(DBTableNames.config.value) table = client.table(DBTableNames.config.value)
if not await table.exists(): if not await table.exists():
await table.create( try:
await table.create(
self.throughput_type, self.throughput_type,
KeySchema(hash_key=KeySpec("param_name", KeyType.string)), KeySchema(hash_key=KeySpec("param_name", KeyType.string)),
) )
except:
raise Exception(
f"Failed to create table - {DBTableNames.config.value}.\nPlease create a new table called {DBTableNames.config.value}\nAND set `hash_key` as 'token'"
)
async def insert_data(self, value: Any, table_name: Literal['user', 'key', 'config']): async def insert_data(
self, value: Any, table_name: Literal["user", "key", "config"]
):
async with ClientSession() as session: async with ClientSession() as session:
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name) client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
table = None table = None
@ -65,14 +97,16 @@ class DynamoDBWrapper(CustomDB):
table = client.table(DBTableNames.key.value) table = client.table(DBTableNames.key.value)
elif table_name == DBTableNames.config.name: elif table_name == DBTableNames.config.name:
table = client.table(DBTableNames.config.value) table = client.table(DBTableNames.config.value)
for k, v in value.items(): for k, v in value.items():
if isinstance(v, datetime): if isinstance(v, datetime):
value[k] = v.isoformat() value[k] = v.isoformat()
await table.put_item(item=value) await table.put_item(item=value)
async def get_data(self, key: str, value: str, table_name: Literal['user', 'key', 'config']): async def get_data(
self, key: str, value: str, table_name: Literal["user", "key", "config"]
):
async with ClientSession() as session: async with ClientSession() as session:
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name) client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
table = None table = None
@ -82,7 +116,7 @@ class DynamoDBWrapper(CustomDB):
table = client.table(DBTableNames.key.value) table = client.table(DBTableNames.key.value)
elif table_name == DBTableNames.config.name: elif table_name == DBTableNames.config.name:
table = client.table(DBTableNames.config.value) table = client.table(DBTableNames.config.value)
response = await table.get_item({key: value}) response = await table.get_item({key: value})
new_response: Any = None new_response: Any = None
@ -90,23 +124,30 @@ class DynamoDBWrapper(CustomDB):
new_response = LiteLLM_UserTable(**response) new_response = LiteLLM_UserTable(**response)
elif table_name == DBTableNames.key.name: elif table_name == DBTableNames.key.name:
new_response = {} new_response = {}
for k, v in response.items(): # handle json string for k, v in response.items(): # handle json string
if (k == "aliases" or k == "config" or k == "metadata") and v is not None and isinstance(v, str): if (
(k == "aliases" or k == "config" or k == "metadata")
and v is not None
and isinstance(v, str)
):
new_response[k] = json.loads(v) new_response[k] = json.loads(v)
else: else:
new_response[k] = v new_response[k] = v
new_response = LiteLLM_VerificationToken(**new_response) new_response = LiteLLM_VerificationToken(**new_response)
elif table_name == DBTableNames.config.name: elif table_name == DBTableNames.config.name:
new_response = LiteLLM_Config(**response) new_response = LiteLLM_Config(**response)
return new_response return new_response
async def update_data(self, key: str, value: Any, table_name: Literal['user', 'key', 'config']): async def update_data(
self, key: str, value: Any, table_name: Literal["user", "key", "config"]
):
async with ClientSession() as session: async with ClientSession() as session:
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name) client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
table = None table = None
key_name = None key_name = None
data_obj: Optional[Union[LiteLLM_Config, LiteLLM_UserTable, LiteLLM_VerificationToken]] = None data_obj: Optional[
Union[LiteLLM_Config, LiteLLM_UserTable, LiteLLM_VerificationToken]
] = None
if table_name == DBTableNames.user.name: if table_name == DBTableNames.user.name:
table = client.table(DBTableNames.user.value) table = client.table(DBTableNames.user.value)
key_name = "user_id" key_name = "user_id"
@ -122,10 +163,11 @@ class DynamoDBWrapper(CustomDB):
key_name = "param_name" key_name = "param_name"
data_obj = LiteLLM_Config(param_name=key, **value) data_obj = LiteLLM_Config(param_name=key, **value)
if data_obj is None: if data_obj is None:
raise Exception(f"invalid table name passed in - {table_name}. Unable to load valid data object - {data_obj}.") raise Exception(
f"invalid table name passed in - {table_name}. Unable to load valid data object - {data_obj}."
)
# Initialize an empty UpdateExpression # Initialize an empty UpdateExpression
actions: List = [] actions: List = []
for field in data_obj.fields_set(): for field in data_obj.fields_set():
@ -143,12 +185,14 @@ class DynamoDBWrapper(CustomDB):
result = await table.update_item( result = await table.update_item(
key={key_name: key}, key={key_name: key},
update_expression=update_expression, update_expression=update_expression,
return_values=ReturnValues.none return_values=ReturnValues.none,
) )
return result return result
async def delete_data(self, keys: List[str], table_name: Literal['user', 'key', 'config']): async def delete_data(
self, keys: List[str], table_name: Literal["user", "key", "config"]
):
""" """
Not Implemented yet. Not Implemented yet.
""" """
return super().delete_data(keys, table_name) return super().delete_data(keys, table_name)