From 99d9a825deadd0a4161797d5b38209fb051fc94e Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 4 Jan 2024 14:44:45 +0530
Subject: [PATCH] feat(proxy_server.py): abstract config update/writing and
 support persisting config in db

Allows the user to opt into writing the config to the DB (via
SAVE_CONFIG_TO_DB) and removes any API keys before sending the config
to the DB.

https://github.com/BerriAI/litellm/issues/1322
---
 litellm/proxy/proxy_server.py | 617 +++++++++++++++++++---------------
 litellm/proxy/schema.prisma   |   5 +
 litellm/proxy/utils.py        |  94 ++++--
 litellm/utils.py              |  23 +-
 4 files changed, 430 insertions(+), 309 deletions(-)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index fc0d0b608..0431ba11e 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -502,232 +502,331 @@ async def _run_background_health_check():
         await asyncio.sleep(health_check_interval)
 
-def load_router_config(router: Optional[litellm.Router], config_file_path: str):
-    global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue
-    config = {}
-    try:
-        if os.path.exists(config_file_path):
+class ProxyConfig:
+    """
+    Abstraction over config loading/updating. Gives us one place to control
+    all config updating logic.
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    async def get_config(self, config_file_path: Optional[str] = None) -> dict:
+        global prisma_client, user_config_file_path
+
+        file_path = config_file_path or user_config_file_path
+        if config_file_path is not None:
             user_config_file_path = config_file_path
-            with open(config_file_path, "r") as file:
-                config = yaml.safe_load(file)
+        # Load existing config
+        ## Yaml
+        if os.path.exists(f"{file_path}"):
+            with open(f"{file_path}", "r") as config_file:
+                config = yaml.safe_load(config_file)
         else:
-            raise Exception(
-                f"Path to config does not exist, Current working directory: {os.getcwd()}, 'os.path.exists({config_file_path})' returned False"
+            config = {
+                "model_list": [],
+                "general_settings": {},
+                "router_settings": {},
+                "litellm_settings": {},
+            }
+
+        ## DB
+        if (
+            prisma_client is not None
+            and litellm.get_secret("SAVE_CONFIG_TO_DB", default_value=False) == True
+        ):
+            _tasks = []
+            keys = [
+                "model_list",
+                "general_settings",
+                "router_settings",
+                "litellm_settings",
+            ]
+            for k in keys:
+                response = prisma_client.get_generic_data(
+                    key="param_name", value=k, table_name="config"
+                )
+                _tasks.append(response)
+
+            responses = await asyncio.gather(*_tasks)
+            # overwrite the file values with whatever is stored in the db
+            for response in responses:
+                if response is not None:
+                    config[response.param_name] = response.param_value
+
+        return config
+
+    async def save_config(self, new_config: dict):
+        global prisma_client, llm_router, user_config_file_path
+        # Load existing config
+        backup_config = await self.get_config()
+
+        # Save the updated config
+        ## YAML
+        with open(f"{user_config_file_path}", "w") as config_file:
+            yaml.dump(new_config, config_file, default_flow_style=False)
+
+        # update Router - verifies if this is a valid config
+        try:
+            (
+                llm_router,
+                llm_model_list,
+                general_settings,
+            ) = await proxy_config.load_config(
+                router=llm_router, config_file_path=user_config_file_path
             )
-    except Exception as e:
-        raise Exception(f"Exception while reading Config: {e}")
+        except Exception as e:
+            traceback.print_exc()
+            # Revert to old config instead
+            with open(f"{user_config_file_path}", "w") as config_file:
+                yaml.dump(backup_config, config_file, default_flow_style=False)
+            raise HTTPException(status_code=400, detail="Invalid config passed in")
 
-    ## PRINT YAML FOR CONFIRMING IT WORKS
-    printed_yaml = copy.deepcopy(config)
-    printed_yaml.pop("environment_variables", None)
+        ## DB - writes valid config to db
+        """
+        - Do not write restricted params like 'api_key' to the database
+        - if api_key is passed, save that to the local environment or connected secret manager (maybe expose `litellm.save_secret()`)
+        """
+        if (
+            prisma_client is not None
+            and litellm.get_secret("SAVE_CONFIG_TO_DB", default_value=False) == True
+        ):
+            ### KEY REMOVAL ###
+            models = new_config.get("model_list", [])
+            for m in models:
+                if m.get("litellm_params", {}).get("api_key", None) is not None:
+                    # pop the key
+                    api_key = m["litellm_params"].pop("api_key")
+                    # store in local env
+                    key_name = f"LITELLM_MODEL_KEY_{uuid.uuid4()}"
+                    os.environ[key_name] = api_key
+                    # save the key name (not the value)
+                    m["litellm_params"]["api_key"] = f"os.environ/{key_name}"
+            await prisma_client.insert_data(data=new_config, table_name="config")
 
-    print_verbose(
-        f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}"
-    )
+    async def load_config(
+        self, router: Optional[litellm.Router], config_file_path: str
+    ):
+        """
+        Load config values into proxy global state
+        """
+        global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue
 
-    ## ENVIRONMENT VARIABLES
-    environment_variables = config.get("environment_variables", None)
-    if environment_variables:
-        for key, value in environment_variables.items():
-            os.environ[key] = value
+        # Load existing config
+        config = await self.get_config(config_file_path=config_file_path)
+        ## PRINT YAML FOR CONFIRMING IT WORKS
+        printed_yaml = copy.deepcopy(config)
+        printed_yaml.pop("environment_variables", None)
 
-    ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
-    litellm_settings = config.get("litellm_settings", None)
-    if litellm_settings is None:
-        litellm_settings = {}
-    if litellm_settings:
-        # ANSI escape code for blue text
-        blue_color_code = "\033[94m"
-        reset_color_code = "\033[0m"
-        for key, value in litellm_settings.items():
-            if key == "cache":
-                print(f"{blue_color_code}\nSetting Cache on Proxy")  # noqa
-                from litellm.caching import Cache
+        print_verbose(
+            f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}"
+        )
 
-                cache_params = {}
-                if "cache_params" in litellm_settings:
-                    cache_params_in_config = litellm_settings["cache_params"]
-                    # overwrie cache_params with cache_params_in_config
-                    cache_params.update(cache_params_in_config)
+        ## ENVIRONMENT VARIABLES
+        environment_variables = config.get("environment_variables", None)
+        if environment_variables:
+            for key, value in environment_variables.items():
+                os.environ[key] = value
 
-                cache_type = cache_params.get("type", "redis")
+        ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
+        litellm_settings = config.get("litellm_settings", None)
+        if litellm_settings is None:
+            litellm_settings = {}
+        if litellm_settings:
+            # ANSI escape code for blue text
+            blue_color_code = "\033[94m"
+            reset_color_code = "\033[0m"
+            for key, value in litellm_settings.items():
+                if key == "cache":
+                    print(f"{blue_color_code}\nSetting Cache on Proxy")  # noqa
+                    from litellm.caching import Cache
 
-                print_verbose(f"passed cache type={cache_type}")
+                    cache_params = {}
+                    if "cache_params" in litellm_settings:
+                        cache_params_in_config = litellm_settings["cache_params"]
+                        # overwrite cache_params with cache_params_in_config
+                        cache_params.update(cache_params_in_config)
 
-                if cache_type == "redis":
-                    cache_host = litellm.get_secret("REDIS_HOST", None)
-                    cache_port = litellm.get_secret("REDIS_PORT", None)
-                    cache_password = litellm.get_secret("REDIS_PASSWORD", None)
+                    cache_type = cache_params.get("type", "redis")
 
-                    cache_params = {
-                        "type": cache_type,
-                        "host": cache_host,
-                        "port": cache_port,
-                        "password": cache_password,
-                    }
-                    # Assuming cache_type, cache_host, cache_port, and cache_password are strings
+                    print_verbose(f"passed cache type={cache_type}")
+
+                    if cache_type == "redis":
+                        cache_host = litellm.get_secret("REDIS_HOST", None)
+                        cache_port = litellm.get_secret("REDIS_PORT", None)
+                        cache_password = litellm.get_secret("REDIS_PASSWORD", None)
+
+                        cache_params = {
+                            "type": cache_type,
+                            "host": cache_host,
+                            "port": cache_port,
+                            "password": cache_password,
+                        }
+                        # Assuming cache_type, cache_host, cache_port, and cache_password are strings
+                        print(  # noqa
+                            f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}"
+                        )  # noqa
+                        print(  # noqa
+                            f"{blue_color_code}Cache Host:{reset_color_code} {cache_host}"
+                        )  # noqa
+                        print(  # noqa
+                            f"{blue_color_code}Cache Port:{reset_color_code} {cache_port}"
+                        )  # noqa
+                        print(  # noqa
+                            f"{blue_color_code}Cache Password:{reset_color_code} {cache_password}"
+                        )
+                        print()  # noqa
+
+                    ## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = `, _redis.py checks for REDIS specific environment variables
+                    litellm.cache = Cache(**cache_params)
                     print(  # noqa
-                        f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}"
-                    )  # noqa
-                    print(  # noqa
-                        f"{blue_color_code}Cache Host:{reset_color_code} {cache_host}"
-                    )  # noqa
-                    print(  # noqa
-                        f"{blue_color_code}Cache Port:{reset_color_code} {cache_port}"
-                    )  # noqa
-                    print(  # noqa
-                        f"{blue_color_code}Cache Password:{reset_color_code} {cache_password}"
+                        f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}"
                     )
-                    print()  # noqa
+                elif key == "callbacks":
+                    litellm.callbacks = [
+                        get_instance_fn(value=value, config_file_path=config_file_path)
+                    ]
+                    print_verbose(
+                        f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
+                    )
+                elif key == "post_call_rules":
+                    litellm.post_call_rules = [
+                        get_instance_fn(value=value, config_file_path=config_file_path)
+                    ]
+                    print_verbose(f"litellm.post_call_rules: {litellm.post_call_rules}")
+                elif key == "success_callback":
+                    litellm.success_callback = []
 
-                ## to pass a complete url, or set ssl=True, etc.
-                ## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = `, _redis.py checks for REDIS specific environment variables
-                litellm.cache = Cache(**cache_params)
-                print(  # noqa
-                    f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}"
-                )
-            elif key == "callbacks":
-                litellm.callbacks = [
-                    get_instance_fn(value=value, config_file_path=config_file_path)
-                ]
-                print_verbose(
-                    f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
-                )
-            elif key == "post_call_rules":
-                litellm.post_call_rules = [
-                    get_instance_fn(value=value, config_file_path=config_file_path)
-                ]
-                print_verbose(f"litellm.post_call_rules: {litellm.post_call_rules}")
-            elif key == "success_callback":
-                litellm.success_callback = []
+                    # initialize success callbacks
+                    for callback in value:
+                        # user passed custom_callbacks.async_on_success_logger. They need us to import a function
+                        if "." in callback:
+                            litellm.success_callback.append(
+                                get_instance_fn(value=callback)
+                            )
+                        # these are litellm callbacks - "langfuse", "sentry", "wandb"
+                        else:
+                            litellm.success_callback.append(callback)
+                    print_verbose(
+                        f"{blue_color_code} Initialized Success Callbacks - {litellm.success_callback} {reset_color_code}"
+                    )
+                elif key == "failure_callback":
+                    litellm.failure_callback = []
 
-                # intialize success callbacks
-                for callback in value:
-                    # user passed custom_callbacks.async_on_succes_logger. They need us to import a function
-                    if "." in callback:
-                        litellm.success_callback.append(get_instance_fn(value=callback))
-                    # these are litellm callbacks - "langfuse", "sentry", "wandb"
-                    else:
-                        litellm.success_callback.append(callback)
-                print_verbose(
-                    f"{blue_color_code} Initialized Success Callbacks - {litellm.success_callback} {reset_color_code}"
-                )
-            elif key == "failure_callback":
-                litellm.failure_callback = []
+                    # initialize failure callbacks
+                    for callback in value:
+                        # user passed custom_callbacks.async_on_success_logger. They need us to import a function
+                        if "." in callback:
+                            litellm.failure_callback.append(
+                                get_instance_fn(value=callback)
+                            )
+                        # these are litellm callbacks - "langfuse", "sentry", "wandb"
+                        else:
+                            litellm.failure_callback.append(callback)
+                    print_verbose(
+                        f"{blue_color_code} Initialized Failure Callbacks - {litellm.failure_callback} {reset_color_code}"
+                    )
+                elif key == "cache_params":
+                    # this is set in the cache branch
+                    # see usage here: https://docs.litellm.ai/docs/proxy/caching
+                    pass
+                else:
+                    setattr(litellm, key, value)
 
-                # intialize success callbacks
-                for callback in value:
-                    # user passed custom_callbacks.async_on_succes_logger. They need us to import a function
-                    if "." in callback:
-                        litellm.failure_callback.append(get_instance_fn(value=callback))
-                    # these are litellm callbacks - "langfuse", "sentry", "wandb"
-                    else:
-                        litellm.failure_callback.append(callback)
-                print_verbose(
-                    f"{blue_color_code} Initialized Success Callbacks - {litellm.failure_callback} {reset_color_code}"
-                )
-            elif key == "cache_params":
-                # this is set in the cache branch
-                # see usage here: https://docs.litellm.ai/docs/proxy/caching
-                pass
-            else:
-                setattr(litellm, key, value)
-
-    ## GENERAL SERVER SETTINGS (e.g. master key,..) # do this after initializing litellm, to ensure sentry logging works for proxylogging
-    general_settings = config.get("general_settings", {})
-    if general_settings is None:
-        general_settings = {}
-    if general_settings:
-        ### LOAD SECRET MANAGER ###
-        key_management_system = general_settings.get("key_management_system", None)
-        if key_management_system is not None:
-            if key_management_system == KeyManagementSystem.AZURE_KEY_VAULT.value:
-                ### LOAD FROM AZURE KEY VAULT ###
-                load_from_azure_key_vault(use_azure_key_vault=True)
-            elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
-                ### LOAD FROM GOOGLE KMS ###
-                load_google_kms(use_google_kms=True)
-            else:
-                raise ValueError("Invalid Key Management System selected")
-        ### [DEPRECATED] LOAD FROM GOOGLE KMS ### old way of loading from google kms
-        use_google_kms = general_settings.get("use_google_kms", False)
-        load_google_kms(use_google_kms=use_google_kms)
-        ### [DEPRECATED] LOAD FROM AZURE KEY VAULT ### old way of loading from azure secret manager
-        use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
-        load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
-        ### ALERTING ###
-        proxy_logging_obj.update_values(
-            alerting=general_settings.get("alerting", None),
-            alerting_threshold=general_settings.get("alerting_threshold", 600),
-        )
-        ### CONNECT TO DATABASE ###
-        database_url = general_settings.get("database_url", None)
-        if database_url and database_url.startswith("os.environ/"):
-            print_verbose(f"GOING INTO LITELLM.GET_SECRET!")
-            database_url = litellm.get_secret(database_url)
-            print_verbose(f"RETRIEVED DB URL: {database_url}")
-        prisma_setup(database_url=database_url)
-        ## COST TRACKING ##
-        cost_tracking()
-        ### MASTER KEY ###
-        master_key = general_settings.get(
-            "master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
-        )
-        if master_key and master_key.startswith("os.environ/"):
-            master_key = litellm.get_secret(master_key)
-        ### CUSTOM API KEY AUTH ###
-        custom_auth = general_settings.get("custom_auth", None)
-        if custom_auth:
-            user_custom_auth = get_instance_fn(
-                value=custom_auth, config_file_path=config_file_path
+        ## GENERAL SERVER SETTINGS (e.g. master key,..) # do this after initializing litellm, to ensure sentry logging works for proxylogging
+        general_settings = config.get("general_settings", {})
+        if general_settings is None:
+            general_settings = {}
+        if general_settings:
+            ### LOAD SECRET MANAGER ###
+            key_management_system = general_settings.get("key_management_system", None)
+            if key_management_system is not None:
+                if key_management_system == KeyManagementSystem.AZURE_KEY_VAULT.value:
+                    ### LOAD FROM AZURE KEY VAULT ###
+                    load_from_azure_key_vault(use_azure_key_vault=True)
+                elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
+                    ### LOAD FROM GOOGLE KMS ###
+                    load_google_kms(use_google_kms=True)
+                else:
+                    raise ValueError("Invalid Key Management System selected")
+            ### [DEPRECATED] LOAD FROM GOOGLE KMS ### old way of loading from google kms
+            use_google_kms = general_settings.get("use_google_kms", False)
+            load_google_kms(use_google_kms=use_google_kms)
+            ### [DEPRECATED] LOAD FROM AZURE KEY VAULT ### old way of loading from azure secret manager
+            use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
+            load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
+            ### ALERTING ###
+            proxy_logging_obj.update_values(
+                alerting=general_settings.get("alerting", None),
+                alerting_threshold=general_settings.get("alerting_threshold", 600),
             )
-        ### BACKGROUND HEALTH CHECKS ###
-        # Enable background health checks
-        use_background_health_checks = general_settings.get(
-            "background_health_checks", False
-        )
-        health_check_interval = general_settings.get("health_check_interval", 300)
+            ### CONNECT TO DATABASE ###
+            database_url = general_settings.get("database_url", None)
+            if database_url and database_url.startswith("os.environ/"):
+                print_verbose(f"GOING INTO LITELLM.GET_SECRET!")
+                database_url = litellm.get_secret(database_url)
+                print_verbose(f"RETRIEVED DB URL: {database_url}")
+            prisma_setup(database_url=database_url)
+            ## COST TRACKING ##
+            cost_tracking()
+            ### MASTER KEY ###
+            master_key = general_settings.get(
+                "master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
+            )
+            if master_key and master_key.startswith("os.environ/"):
+                master_key = litellm.get_secret(master_key)
+            ### CUSTOM API KEY AUTH ###
+            custom_auth = general_settings.get("custom_auth", None)
+            if custom_auth:
+                user_custom_auth = get_instance_fn(
+                    value=custom_auth, config_file_path=config_file_path
+                )
+            ### BACKGROUND HEALTH CHECKS ###
+            # Enable background health checks
+            use_background_health_checks = general_settings.get(
+                "background_health_checks", False
+            )
+            health_check_interval = general_settings.get("health_check_interval", 300)
 
-    router_params: dict = {
-        "num_retries": 3,
-        "cache_responses": litellm.cache
-        != None,  # cache if user passed in cache values
-    }
-    ## MODEL LIST
-    model_list = config.get("model_list", None)
-    if model_list:
-        router_params["model_list"] = model_list
-        print(  # noqa
-            f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m"
-        )  # noqa
-        for model in model_list:
-            ### LOAD FROM os.environ/ ###
-            for k, v in model["litellm_params"].items():
-                if isinstance(v, str) and v.startswith("os.environ/"):
-                    model["litellm_params"][k] = litellm.get_secret(v)
-            print(f"\033[32m    {model.get('model_name', '')}\033[0m")  # noqa
-            litellm_model_name = model["litellm_params"]["model"]
-            litellm_model_api_base = model["litellm_params"].get("api_base", None)
-            if "ollama" in litellm_model_name and litellm_model_api_base is None:
-                run_ollama_serve()
-
-    ## ROUTER SETTINGS (e.g. routing_strategy, ...)
-    router_settings = config.get("router_settings", None)
-    if router_settings and isinstance(router_settings, dict):
-        arg_spec = inspect.getfullargspec(litellm.Router)
-        # model list already set
-        exclude_args = {
-            "self",
-            "model_list",
+        router_params: dict = {
+            "num_retries": 3,
+            "cache_responses": litellm.cache
+            != None,  # cache if user passed in cache values
         }
+        ## MODEL LIST
+        model_list = config.get("model_list", None)
+        if model_list:
+            router_params["model_list"] = model_list
+            print(  # noqa
+                f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m"
+            )  # noqa
+            for model in model_list:
+                ### LOAD FROM os.environ/ ###
+                for k, v in model["litellm_params"].items():
+                    if isinstance(v, str) and v.startswith("os.environ/"):
+                        model["litellm_params"][k] = litellm.get_secret(v)
+                print(f"\033[32m    {model.get('model_name', '')}\033[0m")  # noqa
+                litellm_model_name = model["litellm_params"]["model"]
+                litellm_model_api_base = model["litellm_params"].get("api_base", None)
+                if "ollama" in litellm_model_name and litellm_model_api_base is None:
+                    run_ollama_serve()
 
-        available_args = [x for x in arg_spec.args if x not in exclude_args]
+        ## ROUTER SETTINGS (e.g. routing_strategy, ...)
+        router_settings = config.get("router_settings", None)
+        if router_settings and isinstance(router_settings, dict):
+            arg_spec = inspect.getfullargspec(litellm.Router)
+            # model list already set
+            exclude_args = {
+                "self",
+                "model_list",
+            }
 
-        for k, v in router_settings.items():
-            if k in available_args:
-                router_params[k] = v
+            available_args = [x for x in arg_spec.args if x not in exclude_args]
 
-    router = litellm.Router(**router_params)  # type:ignore
-    return router, model_list, general_settings
+            for k, v in router_settings.items():
+                if k in available_args:
+                    router_params[k] = v
+
+        router = litellm.Router(**router_params)  # type:ignore
+        return router, model_list, general_settings
+
+
+proxy_config = ProxyConfig()
 
 
 async def generate_key_helper_fn(
@@ -856,10 +955,6 @@ def initialize(
     if debug == True:  # this needs to be first, so users can see Router init debugg
         litellm.set_verbose = True
     dynamic_config = {"general": {}, user_model: {}}
-    if config:
-        llm_router, llm_model_list, general_settings = load_router_config(
-            router=llm_router, config_file_path=config
-        )
     if headers:  # model-specific param
         user_headers = headers
         dynamic_config[user_model]["headers"] = headers
@@ -988,7 +1083,7 @@ def parse_cache_control(cache_control):
 
 @router.on_event("startup")
 async def startup_event():
-    global prisma_client, master_key, use_background_health_checks
+    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings
     import json
 
     ### LOAD MASTER KEY ###
@@ -1000,10 +1095,26 @@ async def startup_event():
     print_verbose(f"worker_config: {worker_config}")
     # check if it's a valid file path
    if os.path.isfile(worker_config):
-        initialize(config=worker_config)
+        (
+            llm_router,
+            llm_model_list,
+            general_settings,
+        ) = await proxy_config.load_config(
+            router=llm_router, config_file_path=worker_config
+        )
     else:
         # if not, assume it's a json string
         worker_config = json.loads(os.getenv("WORKER_CONFIG"))
+        if worker_config.get("config", None) is not None:
+            (
+                llm_router,
+                llm_model_list,
+                general_settings,
+            ) = await proxy_config.load_config(
+                router=llm_router, config_file_path=worker_config.pop("config")
+            )
         initialize(**worker_config)
 
     proxy_logging_obj._init_litellm_callbacks()  # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
 
@@ -1825,7 +1936,7 @@ async def user_auth(request: Request):
 
     ### Check if user email in user table
     response = await prisma_client.get_generic_data(
-        key="user_email", value=user_email, db="users"
+        key="user_email", value=user_email, table_name="users"
     )
     ### if so - generate a 24 hr key with that user id
     if response is not None:
@@ -1883,16 +1994,13 @@ async def user_update(request: Request):
     dependencies=[Depends(user_api_key_auth)],
 )
 async def add_new_model(model_params: ModelParams):
-    global llm_router, llm_model_list, general_settings, user_config_file_path
+    global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
     try:
-        print_verbose(f"User config path: {user_config_file_path}")
         # Load existing config
-        if os.path.exists(f"{user_config_file_path}"):
-            with open(f"{user_config_file_path}", "r") as config_file:
-                config = yaml.safe_load(config_file)
-        else:
-            config = {"model_list": []}
-        backup_config = copy.deepcopy(config)
+        config = await proxy_config.get_config()
+
+        print_verbose(f"User config path: {user_config_file_path}")
+        print_verbose(f"Loaded config: {config}")
 
         # Add the new model to the config
         model_info = model_params.model_info.json()
@@ -1907,22 +2015,8 @@ async def add_new_model(model_params: ModelParams):
 
         print_verbose(f"updated model list: {config['model_list']}")
 
-        # Save the updated config
-        with open(f"{user_config_file_path}", "w") as config_file:
-            yaml.dump(config, config_file, default_flow_style=False)
-
-        # update Router
-        try:
-            llm_router, llm_model_list, general_settings = load_router_config(
-                router=llm_router, config_file_path=user_config_file_path
-            )
-        except Exception as e:
-            # Rever to old config instead
-            with open(f"{user_config_file_path}", "w") as config_file:
-                yaml.dump(backup_config, config_file, default_flow_style=False)
-            raise HTTPException(status_code=400, detail="Invalid Model passed in")
-
-        print_verbose(f"llm_model_list: {llm_model_list}")
+        # Save new config
+        await proxy_config.save_config(new_config=config)
         return {"message": "Model added successfully"}
 
     except Exception as e:
@@ -1949,13 +2043,10 @@ async def add_new_model(model_params: ModelParams):
     dependencies=[Depends(user_api_key_auth)],
 )
 async def model_info_v1(request: Request):
-    global llm_model_list, general_settings, user_config_file_path
+    global llm_model_list, general_settings, user_config_file_path, proxy_config
+
     # Load existing config
-    if os.path.exists(f"{user_config_file_path}"):
-        with open(f"{user_config_file_path}", "r") as config_file:
-            config = yaml.safe_load(config_file)
-    else:
-        config = {"model_list": []}  # handle base case
+    config = await proxy_config.get_config()
 
     all_models = config["model_list"]
     for model in all_models:
@@ -1984,18 +2075,18 @@ async def model_info_v1(request: Request):
     dependencies=[Depends(user_api_key_auth)],
 )
 async def delete_model(model_info: ModelInfoDelete):
-    global llm_router, llm_model_list, general_settings, user_config_file_path
+    global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
     try:
         if not os.path.exists(user_config_file_path):
             raise HTTPException(status_code=404, detail="Config file does not exist.")
 
-        with open(user_config_file_path, "r") as config_file:
-            config = yaml.safe_load(config_file)
+        # Load existing config
+        config = await proxy_config.get_config()
 
         # If model_list is not in the config, nothing can be deleted
-        if "model_list" not in config:
+        if len(config.get("model_list", [])) == 0:
             raise HTTPException(
-                status_code=404, detail="No model list available in the config."
+                status_code=400, detail="No model list available in the config."
             )
 
         # Check if the model with the specified model_id exists
@@ -2008,19 +2099,14 @@ async def delete_model(model_info: ModelInfoDelete):
         # If the model was not found, return an error
         if model_to_delete is None:
             raise HTTPException(
-                status_code=404, detail="Model with given model_id not found."
+                status_code=400, detail="Model with given model_id not found."
             )
 
         # Remove model from the list and save the updated config
         config["model_list"].remove(model_to_delete)
 
-        with open(user_config_file_path, "w") as config_file:
-            yaml.dump(config, config_file, default_flow_style=False)
-
-        # Update Router
-        llm_router, llm_model_list, general_settings = load_router_config(
-            router=llm_router, config_file_path=user_config_file_path
-        )
+        # Save updated config
+        await proxy_config.save_config(new_config=config)
 
         return {"message": "Model deleted successfully"}
     except HTTPException as e:
@@ -2200,14 +2286,11 @@ async def update_config(config_info: ConfigYAML):
 
     Currently supports modifying General Settings + LiteLLM settings
     """
-    global llm_router, llm_model_list, general_settings
+    global llm_router, llm_model_list, general_settings, proxy_config
    try:
         # Load existing config
-        if os.path.exists(f"{user_config_file_path}"):
-            with open(f"{user_config_file_path}", "r") as config_file:
-                config = yaml.safe_load(config_file)
-        else:
-            config = {}
+        config = await proxy_config.get_config()
+
         backup_config = copy.deepcopy(config)
 
         print_verbose(f"Loaded config: {config}")
@@ -2240,21 +2323,7 @@ async def update_config(config_info: ConfigYAML):
         }
 
         # Save the updated config
-        with open(f"{user_config_file_path}", "w") as config_file:
-            yaml.dump(config, config_file, default_flow_style=False)
-
-        # update Router
-        try:
-            llm_router, llm_model_list, general_settings = load_router_config(
-                router=llm_router, config_file_path=user_config_file_path
-            )
-        except Exception as e:
-            # Rever to old config instead
-            with open(f"{user_config_file_path}", "w") as config_file:
-                yaml.dump(backup_config, config_file, default_flow_style=False)
-            raise HTTPException(
-                status_code=400, detail=f"Invalid config passed in. Errror - {str(e)}"
-            )
+        await proxy_config.save_config(new_config=config)
         return {"message": "Config updated successfully"}
     except HTTPException as e:
         raise e
diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma
index 7ce05f285..d12cac8f2 100644
--- a/litellm/proxy/schema.prisma
+++ b/litellm/proxy/schema.prisma
@@ -25,4 +25,9 @@ model LiteLLM_VerificationToken {
   user_id String?
   max_parallel_requests Int?
   metadata Json @default("{}")
+}
+
+model LiteLLM_Config {
+  param_name String @id
+  param_value Json?
 }
\ No newline at end of file
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index c727c7988..0be448119 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -301,20 +301,24 @@ class PrismaClient:
         self,
         key: str,
         value: Any,
-        db: Literal["users", "keys"],
+        table_name: Literal["users", "keys", "config"],
     ):
         """
         Generic implementation of get data
         """
         try:
-            if db == "users":
+            if table_name == "users":
                 response = await self.db.litellm_usertable.find_first(
                     where={key: value}  # type: ignore
                 )
-            elif db == "keys":
+            elif table_name == "keys":
                 response = await self.db.litellm_verificationtoken.find_first(  # type: ignore
                     where={key: value}  # type: ignore
                 )
+            elif table_name == "config":
+                response = await self.db.litellm_config.find_first(  # type: ignore
+                    where={key: value}  # type: ignore
+                )
             return response
         except Exception as e:
             asyncio.create_task(
@@ -385,39 +389,66 @@ class PrismaClient:
         max_time=10,  # maximum total time to retry for
         on_backoff=on_backoff,  # specifying the function to call on backoff
     )
-    async def insert_data(self, data: dict):
+    async def insert_data(
+        self, data: dict, table_name: Literal["user+key", "config"] = "user+key"
+    ):
         """
         Add a key to the database. If it already exists, do nothing.
         """
         try:
-            token = data["token"]
-            hashed_token = self.hash_token(token=token)
-            db_data = self.jsonify_object(data=data)
-            db_data["token"] = hashed_token
-            max_budget = db_data.pop("max_budget", None)
-            user_email = db_data.pop("user_email", None)
-            new_verification_token = await self.db.litellm_verificationtoken.upsert(  # type: ignore
-                where={
-                    "token": hashed_token,
-                },
-                data={
-                    "create": {**db_data},  # type: ignore
-                    "update": {},  # don't do anything if it already exists
-                },
-            )
-
-            new_user_row = await self.db.litellm_usertable.upsert(
-                where={"user_id": data["user_id"]},
-                data={
-                    "create": {
-                        "user_id": data["user_id"],
-                        "max_budget": max_budget,
-                        "user_email": user_email,
+            if table_name == "user+key":
+                token = data["token"]
+                hashed_token = self.hash_token(token=token)
+                db_data = self.jsonify_object(data=data)
+                db_data["token"] = hashed_token
+                max_budget = db_data.pop("max_budget", None)
+                user_email = db_data.pop("user_email", None)
+                new_verification_token = await self.db.litellm_verificationtoken.upsert(  # type: ignore
+                    where={
+                        "token": hashed_token,
                     },
-                    "update": {},  # don't do anything if it already exists
-                },
-            )
-            return new_verification_token
+                    data={
+                        "create": {**db_data},  # type: ignore
+                        "update": {},  # don't do anything if it already exists
+                    },
+                )
+
+                new_user_row = await self.db.litellm_usertable.upsert(
+                    where={"user_id": data["user_id"]},
+                    data={
+                        "create": {
+                            "user_id": data["user_id"],
+                            "max_budget": max_budget,
+                            "user_email": user_email,
+                        },
+                        "update": {},  # don't do anything if it already exists
+                    },
+                )
+                return new_verification_token
+            elif table_name == "config":
+                """
+                For each param, upsert the (param_name, param_value) pair
+                into the config table, overwriting any existing value.
+                """
+                tasks = []
+                for k, v in data.items():
+                    updated_data = json.dumps(v)
+                    updated_table_row = self.db.litellm_config.upsert(
+                        where={"param_name": k},
+                        data={
+                            "create": {"param_name": k, "param_value": updated_data},
+                            "update": {"param_value": updated_data},
+                        },
+                    )
+
+                    tasks.append(updated_table_row)
+
+                await asyncio.gather(*tasks)
         except Exception as e:
             asyncio.create_task(
                 self.proxy_logging_obj.failure_handler(original_exception=e)
@@ -527,6 +558,7 @@ class PrismaClient:
     async def disconnect(self):
         try:
             await self.db.disconnect()
+            self.connected = False
         except Exception as e:
             asyncio.create_task(
                 self.proxy_logging_obj.failure_handler(original_exception=e)
diff --git a/litellm/utils.py b/litellm/utils.py
index f62c79c22..9ae6e3498 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -9,7 +9,7 @@ import sys, re, binascii, struct
 import litellm
-import dotenv, json, traceback, threading, base64
+import dotenv, json, traceback, threading, base64, ast
 import subprocess, os
 import litellm, openai
 import itertools
@@ -6621,7 +6621,7 @@ def _is_base64(s):
 
 def get_secret(
     secret_name: str,
-    default_value: Optional[str] = None,
+    default_value: Optional[Union[str, bool]] = None,
 ):
     key_management_system = litellm._key_management_system
     if secret_name.startswith("os.environ/"):
@@ -6672,9 +6672,24 @@ def get_secret(
                 secret = client.get_secret(secret_name).secret_value
             except Exception as e:  # check if it's in os.environ
                 secret = os.getenv(secret_name)
-            return secret
+            try:
+                secret_value_as_bool = ast.literal_eval(secret)
+                if isinstance(secret_value_as_bool, bool):
+                    return secret_value_as_bool
+                else:
+                    return secret
+            except Exception:
+                return secret
         else:
-            return os.environ.get(secret_name)
+            secret = os.environ.get(secret_name)
+            try:
+                secret_value_as_bool = ast.literal_eval(secret)
+                if isinstance(secret_value_as_bool, bool):
+                    return secret_value_as_bool
+                else:
+                    return secret
+            except Exception:
+                return secret
     except Exception as e:
         if default_value is not None:
            return default_value
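
Note (reviewer sketch, not part of the diff): the key-removal step that save_config() applies before handing the config to prisma_client.insert_data() can be illustrated in isolation. The helper name scrub_api_keys and the sample config below are hypothetical, purely for illustration:

    import copy
    import os
    import uuid

    def scrub_api_keys(new_config: dict) -> dict:
        # Replace raw api_key values with "os.environ/" references and keep the
        # actual secrets in the local environment, mirroring the KEY REMOVAL
        # block in ProxyConfig.save_config.
        scrubbed = copy.deepcopy(new_config)
        for m in scrubbed.get("model_list", []):
            litellm_params = m.get("litellm_params", {})
            api_key = litellm_params.pop("api_key", None)
            if api_key is not None:
                key_name = f"LITELLM_MODEL_KEY_{uuid.uuid4()}"
                os.environ[key_name] = api_key  # secret stays on the box
                litellm_params["api_key"] = f"os.environ/{key_name}"  # reference only
        return scrubbed

    if __name__ == "__main__":
        config = {
            "model_list": [
                {
                    "model_name": "gpt-3.5-turbo",
                    "litellm_params": {
                        "model": "gpt-3.5-turbo",
                        "api_key": "sk-example",
                    },
                }
            ]
        }
        print(scrub_api_keys(config)["model_list"][0]["litellm_params"]["api_key"])
        # -> "os.environ/LITELLM_MODEL_KEY_<uuid>"; get_secret() resolves such
        #    "os.environ/" references when the config is loaded.

With SAVE_CONFIG_TO_DB=True set in the environment (read via litellm.get_secret), save_config() persists the scrubbed config to the LiteLLM_Config table, and get_config() reads those rows back over the YAML values on startup.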