feat(proxy_server.py): abstract config update/writing and support persisting config in db

allows user to opt into writing to db (SAVE_CONFIG_TO_DB) and removes any api keys before sending to db

 https://github.com/BerriAI/litellm/issues/1322
This commit is contained in:
Krrish Dholakia 2024-01-04 14:44:45 +05:30
parent 6dea0d3115
commit 99d9a825de
4 changed files with 430 additions and 309 deletions

View file

@ -502,21 +502,113 @@ async def _run_background_health_check():
await asyncio.sleep(health_check_interval) await asyncio.sleep(health_check_interval)
def load_router_config(router: Optional[litellm.Router], config_file_path: str): class ProxyConfig:
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue """
config = {} Abstraction class on top of config loading/updating logic. Gives us one place to control all config updating logic.
try: """
if os.path.exists(config_file_path):
def __init__(self) -> None:
pass
async def get_config(self, config_file_path: Optional[str] = None) -> dict:
global prisma_client, user_config_file_path
file_path = config_file_path or user_config_file_path
if config_file_path is not None:
user_config_file_path = config_file_path user_config_file_path = config_file_path
with open(config_file_path, "r") as file: # Load existing config
config = yaml.safe_load(file) ## Yaml
if os.path.exists(f"{file_path}"):
with open(f"{file_path}", "r") as config_file:
config = yaml.safe_load(config_file)
else: else:
raise Exception( config = {
f"Path to config does not exist, Current working directory: {os.getcwd()}, 'os.path.exists({config_file_path})' returned False" "model_list": [],
"general_settings": {},
"router_settings": {},
"litellm_settings": {},
}
## DB
if (
prisma_client is not None
and litellm.get_secret("SAVE_CONFIG_TO_DB", False) == True
):
_tasks = []
keys = [
"model_list",
"general_settings",
"router_settings",
"litellm_settings",
]
for k in keys:
response = prisma_client.get_generic_data(
key="param_name", value=k, table_name="config"
)
_tasks.append(response)
responses = await asyncio.gather(*_tasks)
return config
async def save_config(self, new_config: dict):
global prisma_client, llm_router, user_config_file_path
# Load existing config
backup_config = await self.get_config()
# Save the updated config
## YAML
with open(f"{user_config_file_path}", "w") as config_file:
yaml.dump(new_config, config_file, default_flow_style=False)
# update Router - verifies if this is a valid config
try:
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(
router=llm_router, config_file_path=user_config_file_path
) )
except Exception as e: except Exception as e:
raise Exception(f"Exception while reading Config: {e}") traceback.print_exc()
# Revert to old config instead
with open(f"{user_config_file_path}", "w") as config_file:
yaml.dump(backup_config, config_file, default_flow_style=False)
raise HTTPException(status_code=400, detail="Invalid config passed in")
## DB - writes valid config to db
"""
- Do not write restricted params like 'api_key' to the database
- if api_key is passed, save that to the local environment or connected secret manager (maybe expose `litellm.save_secret()`)
"""
if (
prisma_client is not None
and litellm.get_secret("SAVE_CONFIG_TO_DB", default_value=False) == True
):
### KEY REMOVAL ###
models = new_config.get("model_list", [])
for m in models:
if m.get("litellm_params", {}).get("api_key", None) is not None:
# pop the key
api_key = m["litellm_params"].pop("api_key")
# store in local env
key_name = f"LITELLM_MODEL_KEY_{uuid.uuid4()}"
os.environ[key_name] = api_key
# save the key name (not the value)
m["litellm_params"]["api_key"] = f"os.environ/{key_name}"
await prisma_client.insert_data(data=new_config, table_name="config")
async def load_config(
self, router: Optional[litellm.Router], config_file_path: str
):
"""
Load config values into proxy global state
"""
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue
# Load existing config
config = await self.get_config(config_file_path=config_file_path)
## PRINT YAML FOR CONFIRMING IT WORKS ## PRINT YAML FOR CONFIRMING IT WORKS
printed_yaml = copy.deepcopy(config) printed_yaml = copy.deepcopy(config)
printed_yaml.pop("environment_variables", None) printed_yaml.pop("environment_variables", None)
@ -604,7 +696,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
for callback in value: for callback in value:
# user passed custom_callbacks.async_on_success_logger. They need us to import a function # user passed custom_callbacks.async_on_success_logger. They need us to import a function
if "." in callback: if "." in callback:
litellm.success_callback.append(get_instance_fn(value=callback)) litellm.success_callback.append(
get_instance_fn(value=callback)
)
# these are litellm callbacks - "langfuse", "sentry", "wandb" # these are litellm callbacks - "langfuse", "sentry", "wandb"
else: else:
litellm.success_callback.append(callback) litellm.success_callback.append(callback)
@ -618,7 +712,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
for callback in value: for callback in value:
# user passed custom_callbacks.async_on_success_logger. They need us to import a function # user passed custom_callbacks.async_on_success_logger. They need us to import a function
if "." in callback: if "." in callback:
litellm.failure_callback.append(get_instance_fn(value=callback)) litellm.failure_callback.append(
get_instance_fn(value=callback)
)
# these are litellm callbacks - "langfuse", "sentry", "wandb" # these are litellm callbacks - "langfuse", "sentry", "wandb"
else: else:
litellm.failure_callback.append(callback) litellm.failure_callback.append(callback)
@ -730,6 +826,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
return router, model_list, general_settings return router, model_list, general_settings
proxy_config = ProxyConfig()
async def generate_key_helper_fn( async def generate_key_helper_fn(
duration: Optional[str], duration: Optional[str],
models: list, models: list,
@ -856,10 +955,6 @@ def initialize(
if debug == True: # this needs to be first, so users can see Router init debug if debug == True: # this needs to be first, so users can see Router init debug
litellm.set_verbose = True litellm.set_verbose = True
dynamic_config = {"general": {}, user_model: {}} dynamic_config = {"general": {}, user_model: {}}
if config:
llm_router, llm_model_list, general_settings = load_router_config(
router=llm_router, config_file_path=config
)
if headers: # model-specific param if headers: # model-specific param
user_headers = headers user_headers = headers
dynamic_config[user_model]["headers"] = headers dynamic_config[user_model]["headers"] = headers
@ -988,7 +1083,7 @@ def parse_cache_control(cache_control):
@router.on_event("startup") @router.on_event("startup")
async def startup_event(): async def startup_event():
global prisma_client, master_key, use_background_health_checks global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings
import json import json
### LOAD MASTER KEY ### ### LOAD MASTER KEY ###
@ -1000,10 +1095,26 @@ async def startup_event():
print_verbose(f"worker_config: {worker_config}") print_verbose(f"worker_config: {worker_config}")
# check if it's a valid file path # check if it's a valid file path
if os.path.isfile(worker_config): if os.path.isfile(worker_config):
initialize(config=worker_config) if worker_config.get("config", None) is not None:
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(
router=llm_router, config_file_path=worker_config.pop("config")
)
initialize(**worker_config)
else: else:
# if not, assume it's a json string # if not, assume it's a json string
worker_config = json.loads(os.getenv("WORKER_CONFIG")) worker_config = json.loads(os.getenv("WORKER_CONFIG"))
if worker_config.get("config", None) is not None:
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(
router=llm_router, config_file_path=worker_config.pop("config")
)
initialize(**worker_config) initialize(**worker_config)
proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
@ -1825,7 +1936,7 @@ async def user_auth(request: Request):
### Check if user email in user table ### Check if user email in user table
response = await prisma_client.get_generic_data( response = await prisma_client.get_generic_data(
key="user_email", value=user_email, db="users" key="user_email", value=user_email, table_name="users"
) )
### if so - generate a 24 hr key with that user id ### if so - generate a 24 hr key with that user id
if response is not None: if response is not None:
@ -1883,16 +1994,13 @@ async def user_update(request: Request):
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
) )
async def add_new_model(model_params: ModelParams): async def add_new_model(model_params: ModelParams):
global llm_router, llm_model_list, general_settings, user_config_file_path global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
try: try:
print_verbose(f"User config path: {user_config_file_path}")
# Load existing config # Load existing config
if os.path.exists(f"{user_config_file_path}"): config = await proxy_config.get_config()
with open(f"{user_config_file_path}", "r") as config_file:
config = yaml.safe_load(config_file) print_verbose(f"User config path: {user_config_file_path}")
else:
config = {"model_list": []}
backup_config = copy.deepcopy(config)
print_verbose(f"Loaded config: {config}") print_verbose(f"Loaded config: {config}")
# Add the new model to the config # Add the new model to the config
model_info = model_params.model_info.json() model_info = model_params.model_info.json()
@ -1907,22 +2015,8 @@ async def add_new_model(model_params: ModelParams):
print_verbose(f"updated model list: {config['model_list']}") print_verbose(f"updated model list: {config['model_list']}")
# Save the updated config # Save new config
with open(f"{user_config_file_path}", "w") as config_file: await proxy_config.save_config(new_config=config)
yaml.dump(config, config_file, default_flow_style=False)
# update Router
try:
llm_router, llm_model_list, general_settings = load_router_config(
router=llm_router, config_file_path=user_config_file_path
)
except Exception as e:
# Revert to old config instead
with open(f"{user_config_file_path}", "w") as config_file:
yaml.dump(backup_config, config_file, default_flow_style=False)
raise HTTPException(status_code=400, detail="Invalid Model passed in")
print_verbose(f"llm_model_list: {llm_model_list}")
return {"message": "Model added successfully"} return {"message": "Model added successfully"}
except Exception as e: except Exception as e:
@ -1949,13 +2043,10 @@ async def add_new_model(model_params: ModelParams):
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
) )
async def model_info_v1(request: Request): async def model_info_v1(request: Request):
global llm_model_list, general_settings, user_config_file_path global llm_model_list, general_settings, user_config_file_path, proxy_config
# Load existing config # Load existing config
if os.path.exists(f"{user_config_file_path}"): config = await proxy_config.get_config()
with open(f"{user_config_file_path}", "r") as config_file:
config = yaml.safe_load(config_file)
else:
config = {"model_list": []} # handle base case
all_models = config["model_list"] all_models = config["model_list"]
for model in all_models: for model in all_models:
@ -1984,18 +2075,18 @@ async def model_info_v1(request: Request):
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
) )
async def delete_model(model_info: ModelInfoDelete): async def delete_model(model_info: ModelInfoDelete):
global llm_router, llm_model_list, general_settings, user_config_file_path global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
try: try:
if not os.path.exists(user_config_file_path): if not os.path.exists(user_config_file_path):
raise HTTPException(status_code=404, detail="Config file does not exist.") raise HTTPException(status_code=404, detail="Config file does not exist.")
with open(user_config_file_path, "r") as config_file: # Load existing config
config = yaml.safe_load(config_file) config = await proxy_config.get_config()
# If model_list is not in the config, nothing can be deleted # If model_list is not in the config, nothing can be deleted
if "model_list" not in config: if len(config.get("model_list", [])) == 0:
raise HTTPException( raise HTTPException(
status_code=404, detail="No model list available in the config." status_code=400, detail="No model list available in the config."
) )
# Check if the model with the specified model_id exists # Check if the model with the specified model_id exists
@ -2008,19 +2099,14 @@ async def delete_model(model_info: ModelInfoDelete):
# If the model was not found, return an error # If the model was not found, return an error
if model_to_delete is None: if model_to_delete is None:
raise HTTPException( raise HTTPException(
status_code=404, detail="Model with given model_id not found." status_code=400, detail="Model with given model_id not found."
) )
# Remove model from the list and save the updated config # Remove model from the list and save the updated config
config["model_list"].remove(model_to_delete) config["model_list"].remove(model_to_delete)
with open(user_config_file_path, "w") as config_file:
yaml.dump(config, config_file, default_flow_style=False)
# Update Router
llm_router, llm_model_list, general_settings = load_router_config(
router=llm_router, config_file_path=user_config_file_path
)
# Save updated config
config = await proxy_config.save_config(new_config=config)
return {"message": "Model deleted successfully"} return {"message": "Model deleted successfully"}
except HTTPException as e: except HTTPException as e:
@ -2200,14 +2286,11 @@ async def update_config(config_info: ConfigYAML):
Currently supports modifying General Settings + LiteLLM settings Currently supports modifying General Settings + LiteLLM settings
""" """
global llm_router, llm_model_list, general_settings global llm_router, llm_model_list, general_settings, proxy_config
try: try:
# Load existing config # Load existing config
if os.path.exists(f"{user_config_file_path}"): config = await proxy_config.get_config()
with open(f"{user_config_file_path}", "r") as config_file:
config = yaml.safe_load(config_file)
else:
config = {}
backup_config = copy.deepcopy(config) backup_config = copy.deepcopy(config)
print_verbose(f"Loaded config: {config}") print_verbose(f"Loaded config: {config}")
@ -2240,21 +2323,7 @@ async def update_config(config_info: ConfigYAML):
} }
# Save the updated config # Save the updated config
with open(f"{user_config_file_path}", "w") as config_file: config = await proxy_config.save_config(new_config=config)
yaml.dump(config, config_file, default_flow_style=False)
# update Router
try:
llm_router, llm_model_list, general_settings = load_router_config(
router=llm_router, config_file_path=user_config_file_path
)
except Exception as e:
# Revert to old config instead
with open(f"{user_config_file_path}", "w") as config_file:
yaml.dump(backup_config, config_file, default_flow_style=False)
raise HTTPException(
status_code=400, detail=f"Invalid config passed in. Error - {str(e)}"
)
return {"message": "Config updated successfully"} return {"message": "Config updated successfully"}
except HTTPException as e: except HTTPException as e:
raise e raise e

View file

@ -26,3 +26,8 @@ model LiteLLM_VerificationToken {
max_parallel_requests Int? max_parallel_requests Int?
metadata Json @default("{}") metadata Json @default("{}")
} }
model LiteLLM_Config {
param_name String @id
param_value Json?
}

View file

@ -301,20 +301,24 @@ class PrismaClient:
self, self,
key: str, key: str,
value: Any, value: Any,
db: Literal["users", "keys"], table_name: Literal["users", "keys", "config"],
): ):
""" """
Generic implementation of get data Generic implementation of get data
""" """
try: try:
if db == "users": if table_name == "users":
response = await self.db.litellm_usertable.find_first( response = await self.db.litellm_usertable.find_first(
where={key: value} # type: ignore where={key: value} # type: ignore
) )
elif db == "keys": elif table_name == "keys":
response = await self.db.litellm_verificationtoken.find_first( # type: ignore response = await self.db.litellm_verificationtoken.find_first( # type: ignore
where={key: value} # type: ignore where={key: value} # type: ignore
) )
elif table_name == "config":
response = await self.db.litellm_config.find_first( # type: ignore
where={key: value} # type: ignore
)
return response return response
except Exception as e: except Exception as e:
asyncio.create_task( asyncio.create_task(
@ -385,11 +389,14 @@ class PrismaClient:
max_time=10, # maximum total time to retry for max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff on_backoff=on_backoff, # specifying the function to call on backoff
) )
async def insert_data(self, data: dict): async def insert_data(
self, data: dict, table_name: Literal["user+key", "config"] = "user+key"
):
""" """
Add a key to the database. If it already exists, do nothing. Add a key to the database. If it already exists, do nothing.
""" """
try: try:
if table_name == "user+key":
token = data["token"] token = data["token"]
hashed_token = self.hash_token(token=token) hashed_token = self.hash_token(token=token)
db_data = self.jsonify_object(data=data) db_data = self.jsonify_object(data=data)
@ -418,6 +425,30 @@ class PrismaClient:
}, },
) )
return new_verification_token return new_verification_token
elif table_name == "config":
"""
For each param,
get the existing table values
Add the new values
Update DB
"""
tasks = []
for k, v in data.items():
updated_data = v
updated_data = json.dumps(updated_data)
updated_table_row = self.db.litellm_config.upsert(
where={"param_name": k},
data={
"create": {"param_name": k, "param_value": updated_data},
"update": {"param_value": updated_data},
},
)
tasks.append(updated_table_row)
await asyncio.gather(*tasks)
except Exception as e: except Exception as e:
asyncio.create_task( asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e) self.proxy_logging_obj.failure_handler(original_exception=e)
@ -527,6 +558,7 @@ class PrismaClient:
async def disconnect(self): async def disconnect(self):
try: try:
await self.db.disconnect() await self.db.disconnect()
self.connected = False
except Exception as e: except Exception as e:
asyncio.create_task( asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e) self.proxy_logging_obj.failure_handler(original_exception=e)

View file

@ -9,7 +9,7 @@
import sys, re, binascii, struct import sys, re, binascii, struct
import litellm import litellm
import dotenv, json, traceback, threading, base64 import dotenv, json, traceback, threading, base64, ast
import subprocess, os import subprocess, os
import litellm, openai import litellm, openai
import itertools import itertools
@ -6621,7 +6621,7 @@ def _is_base64(s):
def get_secret( def get_secret(
secret_name: str, secret_name: str,
default_value: Optional[str] = None, default_value: Optional[Union[str, bool]] = None,
): ):
key_management_system = litellm._key_management_system key_management_system = litellm._key_management_system
if secret_name.startswith("os.environ/"): if secret_name.startswith("os.environ/"):
@ -6672,9 +6672,24 @@ def get_secret(
secret = client.get_secret(secret_name).secret_value secret = client.get_secret(secret_name).secret_value
except Exception as e: # check if it's in os.environ except Exception as e: # check if it's in os.environ
secret = os.getenv(secret_name) secret = os.getenv(secret_name)
try:
secret_value_as_bool = ast.literal_eval(secret)
if isinstance(secret_value_as_bool, bool):
return secret_value_as_bool
else:
return secret
except:
return secret return secret
else: else:
return os.environ.get(secret_name) secret = os.environ.get(secret_name)
try:
secret_value_as_bool = ast.literal_eval(secret)
if isinstance(secret_value_as_bool, bool):
return secret_value_as_bool
else:
return secret
except:
return secret
except Exception as e: except Exception as e:
if default_value is not None: if default_value is not None:
return default_value return default_value