mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
145 lines
No EOL
6.5 KiB
Python
145 lines
No EOL
6.5 KiB
Python
import json
|
|
from aiodynamo.client import Client
|
|
from aiodynamo.credentials import Credentials, StaticCredentials
|
|
from aiodynamo.http.httpx import HTTPX
|
|
from aiodynamo.models import Throughput, KeySchema, KeySpec, KeyType, PayPerRequest
|
|
from yarl import URL
|
|
from litellm.proxy.db.base_client import CustomDB
|
|
from litellm.proxy._types import DynamoDBArgs, DBTableNames, LiteLLM_VerificationToken, LiteLLM_Config, LiteLLM_UserTable
|
|
from litellm import get_secret
|
|
from typing import Any, List, Literal, Optional
|
|
from aiodynamo.expressions import UpdateExpression, F
|
|
from aiodynamo.models import ReturnValues
|
|
from aiodynamo.http.aiohttp import AIOHTTP
|
|
from aiohttp import ClientSession
|
|
from datetime import datetime
|
|
|
|
class DynamoDBWrapper(CustomDB):
|
|
credentials: Credentials
|
|
def __init__(self, database_arguments: DynamoDBArgs):
|
|
self.throughput_type = None
|
|
if database_arguments.billing_mode == "PAY_PER_REQUEST":
|
|
self.throughput_type = PayPerRequest()
|
|
elif database_arguments.billing_mode == "PROVISIONED_THROUGHPUT":
|
|
self.throughput_type = Throughput(read=database_arguments.read_capacity_units, write=database_arguments.write_capacity_units)
|
|
self.region_name = database_arguments.region_name
|
|
|
|
async def connect(self):
|
|
"""
|
|
Connect to DB, and creating / updating any tables
|
|
"""
|
|
async with ClientSession() as session:
|
|
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
|
|
## User
|
|
table = client.table(DBTableNames.user.value)
|
|
if not await table.exists():
|
|
await table.create(
|
|
self.throughput_type,
|
|
KeySchema(hash_key=KeySpec("user_id", KeyType.string)),
|
|
)
|
|
## Token
|
|
table = client.table(DBTableNames.key.value)
|
|
if not await table.exists():
|
|
await table.create(
|
|
self.throughput_type,
|
|
KeySchema(hash_key=KeySpec("token", KeyType.string)),
|
|
)
|
|
## Config
|
|
table = client.table(DBTableNames.config.value)
|
|
if not await table.exists():
|
|
await table.create(
|
|
self.throughput_type,
|
|
KeySchema(hash_key=KeySpec("param_name", KeyType.string)),
|
|
)
|
|
|
|
async def insert_data(self, value: Any, table_name: Literal['user', 'key', 'config']):
|
|
async with ClientSession() as session:
|
|
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
|
|
table = None
|
|
if table_name == DBTableNames.user.name:
|
|
table = client.table(DBTableNames.user.value)
|
|
elif table_name == DBTableNames.key.name:
|
|
table = client.table(DBTableNames.key.value)
|
|
elif table_name == DBTableNames.config.name:
|
|
table = client.table(DBTableNames.config.value)
|
|
|
|
for k, v in value.items():
|
|
if isinstance(v, datetime):
|
|
value[k] = v.isoformat()
|
|
|
|
await table.put_item(item=value)
|
|
|
|
async def get_data(self, key: str, value: str, table_name: Literal['user', 'key', 'config']):
|
|
async with ClientSession() as session:
|
|
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
|
|
table = None
|
|
if table_name == DBTableNames.user.name:
|
|
table = client.table(DBTableNames.user.value)
|
|
elif table_name == DBTableNames.key.name:
|
|
table = client.table(DBTableNames.key.value)
|
|
elif table_name == DBTableNames.config.name:
|
|
table = client.table(DBTableNames.config.value)
|
|
|
|
response = await table.get_item({key: value})
|
|
|
|
|
|
if table_name == DBTableNames.user.name:
|
|
new_response = LiteLLM_UserTable(**response)
|
|
elif table_name == DBTableNames.key.name:
|
|
new_response = {}
|
|
for k, v in response.items(): # handle json string
|
|
if (k == "aliases" or k == "config" or k == "metadata") and v is not None and isinstance(v, str):
|
|
new_response[k] = json.loads(v)
|
|
else:
|
|
new_response[k] = v
|
|
new_response = LiteLLM_VerificationToken(**new_response)
|
|
elif table_name == DBTableNames.config.name:
|
|
new_response = LiteLLM_Config(**response)
|
|
return new_response
|
|
|
|
|
|
async def update_data(self, key: str, value: Any, table_name: Literal['user', 'key', 'config']):
|
|
async with ClientSession() as session:
|
|
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
|
|
table = None
|
|
key_name = None
|
|
data_obj = None
|
|
if table_name == DBTableNames.user.name:
|
|
table = client.table(DBTableNames.user.value)
|
|
key_name = "user_id"
|
|
data_obj = LiteLLM_UserTable(user_id=key, **value)
|
|
|
|
elif table_name == DBTableNames.key.name:
|
|
table = client.table(DBTableNames.key.value)
|
|
key_name = "token"
|
|
data_obj = LiteLLM_VerificationToken(token=key, **value)
|
|
|
|
elif table_name == DBTableNames.config.name:
|
|
table = client.table(DBTableNames.config.value)
|
|
key_name = "param_name"
|
|
data_obj = LiteLLM_Config(param_name=key, **value)
|
|
|
|
# Initialize an empty UpdateExpression
|
|
update_expression = UpdateExpression()
|
|
|
|
# Add updates for each field that has been modified
|
|
for field in data_obj.model_fields_set:
|
|
# If a Pydantic model has a __fields_set__ attribute, it's a set of fields that were set when the model was instantiated
|
|
field_value = getattr(data_obj, field)
|
|
if isinstance(field_value, datetime):
|
|
field_value = field_value.isoformat()
|
|
update_expression = update_expression.set(F(field), field_value)
|
|
|
|
# Perform the update in DynamoDB
|
|
result = await table.update_item(
|
|
key={key_name: key},
|
|
update_expression=update_expression,
|
|
return_values=ReturnValues.NONE
|
|
)
|
|
return result
|
|
|
|
async def delete_data(self, keys: List[str], table_name: Literal['user', 'key', 'config']):
|
|
"""
|
|
Not Implemented yet.
|
|
"""
|
|
return super().delete_data(keys, table_name) |