diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index ee620697af..6447f89892 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -305,6 +305,7 @@ litellm_master_key_hash = None
 disable_spend_logs = False
 jwt_handler = JWTHandler()
 prompt_injection_detection_obj: Optional[_OPTIONAL_PromptInjectionDetection] = None
+store_model_in_db: bool = False
 ### INITIALIZE GLOBAL LOGGING OBJECT ###
 proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
 ### REDIS QUEUE ###
@@ -1922,7 +1923,7 @@ class ProxyConfig:
         """
         Load config values into proxy global state
         """
-        global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache
+        global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db

         # Load existing config
         config = await self.get_config(config_file_path=config_file_path)
@@ -2262,6 +2263,10 @@ class ProxyConfig:
             if master_key is not None and isinstance(master_key, str):
                 litellm_master_key_hash = master_key

+            ### STORE MODEL IN DB ### feature flag for `/model/new`
+            store_model_in_db = general_settings.get("store_model_in_db", False)
+            if store_model_in_db is None:
+                store_model_in_db = False
             ### CUSTOM API KEY AUTH ###
             ## pass filepath
             custom_auth = general_settings.get("custom_auth", None)
@@ -2327,13 +2332,6 @@ class ProxyConfig:
                 "background_health_checks", False
             )
             health_check_interval = general_settings.get("health_check_interval", 300)
-            ## check env ##
-            _store_model_in_db = litellm.get_secret(
-                "STORE_MODEL_IN_DB", None
-            )  # feature flag for `/model/new`
-            verbose_proxy_logger.info(f"'STORE_MODEL_IN_DB'={_store_model_in_db}")
-            if _store_model_in_db is not None and _store_model_in_db == True:
-                general_settings["store_model_in_db"] = True

         router_params: dict = {
             "cache_responses": litellm.cache != None,  # cache if user passed in cache values
@@ -2390,55 +2388,95 @@ class ProxyConfig:
         import base64

         try:
-            if llm_router is None:
-                raise Exception("No router initialized")
-
             if master_key is None or not isinstance(master_key, str):
                 raise Exception(
                     f"Master key is not initialized or formatted. master_key={master_key}"
                 )

-            new_models = await prisma_client.db.litellm_proxymodeltable.find_many(
-                take=10, order={"updated_at": "desc"}
-            )
-
-            for m in new_models:
-                _litellm_params = m.litellm_params
-                if isinstance(_litellm_params, dict):
-                    # decrypt values
-                    for k, v in _litellm_params.items():
-                        if isinstance(v, str):
-                            # decode base64
-                            decoded_b64 = base64.b64decode(v)
-                            # decrypt value
-                            _litellm_params[k] = decrypt_value(
-                                value=decoded_b64, master_key=master_key
-                            )
-                    _litellm_params = LiteLLM_Params(**_litellm_params)
-                else:
-                    verbose_proxy_logger.error(
-                        f"Invalid model added to proxy db. Invalid litellm params. litellm_params={_litellm_params}"
-                    )
-                    continue  # skip to next model
-
-                if m.model_info is not None and isinstance(m.model_info, dict):
-                    if "id" not in m.model_info:
-                        m.model_info["id"] = m.model_id
-                    _model_info = RouterModelInfo(**m.model_info)
-                else:
-                    _model_info = RouterModelInfo(id=m.model_id)
-
-                llm_router.add_deployment(
-                    deployment=Deployment(
-                        model_name=m.model_name,
-                        litellm_params=_litellm_params,
-                        model_info=_model_info,
-                    )
-                )
-
+            if llm_router is None:
+                new_models = (
+                    await prisma_client.db.litellm_proxymodeltable.find_many()
+                )  # get all models in db
+                _model_list: list = []
+                for m in new_models:
+                    _litellm_params = m.litellm_params
+                    if isinstance(_litellm_params, dict):
+                        # decrypt values
+                        for k, v in _litellm_params.items():
+                            if isinstance(v, str):
+                                # decode base64
+                                decoded_b64 = base64.b64decode(v)
+                                # decrypt value
+                                _litellm_params[k] = decrypt_value(
+                                    value=decoded_b64, master_key=master_key
+                                )
+                        _litellm_params = LiteLLM_Params(**_litellm_params)
+                    else:
+                        verbose_proxy_logger.error(
+                            f"Invalid model added to proxy db. Invalid litellm params. litellm_params={_litellm_params}"
+                        )
+                        continue  # skip to next model
+
+                    if m.model_info is not None and isinstance(m.model_info, dict):
+                        if "id" not in m.model_info:
+                            m.model_info["id"] = m.model_id
+                        _model_info = RouterModelInfo(**m.model_info)
+                    else:
+                        _model_info = RouterModelInfo(id=m.model_id)
+
+                    _model_list.append(
+                        Deployment(
+                            model_name=m.model_name,
+                            litellm_params=_litellm_params,
+                            model_info=_model_info,
+                        ).to_json(exclude_none=True)
+                    )
+
+                llm_router = litellm.Router(model_list=_model_list)
+            else:
+                new_models = await prisma_client.db.litellm_proxymodeltable.find_many(
+                    take=10, order={"updated_at": "desc"}
+                )
+
+                for m in new_models:
+                    _litellm_params = m.litellm_params
+                    if isinstance(_litellm_params, dict):
+                        # decrypt values
+                        for k, v in _litellm_params.items():
+                            if isinstance(v, str):
+                                # decode base64
+                                decoded_b64 = base64.b64decode(v)
+                                # decrypt value
+                                _litellm_params[k] = decrypt_value(
+                                    value=decoded_b64, master_key=master_key
+                                )
+                        _litellm_params = LiteLLM_Params(**_litellm_params)
+                    else:
+                        verbose_proxy_logger.error(
+                            f"Invalid model added to proxy db. Invalid litellm params. litellm_params={_litellm_params}"
+                        )
+                        continue  # skip to next model
+
+                    if m.model_info is not None and isinstance(m.model_info, dict):
+                        if "id" not in m.model_info:
+                            m.model_info["id"] = m.model_id
+                        _model_info = RouterModelInfo(**m.model_info)
+                    else:
+                        _model_info = RouterModelInfo(id=m.model_id)
+
+                    llm_router.add_deployment(
+                        deployment=Deployment(
+                            model_name=m.model_name,
+                            litellm_params=_litellm_params,
+                            model_info=_model_info,
+                        )
+                    )
+
             llm_model_list = llm_router.get_model_list()
         except Exception as e:
-            verbose_proxy_logger.error("{}".format(str(e)))
+            verbose_proxy_logger.error(
+                "{}\nTraceback:{}".format(str(e), traceback.format_exc())
+            )


 proxy_config = ProxyConfig()
@@ -2894,7 +2932,7 @@ def on_backoff(details):

 @router.on_event("startup")
 async def startup_event():
-    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client
+    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db
     import json

     ### LOAD MASTER KEY ###
@@ -3043,7 +3081,8 @@ async def startup_event():
         )

         ### ADD NEW MODELS ###
-        if general_settings.get("store_model_in_db", False) == True:
+        store_model_in_db = litellm.get_secret("STORE_MODEL_IN_DB", store_model_in_db)
+        if store_model_in_db == True:
             scheduler.add_job(
                 proxy_config.add_deployment,
                 "interval",
@@ -6854,7 +6893,7 @@ async def add_new_model(
     model_params: Deployment,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
-    global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config, prisma_client, master_key
+    global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config, prisma_client, master_key, store_model_in_db
     try:
         import base64

@@ -6869,7 +6908,7 @@ async def add_new_model(
             )

         # update DB
-        if general_settings.get("store_model_in_db", False) == True:
+        if store_model_in_db == True:
             """
             - store model_list in db
             - store keys separately
@@ -7156,7 +7195,7 @@ async def delete_model(model_info: ModelInfoDelete):
             )

         # update DB
-        if general_settings.get("store_model_in_db", False) == True:
+        if store_model_in_db == True:
             """
             - store model_list in db
             - store keys separately
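
Reviewer note (not part of the patch): the load path in ProxyConfig.add_deployment assumes every string value in a stored model's litellm_params was encrypted with the proxy master key and then base64-encoded when it was written via /model/new. A minimal sketch of the read side, for reference; the import path for decrypt_value and the example values are assumptions for illustration, not taken from this diff:

    import base64

    # assumed import path for the helper this diff calls
    from litellm.proxy.utils import decrypt_value

    def load_litellm_params(raw_params: dict, master_key: str) -> dict:
        """Mirrors the decryption loop above: base64-decode each string
        value, then decrypt it with the proxy master key."""
        decrypted: dict = {}
        for k, v in raw_params.items():
            if isinstance(v, str):
                # stored values are base64-encoded ciphertext
                decoded_b64 = base64.b64decode(v)
                decrypted[k] = decrypt_value(value=decoded_b64, master_key=master_key)
            else:
                decrypted[k] = v  # non-string values are stored as-is
        return decrypted

    # usage (placeholder values, not valid ciphertext):
    # params = load_litellm_params({"api_key": "<b64-ciphertext>"}, master_key="sk-1234")

The flag itself now resolves in two steps: general_settings.store_model_in_db in the config yaml sets the module-level global in load_config, and at startup the STORE_MODEL_IN_DB env var (read via litellm.get_secret, with the global as the default) can still override it, preserving the env-var behavior this diff removes from load_config.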