forked from phoenix/litellm-mirror
fix(dynamo_db.py): if table create fails, tell user what the table + hash key needs to be
This commit is contained in:
parent
0ed54ddc7f
commit
9b3d78c4f3
2 changed files with 75 additions and 31 deletions
|
@ -14,7 +14,7 @@ repos:
|
||||||
hooks:
|
hooks:
|
||||||
- id: mypy
|
- id: mypy
|
||||||
name: mypy
|
name: mypy
|
||||||
entry: python3 -m mypy . --ignore-missing-imports
|
entry: python3 -m mypy --ignore-missing-imports
|
||||||
language: system
|
language: system
|
||||||
types: [python]
|
types: [python]
|
||||||
files: ^litellm/
|
files: ^litellm/
|
|
@ -5,57 +5,89 @@ from aiodynamo.http.httpx import HTTPX
|
||||||
from aiodynamo.models import Throughput, KeySchema, KeySpec, KeyType, PayPerRequest
|
from aiodynamo.models import Throughput, KeySchema, KeySpec, KeyType, PayPerRequest
|
||||||
from yarl import URL
|
from yarl import URL
|
||||||
from litellm.proxy.db.base_client import CustomDB
|
from litellm.proxy.db.base_client import CustomDB
|
||||||
from litellm.proxy._types import DynamoDBArgs, DBTableNames, LiteLLM_VerificationToken, LiteLLM_Config, LiteLLM_UserTable
|
from litellm.proxy._types import (
|
||||||
|
DynamoDBArgs,
|
||||||
|
DBTableNames,
|
||||||
|
LiteLLM_VerificationToken,
|
||||||
|
LiteLLM_Config,
|
||||||
|
LiteLLM_UserTable,
|
||||||
|
)
|
||||||
from litellm import get_secret
|
from litellm import get_secret
|
||||||
from typing import Any, List, Literal, Optional, Union
|
from typing import Any, List, Literal, Optional, Union
|
||||||
from aiodynamo.expressions import (UpdateExpression, F, Value)
|
from aiodynamo.expressions import UpdateExpression, F, Value
|
||||||
from aiodynamo.models import ReturnValues
|
from aiodynamo.models import ReturnValues
|
||||||
from aiodynamo.http.aiohttp import AIOHTTP
|
from aiodynamo.http.aiohttp import AIOHTTP
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
class DynamoDBWrapper(CustomDB):
|
class DynamoDBWrapper(CustomDB):
|
||||||
credentials: Credentials
|
credentials: Credentials
|
||||||
|
|
||||||
def __init__(self, database_arguments: DynamoDBArgs):
|
def __init__(self, database_arguments: DynamoDBArgs):
|
||||||
self.throughput_type = None
|
self.throughput_type = None
|
||||||
if database_arguments.billing_mode == "PAY_PER_REQUEST":
|
if database_arguments.billing_mode == "PAY_PER_REQUEST":
|
||||||
self.throughput_type = PayPerRequest()
|
self.throughput_type = PayPerRequest()
|
||||||
elif database_arguments.billing_mode == "PROVISIONED_THROUGHPUT":
|
elif database_arguments.billing_mode == "PROVISIONED_THROUGHPUT":
|
||||||
if database_arguments.read_capacity_units is not None and isinstance(database_arguments.read_capacity_units, int) and database_arguments.write_capacity_units is not None and isinstance(database_arguments.write_capacity_units, int):
|
if (
|
||||||
self.throughput_type = Throughput(read=database_arguments.read_capacity_units, write=database_arguments.write_capacity_units) # type: ignore
|
database_arguments.read_capacity_units is not None
|
||||||
else:
|
and isinstance(database_arguments.read_capacity_units, int)
|
||||||
raise Exception(f"Invalid args passed in. Need to set both read_capacity_units and write_capacity_units. Args passed in - {database_arguments}")
|
and database_arguments.write_capacity_units is not None
|
||||||
|
and isinstance(database_arguments.write_capacity_units, int)
|
||||||
|
):
|
||||||
|
self.throughput_type = Throughput(read=database_arguments.read_capacity_units, write=database_arguments.write_capacity_units) # type: ignore
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
f"Invalid args passed in. Need to set both read_capacity_units and write_capacity_units. Args passed in - {database_arguments}"
|
||||||
|
)
|
||||||
self.region_name = database_arguments.region_name
|
self.region_name = database_arguments.region_name
|
||||||
|
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
"""
|
"""
|
||||||
Connect to DB, and creating / updating any tables
|
Connect to DB, and creating / updating any tables
|
||||||
"""
|
"""
|
||||||
async with ClientSession() as session:
|
async with ClientSession() as session:
|
||||||
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
|
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
|
||||||
## User
|
## User
|
||||||
table = client.table(DBTableNames.user.value)
|
table = client.table(DBTableNames.user.value)
|
||||||
if not await table.exists():
|
if not await table.exists():
|
||||||
await table.create(
|
try:
|
||||||
|
await table.create(
|
||||||
self.throughput_type,
|
self.throughput_type,
|
||||||
KeySchema(hash_key=KeySpec("user_id", KeyType.string)),
|
KeySchema(hash_key=KeySpec("user_id", KeyType.string)),
|
||||||
)
|
)
|
||||||
## Token
|
except:
|
||||||
|
raise Exception(
|
||||||
|
f"Failed to create table - {DBTableNames.user.value}.\nPlease create a new table called {DBTableNames.user.value}\nAND set `hash_key` as 'user_id'"
|
||||||
|
)
|
||||||
|
## Token
|
||||||
table = client.table(DBTableNames.key.value)
|
table = client.table(DBTableNames.key.value)
|
||||||
if not await table.exists():
|
if not await table.exists():
|
||||||
await table.create(
|
try:
|
||||||
|
await table.create(
|
||||||
self.throughput_type,
|
self.throughput_type,
|
||||||
KeySchema(hash_key=KeySpec("token", KeyType.string)),
|
KeySchema(hash_key=KeySpec("token", KeyType.string)),
|
||||||
)
|
)
|
||||||
## Config
|
except:
|
||||||
|
raise Exception(
|
||||||
|
f"Failed to create table - {DBTableNames.key.value}.\nPlease create a new table called {DBTableNames.key.value}\nAND set `hash_key` as 'token'"
|
||||||
|
)
|
||||||
|
## Config
|
||||||
table = client.table(DBTableNames.config.value)
|
table = client.table(DBTableNames.config.value)
|
||||||
if not await table.exists():
|
if not await table.exists():
|
||||||
await table.create(
|
try:
|
||||||
|
await table.create(
|
||||||
self.throughput_type,
|
self.throughput_type,
|
||||||
KeySchema(hash_key=KeySpec("param_name", KeyType.string)),
|
KeySchema(hash_key=KeySpec("param_name", KeyType.string)),
|
||||||
)
|
)
|
||||||
|
except:
|
||||||
|
raise Exception(
|
||||||
|
f"Failed to create table - {DBTableNames.config.value}.\nPlease create a new table called {DBTableNames.config.value}\nAND set `hash_key` as 'token'"
|
||||||
|
)
|
||||||
|
|
||||||
async def insert_data(self, value: Any, table_name: Literal['user', 'key', 'config']):
|
async def insert_data(
|
||||||
|
self, value: Any, table_name: Literal["user", "key", "config"]
|
||||||
|
):
|
||||||
async with ClientSession() as session:
|
async with ClientSession() as session:
|
||||||
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
|
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
|
||||||
table = None
|
table = None
|
||||||
|
@ -65,14 +97,16 @@ class DynamoDBWrapper(CustomDB):
|
||||||
table = client.table(DBTableNames.key.value)
|
table = client.table(DBTableNames.key.value)
|
||||||
elif table_name == DBTableNames.config.name:
|
elif table_name == DBTableNames.config.name:
|
||||||
table = client.table(DBTableNames.config.value)
|
table = client.table(DBTableNames.config.value)
|
||||||
|
|
||||||
for k, v in value.items():
|
for k, v in value.items():
|
||||||
if isinstance(v, datetime):
|
if isinstance(v, datetime):
|
||||||
value[k] = v.isoformat()
|
value[k] = v.isoformat()
|
||||||
|
|
||||||
await table.put_item(item=value)
|
await table.put_item(item=value)
|
||||||
|
|
||||||
async def get_data(self, key: str, value: str, table_name: Literal['user', 'key', 'config']):
|
async def get_data(
|
||||||
|
self, key: str, value: str, table_name: Literal["user", "key", "config"]
|
||||||
|
):
|
||||||
async with ClientSession() as session:
|
async with ClientSession() as session:
|
||||||
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
|
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
|
||||||
table = None
|
table = None
|
||||||
|
@ -82,7 +116,7 @@ class DynamoDBWrapper(CustomDB):
|
||||||
table = client.table(DBTableNames.key.value)
|
table = client.table(DBTableNames.key.value)
|
||||||
elif table_name == DBTableNames.config.name:
|
elif table_name == DBTableNames.config.name:
|
||||||
table = client.table(DBTableNames.config.value)
|
table = client.table(DBTableNames.config.value)
|
||||||
|
|
||||||
response = await table.get_item({key: value})
|
response = await table.get_item({key: value})
|
||||||
|
|
||||||
new_response: Any = None
|
new_response: Any = None
|
||||||
|
@ -90,23 +124,30 @@ class DynamoDBWrapper(CustomDB):
|
||||||
new_response = LiteLLM_UserTable(**response)
|
new_response = LiteLLM_UserTable(**response)
|
||||||
elif table_name == DBTableNames.key.name:
|
elif table_name == DBTableNames.key.name:
|
||||||
new_response = {}
|
new_response = {}
|
||||||
for k, v in response.items(): # handle json string
|
for k, v in response.items(): # handle json string
|
||||||
if (k == "aliases" or k == "config" or k == "metadata") and v is not None and isinstance(v, str):
|
if (
|
||||||
|
(k == "aliases" or k == "config" or k == "metadata")
|
||||||
|
and v is not None
|
||||||
|
and isinstance(v, str)
|
||||||
|
):
|
||||||
new_response[k] = json.loads(v)
|
new_response[k] = json.loads(v)
|
||||||
else:
|
else:
|
||||||
new_response[k] = v
|
new_response[k] = v
|
||||||
new_response = LiteLLM_VerificationToken(**new_response)
|
new_response = LiteLLM_VerificationToken(**new_response)
|
||||||
elif table_name == DBTableNames.config.name:
|
elif table_name == DBTableNames.config.name:
|
||||||
new_response = LiteLLM_Config(**response)
|
new_response = LiteLLM_Config(**response)
|
||||||
return new_response
|
return new_response
|
||||||
|
|
||||||
|
|
||||||
async def update_data(self, key: str, value: Any, table_name: Literal['user', 'key', 'config']):
|
async def update_data(
|
||||||
|
self, key: str, value: Any, table_name: Literal["user", "key", "config"]
|
||||||
|
):
|
||||||
async with ClientSession() as session:
|
async with ClientSession() as session:
|
||||||
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
|
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
|
||||||
table = None
|
table = None
|
||||||
key_name = None
|
key_name = None
|
||||||
data_obj: Optional[Union[LiteLLM_Config, LiteLLM_UserTable, LiteLLM_VerificationToken]] = None
|
data_obj: Optional[
|
||||||
|
Union[LiteLLM_Config, LiteLLM_UserTable, LiteLLM_VerificationToken]
|
||||||
|
] = None
|
||||||
if table_name == DBTableNames.user.name:
|
if table_name == DBTableNames.user.name:
|
||||||
table = client.table(DBTableNames.user.value)
|
table = client.table(DBTableNames.user.value)
|
||||||
key_name = "user_id"
|
key_name = "user_id"
|
||||||
|
@ -122,10 +163,11 @@ class DynamoDBWrapper(CustomDB):
|
||||||
key_name = "param_name"
|
key_name = "param_name"
|
||||||
data_obj = LiteLLM_Config(param_name=key, **value)
|
data_obj = LiteLLM_Config(param_name=key, **value)
|
||||||
|
|
||||||
if data_obj is None:
|
if data_obj is None:
|
||||||
raise Exception(f"invalid table name passed in - {table_name}. Unable to load valid data object - {data_obj}.")
|
raise Exception(
|
||||||
|
f"invalid table name passed in - {table_name}. Unable to load valid data object - {data_obj}."
|
||||||
|
)
|
||||||
# Initialize an empty UpdateExpression
|
# Initialize an empty UpdateExpression
|
||||||
|
|
||||||
|
|
||||||
actions: List = []
|
actions: List = []
|
||||||
for field in data_obj.fields_set():
|
for field in data_obj.fields_set():
|
||||||
|
@ -143,12 +185,14 @@ class DynamoDBWrapper(CustomDB):
|
||||||
result = await table.update_item(
|
result = await table.update_item(
|
||||||
key={key_name: key},
|
key={key_name: key},
|
||||||
update_expression=update_expression,
|
update_expression=update_expression,
|
||||||
return_values=ReturnValues.none
|
return_values=ReturnValues.none,
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def delete_data(self, keys: List[str], table_name: Literal['user', 'key', 'config']):
|
async def delete_data(
|
||||||
|
self, keys: List[str], table_name: Literal["user", "key", "config"]
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Not Implemented yet.
|
Not Implemented yet.
|
||||||
"""
|
"""
|
||||||
return super().delete_data(keys, table_name)
|
return super().delete_data(keys, table_name)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue