feat(proxy_server.py): abstract config update/writing and support persisting config in db

allows users to opt into persisting the config to the db (SAVE_CONFIG_TO_DB) and removes any api keys before sending it to the db

 https://github.com/BerriAI/litellm/issues/1322
Krrish Dholakia 2024-01-04 14:44:45 +05:30
parent 6dea0d3115
commit 99d9a825de
4 changed files with 430 additions and 309 deletions
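The key-removal behavior the message describes is implemented in save_config below: before a validated config is written to the database, literal api_key values are popped from each model's litellm_params, stashed in the local environment under a generated name, and only an os.environ/ reference is persisted. A standalone sketch of that idea (the helper name and sample config are illustrative, not from the commit):

import os
import uuid

def scrub_api_keys(new_config: dict) -> dict:
    """Replace literal api_key values with os.environ/ references before
    the config is persisted (mirrors the save_config logic in the diff)."""
    for m in new_config.get("model_list", []):
        if m.get("litellm_params", {}).get("api_key") is not None:
            # stash the real key in the local environment ...
            key_name = f"LITELLM_MODEL_KEY_{uuid.uuid4()}"
            os.environ[key_name] = m["litellm_params"].pop("api_key")
            # ... and persist only a reference to it
            m["litellm_params"]["api_key"] = f"os.environ/{key_name}"
    return new_config

config = {
    "model_list": [
        {"model_name": "gpt-3.5-turbo",
         "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-..."}}
    ]
}
print(scrub_api_keys(config)["model_list"][0]["litellm_params"]["api_key"])
# -> os.environ/LITELLM_MODEL_KEY_<uuid>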

@@ -502,232 +502,331 @@ async def _run_background_health_check():
        await asyncio.sleep(health_check_interval)


class ProxyConfig:
    """
    Abstraction class on top of config loading/updating logic. Gives us one place to control all config updating logic.
    """

    def __init__(self) -> None:
        pass

    async def get_config(self, config_file_path: Optional[str] = None) -> dict:
        global prisma_client, user_config_file_path

        file_path = config_file_path or user_config_file_path
        if config_file_path is not None:
            user_config_file_path = config_file_path
        # Load existing config
        ## Yaml
        if os.path.exists(f"{file_path}"):
            with open(f"{file_path}", "r") as config_file:
                config = yaml.safe_load(config_file)
        else:
            config = {
                "model_list": [],
                "general_settings": {},
                "router_settings": {},
                "litellm_settings": {},
            }

        ## DB
        if (
            prisma_client is not None
            and litellm.get_secret("SAVE_CONFIG_TO_DB", False) == True
        ):
            _tasks = []
            keys = [
                "model_list",
                "general_settings",
                "router_settings",
                "litellm_settings",
            ]
            for k in keys:
                response = prisma_client.get_generic_data(
                    key="param_name", value=k, table_name="config"
                )
                _tasks.append(response)

            responses = await asyncio.gather(*_tasks)

        return config

    async def save_config(self, new_config: dict):
        global prisma_client, llm_router, user_config_file_path
        # Load existing config
        backup_config = await self.get_config()

        # Save the updated config
        ## YAML
        with open(f"{user_config_file_path}", "w") as config_file:
            yaml.dump(new_config, config_file, default_flow_style=False)

        # update Router - verifies if this is a valid config
        try:
            (
                llm_router,
                llm_model_list,
                general_settings,
            ) = await proxy_config.load_config(
                router=llm_router, config_file_path=user_config_file_path
            )
        except Exception as e:
            traceback.print_exc()
            # Revert to old config instead
            with open(f"{user_config_file_path}", "w") as config_file:
                yaml.dump(backup_config, config_file, default_flow_style=False)
            raise HTTPException(status_code=400, detail="Invalid config passed in")

        ## DB - writes valid config to db
        """
        - Do not write restricted params like 'api_key' to the database
        - if api_key is passed, save that to the local environment or connected secret manager (maybe expose `litellm.save_secret()`)
        """
        if (
            prisma_client is not None
            and litellm.get_secret("SAVE_CONFIG_TO_DB", default_value=False) == True
        ):
            ### KEY REMOVAL ###
            models = new_config.get("model_list", [])
            for m in models:
                if m.get("litellm_params", {}).get("api_key", None) is not None:
                    # pop the key
                    api_key = m["litellm_params"].pop("api_key")
                    # store in local env
                    key_name = f"LITELLM_MODEL_KEY_{uuid.uuid4()}"
                    os.environ[key_name] = api_key
                    # save the key name (not the value)
                    m["litellm_params"]["api_key"] = f"os.environ/{key_name}"
            await prisma_client.insert_data(data=new_config, table_name="config")

    async def load_config(
        self, router: Optional[litellm.Router], config_file_path: str
    ):
        """
        Load config values into proxy global state
        """
        global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue

        # Load existing config
        config = await self.get_config(config_file_path=config_file_path)
        ## PRINT YAML FOR CONFIRMING IT WORKS
        printed_yaml = copy.deepcopy(config)
        printed_yaml.pop("environment_variables", None)

        print_verbose(
            f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}"
        )

        ## ENVIRONMENT VARIABLES
        environment_variables = config.get("environment_variables", None)
        if environment_variables:
            for key, value in environment_variables.items():
                os.environ[key] = value

        ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
        litellm_settings = config.get("litellm_settings", None)
        if litellm_settings is None:
            litellm_settings = {}
        if litellm_settings:
            # ANSI escape code for blue text
            blue_color_code = "\033[94m"
            reset_color_code = "\033[0m"
            for key, value in litellm_settings.items():
                if key == "cache":
                    print(f"{blue_color_code}\nSetting Cache on Proxy")  # noqa
                    from litellm.caching import Cache

                    cache_params = {}
                    if "cache_params" in litellm_settings:
                        cache_params_in_config = litellm_settings["cache_params"]
                        # overwrite cache_params with cache_params_in_config
                        cache_params.update(cache_params_in_config)

                    cache_type = cache_params.get("type", "redis")

                    print_verbose(f"passed cache type={cache_type}")

                    if cache_type == "redis":
                        cache_host = litellm.get_secret("REDIS_HOST", None)
                        cache_port = litellm.get_secret("REDIS_PORT", None)
                        cache_password = litellm.get_secret("REDIS_PASSWORD", None)

                        cache_params = {
                            "type": cache_type,
                            "host": cache_host,
                            "port": cache_port,
                            "password": cache_password,
                        }
                        # Assuming cache_type, cache_host, cache_port, and cache_password are strings
                        print(  # noqa
                            f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}"
                        )  # noqa
                        print(  # noqa
                            f"{blue_color_code}Cache Host:{reset_color_code} {cache_host}"
                        )  # noqa
                        print(  # noqa
                            f"{blue_color_code}Cache Port:{reset_color_code} {cache_port}"
                        )  # noqa
                        print(  # noqa
                            f"{blue_color_code}Cache Password:{reset_color_code} {cache_password}"
                        )
                        print()  # noqa

                    ## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables
                    litellm.cache = Cache(**cache_params)
                    print(  # noqa
                        f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}"
                    )
                elif key == "callbacks":
                    litellm.callbacks = [
                        get_instance_fn(value=value, config_file_path=config_file_path)
                    ]
                    print_verbose(
                        f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
                    )
                elif key == "post_call_rules":
                    litellm.post_call_rules = [
                        get_instance_fn(value=value, config_file_path=config_file_path)
                    ]
                    print_verbose(f"litellm.post_call_rules: {litellm.post_call_rules}")
                elif key == "success_callback":
                    litellm.success_callback = []

                    # initialize success callbacks
                    for callback in value:
                        # user passed custom_callbacks.async_on_succes_logger. They need us to import a function
                        if "." in callback:
                            litellm.success_callback.append(
                                get_instance_fn(value=callback)
                            )
                        # these are litellm callbacks - "langfuse", "sentry", "wandb"
                        else:
                            litellm.success_callback.append(callback)
                    print_verbose(
                        f"{blue_color_code} Initialized Success Callbacks - {litellm.success_callback} {reset_color_code}"
                    )
                elif key == "failure_callback":
                    litellm.failure_callback = []

                    # initialize failure callbacks
                    for callback in value:
                        # user passed custom_callbacks.async_on_succes_logger. They need us to import a function
                        if "." in callback:
                            litellm.failure_callback.append(
                                get_instance_fn(value=callback)
                            )
                        # these are litellm callbacks - "langfuse", "sentry", "wandb"
                        else:
                            litellm.failure_callback.append(callback)
                    print_verbose(
                        f"{blue_color_code} Initialized Failure Callbacks - {litellm.failure_callback} {reset_color_code}"
                    )
                elif key == "cache_params":
                    # this is set in the cache branch
                    # see usage here: https://docs.litellm.ai/docs/proxy/caching
                    pass
                else:
                    setattr(litellm, key, value)

        ## GENERAL SERVER SETTINGS (e.g. master key,..) # do this after initializing litellm, to ensure sentry logging works for proxylogging
        general_settings = config.get("general_settings", {})
        if general_settings is None:
            general_settings = {}
        if general_settings:
            ### LOAD SECRET MANAGER ###
            key_management_system = general_settings.get("key_management_system", None)
            if key_management_system is not None:
                if key_management_system == KeyManagementSystem.AZURE_KEY_VAULT.value:
                    ### LOAD FROM AZURE KEY VAULT ###
                    load_from_azure_key_vault(use_azure_key_vault=True)
                elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
                    ### LOAD FROM GOOGLE KMS ###
                    load_google_kms(use_google_kms=True)
                else:
                    raise ValueError("Invalid Key Management System selected")

            ### [DEPRECATED] LOAD FROM GOOGLE KMS ### old way of loading from google kms
            use_google_kms = general_settings.get("use_google_kms", False)
            load_google_kms(use_google_kms=use_google_kms)

            ### [DEPRECATED] LOAD FROM AZURE KEY VAULT ### old way of loading from azure secret manager
            use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
            load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)

            ### ALERTING ###
            proxy_logging_obj.update_values(
                alerting=general_settings.get("alerting", None),
                alerting_threshold=general_settings.get("alerting_threshold", 600),
            )

            ### CONNECT TO DATABASE ###
            database_url = general_settings.get("database_url", None)
            if database_url and database_url.startswith("os.environ/"):
                print_verbose(f"GOING INTO LITELLM.GET_SECRET!")
                database_url = litellm.get_secret(database_url)
                print_verbose(f"RETRIEVED DB URL: {database_url}")
            prisma_setup(database_url=database_url)

            ## COST TRACKING ##
            cost_tracking()

            ### MASTER KEY ###
            master_key = general_settings.get(
                "master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
            )
            if master_key and master_key.startswith("os.environ/"):
                master_key = litellm.get_secret(master_key)

            ### CUSTOM API KEY AUTH ###
            custom_auth = general_settings.get("custom_auth", None)
            if custom_auth:
                user_custom_auth = get_instance_fn(
                    value=custom_auth, config_file_path=config_file_path
                )

            ### BACKGROUND HEALTH CHECKS ###
            # Enable background health checks
            use_background_health_checks = general_settings.get(
                "background_health_checks", False
            )
            health_check_interval = general_settings.get("health_check_interval", 300)

        router_params: dict = {
            "num_retries": 3,
            "cache_responses": litellm.cache
            != None,  # cache if user passed in cache values
        }

        ## MODEL LIST
        model_list = config.get("model_list", None)
        if model_list:
            router_params["model_list"] = model_list
            print(  # noqa
                f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m"
            )  # noqa
            for model in model_list:
                ### LOAD FROM os.environ/ ###
                for k, v in model["litellm_params"].items():
                    if isinstance(v, str) and v.startswith("os.environ/"):
                        model["litellm_params"][k] = litellm.get_secret(v)
                print(f"\033[32m    {model.get('model_name', '')}\033[0m")  # noqa
                litellm_model_name = model["litellm_params"]["model"]
                litellm_model_api_base = model["litellm_params"].get("api_base", None)
                if "ollama" in litellm_model_name and litellm_model_api_base is None:
                    run_ollama_serve()

        ## ROUTER SETTINGS (e.g. routing_strategy, ...)
        router_settings = config.get("router_settings", None)
        if router_settings and isinstance(router_settings, dict):
            arg_spec = inspect.getfullargspec(litellm.Router)
            # model list already set
            exclude_args = {
                "self",
                "model_list",
            }

            available_args = [x for x in arg_spec.args if x not in exclude_args]

            for k, v in router_settings.items():
                if k in available_args:
                    router_params[k] = v

        router = litellm.Router(**router_params)  # type:ignore
        return router, model_list, general_settings


proxy_config = ProxyConfig()


async def generate_key_helper_fn(
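A note on get_config's DB overlay: each prisma_client.get_generic_data(...) call in the hunk above returns a coroutine, and the four per-section lookups are collected and awaited concurrently with a single asyncio.gather (as committed, the gathered responses are not yet merged back into config). The pattern in isolation, with a stand-in fetch function in place of the Prisma call:

import asyncio

async def fetch_section(name: str) -> dict:
    # stand-in for prisma_client.get_generic_data(key="param_name", value=name, table_name="config")
    await asyncio.sleep(0.01)
    return {name: "..."}

async def main():
    keys = ["model_list", "general_settings", "router_settings", "litellm_settings"]
    # build the coroutines first, then run them all concurrently
    responses = await asyncio.gather(*(fetch_section(k) for k in keys))
    print(responses)

asyncio.run(main())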
@@ -856,10 +955,6 @@ def initialize(
    if debug == True:  # this needs to be first, so users can see Router init debugg
        litellm.set_verbose = True
    dynamic_config = {"general": {}, user_model: {}}
    if headers:  # model-specific param
        user_headers = headers
        dynamic_config[user_model]["headers"] = headers
@@ -988,7 +1083,7 @@ def parse_cache_control(cache_control):
@router.on_event("startup")
async def startup_event():
    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings
    import json

    ### LOAD MASTER KEY ###
@@ -1000,10 +1095,26 @@ async def startup_event():
    print_verbose(f"worker_config: {worker_config}")
    # check if it's a valid file path
    if os.path.isfile(worker_config):
        if worker_config.get("config", None) is not None:
            (
                llm_router,
                llm_model_list,
                general_settings,
            ) = await proxy_config.load_config(
                router=llm_router, config_file_path=worker_config.pop("config")
            )
        initialize(**worker_config)
    else:
        # if not, assume it's a json string
        worker_config = json.loads(os.getenv("WORKER_CONFIG"))
        if worker_config.get("config", None) is not None:
            (
                llm_router,
                llm_model_list,
                general_settings,
            ) = await proxy_config.load_config(
                router=llm_router, config_file_path=worker_config.pop("config")
            )
        initialize(**worker_config)

    proxy_logging_obj._init_litellm_callbacks()  # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
@@ -1825,7 +1936,7 @@ async def user_auth(request: Request):
    ### Check if user email in user table
    response = await prisma_client.get_generic_data(
        key="user_email", value=user_email, table_name="users"
    )
    ### if so - generate a 24 hr key with that user id
    if response is not None:
@@ -1883,16 +1994,13 @@ async def user_update(request: Request):
    dependencies=[Depends(user_api_key_auth)],
)
async def add_new_model(model_params: ModelParams):
    global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
    try:
        # Load existing config
        config = await proxy_config.get_config()

        print_verbose(f"User config path: {user_config_file_path}")
        print_verbose(f"Loaded config: {config}")
        # Add the new model to the config
        model_info = model_params.model_info.json()

@@ -1907,22 +2015,8 @@ async def add_new_model(model_params: ModelParams):
        print_verbose(f"updated model list: {config['model_list']}")

        # Save new config
        await proxy_config.save_config(new_config=config)
        return {"message": "Model added successfully"}

    except Exception as e:
@@ -1949,13 +2043,10 @@ async def add_new_model(model_params: ModelParams):
    dependencies=[Depends(user_api_key_auth)],
)
async def model_info_v1(request: Request):
    global llm_model_list, general_settings, user_config_file_path, proxy_config

    # Load existing config
    config = await proxy_config.get_config()

    all_models = config["model_list"]
    for model in all_models:
@@ -1984,18 +2075,18 @@ async def model_info_v1(request: Request):
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_model(model_info: ModelInfoDelete):
    global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
    try:
        if not os.path.exists(user_config_file_path):
            raise HTTPException(status_code=404, detail="Config file does not exist.")

        # Load existing config
        config = await proxy_config.get_config()

        # If model_list is not in the config, nothing can be deleted
        if len(config.get("model_list", [])) == 0:
            raise HTTPException(
                status_code=400, detail="No model list available in the config."
            )

        # Check if the model with the specified model_id exists

@@ -2008,19 +2099,14 @@ async def delete_model(model_info: ModelInfoDelete):
        # If the model was not found, return an error
        if model_to_delete is None:
            raise HTTPException(
                status_code=400, detail="Model with given model_id not found."
            )

        # Remove model from the list and save the updated config
        config["model_list"].remove(model_to_delete)

        # Save updated config
        config = await proxy_config.save_config(new_config=config)
        return {"message": "Model deleted successfully"}

    except HTTPException as e:
@@ -2200,14 +2286,11 @@ async def update_config(config_info: ConfigYAML):
    Currently supports modifying General Settings + LiteLLM settings
    """
    global llm_router, llm_model_list, general_settings, proxy_config
    try:
        # Load existing config
        config = await proxy_config.get_config()

        backup_config = copy.deepcopy(config)
        print_verbose(f"Loaded config: {config}")

@@ -2240,21 +2323,7 @@ async def update_config(config_info: ConfigYAML):
        }

        # Save the updated config
        config = await proxy_config.save_config(new_config=config)
        return {"message": "Config updated successfully"}
    except HTTPException as e:
        raise e
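The endpoints above (add_new_model, delete_model, update_config) now share one read-modify-write cycle through the proxy_config singleton. A minimal sketch of that flow, assuming an initialized proxy (the handler and the request_timeout setting are illustrative, not part of this commit):

# hypothetical admin handler showing the shared flow
async def set_default_timeout(timeout_seconds: int):
    # 1. read the current config (yaml, plus the db overlay when SAVE_CONFIG_TO_DB is set)
    config = await proxy_config.get_config()

    # 2. mutate the in-memory dict
    config.setdefault("litellm_settings", {})["request_timeout"] = timeout_seconds

    # 3. write it back: save_config() dumps the yaml, reloads the router to
    #    validate, reverts on failure, and upserts the scrubbed config to the db
    await proxy_config.save_config(new_config=config)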

@@ -25,4 +25,9 @@ model LiteLLM_VerificationToken {
  user_id String?
  max_parallel_requests Int?
  metadata Json @default("{}")
}

model LiteLLM_Config {
  param_name String @id
  param_value Json?
}
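Each top-level config section becomes one row in LiteLLM_Config, keyed by param_name with the serialized section in param_value. A sketch of reading a section back with prisma-client-py, assuming a client generated from the schema above (the standalone helper is illustrative; the proxy itself goes through PrismaClient.get_generic_data):

import asyncio, json
from prisma import Prisma  # prisma-client-py client generated from the schema above

async def read_config_section(param_name: str):
    db = Prisma()
    await db.connect()
    # one row per top-level config section, e.g. "model_list" or "router_settings"
    row = await db.litellm_config.find_first(where={"param_name": param_name})
    await db.disconnect()
    if row is None:
        return None
    # insert_data() stores json.dumps(value), so decode if it comes back as a string
    value = row.param_value
    return json.loads(value) if isinstance(value, str) else value

print(asyncio.run(read_config_section("router_settings")))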

@@ -301,20 +301,24 @@ class PrismaClient:
        self,
        key: str,
        value: Any,
        table_name: Literal["users", "keys", "config"],
    ):
        """
        Generic implementation of get data
        """
        try:
            if table_name == "users":
                response = await self.db.litellm_usertable.find_first(
                    where={key: value}  # type: ignore
                )
            elif table_name == "keys":
                response = await self.db.litellm_verificationtoken.find_first(  # type: ignore
                    where={key: value}  # type: ignore
                )
            elif table_name == "config":
                response = await self.db.litellm_config.find_first(  # type: ignore
                    where={key: value}  # type: ignore
                )
            return response
        except Exception as e:
            asyncio.create_task(
@@ -385,39 +389,66 @@ class PrismaClient:
        max_time=10,  # maximum total time to retry for
        on_backoff=on_backoff,  # specifying the function to call on backoff
    )
    async def insert_data(
        self, data: dict, table_name: Literal["user+key", "config"] = "user+key"
    ):
        """
        Add a key to the database. If it already exists, do nothing.
        """
        try:
            if table_name == "user+key":
                token = data["token"]
                hashed_token = self.hash_token(token=token)
                db_data = self.jsonify_object(data=data)
                db_data["token"] = hashed_token
                max_budget = db_data.pop("max_budget", None)
                user_email = db_data.pop("user_email", None)
                new_verification_token = await self.db.litellm_verificationtoken.upsert(  # type: ignore
                    where={
                        "token": hashed_token,
                    },
                    data={
                        "create": {**db_data},  # type: ignore
                        "update": {},  # don't do anything if it already exists
                    },
                )
                new_user_row = await self.db.litellm_usertable.upsert(
                    where={"user_id": data["user_id"]},
                    data={
                        "create": {
                            "user_id": data["user_id"],
                            "max_budget": max_budget,
                            "user_email": user_email,
                        },
                        "update": {},  # don't do anything if it already exists
                    },
                )
                return new_verification_token
            elif table_name == "config":
                """
                For each param,
                get the existing table values

                Add the new values

                Update DB
                """
                tasks = []
                for k, v in data.items():
                    updated_data = v
                    updated_data = json.dumps(updated_data)
                    updated_table_row = self.db.litellm_config.upsert(
                        where={"param_name": k},
                        data={
                            "create": {"param_name": k, "param_value": updated_data},
                            "update": {"param_value": updated_data},
                        },
                    )

                    tasks.append(updated_table_row)
                await asyncio.gather(*tasks)
        except Exception as e:
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(original_exception=e)
@@ -527,6 +558,7 @@ class PrismaClient:
    async def disconnect(self):
        try:
            await self.db.disconnect()
            self.connected = False
        except Exception as e:
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(original_exception=e)

@@ -9,7 +9,7 @@
import sys, re, binascii, struct
import litellm
import dotenv, json, traceback, threading, base64, ast
import subprocess, os
import litellm, openai
import itertools

@@ -6621,7 +6621,7 @@ def _is_base64(s):
def get_secret(
    secret_name: str,
    default_value: Optional[Union[str, bool]] = None,
):
    key_management_system = litellm._key_management_system
    if secret_name.startswith("os.environ/"):

@@ -6672,9 +6672,24 @@ def get_secret(
                secret = client.get_secret(secret_name).secret_value
            except Exception as e:  # check if it's in os.environ
                secret = os.getenv(secret_name)
            try:
                secret_value_as_bool = ast.literal_eval(secret)
                if isinstance(secret_value_as_bool, bool):
                    return secret_value_as_bool
                else:
                    return secret
            except:
                return secret
        else:
            secret = os.environ.get(secret_name)
            try:
                secret_value_as_bool = ast.literal_eval(secret)
                if isinstance(secret_value_as_bool, bool):
                    return secret_value_as_bool
                else:
                    return secret
            except:
                return secret
    except Exception as e:
        if default_value is not None:
            return default_value
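The get_secret change is what makes flags like SAVE_CONFIG_TO_DB usable as booleans: env-var strings such as "True" or "False" are now coerced to real bools via ast.literal_eval, while anything else falls through unchanged. A quick check of the coercion logic in isolation:

import ast

def coerce(secret):
    # mirrors the new branch in get_secret: return a bool when the string parses as one
    try:
        value = ast.literal_eval(secret)
        return value if isinstance(value, bool) else secret
    except Exception:
        return secret

print(coerce("True"), coerce("False"), coerce("redis"), coerce("6379"))
# -> True False redis 6379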