feat(proxy_server.py): abstract config update/writing and support persisting config in db

allows users to opt into persisting the proxy config to the db (via the SAVE_CONFIG_TO_DB flag) and strips any api keys from the config before it is written to the db

 https://github.com/BerriAI/litellm/issues/1322
Krrish Dholakia 2024-01-04 14:44:45 +05:30
parent 6dea0d3115
commit 99d9a825de
4 changed files with 430 additions and 309 deletions
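The key-scrubbing step works roughly as follows; a minimal standalone sketch (the helper name is illustrative — the commit inlines this logic in `ProxyConfig.save_config`):

```python
import os, uuid

def strip_api_keys(new_config: dict) -> dict:
    """Hypothetical helper mirroring the key-removal block in save_config."""
    for m in new_config.get("model_list", []):
        api_key = m.get("litellm_params", {}).pop("api_key", None)
        if api_key is not None:
            # store the raw key only in the local environment
            key_name = f"LITELLM_MODEL_KEY_{uuid.uuid4()}"
            os.environ[key_name] = api_key
            # persist a reference to the key, never the key itself
            m["litellm_params"]["api_key"] = f"os.environ/{key_name}"
    return new_config
```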

litellm/proxy/proxy_server.py

@@ -502,21 +502,113 @@ async def _run_background_health_check():
await asyncio.sleep(health_check_interval)
def load_router_config(router: Optional[litellm.Router], config_file_path: str):
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue
config = {}
try:
if os.path.exists(config_file_path):
class ProxyConfig:
"""
Abstraction class on top of config loading/updating logic. Gives us one place to control all config updating logic.
"""
def __init__(self) -> None:
pass
async def get_config(self, config_file_path: Optional[str] = None) -> dict:
global prisma_client, user_config_file_path
file_path = config_file_path or user_config_file_path
if config_file_path is not None:
user_config_file_path = config_file_path
with open(config_file_path, "r") as file:
config = yaml.safe_load(file)
# Load existing config
## Yaml
if os.path.exists(f"{file_path}"):
with open(f"{file_path}", "r") as config_file:
config = yaml.safe_load(config_file)
else:
raise Exception(
f"Path to config does not exist, Current working directory: {os.getcwd()}, 'os.path.exists({config_file_path})' returned False"
config = {
"model_list": [],
"general_settings": {},
"router_settings": {},
"litellm_settings": {},
}
## DB
if (
prisma_client is not None
and litellm.get_secret("SAVE_CONFIG_TO_DB", False) == True
):
_tasks = []
keys = [
"model_list",
"general_settings",
"router_settings",
"litellm_settings",
]
for k in keys:
response = prisma_client.get_generic_data(
key="param_name", value=k, table_name="config"
)
_tasks.append(response)
responses = await asyncio.gather(*_tasks)
return config
async def save_config(self, new_config: dict):
global prisma_client, llm_router, user_config_file_path
# Load existing config
backup_config = await self.get_config()
# Save the updated config
## YAML
with open(f"{user_config_file_path}", "w") as config_file:
yaml.dump(new_config, config_file, default_flow_style=False)
# update Router - verifies if this is a valid config
try:
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(
router=llm_router, config_file_path=user_config_file_path
)
except Exception as e:
raise Exception(f"Exception while reading Config: {e}")
traceback.print_exc()
# Revert to old config instead
with open(f"{user_config_file_path}", "w") as config_file:
yaml.dump(backup_config, config_file, default_flow_style=False)
raise HTTPException(status_code=400, detail="Invalid config passed in")
## DB - writes valid config to db
"""
- Do not write restricted params like 'api_key' to the database
- if api_key is passed, save that to the local environment or a connected secret manager (maybe expose `litellm.save_secret()`)
"""
if (
prisma_client is not None
and litellm.get_secret("SAVE_CONFIG_TO_DB", default_value=False) == True
):
### KEY REMOVAL ###
models = new_config.get("model_list", [])
for m in models:
if m.get("litellm_params", {}).get("api_key", None) is not None:
# pop the key
api_key = m["litellm_params"].pop("api_key")
# store in local env
key_name = f"LITELLM_MODEL_KEY_{uuid.uuid4()}"
os.environ[key_name] = api_key
# save the key name (not the value)
m["litellm_params"]["api_key"] = f"os.environ/{key_name}"
await prisma_client.insert_data(data=new_config, table_name="config")
async def load_config(
self, router: Optional[litellm.Router], config_file_path: str
):
"""
Load config values into proxy global state
"""
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue
# Load existing config
config = await self.get_config(config_file_path=config_file_path)
## PRINT YAML FOR CONFIRMING IT WORKS
printed_yaml = copy.deepcopy(config)
printed_yaml.pop("environment_variables", None)
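Endpoints further down now funnel through this class instead of touching the YAML directly; a hedged sketch of the intended read-modify-write cycle (assumes an async context and the module-level `proxy_config` instance created below):

```python
# illustrative flow, mirroring what /model/new does after this commit
config = await proxy_config.get_config()          # YAML, plus DB rows if SAVE_CONFIG_TO_DB is set
config["model_list"].append(
    {"model_name": "gpt-4", "litellm_params": {"model": "gpt-4"}}  # made-up entry
)
await proxy_config.save_config(new_config=config) # validates via Router, reverts on failure, then persists
```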
@@ -604,7 +696,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
for callback in value:
# user passed custom_callbacks.async_on_success_logger. They need us to import a function
if "." in callback:
litellm.success_callback.append(get_instance_fn(value=callback))
litellm.success_callback.append(
get_instance_fn(value=callback)
)
# these are litellm callbacks - "langfuse", "sentry", "wandb"
else:
litellm.success_callback.append(callback)
@@ -618,7 +712,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
for callback in value:
# user passed custom_callbacks.async_on_success_logger. They need us to import a function
if "." in callback:
litellm.failure_callback.append(get_instance_fn(value=callback))
litellm.failure_callback.append(
get_instance_fn(value=callback)
)
# these are litellm callbacks - "langfuse", "sentry", "wandb"
else:
litellm.failure_callback.append(callback)
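For reference, the two callback styles these loops accept (a hedged example; in the YAML this sits under `litellm_settings`, shown here as the equivalent dict):

```python
litellm_settings = {
    "success_callback": [
        "langfuse",                                  # built-in litellm callback, appended by name
        "custom_callbacks.async_on_success_logger",  # dotted path, imported via get_instance_fn
    ],
    "failure_callback": ["sentry"],                  # built-in litellm callback
}
```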
@@ -730,6 +826,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
return router, model_list, general_settings
proxy_config = ProxyConfig()
async def generate_key_helper_fn(
duration: Optional[str],
models: list,
@@ -856,10 +955,6 @@ def initialize(
if debug == True: # this needs to be first, so users can see Router init debug logs
litellm.set_verbose = True
dynamic_config = {"general": {}, user_model: {}}
if config:
llm_router, llm_model_list, general_settings = load_router_config(
router=llm_router, config_file_path=config
)
if headers: # model-specific param
user_headers = headers
dynamic_config[user_model]["headers"] = headers
@@ -988,7 +1083,7 @@ def parse_cache_control(cache_control):
@router.on_event("startup")
async def startup_event():
global prisma_client, master_key, use_background_health_checks
global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings
import json
### LOAD MASTER KEY ###
@@ -1000,10 +1095,26 @@ async def startup_event():
print_verbose(f"worker_config: {worker_config}")
# check if it's a valid file path
if os.path.isfile(worker_config):
initialize(config=worker_config)
if worker_config.get("config", None) is not None:
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(
router=llm_router, config_file_path=worker_config.pop("config")
)
initialize(**worker_config)
else:
# if not, assume it's a json string
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
if worker_config.get("config", None) is not None:
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(
router=llm_router, config_file_path=worker_config.pop("config")
)
initialize(**worker_config)
proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
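A hedged illustration of the two `WORKER_CONFIG` shapes the startup path distinguishes (paths and extra keys are made up):

```python
import json, os

# shape 1: a plain file path
os.environ["WORKER_CONFIG"] = "/app/config.yaml"

# shape 2: a JSON string; the "config" key is popped and loaded via
# proxy_config.load_config, the remaining keys go to initialize(**worker_config)
os.environ["WORKER_CONFIG"] = json.dumps({"config": "/app/config.yaml", "debug": True})
```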
@@ -1825,7 +1936,7 @@ async def user_auth(request: Request):
### Check if user email in user table
response = await prisma_client.get_generic_data(
key="user_email", value=user_email, db="users"
key="user_email", value=user_email, table_name="users"
)
### if so - generate a 24 hr key with that user id
if response is not None:
@@ -1883,16 +1994,13 @@ async def user_update(request: Request):
dependencies=[Depends(user_api_key_auth)],
)
async def add_new_model(model_params: ModelParams):
global llm_router, llm_model_list, general_settings, user_config_file_path
global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
try:
print_verbose(f"User config path: {user_config_file_path}")
# Load existing config
if os.path.exists(f"{user_config_file_path}"):
with open(f"{user_config_file_path}", "r") as config_file:
config = yaml.safe_load(config_file)
else:
config = {"model_list": []}
backup_config = copy.deepcopy(config)
config = await proxy_config.get_config()
print_verbose(f"User config path: {user_config_file_path}")
print_verbose(f"Loaded config: {config}")
# Add the new model to the config
model_info = model_params.model_info.json()
@@ -1907,22 +2015,8 @@ async def add_new_model(model_params: ModelParams):
print_verbose(f"updated model list: {config['model_list']}")
# Save the updated config
with open(f"{user_config_file_path}", "w") as config_file:
yaml.dump(config, config_file, default_flow_style=False)
# update Router
try:
llm_router, llm_model_list, general_settings = load_router_config(
router=llm_router, config_file_path=user_config_file_path
)
except Exception as e:
# Revert to old config instead
with open(f"{user_config_file_path}", "w") as config_file:
yaml.dump(backup_config, config_file, default_flow_style=False)
raise HTTPException(status_code=400, detail="Invalid Model passed in")
print_verbose(f"llm_model_list: {llm_model_list}")
# Save new config
await proxy_config.save_config(new_config=config)
return {"message": "Model added successfully"}
except Exception as e:
@@ -1949,13 +2043,10 @@ async def add_new_model(model_params: ModelParams):
dependencies=[Depends(user_api_key_auth)],
)
async def model_info_v1(request: Request):
global llm_model_list, general_settings, user_config_file_path
global llm_model_list, general_settings, user_config_file_path, proxy_config
# Load existing config
if os.path.exists(f"{user_config_file_path}"):
with open(f"{user_config_file_path}", "r") as config_file:
config = yaml.safe_load(config_file)
else:
config = {"model_list": []} # handle base case
config = await proxy_config.get_config()
all_models = config["model_list"]
for model in all_models:
@@ -1984,18 +2075,18 @@ async def model_info_v1(request: Request):
dependencies=[Depends(user_api_key_auth)],
)
async def delete_model(model_info: ModelInfoDelete):
global llm_router, llm_model_list, general_settings, user_config_file_path
global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
try:
if not os.path.exists(user_config_file_path):
raise HTTPException(status_code=404, detail="Config file does not exist.")
with open(user_config_file_path, "r") as config_file:
config = yaml.safe_load(config_file)
# Load existing config
config = await proxy_config.get_config()
# If model_list is not in the config, nothing can be deleted
if "model_list" not in config:
if len(config.get("model_list", [])) == 0:
raise HTTPException(
status_code=404, detail="No model list available in the config."
status_code=400, detail="No model list available in the config."
)
# Check if the model with the specified model_id exists
@@ -2008,19 +2099,14 @@ async def delete_model(model_info: ModelInfoDelete):
# If the model was not found, return an error
if model_to_delete is None:
raise HTTPException(
status_code=404, detail="Model with given model_id not found."
status_code=400, detail="Model with given model_id not found."
)
# Remove model from the list and save the updated config
config["model_list"].remove(model_to_delete)
with open(user_config_file_path, "w") as config_file:
yaml.dump(config, config_file, default_flow_style=False)
# Update Router
llm_router, llm_model_list, general_settings = load_router_config(
router=llm_router, config_file_path=user_config_file_path
)
# Save updated config
config = await proxy_config.save_config(new_config=config)
return {"message": "Model deleted successfully"}
except HTTPException as e:
@@ -2200,14 +2286,11 @@ async def update_config(config_info: ConfigYAML):
Currently supports modifying General Settings + LiteLLM settings
"""
global llm_router, llm_model_list, general_settings
global llm_router, llm_model_list, general_settings, proxy_config
try:
# Load existing config
if os.path.exists(f"{user_config_file_path}"):
with open(f"{user_config_file_path}", "r") as config_file:
config = yaml.safe_load(config_file)
else:
config = {}
config = await proxy_config.get_config()
backup_config = copy.deepcopy(config)
print_verbose(f"Loaded config: {config}")
@@ -2240,21 +2323,7 @@ async def update_config(config_info: ConfigYAML):
}
# Save the updated config
with open(f"{user_config_file_path}", "w") as config_file:
yaml.dump(config, config_file, default_flow_style=False)
# update Router
try:
llm_router, llm_model_list, general_settings = load_router_config(
router=llm_router, config_file_path=user_config_file_path
)
except Exception as e:
# Revert to old config instead
with open(f"{user_config_file_path}", "w") as config_file:
yaml.dump(backup_config, config_file, default_flow_style=False)
raise HTTPException(
status_code=400, detail=f"Invalid config passed in. Errror - {str(e)}"
)
config = await proxy_config.save_config(new_config=config)
return {"message": "Config updated successfully"}
except HTTPException as e:
raise e
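A hedged example call against this endpoint (the route decorator sits above the shown hunk; URL, port, and key are assumptions):

```python
import requests

requests.post(
    "http://0.0.0.0:8000/config/update",                 # assumed proxy address and route
    headers={"Authorization": "Bearer sk-1234"},         # proxy master key (placeholder)
    json={"litellm_settings": {"success_callback": ["langfuse"]}},
)
```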

litellm/proxy/schema.prisma

@@ -26,3 +26,8 @@ model LiteLLM_VerificationToken {
max_parallel_requests Int?
metadata Json @default("{}")
}
model LiteLLM_Config {
param_name String @id
param_value Json?
}
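Each top-level config section becomes one row keyed by `param_name`; a hedged sketch of reading one back (uses the `PrismaClient` changes in the next file; values are illustrative):

```python
row = await prisma_client.get_generic_data(
    key="param_name", value="litellm_settings", table_name="config"
)
# row.param_value holds the JSON that save_config wrote,
# e.g. {"success_callback": ["langfuse"]}
```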

litellm/proxy/utils.py

@@ -301,20 +301,24 @@ class PrismaClient:
self,
key: str,
value: Any,
db: Literal["users", "keys"],
table_name: Literal["users", "keys", "config"],
):
"""
Generic implementation of get data
"""
try:
if db == "users":
if table_name == "users":
response = await self.db.litellm_usertable.find_first(
where={key: value} # type: ignore
)
elif db == "keys":
elif table_name == "keys":
response = await self.db.litellm_verificationtoken.find_first( # type: ignore
where={key: value} # type: ignore
)
elif table_name == "config":
response = await self.db.litellm_config.find_first( # type: ignore
where={key: value} # type: ignore
)
return response
except Exception as e:
asyncio.create_task(
@@ -385,11 +389,14 @@
max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def insert_data(self, data: dict):
async def insert_data(
self, data: dict, table_name: Literal["user+key", "config"] = "user+key"
):
"""
Add a key to the database; if it already exists, do nothing. With table_name="config", upsert the given config params instead.
"""
try:
if table_name == "user+key":
token = data["token"]
hashed_token = self.hash_token(token=token)
db_data = self.jsonify_object(data=data)
@@ -418,6 +425,30 @@
},
)
return new_verification_token
elif table_name == "config":
"""
For each param:
- get the existing table values
- add the new values
- update the DB
"""
tasks = []
for k, v in data.items():
updated_data = v
updated_data = json.dumps(updated_data)
updated_table_row = self.db.litellm_config.upsert(
where={"param_name": k},
data={
"create": {"param_name": k, "param_value": updated_data},
"update": {"param_value": updated_data},
},
)
tasks.append(updated_table_row)
await asyncio.gather(*tasks)
except Exception as e:
asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
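An illustrative call into the new `config` branch, matching what `ProxyConfig.save_config` issues (payload values are made up):

```python
await prisma_client.insert_data(
    data={
        "litellm_settings": {"success_callback": ["langfuse"]},
        "general_settings": {},
    },
    table_name="config",  # one LiteLLM_Config upsert per top-level key
)
```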
@@ -527,6 +558,7 @@ class PrismaClient:
async def disconnect(self):
try:
await self.db.disconnect()
self.connected = False
except Exception as e:
asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)

litellm/utils.py

@@ -9,7 +9,7 @@
import sys, re, binascii, struct
import litellm
import dotenv, json, traceback, threading, base64
import dotenv, json, traceback, threading, base64, ast
import subprocess, os
import litellm, openai
import itertools
@@ -6621,7 +6621,7 @@ def _is_base64(s):
def get_secret(
secret_name: str,
default_value: Optional[str] = None,
default_value: Optional[Union[str, bool]] = None,
):
key_management_system = litellm._key_management_system
if secret_name.startswith("os.environ/"):
@@ -6672,9 +6672,24 @@
secret = client.get_secret(secret_name).secret_value
except Exception as e: # check if it's in os.environ
secret = os.getenv(secret_name)
try:
secret_value_as_bool = ast.literal_eval(secret)
if isinstance(secret_value_as_bool, bool):
return secret_value_as_bool
else:
return secret
except:
return secret
else:
return os.environ.get(secret_name)
secret = os.environ.get(secret_name)
try:
secret_value_as_bool = ast.literal_eval(secret)
if isinstance(secret_value_as_bool, bool):
return secret_value_as_bool
else:
return secret
except:
return secret
except Exception as e:
if default_value is not None:
return default_value
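The net effect on boolean flags like `SAVE_CONFIG_TO_DB` (a quick hedged example, exercising the plain-environment fallback path):

```python
import os
os.environ["SAVE_CONFIG_TO_DB"] = "True"
litellm.get_secret("SAVE_CONFIG_TO_DB", default_value=False)  # -> True, a real bool via ast.literal_eval
os.environ["SAVE_CONFIG_TO_DB"] = "not-a-bool"
litellm.get_secret("SAVE_CONFIG_TO_DB", default_value=False)  # -> "not-a-bool", returned as-is
```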