diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index ffa5f1669..7d3afeb0f 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -307,9 +307,8 @@ async def user_api_key_auth( ) -def prisma_setup(database_url: Optional[str]): +async def prisma_setup(database_url: Optional[str]): global prisma_client, proxy_logging_obj, user_api_key_cache - if ( database_url is not None and prisma_client is None ): # don't re-initialize prisma client after initial init @@ -321,6 +320,8 @@ def prisma_setup(database_url: Optional[str]): print_verbose( f"Error when initializing prisma, Ensure you run pip install prisma {str(e)}" ) + if prisma_client is not None and prisma_client.db.is_connected() == False: + await prisma_client.connect() def load_from_azure_key_vault(use_azure_key_vault: bool = False): @@ -534,6 +535,7 @@ class ProxyConfig: prisma_client is not None and litellm.get_secret("SAVE_CONFIG_TO_DB", False) == True ): + await prisma_setup(database_url=None) # in case it's not been connected yet _tasks = [] keys = [ "model_list", @@ -761,7 +763,7 @@ class ProxyConfig: print_verbose(f"GOING INTO LITELLM.GET_SECRET!") database_url = litellm.get_secret(database_url) print_verbose(f"RETRIEVED DB URL: {database_url}") - prisma_setup(database_url=database_url) + await prisma_setup(database_url=database_url) ## COST TRACKING ## cost_tracking() ### MASTER KEY ### @@ -930,7 +932,7 @@ def save_worker_config(**data): os.environ["WORKER_CONFIG"] = json.dumps(data) -def initialize( +async def initialize( model=None, alias=None, api_base=None, @@ -948,13 +950,19 @@ def initialize( use_queue=False, config=None, ): - global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth + global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth, prisma_client generate_feedback_box() user_model = model user_debug = debug if debug == True: # this needs to be first, so users can see Router init debugg litellm.set_verbose = True dynamic_config = {"general": {}, user_model: {}} + if config: + ( + llm_router, + llm_model_list, + general_settings, + ) = await proxy_config.load_config(router=llm_router, config_file_path=config) if headers: # model-specific param user_headers = headers dynamic_config[user_model]["headers"] = headers @@ -1095,28 +1103,11 @@ async def startup_event(): print_verbose(f"worker_config: {worker_config}") # check if it's a valid file path if os.path.isfile(worker_config): - if worker_config.get("config", None) is not None: - ( - llm_router, - llm_model_list, - general_settings, - ) = await proxy_config.load_config( - router=llm_router, config_file_path=worker_config.pop("config") - ) - initialize(**worker_config) + await initialize(**worker_config) else: # if not, assume it's a json string worker_config = json.loads(os.getenv("WORKER_CONFIG")) - if worker_config.get("config", None) is not None: - ( - llm_router, - llm_model_list, - general_settings, - ) = await proxy_config.load_config( - router=llm_router, config_file_path=worker_config.pop("config") - ) - initialize(**worker_config) - + await initialize(**worker_config) proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made if use_background_health_checks: @@ -1124,10 +1115,6 @@ async def startup_event(): _run_background_health_check() ) # start the background health check coroutine. - print_verbose(f"prisma client - {prisma_client}") - if prisma_client is not None: - await prisma_client.connect() - if prisma_client is not None and master_key is not None: # add master key to db await generate_key_helper_fn( @@ -1331,7 +1318,7 @@ async def chat_completion( user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks(), ): - global general_settings, user_debug, proxy_logging_obj + global general_settings, user_debug, proxy_logging_obj, llm_model_list try: data = {} body = await request.body() diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 0be448119..3b90a2ad5 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -255,7 +255,6 @@ class PrismaClient: ) ## init logging object self.proxy_logging_obj = proxy_logging_obj - self.connected = False os.environ["DATABASE_URL"] = database_url # Save the current working directory original_dir = os.getcwd() @@ -536,11 +535,7 @@ class PrismaClient: ) async def connect(self): try: - if self.connected == False: - await self.db.connect() - self.connected = True - else: - return + await self.db.connect() except Exception as e: asyncio.create_task( self.proxy_logging_obj.failure_handler(original_exception=e) @@ -558,7 +553,6 @@ class PrismaClient: async def disconnect(self): try: await self.db.disconnect() - self.connected = False except Exception as e: asyncio.create_task( self.proxy_logging_obj.failure_handler(original_exception=e) diff --git a/litellm/tests/test_proxy_custom_auth.py b/litellm/tests/test_proxy_custom_auth.py index f16f1d379..ceb3d1c93 100644 --- a/litellm/tests/test_proxy_custom_auth.py +++ b/litellm/tests/test_proxy_custom_auth.py @@ -10,7 +10,7 @@ import os, io sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path -import pytest +import pytest, asyncio import litellm from litellm import embedding, completion, completion_cost, Timeout from litellm import RateLimitError @@ -22,6 +22,7 @@ from litellm.proxy.proxy_server import ( router, save_worker_config, initialize, + ProxyConfig, ) # Replace with the actual module where your FastAPI router is defined @@ -36,7 +37,7 @@ def client(): config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml" # initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables app = FastAPI() - initialize(config=config_fp) + asyncio.run(initialize(config=config_fp)) app.include_router(router) # Include your router in the test app return TestClient(app) diff --git a/litellm/tests/test_proxy_custom_logger.py b/litellm/tests/test_proxy_custom_logger.py index f8828d137..e47351a9b 100644 --- a/litellm/tests/test_proxy_custom_logger.py +++ b/litellm/tests/test_proxy_custom_logger.py @@ -23,6 +23,7 @@ from litellm.proxy.proxy_server import ( router, save_worker_config, initialize, + startup_event, ) # Replace with the actual module where your FastAPI router is defined filepath = os.path.dirname(os.path.abspath(__file__)) @@ -39,8 +40,8 @@ python_file_path = f"{filepath}/test_configs/custom_callbacks.py" def client(): filepath = os.path.dirname(os.path.abspath(__file__)) config_fp = f"{filepath}/test_configs/test_custom_logger.yaml" - initialize(config=config_fp) app = FastAPI() + asyncio.run(initialize(config=config_fp)) app.include_router(router) # Include your router in the test app return TestClient(app) diff --git a/litellm/tests/test_proxy_exception_mapping.py b/litellm/tests/test_proxy_exception_mapping.py index ff3b358a9..d5be29a61 100644 --- a/litellm/tests/test_proxy_exception_mapping.py +++ b/litellm/tests/test_proxy_exception_mapping.py @@ -24,7 +24,7 @@ from litellm.proxy.proxy_server import ( def client(): filepath = os.path.dirname(os.path.abspath(__file__)) config_fp = f"{filepath}/test_configs/test_bad_config.yaml" - initialize(config=config_fp) + asyncio.run(initialize(config=config_fp)) app = FastAPI() app.include_router(router) # Include your router in the test app return TestClient(app) diff --git a/litellm/tests/test_proxy_pass_user_config.py b/litellm/tests/test_proxy_pass_user_config.py index ea5f189c2..30fa1eeb1 100644 --- a/litellm/tests/test_proxy_pass_user_config.py +++ b/litellm/tests/test_proxy_pass_user_config.py @@ -10,7 +10,7 @@ import os, io sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path -import pytest, logging +import pytest, logging, asyncio import litellm from litellm import embedding, completion, completion_cost, Timeout from litellm import RateLimitError @@ -46,7 +46,7 @@ def client_no_auth(): filepath = os.path.dirname(os.path.abspath(__file__)) config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml" # initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables - initialize(config=config_fp, debug=True) + asyncio.run(initialize(config=config_fp, debug=True)) app = FastAPI() app.include_router(router) # Include your router in the test app diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index b7b4b0c40..0fb8c742a 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -10,7 +10,7 @@ import os, io sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path -import pytest, logging +import pytest, logging, asyncio import litellm from litellm import embedding, completion, completion_cost, Timeout from litellm import RateLimitError @@ -45,7 +45,7 @@ def client_no_auth(): filepath = os.path.dirname(os.path.abspath(__file__)) config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml" # initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables - initialize(config=config_fp, debug=True) + asyncio.run(initialize(config=config_fp, debug=True)) app = FastAPI() app.include_router(router) # Include your router in the test app diff --git a/litellm/tests/test_proxy_server_caching.py b/litellm/tests/test_proxy_server_caching.py index cb8ca7609..a1935bd05 100644 --- a/litellm/tests/test_proxy_server_caching.py +++ b/litellm/tests/test_proxy_server_caching.py @@ -12,7 +12,7 @@ import os, io sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path -import pytest, logging +import pytest, logging, asyncio import litellm from litellm import embedding, completion, completion_cost, Timeout from litellm import RateLimitError @@ -47,7 +47,7 @@ def client_no_auth(): filepath = os.path.dirname(os.path.abspath(__file__)) config_fp = f"{filepath}/test_configs/test_cloudflare_azure_with_cache_config.yaml" # initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables - initialize(config=config_fp, debug=True) + asyncio.run(initialize(config=config_fp, debug=True)) app = FastAPI() app.include_router(router) # Include your router in the test app