mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
fix(dynamo_db.py): add cost tracking support for key + user
This commit is contained in:
parent
9b3d78c4f3
commit
f94a37a836
4 changed files with 163 additions and 107 deletions
|
@ -1,5 +1,7 @@
|
||||||
from typing import Any, Literal, List
|
from typing import Any, Literal, List
|
||||||
class CustomDB:
|
|
||||||
|
|
||||||
|
class CustomDB:
|
||||||
"""
|
"""
|
||||||
Implements a base class that we expect any custom db implementation (e.g. DynamoDB) to follow
|
Implements a base class that we expect any custom db implementation (e.g. DynamoDB) to follow
|
||||||
"""
|
"""
|
||||||
|
@ -7,37 +9,45 @@ class CustomDB:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_data(self, key: str, value: str, table_name: Literal["user", "key", "config"]):
|
def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
|
||||||
"""
|
"""
|
||||||
Check if key valid
|
Check if key valid
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]):
|
def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]):
|
||||||
"""
|
"""
|
||||||
For new key / user logic
|
For new key / user logic
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def update_data(self, key: str, value: Any, table_name: Literal["user", "key", "config"]):
|
def update_data(
|
||||||
|
self, key: str, value: Any, table_name: Literal["user", "key", "config"]
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
For cost tracking logic
|
For cost tracking logic
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def delete_data(self, keys: List[str], table_name: Literal["user", "key", "config"]):
|
def delete_data(
|
||||||
|
self, keys: List[str], table_name: Literal["user", "key", "config"]
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
For /key/delete endpoint s
|
For /key/delete endpoint s
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def connect(self, ):
|
def connect(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
For connecting to db and creating / updating any tables
|
||||||
"""
|
"""
|
||||||
For connecting to db and creating / updating any tables
|
|
||||||
"""
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def disconnect(self, ):
|
def disconnect(
|
||||||
|
self,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
For closing connection on server shutdown
|
For closing connection on server shutdown
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -104,20 +104,22 @@ class DynamoDBWrapper(CustomDB):
|
||||||
|
|
||||||
await table.put_item(item=value)
|
await table.put_item(item=value)
|
||||||
|
|
||||||
async def get_data(
|
async def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
|
||||||
self, key: str, value: str, table_name: Literal["user", "key", "config"]
|
|
||||||
):
|
|
||||||
async with ClientSession() as session:
|
async with ClientSession() as session:
|
||||||
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
|
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
|
||||||
table = None
|
table = None
|
||||||
|
key_name = None
|
||||||
if table_name == DBTableNames.user.name:
|
if table_name == DBTableNames.user.name:
|
||||||
table = client.table(DBTableNames.user.value)
|
table = client.table(DBTableNames.user.value)
|
||||||
|
key_name = "user_id"
|
||||||
elif table_name == DBTableNames.key.name:
|
elif table_name == DBTableNames.key.name:
|
||||||
table = client.table(DBTableNames.key.value)
|
table = client.table(DBTableNames.key.value)
|
||||||
|
key_name = "token"
|
||||||
elif table_name == DBTableNames.config.name:
|
elif table_name == DBTableNames.config.name:
|
||||||
table = client.table(DBTableNames.config.value)
|
table = client.table(DBTableNames.config.value)
|
||||||
|
key_name = "param_name"
|
||||||
|
|
||||||
response = await table.get_item({key: value})
|
response = await table.get_item({key_name: key})
|
||||||
|
|
||||||
new_response: Any = None
|
new_response: Any = None
|
||||||
if table_name == DBTableNames.user.name:
|
if table_name == DBTableNames.user.name:
|
||||||
|
@ -139,46 +141,41 @@ class DynamoDBWrapper(CustomDB):
|
||||||
return new_response
|
return new_response
|
||||||
|
|
||||||
async def update_data(
|
async def update_data(
|
||||||
self, key: str, value: Any, table_name: Literal["user", "key", "config"]
|
self, key: str, value: dict, table_name: Literal["user", "key", "config"]
|
||||||
):
|
):
|
||||||
async with ClientSession() as session:
|
async with ClientSession() as session:
|
||||||
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
|
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
|
||||||
table = None
|
table = None
|
||||||
key_name = None
|
key_name = None
|
||||||
data_obj: Optional[
|
try:
|
||||||
Union[LiteLLM_Config, LiteLLM_UserTable, LiteLLM_VerificationToken]
|
if table_name == DBTableNames.user.name:
|
||||||
] = None
|
table = client.table(DBTableNames.user.value)
|
||||||
if table_name == DBTableNames.user.name:
|
key_name = "user_id"
|
||||||
table = client.table(DBTableNames.user.value)
|
|
||||||
key_name = "user_id"
|
|
||||||
data_obj = LiteLLM_UserTable(user_id=key, **value)
|
|
||||||
|
|
||||||
elif table_name == DBTableNames.key.name:
|
elif table_name == DBTableNames.key.name:
|
||||||
table = client.table(DBTableNames.key.value)
|
table = client.table(DBTableNames.key.value)
|
||||||
key_name = "token"
|
key_name = "token"
|
||||||
data_obj = LiteLLM_VerificationToken(token=key, **value)
|
|
||||||
|
|
||||||
elif table_name == DBTableNames.config.name:
|
elif table_name == DBTableNames.config.name:
|
||||||
table = client.table(DBTableNames.config.value)
|
table = client.table(DBTableNames.config.value)
|
||||||
key_name = "param_name"
|
key_name = "param_name"
|
||||||
data_obj = LiteLLM_Config(param_name=key, **value)
|
else:
|
||||||
|
raise Exception(
|
||||||
|
f"Invalid table name. Needs to be one of - {DBTableNames.user.name}, {DBTableNames.key.name}, {DBTableNames.config.name}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error connecting to table - {str(e)}")
|
||||||
|
|
||||||
if data_obj is None:
|
|
||||||
raise Exception(
|
|
||||||
f"invalid table name passed in - {table_name}. Unable to load valid data object - {data_obj}."
|
|
||||||
)
|
|
||||||
# Initialize an empty UpdateExpression
|
# Initialize an empty UpdateExpression
|
||||||
|
|
||||||
actions: List = []
|
actions: List = []
|
||||||
for field in data_obj.fields_set():
|
for k, v in value.items():
|
||||||
field_value = getattr(data_obj, field)
|
|
||||||
|
|
||||||
# Convert datetime object to ISO8601 string
|
# Convert datetime object to ISO8601 string
|
||||||
if isinstance(field_value, datetime):
|
if isinstance(v, datetime):
|
||||||
field_value = field_value.isoformat()
|
v = v.isoformat()
|
||||||
|
|
||||||
# Accumulate updates
|
# Accumulate updates
|
||||||
actions.append((F(field), Value(value=field_value)))
|
actions.append((F(k), Value(value=v)))
|
||||||
|
|
||||||
update_expression = UpdateExpression(set_updates=actions)
|
update_expression = UpdateExpression(set_updates=actions)
|
||||||
# Perform the update in DynamoDB
|
# Perform the update in DynamoDB
|
||||||
|
|
|
@ -244,12 +244,15 @@ async def user_api_key_auth(
|
||||||
if valid_token is None:
|
if valid_token is None:
|
||||||
## check db
|
## check db
|
||||||
verbose_proxy_logger.debug(f"api key: {api_key}")
|
verbose_proxy_logger.debug(f"api key: {api_key}")
|
||||||
if prisma_client is not None:
|
if prisma_client is not None:
|
||||||
valid_token = await prisma_client.get_data(
|
valid_token = await prisma_client.get_data(
|
||||||
token=api_key, expires=datetime.utcnow().replace(tzinfo=timezone.utc)
|
token=api_key,
|
||||||
|
expires=datetime.utcnow().replace(tzinfo=timezone.utc),
|
||||||
|
)
|
||||||
|
elif custom_db_client is not None:
|
||||||
|
valid_token = await custom_db_client.get_data(
|
||||||
|
key=api_key, table_name="key"
|
||||||
)
|
)
|
||||||
elif custom_db_client is not None:
|
|
||||||
valid_token = await custom_db_client.get_data(key="token", value=api_key, table_name="key")
|
|
||||||
# Token exists, now check expiration.
|
# Token exists, now check expiration.
|
||||||
if valid_token.expires is not None:
|
if valid_token.expires is not None:
|
||||||
expiry_time = datetime.fromisoformat(valid_token.expires)
|
expiry_time = datetime.fromisoformat(valid_token.expires)
|
||||||
|
@ -297,7 +300,7 @@ async def user_api_key_auth(
|
||||||
|
|
||||||
This makes the user row data accessible to pre-api call hooks.
|
This makes the user row data accessible to pre-api call hooks.
|
||||||
"""
|
"""
|
||||||
if prisma_client is not None:
|
if prisma_client is not None:
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
_cache_user_row(
|
_cache_user_row(
|
||||||
user_id=valid_token.user_id,
|
user_id=valid_token.user_id,
|
||||||
|
@ -386,8 +389,8 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
|
||||||
|
|
||||||
|
|
||||||
def cost_tracking():
|
def cost_tracking():
|
||||||
global prisma_client
|
global prisma_client, custom_db_client
|
||||||
if prisma_client is not None:
|
if prisma_client is not None or custom_db_client is not None:
|
||||||
if isinstance(litellm.success_callback, list):
|
if isinstance(litellm.success_callback, list):
|
||||||
verbose_proxy_logger.debug("setting litellm success callback to track cost")
|
verbose_proxy_logger.debug("setting litellm success callback to track cost")
|
||||||
if (track_cost_callback) not in litellm.success_callback: # type: ignore
|
if (track_cost_callback) not in litellm.success_callback: # type: ignore
|
||||||
|
@ -400,7 +403,7 @@ async def track_cost_callback(
|
||||||
start_time=None,
|
start_time=None,
|
||||||
end_time=None, # start/end time for completion
|
end_time=None, # start/end time for completion
|
||||||
):
|
):
|
||||||
global prisma_client
|
global prisma_client, custom_db_client
|
||||||
try:
|
try:
|
||||||
# check if it has collected an entire stream response
|
# check if it has collected an entire stream response
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
|
@ -419,10 +422,10 @@ async def track_cost_callback(
|
||||||
user_id = kwargs["litellm_params"]["metadata"].get(
|
user_id = kwargs["litellm_params"]["metadata"].get(
|
||||||
"user_api_key_user_id", None
|
"user_api_key_user_id", None
|
||||||
)
|
)
|
||||||
if user_api_key and prisma_client:
|
if user_api_key and (
|
||||||
await update_prisma_database(
|
prisma_client is not None or custom_db_client is not None
|
||||||
token=user_api_key, response_cost=response_cost
|
):
|
||||||
)
|
await update_database(token=user_api_key, response_cost=response_cost)
|
||||||
elif kwargs["stream"] == False: # for non streaming responses
|
elif kwargs["stream"] == False: # for non streaming responses
|
||||||
response_cost = litellm.completion_cost(
|
response_cost = litellm.completion_cost(
|
||||||
completion_response=completion_response
|
completion_response=completion_response
|
||||||
|
@ -433,15 +436,17 @@ async def track_cost_callback(
|
||||||
user_id = kwargs["litellm_params"]["metadata"].get(
|
user_id = kwargs["litellm_params"]["metadata"].get(
|
||||||
"user_api_key_user_id", None
|
"user_api_key_user_id", None
|
||||||
)
|
)
|
||||||
if user_api_key and prisma_client:
|
if user_api_key and (
|
||||||
await update_prisma_database(
|
prisma_client is not None or custom_db_client is not None
|
||||||
|
):
|
||||||
|
await update_database(
|
||||||
token=user_api_key, response_cost=response_cost, user_id=user_id
|
token=user_api_key, response_cost=response_cost, user_id=user_id
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}")
|
verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
async def update_prisma_database(token, response_cost, user_id=None):
|
async def update_database(token, response_cost, user_id=None):
|
||||||
try:
|
try:
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"Enters prisma db call, token: {token}; user_id: {user_id}"
|
f"Enters prisma db call, token: {token}; user_id: {user_id}"
|
||||||
|
@ -451,7 +456,12 @@ async def update_prisma_database(token, response_cost, user_id=None):
|
||||||
async def _update_user_db():
|
async def _update_user_db():
|
||||||
if user_id is None:
|
if user_id is None:
|
||||||
return
|
return
|
||||||
existing_spend_obj = await prisma_client.get_data(user_id=user_id)
|
if prisma_client is not None:
|
||||||
|
existing_spend_obj = await prisma_client.get_data(user_id=user_id)
|
||||||
|
elif custom_db_client is not None:
|
||||||
|
existing_spend_obj = await custom_db_client.get_data(
|
||||||
|
key=user_id, table_name="user"
|
||||||
|
)
|
||||||
if existing_spend_obj is None:
|
if existing_spend_obj is None:
|
||||||
existing_spend = 0
|
existing_spend = 0
|
||||||
else:
|
else:
|
||||||
|
@ -462,23 +472,49 @@ async def update_prisma_database(token, response_cost, user_id=None):
|
||||||
|
|
||||||
verbose_proxy_logger.debug(f"new cost: {new_spend}")
|
verbose_proxy_logger.debug(f"new cost: {new_spend}")
|
||||||
# Update the cost column for the given user id
|
# Update the cost column for the given user id
|
||||||
await prisma_client.update_data(user_id=user_id, data={"spend": new_spend})
|
if prisma_client is not None:
|
||||||
|
await prisma_client.update_data(
|
||||||
|
user_id=user_id, data={"spend": new_spend}
|
||||||
|
)
|
||||||
|
elif custom_db_client is not None:
|
||||||
|
await custom_db_client.update_data(
|
||||||
|
key=user_id, value={"spend": new_spend}, table_name="user"
|
||||||
|
)
|
||||||
|
|
||||||
### UPDATE KEY SPEND ###
|
### UPDATE KEY SPEND ###
|
||||||
async def _update_key_db():
|
async def _update_key_db():
|
||||||
# Fetch the existing cost for the given token
|
if prisma_client is not None:
|
||||||
existing_spend_obj = await prisma_client.get_data(token=token)
|
# Fetch the existing cost for the given token
|
||||||
verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}")
|
existing_spend_obj = await prisma_client.get_data(token=token)
|
||||||
if existing_spend_obj is None:
|
verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}")
|
||||||
existing_spend = 0
|
if existing_spend_obj is None:
|
||||||
else:
|
existing_spend = 0
|
||||||
existing_spend = existing_spend_obj.spend
|
else:
|
||||||
# Calculate the new cost by adding the existing cost and response_cost
|
existing_spend = existing_spend_obj.spend
|
||||||
new_spend = existing_spend + response_cost
|
# Calculate the new cost by adding the existing cost and response_cost
|
||||||
|
new_spend = existing_spend + response_cost
|
||||||
|
|
||||||
verbose_proxy_logger.debug(f"new cost: {new_spend}")
|
verbose_proxy_logger.debug(f"new cost: {new_spend}")
|
||||||
# Update the cost column for the given token
|
# Update the cost column for the given token
|
||||||
await prisma_client.update_data(token=token, data={"spend": new_spend})
|
await prisma_client.update_data(token=token, data={"spend": new_spend})
|
||||||
|
elif custom_db_client is not None:
|
||||||
|
# Fetch the existing cost for the given token
|
||||||
|
existing_spend_obj = await custom_db_client.get_data(
|
||||||
|
key=token, table_name="key"
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}")
|
||||||
|
if existing_spend_obj is None:
|
||||||
|
existing_spend = 0
|
||||||
|
else:
|
||||||
|
existing_spend = existing_spend_obj.spend
|
||||||
|
# Calculate the new cost by adding the existing cost and response_cost
|
||||||
|
new_spend = existing_spend + response_cost
|
||||||
|
|
||||||
|
verbose_proxy_logger.debug(f"new cost: {new_spend}")
|
||||||
|
# Update the cost column for the given token
|
||||||
|
await custom_db_client.update_data(
|
||||||
|
key=token, value={"spend": new_spend}, table_name="key"
|
||||||
|
)
|
||||||
|
|
||||||
tasks = []
|
tasks = []
|
||||||
tasks.append(_update_user_db())
|
tasks.append(_update_user_db())
|
||||||
|
@ -802,8 +838,6 @@ class ProxyConfig:
|
||||||
verbose_proxy_logger.debug(f"GOING INTO LITELLM.GET_SECRET!")
|
verbose_proxy_logger.debug(f"GOING INTO LITELLM.GET_SECRET!")
|
||||||
database_url = litellm.get_secret(database_url)
|
database_url = litellm.get_secret(database_url)
|
||||||
verbose_proxy_logger.debug(f"RETRIEVED DB URL: {database_url}")
|
verbose_proxy_logger.debug(f"RETRIEVED DB URL: {database_url}")
|
||||||
## COST TRACKING ##
|
|
||||||
cost_tracking()
|
|
||||||
### MASTER KEY ###
|
### MASTER KEY ###
|
||||||
master_key = general_settings.get(
|
master_key = general_settings.get(
|
||||||
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
|
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
|
||||||
|
@ -821,7 +855,11 @@ class ProxyConfig:
|
||||||
database_type = general_settings.get("database_type", None)
|
database_type = general_settings.get("database_type", None)
|
||||||
if database_type is not None and database_type == "dynamo_db":
|
if database_type is not None and database_type == "dynamo_db":
|
||||||
database_args = general_settings.get("database_args", None)
|
database_args = general_settings.get("database_args", None)
|
||||||
custom_db_client = DBClient(custom_db_args=database_args, custom_db_type=database_type)
|
custom_db_client = DBClient(
|
||||||
|
custom_db_args=database_args, custom_db_type=database_type
|
||||||
|
)
|
||||||
|
## COST TRACKING ##
|
||||||
|
cost_tracking()
|
||||||
### BACKGROUND HEALTH CHECKS ###
|
### BACKGROUND HEALTH CHECKS ###
|
||||||
# Enable background health checks
|
# Enable background health checks
|
||||||
use_background_health_checks = general_settings.get(
|
use_background_health_checks = general_settings.get(
|
||||||
|
@ -930,9 +968,10 @@ async def generate_key_helper_fn(
|
||||||
try:
|
try:
|
||||||
# Create a new verification token (you may want to enhance this logic based on your needs)
|
# Create a new verification token (you may want to enhance this logic based on your needs)
|
||||||
user_data = {
|
user_data = {
|
||||||
"max_budget": max_budget,
|
"max_budget": max_budget,
|
||||||
"user_email": user_email,
|
"user_email": user_email,
|
||||||
"user_id": user_id
|
"user_id": user_id,
|
||||||
|
"spend": spend,
|
||||||
}
|
}
|
||||||
key_data = {
|
key_data = {
|
||||||
"token": token,
|
"token": token,
|
||||||
|
@ -945,17 +984,15 @@ async def generate_key_helper_fn(
|
||||||
"max_parallel_requests": max_parallel_requests,
|
"max_parallel_requests": max_parallel_requests,
|
||||||
"metadata": metadata_json,
|
"metadata": metadata_json,
|
||||||
}
|
}
|
||||||
if prisma_client is not None:
|
if prisma_client is not None:
|
||||||
verification_token_data = dict(key_data)
|
verification_token_data = dict(key_data)
|
||||||
verification_token_data.update(user_data)
|
verification_token_data.update(user_data)
|
||||||
verbose_proxy_logger.debug("PrismaClient: Before Insert Data")
|
verbose_proxy_logger.debug("PrismaClient: Before Insert Data")
|
||||||
await prisma_client.insert_data(
|
await prisma_client.insert_data(data=verification_token_data)
|
||||||
data=verification_token_data
|
elif custom_db_client is not None:
|
||||||
)
|
|
||||||
elif custom_db_client is not None:
|
|
||||||
## CREATE USER (If necessary)
|
## CREATE USER (If necessary)
|
||||||
await custom_db_client.insert_data(value=user_data, table_name="user")
|
await custom_db_client.insert_data(value=user_data, table_name="user")
|
||||||
## CREATE KEY
|
## CREATE KEY
|
||||||
await custom_db_client.insert_data(value=key_data, table_name="key")
|
await custom_db_client.insert_data(value=key_data, table_name="key")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
@ -1230,26 +1267,23 @@ async def startup_event():
|
||||||
verbose_proxy_logger.debug(f"prisma client - {prisma_client}")
|
verbose_proxy_logger.debug(f"prisma client - {prisma_client}")
|
||||||
if prisma_client is not None:
|
if prisma_client is not None:
|
||||||
await prisma_client.connect()
|
await prisma_client.connect()
|
||||||
|
|
||||||
if custom_db_client is not None:
|
if custom_db_client is not None:
|
||||||
await custom_db_client.connect()
|
await custom_db_client.connect()
|
||||||
|
|
||||||
if prisma_client is not None and master_key is not None:
|
if prisma_client is not None and master_key is not None:
|
||||||
# add master key to db
|
# add master key to db
|
||||||
print(f'prisma_client: {prisma_client}')
|
|
||||||
await generate_key_helper_fn(
|
await generate_key_helper_fn(
|
||||||
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
|
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
|
||||||
)
|
)
|
||||||
|
|
||||||
if custom_db_client is not None and master_key is not None:
|
if custom_db_client is not None and master_key is not None:
|
||||||
# add master key to db
|
# add master key to db
|
||||||
print(f'custom_db_client: {custom_db_client}')
|
|
||||||
await generate_key_helper_fn(
|
await generate_key_helper_fn(
|
||||||
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
|
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#### API ENDPOINTS ####
|
#### API ENDPOINTS ####
|
||||||
@router.get(
|
@router.get(
|
||||||
"/v1/models", dependencies=[Depends(user_api_key_auth)], tags=["model management"]
|
"/v1/models", dependencies=[Depends(user_api_key_auth)], tags=["model management"]
|
||||||
|
|
|
@ -255,10 +255,10 @@ class PrismaClient:
|
||||||
print_verbose(
|
print_verbose(
|
||||||
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
|
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
|
||||||
)
|
)
|
||||||
## init logging object
|
## init logging object
|
||||||
self.proxy_logging_obj = proxy_logging_obj
|
self.proxy_logging_obj = proxy_logging_obj
|
||||||
|
|
||||||
if os.getenv("DATABASE_URL", None) is None: # setup hasn't taken place
|
if os.getenv("DATABASE_URL", None) is None: # setup hasn't taken place
|
||||||
os.environ["DATABASE_URL"] = database_url
|
os.environ["DATABASE_URL"] = database_url
|
||||||
# Save the current working directory
|
# Save the current working directory
|
||||||
original_dir = os.getcwd()
|
original_dir = os.getcwd()
|
||||||
|
@ -334,7 +334,7 @@ class PrismaClient:
|
||||||
self.proxy_logging_obj.failure_handler(original_exception=e)
|
self.proxy_logging_obj.failure_handler(original_exception=e)
|
||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
@backoff.on_exception(
|
@backoff.on_exception(
|
||||||
backoff.expo,
|
backoff.expo,
|
||||||
Exception, # base exception to catch for the backoff
|
Exception, # base exception to catch for the backoff
|
||||||
|
@ -582,47 +582,60 @@ class PrismaClient:
|
||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
class DBClient:
|
class DBClient:
|
||||||
"""
|
"""
|
||||||
Routes requests for CustomAuth
|
Routes requests for CustomAuth
|
||||||
|
|
||||||
[TODO] route b/w customauth and prisma
|
[TODO] route b/w customauth and prisma
|
||||||
"""
|
"""
|
||||||
def __init__(self, custom_db_type: Literal["dynamo_db"], custom_db_args: dict) -> None:
|
|
||||||
|
def __init__(
|
||||||
|
self, custom_db_type: Literal["dynamo_db"], custom_db_args: dict
|
||||||
|
) -> None:
|
||||||
if custom_db_type == "dynamo_db":
|
if custom_db_type == "dynamo_db":
|
||||||
self.db = DynamoDBWrapper(database_arguments=DynamoDBArgs(**custom_db_args))
|
self.db = DynamoDBWrapper(database_arguments=DynamoDBArgs(**custom_db_args))
|
||||||
|
|
||||||
async def get_data(self, key: str, value: Any, table_name: Literal["user", "key", "config"]):
|
async def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
|
||||||
"""
|
"""
|
||||||
Check if key valid
|
Check if key valid
|
||||||
"""
|
"""
|
||||||
return await self.db.get_data(key=key, value=value, table_name=table_name)
|
return await self.db.get_data(key=key, table_name=table_name)
|
||||||
|
|
||||||
async def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]):
|
async def insert_data(
|
||||||
|
self, value: Any, table_name: Literal["user", "key", "config"]
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
For new key / user logic
|
For new key / user logic
|
||||||
"""
|
"""
|
||||||
return await self.db.insert_data(value=value, table_name=table_name)
|
return await self.db.insert_data(value=value, table_name=table_name)
|
||||||
|
|
||||||
async def update_data(self, key: str, value: Any, table_name: Literal["user", "key", "config"]):
|
async def update_data(
|
||||||
|
self, key: str, value: Any, table_name: Literal["user", "key", "config"]
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
For cost tracking logic
|
For cost tracking logic
|
||||||
|
|
||||||
|
key - hash_key value \n
|
||||||
|
value - dict with updated values
|
||||||
"""
|
"""
|
||||||
return await self.db.update_data(key=key, value=value, table_name=table_name)
|
return await self.db.update_data(key=key, value=value, table_name=table_name)
|
||||||
|
|
||||||
async def delete_data(self, keys: List[str], table_name: Literal["user", "key", "config"]):
|
async def delete_data(
|
||||||
|
self, keys: List[str], table_name: Literal["user", "key", "config"]
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
For /key/delete endpoints
|
For /key/delete endpoints
|
||||||
"""
|
"""
|
||||||
return await self.db.delete_data(keys=keys, table_name=table_name)
|
return await self.db.delete_data(keys=keys, table_name=table_name)
|
||||||
|
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
"""
|
"""
|
||||||
For connecting to db and creating / updating any tables
|
For connecting to db and creating / updating any tables
|
||||||
"""
|
"""
|
||||||
return await self.db.connect()
|
return await self.db.connect()
|
||||||
|
|
||||||
async def disconnect(self):
|
async def disconnect(self):
|
||||||
"""
|
"""
|
||||||
For closing connection on server shutdown
|
For closing connection on server shutdown
|
||||||
"""
|
"""
|
||||||
|
@ -669,7 +682,9 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
|
||||||
|
|
||||||
|
|
||||||
### HELPER FUNCTIONS ###
|
### HELPER FUNCTIONS ###
|
||||||
async def _cache_user_row(user_id: str, cache: DualCache, db: Union[PrismaClient, DBClient]):
|
async def _cache_user_row(
|
||||||
|
user_id: str, cache: DualCache, db: Union[PrismaClient, DBClient]
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Check if a user_id exists in cache,
|
Check if a user_id exists in cache,
|
||||||
if not retrieve it.
|
if not retrieve it.
|
||||||
|
@ -681,7 +696,7 @@ async def _cache_user_row(user_id: str, cache: DualCache, db: Union[PrismaClient
|
||||||
if isinstance(db, PrismaClient):
|
if isinstance(db, PrismaClient):
|
||||||
user_row = await db.get_data(user_id=user_id)
|
user_row = await db.get_data(user_id=user_id)
|
||||||
elif isinstance(db, DBClient):
|
elif isinstance(db, DBClient):
|
||||||
user_row = await db.get_data(key="user_id", value=user_id, table_name="user")
|
user_row = await db.get_data(key=user_id, table_name="user")
|
||||||
if user_row is not None:
|
if user_row is not None:
|
||||||
print_verbose(f"User Row: {user_row}, type = {type(user_row)}")
|
print_verbose(f"User Row: {user_row}, type = {type(user_row)}")
|
||||||
if hasattr(user_row, "model_dump_json") and callable(
|
if hasattr(user_row, "model_dump_json") and callable(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue