litellm-mirror/litellm/proxy/db/dynamo_db.py

198 lines
8.3 KiB
Python

import json
from aiodynamo.client import Client
from aiodynamo.credentials import Credentials, StaticCredentials
from aiodynamo.http.httpx import HTTPX
from aiodynamo.models import Throughput, KeySchema, KeySpec, KeyType, PayPerRequest
from yarl import URL
from litellm.proxy.db.base_client import CustomDB
from litellm.proxy._types import (
DynamoDBArgs,
DBTableNames,
LiteLLM_VerificationToken,
LiteLLM_Config,
LiteLLM_UserTable,
)
from litellm import get_secret
from typing import Any, List, Literal, Optional, Union
from aiodynamo.expressions import UpdateExpression, F, Value
from aiodynamo.models import ReturnValues
from aiodynamo.http.aiohttp import AIOHTTP
from aiohttp import ClientSession
from datetime import datetime
class DynamoDBWrapper(CustomDB):
    """DynamoDB-backed implementation of the proxy's ``CustomDB`` interface.

    Every operation opens its own ``aiohttp`` session and builds an
    ``aiodynamo`` client from ``Credentials.auto()`` and the configured
    region, so no connection state is kept on the instance between calls.
    """

    credentials: Credentials

    def __init__(self, database_arguments: DynamoDBArgs):
        """Store the billing/throughput settings and region for later calls.

        Args:
            database_arguments: DynamoDB settings. ``billing_mode`` selects
                between pay-per-request and provisioned throughput; the latter
                requires integer ``read_capacity_units`` and
                ``write_capacity_units``.

        Raises:
            Exception: if provisioned throughput is requested without both
                capacity values being integers.
        """
        # Stays None when billing_mode matches neither known value.
        self.throughput_type = None
        if database_arguments.billing_mode == "PAY_PER_REQUEST":
            self.throughput_type = PayPerRequest()
        elif database_arguments.billing_mode == "PROVISIONED_THROUGHPUT":
            # isinstance() already rejects None, so no separate None check.
            if isinstance(database_arguments.read_capacity_units, int) and isinstance(
                database_arguments.write_capacity_units, int
            ):
                self.throughput_type = Throughput(read=database_arguments.read_capacity_units, write=database_arguments.write_capacity_units)  # type: ignore
            else:
                raise Exception(
                    f"Invalid args passed in. Need to set both read_capacity_units and write_capacity_units. Args passed in - {database_arguments}"
                )
        self.region_name = database_arguments.region_name

    @staticmethod
    def _resolve_table(client: Client, table_name: str) -> Optional[Any]:
        """Return the table handle for a logical table name, or None if unknown."""
        for entry in (DBTableNames.user, DBTableNames.key, DBTableNames.config):
            if table_name == entry.name:
                return client.table(entry.value)
        return None

    async def _ensure_table(self, client: Client, table_name: str, hash_key: str) -> None:
        """Create ``table_name`` with a string hash key if it does not exist yet."""
        table = client.table(table_name)
        if await table.exists():
            return
        try:
            await table.create(
                self.throughput_type,
                KeySchema(hash_key=KeySpec(hash_key, KeyType.string)),
            )
        except Exception as e:
            # `except Exception` (not a bare except) so task cancellation and
            # KeyboardInterrupt propagate; chain the cause for debugging.
            raise Exception(
                f"Failed to create table - {table_name}.\nPlease create a new table called {table_name}\nAND set `hash_key` as '{hash_key}'"
            ) from e

    async def connect(self):
        """
        Connect to the DB and create the user, key and config tables if missing.

        Raises:
            Exception: if a missing table cannot be created; the message names
                the table and the hash key it must be created with.
        """
        async with ClientSession() as session:
            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
            ## User
            await self._ensure_table(client, DBTableNames.user.value, "user_id")
            ## Token
            await self._ensure_table(client, DBTableNames.key.value, "token")
            ## Config
            await self._ensure_table(client, DBTableNames.config.value, "param_name")

    async def insert_data(
        self, value: Any, table_name: Literal["user", "key", "config"]
    ):
        """Put ``value`` as an item into the table selected by ``table_name``.

        ``datetime`` values are stored as ISO 8601 strings. The conversion is
        done on a copy, so the caller's mapping is left unmodified.

        Args:
            value: mapping of attribute names to values (must support ``.items()``).
            table_name: logical table name ("user", "key" or "config").

        Raises:
            Exception: if ``table_name`` is not one of the known tables.
        """
        async with ClientSession() as session:
            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
            table = self._resolve_table(client, table_name)
            if table is None:
                raise Exception(f"invalid table name passed in - {table_name}.")
            item = {
                k: v.isoformat() if isinstance(v, datetime) else v
                for k, v in value.items()
            }
            await table.put_item(item=item)

    async def get_data(
        self, key: str, value: str, table_name: Literal["user", "key", "config"]
    ):
        """Fetch one item by key and wrap it in the matching proxy model.

        Args:
            key: name of the key attribute to look up.
            value: value of that key attribute.
            table_name: logical table name ("user", "key" or "config").

        Returns:
            ``LiteLLM_UserTable``, ``LiteLLM_VerificationToken`` or
            ``LiteLLM_Config``, depending on ``table_name``.

        Raises:
            Exception: if ``table_name`` is not one of the known tables.
        """
        # NOTE(review): behaviour for a missing item is whatever aiodynamo's
        # get_item does (it is not handled here) - confirm against callers.
        async with ClientSession() as session:
            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
            table = self._resolve_table(client, table_name)
            if table is None:
                raise Exception(f"invalid table name passed in - {table_name}.")
            response = await table.get_item({key: value})

            if table_name == DBTableNames.user.name:
                return LiteLLM_UserTable(**response)
            if table_name == DBTableNames.key.name:
                # These fields may come back as JSON strings; decode them
                # before building the model.
                decoded = {
                    k: json.loads(v)
                    if k in ("aliases", "config", "metadata") and isinstance(v, str)
                    else v
                    for k, v in response.items()
                }
                return LiteLLM_VerificationToken(**decoded)
            return LiteLLM_Config(**response)

    async def update_data(
        self, key: str, value: Any, table_name: Literal["user", "key", "config"]
    ):
        """Update the attributes in ``value`` on the item identified by ``key``.

        The values are first loaded into the table's model, and only the
        fields that were explicitly set are written. ``datetime`` values are
        stored as ISO 8601 strings.

        Args:
            key: value of the table's hash key (user_id / token / param_name).
            value: mapping of attribute names to new values.
            table_name: logical table name ("user", "key" or "config").

        Returns:
            Whatever ``table.update_item`` returns when called with
            ``ReturnValues.none``.

        Raises:
            Exception: if ``table_name`` is not one of the known tables.
        """
        async with ClientSession() as session:
            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
            table = None
            key_name = None
            data_obj: Optional[
                Union[LiteLLM_Config, LiteLLM_UserTable, LiteLLM_VerificationToken]
            ] = None
            if table_name == DBTableNames.user.name:
                table = client.table(DBTableNames.user.value)
                key_name = "user_id"
                data_obj = LiteLLM_UserTable(user_id=key, **value)
            elif table_name == DBTableNames.key.name:
                table = client.table(DBTableNames.key.value)
                key_name = "token"
                data_obj = LiteLLM_VerificationToken(token=key, **value)
            elif table_name == DBTableNames.config.name:
                table = client.table(DBTableNames.config.value)
                key_name = "param_name"
                data_obj = LiteLLM_Config(param_name=key, **value)
            if data_obj is None:
                raise Exception(
                    f"invalid table name passed in - {table_name}. Unable to load valid data object - {data_obj}."
                )

            # Accumulate one SET action per explicitly-set field.
            actions: List = []
            for field in data_obj.fields_set():
                # The key is passed to the model constructor above, so it is
                # always in fields_set(); DynamoDB's UpdateItem does not allow
                # modifying primary-key attributes, so it must not be in SET.
                if field == key_name:
                    continue
                field_value = getattr(data_obj, field)
                # Convert datetime object to ISO8601 string
                if isinstance(field_value, datetime):
                    field_value = field_value.isoformat()
                actions.append((F(field), Value(value=field_value)))
            update_expression = UpdateExpression(set_updates=actions)

            # Perform the update in DynamoDB
            result = await table.update_item(
                key={key_name: key},
                update_expression=update_expression,
                return_values=ReturnValues.none,
            )
            return result

    async def delete_data(
        self, keys: List[str], table_name: Literal["user", "key", "config"]
    ):
        """
        Not implemented yet; delegates to the base class's ``delete_data``.
        """
        return super().delete_data(keys, table_name)