# litellm-mirror/litellm/proxy/db/dynamo_db.py
# (repository-viewer metadata: 145 lines, no trailing newline, 6.5 KiB, Python)

import json
from aiodynamo.client import Client
from aiodynamo.credentials import Credentials, StaticCredentials
from aiodynamo.http.httpx import HTTPX
from aiodynamo.models import Throughput, KeySchema, KeySpec, KeyType, PayPerRequest
from yarl import URL
from litellm.proxy.db.base_client import CustomDB
from litellm.proxy._types import DynamoDBArgs, DBTableNames, LiteLLM_VerificationToken, LiteLLM_Config, LiteLLM_UserTable
from litellm import get_secret
from typing import Any, List, Literal, Optional
from aiodynamo.expressions import UpdateExpression, F
from aiodynamo.models import ReturnValues
from aiodynamo.http.aiohttp import AIOHTTP
from aiohttp import ClientSession
from datetime import datetime
class DynamoDBWrapper(CustomDB):
    """DynamoDB-backed implementation of the proxy's ``CustomDB`` interface.

    Users, virtual keys and config params live in three tables whose physical
    names come from ``DBTableNames``; callers select a table by its logical
    name (``"user"``, ``"key"`` or ``"config"``). Every operation opens its own
    ``aiohttp.ClientSession`` and resolves credentials via ``Credentials.auto()``.
    """

    credentials: Credentials

    # Attributes of the key table that may be stored as JSON-encoded strings
    # and are decoded on read.
    _JSON_STRING_FIELDS = ("aliases", "config", "metadata")

    def __init__(self, database_arguments: DynamoDBArgs):
        """Capture billing mode and region from ``database_arguments``.

        ``throughput_type`` becomes ``PayPerRequest()`` for ``"PAY_PER_REQUEST"``,
        a provisioned ``Throughput`` for ``"PROVISIONED_THROUGHPUT"``, and stays
        ``None`` for any other billing mode (``connect`` then passes ``None`` to
        table creation).
        """
        self.throughput_type = None
        if database_arguments.billing_mode == "PAY_PER_REQUEST":
            self.throughput_type = PayPerRequest()
        elif database_arguments.billing_mode == "PROVISIONED_THROUGHPUT":
            self.throughput_type = Throughput(
                read=database_arguments.read_capacity_units,
                write=database_arguments.write_capacity_units,
            )
        self.region_name = database_arguments.region_name

    @staticmethod
    def _table_spec(table_name: str):
        """Map a logical table name to ``(DBTableNames member, hash-key attribute, model class)``.

        Raises:
            ValueError: if ``table_name`` is not one of the ``DBTableNames``
                member names accepted by this class (user, key, config).
        """
        specs = {
            DBTableNames.user.name: (DBTableNames.user, "user_id", LiteLLM_UserTable),
            DBTableNames.key.name: (DBTableNames.key, "token", LiteLLM_VerificationToken),
            DBTableNames.config.name: (DBTableNames.config, "param_name", LiteLLM_Config),
        }
        try:
            return specs[table_name]
        except KeyError:
            raise ValueError(
                f"Invalid table name '{table_name}'. Needs to be one of - {', '.join(specs)}"
            ) from None

    async def connect(self):
        """
        Connect to DB and create the user, key and config tables if they do not exist.

        Tables that already exist are left untouched (no schema or throughput update).
        """
        async with ClientSession() as session:
            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
            for logical_name in (
                DBTableNames.user.name,
                DBTableNames.key.name,
                DBTableNames.config.name,
            ):
                member, hash_key, _ = self._table_spec(logical_name)
                table = client.table(member.value)
                if not await table.exists():
                    await table.create(
                        self.throughput_type,
                        KeySchema(hash_key=KeySpec(hash_key, KeyType.string)),
                    )

    async def insert_data(self, value: Any, table_name: Literal['user', 'key', 'config']):
        """Put ``value`` as a whole item into the table selected by ``table_name``.

        ``datetime`` values are written as ISO-8601 strings. The caller's dict is
        left unmodified.

        Raises:
            ValueError: if ``table_name`` is not a known table.
        """
        member, _, _ = self._table_spec(table_name)
        # Serialize into a fresh dict so the caller keeps its datetime objects.
        item = {
            k: v.isoformat() if isinstance(v, datetime) else v
            for k, v in value.items()
        }
        async with ClientSession() as session:
            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
            await client.table(member.value).put_item(item=item)

    async def get_data(self, key: str, value: str, table_name: Literal['user', 'key', 'config']):
        """Fetch the item whose attribute ``key`` equals ``value`` and wrap it in the table's model.

        For the key table, ``aliases``/``config``/``metadata`` attributes that are
        stored as strings are JSON-decoded before building the
        ``LiteLLM_VerificationToken``.

        Raises:
            ValueError: if ``table_name`` is not a known table.
        """
        member, _, model_cls = self._table_spec(table_name)
        async with ClientSession() as session:
            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
            response = await client.table(member.value).get_item({key: value})
        if table_name == DBTableNames.key.name:
            # handle json string
            response = {
                k: json.loads(v) if k in self._JSON_STRING_FIELDS and isinstance(v, str) else v
                for k, v in response.items()
            }
        return model_cls(**response)

    async def update_data(self, key: str, value: Any, table_name: Literal['user', 'key', 'config']):
        """Set the fields in ``value`` on the item whose hash key equals ``key``.

        ``value`` is validated through the table's pydantic model, and every
        field explicitly provided is written (``datetime`` values as ISO-8601
        strings). The hash-key attribute is never part of the update
        expression: DynamoDB rejects updates that modify key attributes.

        Returns:
            The result of ``Table.update_item``, or ``None`` without sending a
            request when ``value`` contains no non-key fields.

        Raises:
            ValueError: if ``table_name`` is not a known table.
        """
        member, key_name, model_cls = self._table_spec(table_name)
        # Merge so an explicit ``key`` wins over a duplicate key attribute in ``value``
        # (passing both as keywords would raise TypeError).
        data_obj = model_cls(**{**value, key_name: key})

        update_expression = None
        # model_fields_set holds only the fields that were explicitly provided.
        for field in data_obj.model_fields_set:
            if field == key_name:
                continue
            field_value = getattr(data_obj, field)
            if isinstance(field_value, datetime):
                field_value = field_value.isoformat()
            action = F(field).set(field_value)
            update_expression = (
                action if update_expression is None else update_expression & action
            )
        if update_expression is None:
            # Nothing to update; an empty update expression is invalid for DynamoDB.
            return None

        async with ClientSession() as session:
            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
            # return_values is left at aiodynamo's default (ReturnValues.none);
            # the enum has no upper-case ``NONE`` member.
            return await client.table(member.value).update_item(
                key={key_name: key},
                update_expression=update_expression,
            )

    async def delete_data(self, keys: List[str], table_name: Literal['user', 'key', 'config']):
        """
        Not Implemented yet.

        Delegates to ``CustomDB.delete_data`` and returns its result unchanged.
        """
        # NOTE(review): if CustomDB.delete_data is a coroutine function, the
        # coroutine returned here is never awaited — confirm against base_client.
        return super().delete_data(keys, table_name)