mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
feat(proxy_server.py): abstract config update/writing and support persisting config in db
allows user to opt into writing to db (SAVE_CONFIG_TO_DB) and removes any api keys before sending to db https://github.com/BerriAI/litellm/issues/1322
This commit is contained in:
parent
6dea0d3115
commit
99d9a825de
4 changed files with 430 additions and 309 deletions
|
@ -301,20 +301,24 @@ class PrismaClient:
|
|||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
db: Literal["users", "keys"],
|
||||
table_name: Literal["users", "keys", "config"],
|
||||
):
|
||||
"""
|
||||
Generic implementation of get data
|
||||
"""
|
||||
try:
|
||||
if db == "users":
|
||||
if table_name == "users":
|
||||
response = await self.db.litellm_usertable.find_first(
|
||||
where={key: value} # type: ignore
|
||||
)
|
||||
elif db == "keys":
|
||||
elif table_name == "keys":
|
||||
response = await self.db.litellm_verificationtoken.find_first( # type: ignore
|
||||
where={key: value} # type: ignore
|
||||
)
|
||||
elif table_name == "config":
|
||||
response = await self.db.litellm_config.find_first( # type: ignore
|
||||
where={key: value} # type: ignore
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
asyncio.create_task(
|
||||
|
@ -385,39 +389,66 @@ class PrismaClient:
|
|||
max_time=10, # maximum total time to retry for
|
||||
on_backoff=on_backoff, # specifying the function to call on backoff
|
||||
)
|
||||
async def insert_data(self, data: dict):
|
||||
async def insert_data(
|
||||
self, data: dict, table_name: Literal["user+key", "config"] = "user+key"
|
||||
):
|
||||
"""
|
||||
Add a key to the database. If it already exists, do nothing.
|
||||
"""
|
||||
try:
|
||||
token = data["token"]
|
||||
hashed_token = self.hash_token(token=token)
|
||||
db_data = self.jsonify_object(data=data)
|
||||
db_data["token"] = hashed_token
|
||||
max_budget = db_data.pop("max_budget", None)
|
||||
user_email = db_data.pop("user_email", None)
|
||||
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
|
||||
where={
|
||||
"token": hashed_token,
|
||||
},
|
||||
data={
|
||||
"create": {**db_data}, # type: ignore
|
||||
"update": {}, # don't do anything if it already exists
|
||||
},
|
||||
)
|
||||
|
||||
new_user_row = await self.db.litellm_usertable.upsert(
|
||||
where={"user_id": data["user_id"]},
|
||||
data={
|
||||
"create": {
|
||||
"user_id": data["user_id"],
|
||||
"max_budget": max_budget,
|
||||
"user_email": user_email,
|
||||
if table_name == "user+key":
|
||||
token = data["token"]
|
||||
hashed_token = self.hash_token(token=token)
|
||||
db_data = self.jsonify_object(data=data)
|
||||
db_data["token"] = hashed_token
|
||||
max_budget = db_data.pop("max_budget", None)
|
||||
user_email = db_data.pop("user_email", None)
|
||||
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
|
||||
where={
|
||||
"token": hashed_token,
|
||||
},
|
||||
"update": {}, # don't do anything if it already exists
|
||||
},
|
||||
)
|
||||
return new_verification_token
|
||||
data={
|
||||
"create": {**db_data}, # type: ignore
|
||||
"update": {}, # don't do anything if it already exists
|
||||
},
|
||||
)
|
||||
|
||||
new_user_row = await self.db.litellm_usertable.upsert(
|
||||
where={"user_id": data["user_id"]},
|
||||
data={
|
||||
"create": {
|
||||
"user_id": data["user_id"],
|
||||
"max_budget": max_budget,
|
||||
"user_email": user_email,
|
||||
},
|
||||
"update": {}, # don't do anything if it already exists
|
||||
},
|
||||
)
|
||||
return new_verification_token
|
||||
elif table_name == "config":
|
||||
"""
|
||||
For each param,
|
||||
get the existing table values
|
||||
|
||||
Add the new values
|
||||
|
||||
Update DB
|
||||
"""
|
||||
tasks = []
|
||||
for k, v in data.items():
|
||||
updated_data = v
|
||||
updated_data = json.dumps(updated_data)
|
||||
updated_table_row = self.db.litellm_config.upsert(
|
||||
where={"param_name": k},
|
||||
data={
|
||||
"create": {"param_name": k, "param_value": updated_data},
|
||||
"update": {"param_value": updated_data},
|
||||
},
|
||||
)
|
||||
|
||||
tasks.append(updated_table_row)
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
except Exception as e:
|
||||
asyncio.create_task(
|
||||
self.proxy_logging_obj.failure_handler(original_exception=e)
|
||||
|
@ -527,6 +558,7 @@ class PrismaClient:
|
|||
async def disconnect(self):
|
||||
try:
|
||||
await self.db.disconnect()
|
||||
self.connected = False
|
||||
except Exception as e:
|
||||
asyncio.create_task(
|
||||
self.proxy_logging_obj.failure_handler(original_exception=e)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue