feat(proxy_server.py): abstract config update/writing and support persisting config in db

allows user to opt into writing to db (SAVE_CONFIG_TO_DB) and removes any api keys before sending to db

 https://github.com/BerriAI/litellm/issues/1322
This commit is contained in:
Krrish Dholakia 2024-01-04 14:44:45 +05:30
parent 6dea0d3115
commit 99d9a825de
4 changed files with 430 additions and 309 deletions

View file

@ -301,20 +301,24 @@ class PrismaClient:
self,
key: str,
value: Any,
db: Literal["users", "keys"],
table_name: Literal["users", "keys", "config"],
):
"""
Generic implementation of get data
"""
try:
if db == "users":
if table_name == "users":
response = await self.db.litellm_usertable.find_first(
where={key: value} # type: ignore
)
elif db == "keys":
elif table_name == "keys":
response = await self.db.litellm_verificationtoken.find_first( # type: ignore
where={key: value} # type: ignore
)
elif table_name == "config":
response = await self.db.litellm_config.find_first( # type: ignore
where={key: value} # type: ignore
)
return response
except Exception as e:
asyncio.create_task(
@ -385,39 +389,66 @@ class PrismaClient:
max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def insert_data(self, data: dict):
async def insert_data(
self, data: dict, table_name: Literal["user+key", "config"] = "user+key"
):
"""
Add a key to the database. If it already exists, do nothing.
"""
try:
token = data["token"]
hashed_token = self.hash_token(token=token)
db_data = self.jsonify_object(data=data)
db_data["token"] = hashed_token
max_budget = db_data.pop("max_budget", None)
user_email = db_data.pop("user_email", None)
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
where={
"token": hashed_token,
},
data={
"create": {**db_data}, # type: ignore
"update": {}, # don't do anything if it already exists
},
)
new_user_row = await self.db.litellm_usertable.upsert(
where={"user_id": data["user_id"]},
data={
"create": {
"user_id": data["user_id"],
"max_budget": max_budget,
"user_email": user_email,
if table_name == "user+key":
token = data["token"]
hashed_token = self.hash_token(token=token)
db_data = self.jsonify_object(data=data)
db_data["token"] = hashed_token
max_budget = db_data.pop("max_budget", None)
user_email = db_data.pop("user_email", None)
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
where={
"token": hashed_token,
},
"update": {}, # don't do anything if it already exists
},
)
return new_verification_token
data={
"create": {**db_data}, # type: ignore
"update": {}, # don't do anything if it already exists
},
)
new_user_row = await self.db.litellm_usertable.upsert(
where={"user_id": data["user_id"]},
data={
"create": {
"user_id": data["user_id"],
"max_budget": max_budget,
"user_email": user_email,
},
"update": {}, # don't do anything if it already exists
},
)
return new_verification_token
elif table_name == "config":
"""
For each param,
get the existing table values
Add the new values
Update DB
"""
tasks = []
for k, v in data.items():
updated_data = v
updated_data = json.dumps(updated_data)
updated_table_row = self.db.litellm_config.upsert(
where={"param_name": k},
data={
"create": {"param_name": k, "param_value": updated_data},
"update": {"param_value": updated_data},
},
)
tasks.append(updated_table_row)
await asyncio.gather(*tasks)
except Exception as e:
asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
@ -527,6 +558,7 @@ class PrismaClient:
async def disconnect(self):
try:
await self.db.disconnect()
self.connected = False
except Exception as e:
asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)