# litellm-mirror/litellm/proxy/db/dynamo_db.py
# 2024-01-09 22:55:18 +05:30
#
# 154 lines
# No EOL
# 7.2 KiB
# Python

import json
from aiodynamo.client import Client
from aiodynamo.credentials import Credentials, StaticCredentials
from aiodynamo.http.httpx import HTTPX
from aiodynamo.models import Throughput, KeySchema, KeySpec, KeyType, PayPerRequest
from yarl import URL
from litellm.proxy.db.base_client import CustomDB
from litellm.proxy._types import DynamoDBArgs, DBTableNames, LiteLLM_VerificationToken, LiteLLM_Config, LiteLLM_UserTable
from litellm import get_secret
from typing import Any, List, Literal, Optional, Union
from aiodynamo.expressions import (UpdateExpression, F, Value)
from aiodynamo.models import ReturnValues
from aiodynamo.http.aiohttp import AIOHTTP
from aiohttp import ClientSession
from datetime import datetime
class DynamoDBWrapper(CustomDB):
    """
    DynamoDB implementation of the proxy's `CustomDB` interface.

    Every operation opens its own aiohttp `ClientSession`, builds an
    aiodynamo `Client` with auto-discovered credentials for the configured
    region, and works against one of three tables (user / key / config)
    whose real names come from `DBTableNames`.
    """

    credentials: Credentials

    def __init__(self, database_arguments: DynamoDBArgs):
        """
        Store the region and translate the billing mode into an aiodynamo
        throughput object.

        Raises:
            Exception: billing mode is "PROVISIONED_THROUGHPUT" but read and
                write capacity units are not both integers.
        """
        # Stays None for any billing mode other than the two handled below.
        self.throughput_type = None
        if database_arguments.billing_mode == "PAY_PER_REQUEST":
            self.throughput_type = PayPerRequest()
        elif database_arguments.billing_mode == "PROVISIONED_THROUGHPUT":
            # isinstance() already rejects None, so no separate None check.
            if isinstance(database_arguments.read_capacity_units, int) and isinstance(
                database_arguments.write_capacity_units, int
            ):
                self.throughput_type = Throughput(read=database_arguments.read_capacity_units, write=database_arguments.write_capacity_units)  # type: ignore
            else:
                raise Exception(
                    f"Invalid args passed in. Need to set both read_capacity_units and write_capacity_units. Args passed in - {database_arguments}"
                )
        self.region_name = database_arguments.region_name

    @staticmethod
    def _resolve_table(client: Client, table_name: str):
        """
        Map a logical table name ('user' / 'key' / 'config') to its table handle.

        Raises:
            Exception: `table_name` is not one of the known logical names.
        """
        for entry in (DBTableNames.user, DBTableNames.key, DBTableNames.config):
            if table_name == entry.name:
                return client.table(entry.value)
        raise Exception(f"invalid table name passed in - {table_name}.")

    async def connect(self):
        """
        Connect to the DB and create any of the user / key / config tables
        that do not exist yet.
        """
        # (table name, hash key attribute) for each table this wrapper manages.
        table_specs = (
            (DBTableNames.user.value, "user_id"),  ## User
            (DBTableNames.key.value, "token"),  ## Token
            (DBTableNames.config.value, "param_name"),  ## Config
        )
        async with ClientSession() as session:
            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
            for name, hash_key in table_specs:
                table = client.table(name)
                if not await table.exists():
                    await table.create(
                        self.throughput_type,
                        KeySchema(hash_key=KeySpec(hash_key, KeyType.string)),
                    )

    async def insert_data(self, value: Any, table_name: Literal['user', 'key', 'config']):
        """
        Put `value` (a dict-like item) into the table named by `table_name`.

        datetime values are stored as ISO 8601 strings. The conversion is
        done on a copy so the caller's dict is left untouched.

        Raises:
            Exception: `table_name` is not 'user', 'key' or 'config'.
        """
        async with ClientSession() as session:
            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
            table = self._resolve_table(client, table_name)
            item = {
                k: (v.isoformat() if isinstance(v, datetime) else v)
                for k, v in value.items()
            }
            await table.put_item(item=item)

    async def get_data(self, key: str, value: str, table_name: Literal['user', 'key', 'config']):
        """
        Fetch the item whose attribute `key` equals `value` and wrap it in the
        model for that table (`LiteLLM_UserTable`, `LiteLLM_VerificationToken`
        or `LiteLLM_Config`).

        For the key table, the "aliases", "config" and "metadata" attributes
        are JSON-decoded when they come back as strings.

        Raises:
            Exception: `table_name` is not 'user', 'key' or 'config'.
        """
        async with ClientSession() as session:
            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
            table = self._resolve_table(client, table_name)
            response = await table.get_item({key: value})

            new_response: Any = None
            if table_name == DBTableNames.user.name:
                new_response = LiteLLM_UserTable(**response)
            elif table_name == DBTableNames.key.name:
                json_fields = ("aliases", "config", "metadata")
                decoded = {
                    k: (json.loads(v) if k in json_fields and isinstance(v, str) else v)  # handle json string
                    for k, v in response.items()
                }
                new_response = LiteLLM_VerificationToken(**decoded)
            elif table_name == DBTableNames.config.name:
                new_response = LiteLLM_Config(**response)
            return new_response

    async def update_data(self, key: str, value: Any, table_name: Literal['user', 'key', 'config']):
        """
        Update the item identified by `key` with the fields in `value`.

        `value` is validated by building the table's model from it; every
        field the model reports as set is written with a SET action, except
        the primary key itself. datetime values are stored as ISO 8601 strings.

        Raises:
            Exception: `table_name` is not 'user', 'key' or 'config'.
        """
        async with ClientSession() as session:
            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
            table = None
            key_name = None
            data_obj: Optional[Union[LiteLLM_Config, LiteLLM_UserTable, LiteLLM_VerificationToken]] = None
            if table_name == DBTableNames.user.name:
                table = client.table(DBTableNames.user.value)
                key_name = "user_id"
                data_obj = LiteLLM_UserTable(user_id=key, **value)
            elif table_name == DBTableNames.key.name:
                table = client.table(DBTableNames.key.value)
                key_name = "token"
                data_obj = LiteLLM_VerificationToken(token=key, **value)
            elif table_name == DBTableNames.config.name:
                table = client.table(DBTableNames.config.value)
                key_name = "param_name"
                data_obj = LiteLLM_Config(param_name=key, **value)

            if data_obj is None:
                raise Exception(f"invalid table name passed in - {table_name}. Unable to load valid data object - {data_obj}.")

            # Accumulate one SET action per field that was set on the model
            actions: List = []
            for field in data_obj.fields_set():
                # The key is passed to the model explicitly above, so it is
                # expected among the set fields. DynamoDB rejects update
                # expressions that touch a primary-key attribute; it is
                # already supplied via `key=` below.
                if field == key_name:
                    continue
                field_value = getattr(data_obj, field)

                # Convert datetime object to ISO8601 string
                if isinstance(field_value, datetime):
                    field_value = field_value.isoformat()

                actions.append((F(field), Value(value=field_value)))

            update_expression = UpdateExpression(set_updates=actions)
            # Perform the update in DynamoDB
            result = await table.update_item(
                key={key_name: key},
                update_expression=update_expression,
                return_values=ReturnValues.none
            )
            return result

    async def delete_data(self, keys: List[str], table_name: Literal['user', 'key', 'config']):
        """
        Not Implemented yet. Delegates to the `CustomDB` base implementation.
        """
        return super().delete_data(keys, table_name)